summaryrefslogtreecommitdiffstats
path: root/vendor/golang.org/x/net/http2/transport_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/golang.org/x/net/http2/transport_test.go')
-rw-r--r--vendor/golang.org/x/net/http2/transport_test.go283
1 files changed, 281 insertions, 2 deletions
diff --git a/vendor/golang.org/x/net/http2/transport_test.go b/vendor/golang.org/x/net/http2/transport_test.go
index 15dfa0739..ac4661f48 100644
--- a/vendor/golang.org/x/net/http2/transport_test.go
+++ b/vendor/golang.org/x/net/http2/transport_test.go
@@ -685,7 +685,7 @@ func newLocalListener(t *testing.T) net.Listener {
return ln
}
-func (ct *clientTester) greet() {
+func (ct *clientTester) greet(settings ...Setting) {
buf := make([]byte, len(ClientPreface))
_, err := io.ReadFull(ct.sc, buf)
if err != nil {
@@ -699,7 +699,7 @@ func (ct *clientTester) greet() {
ct.t.Fatalf("Wanted client settings frame; got %v", f)
_ = sf // stash it away?
}
- if err := ct.fr.WriteSettings(); err != nil {
+ if err := ct.fr.WriteSettings(settings...); err != nil {
ct.t.Fatal(err)
}
if err := ct.fr.WriteSettingsAck(); err != nil {
@@ -2926,6 +2926,285 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
}
}
+func TestTransportRetryAfterRefusedStream(t *testing.T) {
+ clientDone := make(chan struct{})
+ ct := newClientTester(t)
+ ct.client = func() error {
+ defer ct.cc.(*net.TCPConn).CloseWrite()
+ defer close(clientDone)
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ resp, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return fmt.Errorf("RoundTrip: %v", err)
+ }
+ resp.Body.Close()
+ if resp.StatusCode != 204 {
+ return fmt.Errorf("Status = %v; want 204", resp.StatusCode)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ nreq := 0
+
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ select {
+ case <-clientDone:
+ // If the client's done, it
+ // will have reported any
+ // errors on its side.
+ return nil
+ default:
+ return err
+ }
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ case *HeadersFrame:
+ if !f.HeadersEnded() {
+ return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+ }
+ nreq++
+ if nreq == 1 {
+ ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
+ } else {
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: buf.Bytes(),
+ })
+ }
+ default:
+ return fmt.Errorf("Unexpected client frame %v", f)
+ }
+ }
+ }
+ ct.run()
+}
+
+func TestTransportRetryHasLimit(t *testing.T) {
+ // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s.
+ if testing.Short() {
+ t.Skip("skipping long test in short mode")
+ }
+ clientDone := make(chan struct{})
+ ct := newClientTester(t)
+ ct.client = func() error {
+ defer ct.cc.(*net.TCPConn).CloseWrite()
+ defer close(clientDone)
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ resp, err := ct.tr.RoundTrip(req)
+ if err == nil {
+ return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
+ }
+ t.Logf("expected error, got: %v", err)
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ select {
+ case <-clientDone:
+ // If the client's done, it
+ // will have reported any
+ // errors on its side.
+ return nil
+ default:
+ return err
+ }
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ case *HeadersFrame:
+ if !f.HeadersEnded() {
+ return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+ }
+ ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
+ default:
+ return fmt.Errorf("Unexpected client frame %v", f)
+ }
+ }
+ }
+ ct.run()
+}
+
+func TestTransportRequestsStallAtServerLimit(t *testing.T) {
+ const maxConcurrent = 2
+
+ greet := make(chan struct{}) // server sends initial SETTINGS frame
+ gotRequest := make(chan struct{}) // server received a request
+ clientDone := make(chan struct{})
+
+ // Collect errors from goroutines.
+ var wg sync.WaitGroup
+ errs := make(chan error, 100)
+ defer func() {
+ wg.Wait()
+ close(errs)
+ for err := range errs {
+ t.Error(err)
+ }
+ }()
+
+ // We will send maxConcurrent+2 requests. This checker goroutine waits for the
+ // following stages:
+ // 1. The first maxConcurrent requests are received by the server.
+ // 2. The client will cancel the next request
+ // 3. The server is unblocked so it can service the first maxConcurrent requests
+ // 4. The client will send the final request
+ wg.Add(1)
+ unblockClient := make(chan struct{})
+ clientRequestCancelled := make(chan struct{})
+ unblockServer := make(chan struct{})
+ go func() {
+ defer wg.Done()
+ // Stage 1.
+ for k := 0; k < maxConcurrent; k++ {
+ <-gotRequest
+ }
+ // Stage 2.
+ close(unblockClient)
+ <-clientRequestCancelled
+ // Stage 3: give some time for the final RoundTrip call to be scheduled and
+ // verify that the final request is not sent.
+ time.Sleep(50 * time.Millisecond)
+ select {
+ case <-gotRequest:
+ errs <- errors.New("last request did not stall")
+ close(unblockServer)
+ return
+ default:
+ }
+ close(unblockServer)
+ // Stage 4.
+ <-gotRequest
+ }()
+
+ ct := newClientTester(t)
+ ct.client = func() error {
+ var wg sync.WaitGroup
+ defer func() {
+ wg.Wait()
+ close(clientDone)
+ ct.cc.(*net.TCPConn).CloseWrite()
+ }()
+ for k := 0; k < maxConcurrent+2; k++ {
+ wg.Add(1)
+ go func(k int) {
+ defer wg.Done()
+ // Don't send the second request until after receiving SETTINGS from the server
+ // to avoid a race where we use the default SettingMaxConcurrentStreams, which
+ // is much larger than maxConcurrent. We have to send the first request before
+ // waiting because the first request triggers the dial and greet.
+ if k > 0 {
+ <-greet
+ }
+ // Block until maxConcurrent requests are sent before sending any more.
+ if k >= maxConcurrent {
+ <-unblockClient
+ }
+ req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
+ if k == maxConcurrent {
+ // This request will be canceled.
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+ close(cancel)
+ _, err := ct.tr.RoundTrip(req)
+ close(clientRequestCancelled)
+ if err == nil {
+ errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
+ return
+ }
+ } else {
+ resp, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
+ return
+ }
+ ioutil.ReadAll(resp.Body)
+ resp.Body.Close()
+ if resp.StatusCode != 204 {
+ errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
+ return
+ }
+ }
+ }(k)
+ }
+ return nil
+ }
+
+ ct.server = func() error {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+
+ // Server write loop.
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ writeResp := make(chan uint32, maxConcurrent+1)
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ <-unblockServer
+ for id := range writeResp {
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: id,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: buf.Bytes(),
+ })
+ }
+ }()
+
+ // Server read loop.
+ var nreq int
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ select {
+ case <-clientDone:
+ // If the client's done, it will have reported any errors on its side.
+ return nil
+ default:
+ return err
+ }
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame:
+ case *SettingsFrame:
+ // Wait for the client SETTINGS ack until ending the greet.
+ close(greet)
+ case *HeadersFrame:
+ if !f.HeadersEnded() {
+ return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+ }
+ gotRequest <- struct{}{}
+ nreq++
+ writeResp <- f.StreamID
+ if nreq == maxConcurrent+1 {
+ close(writeResp)
+ }
+ default:
+ return fmt.Errorf("Unexpected client frame %v", f)
+ }
+ }
+ }
+
+ ct.run()
+}
+
func TestAuthorityAddr(t *testing.T) {
tests := []struct {
scheme, authority string