diff options
Diffstat (limited to 'vendor/golang.org/x/net/http2/transport_test.go')
-rw-r--r-- | vendor/golang.org/x/net/http2/transport_test.go | 177 |
1 files changed, 173 insertions, 4 deletions
diff --git a/vendor/golang.org/x/net/http2/transport_test.go b/vendor/golang.org/x/net/http2/transport_test.go index f9287e575..8ef4f3388 100644 --- a/vendor/golang.org/x/net/http2/transport_test.go +++ b/vendor/golang.org/x/net/http2/transport_test.go @@ -2073,10 +2073,11 @@ func TestTransportHandlerBodyClose(t *testing.T) { // https://golang.org/issue/15930 func TestTransportFlowControl(t *testing.T) { - const ( - total = 100 << 20 // 100MB - bufLen = 1 << 16 - ) + const bufLen = 64 << 10 + var total int64 = 100 << 20 // 100MB + if testing.Short() { + total = 10 << 20 + } var wrote int64 // updated atomically st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -2745,3 +2746,171 @@ func TestTransportCancelDataResponseRace(t *testing.T) { t.Errorf("Got = %q; want %q", slurp, msg) } } + +func TestTransportRetryAfterGOAWAY(t *testing.T) { + var dialer struct { + sync.Mutex + count int + } + ct1 := make(chan *clientTester) + ct2 := make(chan *clientTester) + + ln := newLocalListener(t) + defer ln.Close() + + tr := &Transport{ + TLSClientConfig: tlsConfigInsecure, + } + tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { + dialer.Lock() + defer dialer.Unlock() + dialer.count++ + if dialer.count == 3 { + return nil, errors.New("unexpected number of dials") + } + cc, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + return nil, fmt.Errorf("dial error: %v", err) + } + sc, err := ln.Accept() + if err != nil { + return nil, fmt.Errorf("accept error: %v", err) + } + ct := &clientTester{ + t: t, + tr: tr, + cc: cc, + sc: sc, + fr: NewFramer(sc, sc), + } + switch dialer.count { + case 1: + ct1 <- ct + case 2: + ct2 <- ct + } + return cc, nil + } + + errs := make(chan error, 3) + done := make(chan struct{}) + defer close(done) + + // Client. + go func() { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := tr.RoundTrip(req) + if res != nil { + res.Body.Close() + if got := res.Header.Get("Foo"); got != "bar" { + err = fmt.Errorf("foo header = %q; want bar", got) + } + } + if err != nil { + err = fmt.Errorf("RoundTrip: %v", err) + } + errs <- err + }() + + connToClose := make(chan io.Closer, 2) + + // Server for the first request. + go func() { + var ct *clientTester + select { + case ct = <-ct1: + case <-done: + return + } + + connToClose <- ct.cc + ct.greet() + hf, err := ct.firstHeaders() + if err != nil { + errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err) + return + } + t.Logf("server1 got %v", hf) + if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { + errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err) + return + } + errs <- nil + }() + + // Server for the second request. + go func() { + var ct *clientTester + select { + case ct = <-ct2: + case <-done: + return + } + + connToClose <- ct.cc + ct.greet() + hf, err := ct.firstHeaders() + if err != nil { + errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err) + return + } + t.Logf("server2 got %v", hf) + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) + err = ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + if err != nil { + errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err) + } else { + errs <- nil + } + }() + + for k := 0; k < 3; k++ { + select { + case err := <-errs: + if err != nil { + t.Error(err) + } + case <-time.After(1 * time.Second): + t.Errorf("timed out") + } + } + + for { + select { + case c := <-connToClose: + c.Close() + default: + return + } + } +} + +func TestAuthorityAddr(t *testing.T) { + tests := []struct { + scheme, authority string + want string + }{ + {"http", "foo.com", "foo.com:80"}, + {"https", "foo.com", "foo.com:443"}, + {"https", "foo.com:1234", "foo.com:1234"}, + {"https", "1.2.3.4:1234", "1.2.3.4:1234"}, + {"https", "1.2.3.4", "1.2.3.4:443"}, + {"https", "[::1]:1234", "[::1]:1234"}, + {"https", "[::1]", "[::1]:443"}, + } + for _, tt := range tests { + got := authorityAddr(tt.scheme, tt.authority) + if got != tt.want { + t.Errorf("authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want) + } + } +} |