summaryrefslogtreecommitdiffstats
path: root/vendor/golang.org/x/net/http2/transport_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/golang.org/x/net/http2/transport_test.go')
-rw-r--r--vendor/golang.org/x/net/http2/transport_test.go177
1 files changed, 173 insertions, 4 deletions
diff --git a/vendor/golang.org/x/net/http2/transport_test.go b/vendor/golang.org/x/net/http2/transport_test.go
index f9287e575..8ef4f3388 100644
--- a/vendor/golang.org/x/net/http2/transport_test.go
+++ b/vendor/golang.org/x/net/http2/transport_test.go
@@ -2073,10 +2073,11 @@ func TestTransportHandlerBodyClose(t *testing.T) {
// https://golang.org/issue/15930
func TestTransportFlowControl(t *testing.T) {
- const (
- total = 100 << 20 // 100MB
- bufLen = 1 << 16
- )
+ const bufLen = 64 << 10
+ var total int64 = 100 << 20 // 100MB
+ if testing.Short() {
+ total = 10 << 20
+ }
var wrote int64 // updated atomically
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
@@ -2745,3 +2746,171 @@ func TestTransportCancelDataResponseRace(t *testing.T) {
t.Errorf("Got = %q; want %q", slurp, msg)
}
}
+
+func TestTransportRetryAfterGOAWAY(t *testing.T) {
+ var dialer struct {
+ sync.Mutex
+ count int
+ }
+ ct1 := make(chan *clientTester)
+ ct2 := make(chan *clientTester)
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ }
+ tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ dialer.Lock()
+ defer dialer.Unlock()
+ dialer.count++
+ if dialer.count == 3 {
+ return nil, errors.New("unexpected number of dials")
+ }
+ cc, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil {
+ return nil, fmt.Errorf("dial error: %v", err)
+ }
+ sc, err := ln.Accept()
+ if err != nil {
+ return nil, fmt.Errorf("accept error: %v", err)
+ }
+ ct := &clientTester{
+ t: t,
+ tr: tr,
+ cc: cc,
+ sc: sc,
+ fr: NewFramer(sc, sc),
+ }
+ switch dialer.count {
+ case 1:
+ ct1 <- ct
+ case 2:
+ ct2 <- ct
+ }
+ return cc, nil
+ }
+
+ errs := make(chan error, 3)
+ done := make(chan struct{})
+ defer close(done)
+
+ // Client.
+ go func() {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ res, err := tr.RoundTrip(req)
+ if res != nil {
+ res.Body.Close()
+ if got := res.Header.Get("Foo"); got != "bar" {
+ err = fmt.Errorf("foo header = %q; want bar", got)
+ }
+ }
+ if err != nil {
+ err = fmt.Errorf("RoundTrip: %v", err)
+ }
+ errs <- err
+ }()
+
+ connToClose := make(chan io.Closer, 2)
+
+ // Server for the first request.
+ go func() {
+ var ct *clientTester
+ select {
+ case ct = <-ct1:
+ case <-done:
+ return
+ }
+
+ connToClose <- ct.cc
+ ct.greet()
+ hf, err := ct.firstHeaders()
+ if err != nil {
+ errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err)
+ return
+ }
+ t.Logf("server1 got %v", hf)
+ if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
+ errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err)
+ return
+ }
+ errs <- nil
+ }()
+
+ // Server for the second request.
+ go func() {
+ var ct *clientTester
+ select {
+ case ct = <-ct2:
+ case <-done:
+ return
+ }
+
+ connToClose <- ct.cc
+ ct.greet()
+ hf, err := ct.firstHeaders()
+ if err != nil {
+ errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err)
+ return
+ }
+ t.Logf("server2 got %v", hf)
+
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
+ err = ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+ if err != nil {
+ errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err)
+ } else {
+ errs <- nil
+ }
+ }()
+
+ for k := 0; k < 3; k++ {
+ select {
+ case err := <-errs:
+ if err != nil {
+ t.Error(err)
+ }
+ case <-time.After(1 * time.Second):
+ t.Errorf("timed out")
+ }
+ }
+
+ for {
+ select {
+ case c := <-connToClose:
+ c.Close()
+ default:
+ return
+ }
+ }
+}
+
+func TestAuthorityAddr(t *testing.T) {
+ tests := []struct {
+ scheme, authority string
+ want string
+ }{
+ {"http", "foo.com", "foo.com:80"},
+ {"https", "foo.com", "foo.com:443"},
+ {"https", "foo.com:1234", "foo.com:1234"},
+ {"https", "1.2.3.4:1234", "1.2.3.4:1234"},
+ {"https", "1.2.3.4", "1.2.3.4:443"},
+ {"https", "[::1]:1234", "[::1]:1234"},
+ {"https", "[::1]", "[::1]:443"},
+ }
+ for _, tt := range tests {
+ got := authorityAddr(tt.scheme, tt.authority)
+ if got != tt.want {
+ t.Errorf("authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want)
+ }
+ }
+}