summaryrefslogtreecommitdiffstats
path: root/vendor/golang.org/x/net/nettest/conntest_test.go
blob: 9f9453fb562b613b5ecf0ae0c0671becc5e8b44a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// +build go1.8

package nettest

import (
	"net"
	"os"
	"runtime"
	"testing"

	"golang.org/x/net/internal/nettest"
)

// TestTestConn runs this package's TestConn suite against real connections
// created on each network listed in tests. A network that
// nettest.TestableNetwork reports as unavailable on the current platform is
// skipped rather than failed.
func TestTestConn(t *testing.T) {
	tests := []struct{ name, network string }{
		{"TCP", "tcp"},
		{"UnixPipe", "unix"},
		{"UnixPacketPipe", "unixpacket"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if !nettest.TestableNetwork(tt.network) {
				t.Skipf("not supported on %s", runtime.GOOS)
			}

			// mp builds a connected pair over tt.network: c1 is the dialing
			// side, c2 is the accepted side. stop closes both ends and the
			// listener and, for Unix-domain networks, removes the socket
			// file. On failure all results except err are nil and any
			// partially created resources have already been released.
			mp := func() (c1, c2 net.Conn, stop func(), err error) {
				ln, err := nettest.NewLocalListener(tt.network)
				if err != nil {
					return nil, nil, nil, err
				}

				// Start a connection between two endpoints.
				var err1, err2 error
				done := make(chan struct{})
				go func() {
					c2, err2 = ln.Accept()
					close(done)
				}()
				c1, err1 = net.Dial(ln.Addr().Network(), ln.Addr().String())
				if err1 != nil {
					// The dial failed, so no connection will ever reach the
					// listener. Close it to unblock the pending Accept;
					// otherwise the receive on done below would hang.
					ln.Close()
				}
				<-done

				stop = func() {
					if err1 == nil {
						c1.Close()
					}
					if err2 == nil {
						c2.Close()
					}
					// Closing an already-closed listener only returns an
					// error, which is deliberately ignored during cleanup.
					ln.Close()
					switch tt.network {
					case "unix", "unixpacket":
						os.Remove(ln.Addr().String())
					}
				}

				switch {
				case err1 != nil:
					stop()
					return nil, nil, nil, err1
				case err2 != nil:
					stop()
					return nil, nil, nil, err2
				default:
					return c1, c2, stop, nil
				}
			}

			TestConn(t, mp)
		})
	}
}