summaryrefslogtreecommitdiffstats
path: root/vendor/google.golang.org/grpc/transport
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/transport')
-rw-r--r--vendor/google.golang.org/grpc/transport/controlbuf.go101
-rw-r--r--vendor/google.golang.org/grpc/transport/flowcontrol.go10
-rw-r--r--vendor/google.golang.org/grpc/transport/http2_client.go36
-rw-r--r--vendor/google.golang.org/grpc/transport/http2_server.go98
-rw-r--r--vendor/google.golang.org/grpc/transport/http_util.go54
-rw-r--r--vendor/google.golang.org/grpc/transport/transport.go37
6 files changed, 216 insertions, 120 deletions
diff --git a/vendor/google.golang.org/grpc/transport/controlbuf.go b/vendor/google.golang.org/grpc/transport/controlbuf.go
index e147cd51b..5c5891a11 100644
--- a/vendor/google.golang.org/grpc/transport/controlbuf.go
+++ b/vendor/google.golang.org/grpc/transport/controlbuf.go
@@ -28,6 +28,10 @@ import (
"golang.org/x/net/http2/hpack"
)
+var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
+ e.SetMaxDynamicTableSizeLimit(v)
+}
+
type itemNode struct {
it interface{}
next *itemNode
@@ -80,6 +84,13 @@ func (il *itemList) isEmpty() bool {
// the control buffer of transport. They represent different aspects of
// control tasks, e.g., flow control, settings, streaming resetting, etc.
+// registerStream is used to register an incoming stream with loopy writer.
+type registerStream struct {
+ streamID uint32
+ wq *writeQuota
+}
+
+// headerFrame is also used to register stream on the client-side.
type headerFrame struct {
streamID uint32
hf []hpack.HeaderField
@@ -361,44 +372,47 @@ func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimato
const minBatchSize = 1000
// run should be run in a separate goroutine.
-func (l *loopyWriter) run() {
- var (
- it interface{}
- err error
- isEmpty bool
- )
+func (l *loopyWriter) run() (err error) {
defer func() {
- errorf("transport: loopyWriter.run returning. Err: %v", err)
+ if err == ErrConnClosing {
+ // Don't log ErrConnClosing as error since it happens
+ // 1. When the connection is closed by some other known issue.
+ // 2. User closed the connection.
+ // 3. A graceful close of connection.
+ infof("transport: loopyWriter.run returning. %v", err)
+ err = nil
+ }
}()
for {
- it, err = l.cbuf.get(true)
+ it, err := l.cbuf.get(true)
if err != nil {
- return
+ return err
}
if err = l.handle(it); err != nil {
- return
+ return err
}
if _, err = l.processData(); err != nil {
- return
+ return err
}
gosched := true
hasdata:
for {
- it, err = l.cbuf.get(false)
+ it, err := l.cbuf.get(false)
if err != nil {
- return
+ return err
}
if it != nil {
if err = l.handle(it); err != nil {
- return
+ return err
}
if _, err = l.processData(); err != nil {
- return
+ return err
}
continue hasdata
}
- if isEmpty, err = l.processData(); err != nil {
- return
+ isEmpty, err := l.processData()
+ if err != nil {
+ return err
}
if !isEmpty {
continue hasdata
@@ -450,30 +464,39 @@ func (l *loopyWriter) incomingSettingsHandler(s *incomingSettings) error {
return l.framer.fr.WriteSettingsAck()
}
+func (l *loopyWriter) registerStreamHandler(h *registerStream) error {
+ str := &outStream{
+ id: h.streamID,
+ state: empty,
+ itl: &itemList{},
+ wq: h.wq,
+ }
+ l.estdStreams[h.streamID] = str
+ return nil
+}
+
func (l *loopyWriter) headerHandler(h *headerFrame) error {
if l.side == serverSide {
- if h.endStream { // Case 1.A: Server wants to close stream.
- // Make sure it's not a trailers only response.
- if str, ok := l.estdStreams[h.streamID]; ok {
- if str.state != empty { // either active or waiting on stream quota.
- // add it str's list of items.
- str.itl.enqueue(h)
- return nil
- }
- }
- if err := l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite); err != nil {
- return err
- }
- return l.cleanupStreamHandler(h.cleanup)
+ str, ok := l.estdStreams[h.streamID]
+ if !ok {
+ warningf("transport: loopy doesn't recognize the stream: %d", h.streamID)
+ return nil
+ }
+ // Case 1.A: Server is responding back with headers.
+ if !h.endStream {
+ return l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite)
}
- // Case 1.B: Server is responding back with headers.
- str := &outStream{
- state: empty,
- itl: &itemList{},
- wq: h.wq,
+ // else: Case 1.B: Server wants to close stream.
+
+ if str.state != empty { // either active or waiting on stream quota.
+ // add it str's list of items.
+ str.itl.enqueue(h)
+ return nil
+ }
+ if err := l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite); err != nil {
+ return err
}
- l.estdStreams[h.streamID] = str
- return l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite)
+ return l.cleanupStreamHandler(h.cleanup)
}
// Case 2: Client wants to originate stream.
str := &outStream{
@@ -632,6 +655,8 @@ func (l *loopyWriter) handle(i interface{}) error {
return l.outgoingSettingsHandler(i)
case *headerFrame:
return l.headerHandler(i)
+ case *registerStream:
+ return l.registerStreamHandler(i)
case *cleanupStream:
return l.cleanupStreamHandler(i)
case *incomingGoAway:
@@ -664,6 +689,8 @@ func (l *loopyWriter) applySettings(ss []http2.Setting) error {
}
}
}
+ case http2.SettingHeaderTableSize:
+ updateHeaderTblSize(l.hEnc, s.Val)
}
}
return nil
diff --git a/vendor/google.golang.org/grpc/transport/flowcontrol.go b/vendor/google.golang.org/grpc/transport/flowcontrol.go
index 378f5c450..bbf98b6f5 100644
--- a/vendor/google.golang.org/grpc/transport/flowcontrol.go
+++ b/vendor/google.golang.org/grpc/transport/flowcontrol.go
@@ -58,14 +58,20 @@ type writeQuota struct {
ch chan struct{}
// done is triggered in error case.
done <-chan struct{}
+ // replenish is called by loopyWriter to give quota back to.
+ // It is implemented as a field so that it can be updated
+ // by tests.
+ replenish func(n int)
}
func newWriteQuota(sz int32, done <-chan struct{}) *writeQuota {
- return &writeQuota{
+ w := &writeQuota{
quota: sz,
ch: make(chan struct{}, 1),
done: done,
}
+ w.replenish = w.realReplenish
+ return w
}
func (w *writeQuota) get(sz int32) error {
@@ -83,7 +89,7 @@ func (w *writeQuota) get(sz int32) error {
}
}
-func (w *writeQuota) replenish(n int) {
+func (w *writeQuota) realReplenish(n int) {
sz := int32(n)
a := atomic.AddInt32(&w.quota, sz)
b := a - sz
diff --git a/vendor/google.golang.org/grpc/transport/http2_client.go b/vendor/google.golang.org/grpc/transport/http2_client.go
index 1fdabd954..eaf007eb0 100644
--- a/vendor/google.golang.org/grpc/transport/http2_client.go
+++ b/vendor/google.golang.org/grpc/transport/http2_client.go
@@ -31,9 +31,9 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
- "google.golang.org/grpc/channelz"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
@@ -76,8 +76,9 @@ type http2Client struct {
// Boolean to keep track of reading activity on transport.
// 1 is true and 0 is false.
- activity uint32 // Accessed atomically.
- kp keepalive.ClientParameters
+ activity uint32 // Accessed atomically.
+ kp keepalive.ClientParameters
+ keepaliveEnabled bool
statsHandler stats.Handler
@@ -259,6 +260,10 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
if channelz.IsOn() {
t.channelzID = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, "")
}
+ if t.kp.Time != infinity {
+ t.keepaliveEnabled = true
+ go t.keepalive()
+ }
// Start the reader goroutine for incoming message. Each transport has
// a dedicated goroutine which reads HTTP2 frame from network. Then it
// dispatches the frame to the corresponding stream entity.
@@ -295,13 +300,17 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
t.framer.writer.Flush()
go func() {
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst)
- t.loopy.run()
- t.conn.Close()
+ err := t.loopy.run()
+ if err != nil {
+ errorf("transport: loopyWriter.run returning. Err: %v", err)
+ }
+ // If it's a connection error, let reader goroutine handle it
+ // since there might be data in the buffers.
+ if _, ok := err.(net.Error); !ok {
+ t.conn.Close()
+ }
close(t.writerDone)
}()
- if t.kp.Time != infinity {
- go t.keepalive()
- }
return t, nil
}
@@ -537,7 +546,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
var sendPing bool
// If the number of active streams change from 0 to 1, then check if keepalive
// has gone dormant. If so, wake it up.
- if len(t.activeStreams) == 1 {
+ if len(t.activeStreams) == 1 && t.keepaliveEnabled {
select {
case t.awakenKeepalive <- struct{}{}:
sendPing = true
@@ -735,6 +744,7 @@ func (t *http2Client) GracefulClose() error {
if active == 0 {
return t.Close()
}
+ t.controlBuf.put(&incomingGoAway{})
return nil
}
@@ -1109,7 +1119,9 @@ func (t *http2Client) reader() {
t.Close()
return
}
- atomic.CompareAndSwapUint32(&t.activity, 0, 1)
+ if t.keepaliveEnabled {
+ atomic.CompareAndSwapUint32(&t.activity, 0, 1)
+ }
sf, ok := frame.(*http2.SettingsFrame)
if !ok {
t.Close()
@@ -1121,7 +1133,9 @@ func (t *http2Client) reader() {
// loop to keep reading incoming messages on this transport.
for {
frame, err := t.framer.fr.ReadFrame()
- atomic.CompareAndSwapUint32(&t.activity, 0, 1)
+ if t.keepaliveEnabled {
+ atomic.CompareAndSwapUint32(&t.activity, 0, 1)
+ }
if err != nil {
// Abort an active stream if the http2.Framer returns a
// http2.StreamError. This can happen only if the server's response
diff --git a/vendor/google.golang.org/grpc/transport/http2_server.go b/vendor/google.golang.org/grpc/transport/http2_server.go
index 8b93e222e..19acedb2b 100644
--- a/vendor/google.golang.org/grpc/transport/http2_server.go
+++ b/vendor/google.golang.org/grpc/transport/http2_server.go
@@ -24,7 +24,6 @@ import (
"fmt"
"io"
"math"
- "math/rand"
"net"
"strconv"
"sync"
@@ -36,9 +35,11 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
- "google.golang.org/grpc/channelz"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/grpclog"
+ "google.golang.org/grpc/internal/channelz"
+ "google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
@@ -273,7 +274,9 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
go func() {
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst)
t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler
- t.loopy.run()
+ if err := t.loopy.run(); err != nil {
+ errorf("transport: loopyWriter.run returning. Err: %v", err)
+ }
t.conn.Close()
close(t.writerDone)
}()
@@ -413,6 +416,11 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.updateWindow(s, uint32(n))
},
}
+ // Register the stream with loopy.
+ t.controlBuf.put(&registerStream{
+ streamID: s.id,
+ wq: s.wq,
+ })
handle(s)
return
}
@@ -682,12 +690,25 @@ func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) {
})
}
+func appendHeaderFieldsFromMD(headerFields []hpack.HeaderField, md metadata.MD) []hpack.HeaderField {
+ for k, vv := range md {
+ if isReservedHeader(k) {
+ // Clients don't tolerate reading restricted headers after some non restricted ones were sent.
+ continue
+ }
+ for _, v := range vv {
+ headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+ }
+ }
+ return headerFields
+}
+
// WriteHeader sends the header metedata md back to the client.
func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
- if s.headerOk || s.getState() == streamDone {
+ if s.updateHeaderSent() || s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
- s.headerOk = true
+ s.hdrMu.Lock()
if md.Len() > 0 {
if s.header.Len() > 0 {
s.header = metadata.Join(s.header, md)
@@ -695,7 +716,12 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
s.header = md
}
}
- md = s.header
+ t.writeHeaderLocked(s)
+ s.hdrMu.Unlock()
+ return nil
+}
+
+func (t *http2Server) writeHeaderLocked(s *Stream) {
// TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields
// first and create a slice of that exact size.
headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else.
@@ -704,15 +730,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
if s.sendCompress != "" {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
}
- for k, vv := range md {
- if isReservedHeader(k) {
- // Clients don't tolerate reading restricted headers after some non restricted ones were sent.
- continue
- }
- for _, v := range vv {
- headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
- }
- }
+ headerFields = appendHeaderFieldsFromMD(headerFields, s.header)
t.controlBuf.put(&headerFrame{
streamID: s.id,
hf: headerFields,
@@ -720,7 +738,6 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
onWrite: func() {
atomic.StoreUint32(&t.resetPingStrikes, 1)
},
- wq: s.wq,
})
if t.stats != nil {
// Note: WireLength is not set in outHeader.
@@ -728,7 +745,6 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
outHeader := &stats.OutHeader{}
t.stats.HandleRPC(s.Context(), outHeader)
}
- return nil
}
// WriteStatus sends stream status to the client and terminates the stream.
@@ -736,21 +752,20 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early
// OK is adopted.
func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
- if !s.headerOk && s.header.Len() > 0 {
- if err := t.WriteHeader(s, nil); err != nil {
- return err
- }
- } else {
- if s.getState() == streamDone {
- return nil
- }
+ if s.getState() == streamDone {
+ return nil
}
+ s.hdrMu.Lock()
// TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields
// first and create a slice of that exact size.
headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else.
- if !s.headerOk {
- headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
- headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
+ if !s.updateHeaderSent() { // No headers have been sent.
+ if len(s.header) > 0 { // Send a separate header frame.
+ t.writeHeaderLocked(s)
+ } else { // Send a trailer only response.
+ headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
+ }
}
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
@@ -759,23 +774,15 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
stBytes, err := proto.Marshal(p)
if err != nil {
// TODO: return error instead, when callers are able to handle it.
- panic(err)
+ grpclog.Errorf("transport: failed to marshal rpc status: %v, error: %v", p, err)
+ } else {
+ headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
}
-
- headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
}
// Attach the trailer metadata.
- for k, vv := range s.trailer {
- // Clients don't tolerate reading restricted headers after some non restricted ones were sent.
- if isReservedHeader(k) {
- continue
- }
- for _, v := range vv {
- headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
- }
- }
- trailer := &headerFrame{
+ headerFields = appendHeaderFieldsFromMD(headerFields, s.trailer)
+ trailingHeader := &headerFrame{
streamID: s.id,
hf: headerFields,
endStream: true,
@@ -783,7 +790,8 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
atomic.StoreUint32(&t.resetPingStrikes, 1)
},
}
- t.closeStream(s, false, 0, trailer, true)
+ s.hdrMu.Unlock()
+ t.closeStream(s, false, 0, trailingHeader, true)
if t.stats != nil {
t.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
}
@@ -793,7 +801,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error
// is returns if it fails (e.g., framing error, transport error).
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
- if !s.headerOk { // Headers haven't been written yet.
+ if !s.isHeaderSent() { // Headers haven't been written yet.
if err := t.WriteHeader(s, nil); err != nil {
// TODO(mmukhi, dfawley): Make sure this is the right code to return.
return streamErrorf(codes.Internal, "transport: %v", err)
@@ -1123,14 +1131,12 @@ func (t *http2Server) getOutFlowWindow() int64 {
}
}
-var rgen = rand.New(rand.NewSource(time.Now().UnixNano()))
-
func getJitter(v time.Duration) time.Duration {
if v == infinity {
return 0
}
// Generate a jitter between +/- 10% of the value.
r := int64(v / 10)
- j := rgen.Int63n(2*r) - r
+ j := grpcrand.Int63n(2*r) - r
return time.Duration(j)
}
diff --git a/vendor/google.golang.org/grpc/transport/http_util.go b/vendor/google.golang.org/grpc/transport/http_util.go
index a35586608..7d15c7d74 100644
--- a/vendor/google.golang.org/grpc/transport/http_util.go
+++ b/vendor/google.golang.org/grpc/transport/http_util.go
@@ -28,6 +28,7 @@ import (
"strconv"
"strings"
"time"
+ "unicode/utf8"
"github.com/golang/protobuf/proto"
"golang.org/x/net/http2"
@@ -437,16 +438,17 @@ func decodeTimeout(s string) (time.Duration, error) {
const (
spaceByte = ' '
- tildaByte = '~'
+ tildeByte = '~'
percentByte = '%'
)
// encodeGrpcMessage is used to encode status code in header field
-// "grpc-message".
-// It checks to see if each individual byte in msg is an
-// allowable byte, and then either percent encoding or passing it through.
-// When percent encoding, the byte is converted into hexadecimal notation
-// with a '%' prepended.
+// "grpc-message". It does percent encoding and also replaces invalid utf-8
+// characters with Unicode replacement character.
+//
+// It checks to see if each individual byte in msg is an allowable byte, and
+// then either percent encoding or passing it through. When percent encoding,
+// the byte is converted into hexadecimal notation with a '%' prepended.
func encodeGrpcMessage(msg string) string {
if msg == "" {
return ""
@@ -454,7 +456,7 @@ func encodeGrpcMessage(msg string) string {
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
- if !(c >= spaceByte && c < tildaByte && c != percentByte) {
+ if !(c >= spaceByte && c <= tildeByte && c != percentByte) {
return encodeGrpcMessageUnchecked(msg)
}
}
@@ -463,14 +465,26 @@ func encodeGrpcMessage(msg string) string {
func encodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
- lenMsg := len(msg)
- for i := 0; i < lenMsg; i++ {
- c := msg[i]
- if c >= spaceByte && c < tildaByte && c != percentByte {
- buf.WriteByte(c)
- } else {
- buf.WriteString(fmt.Sprintf("%%%02X", c))
+ for len(msg) > 0 {
+ r, size := utf8.DecodeRuneInString(msg)
+ for _, b := range []byte(string(r)) {
+ if size > 1 {
+ // If size > 1, r is not ascii. Always do percent encoding.
+ buf.WriteString(fmt.Sprintf("%%%02X", b))
+ continue
+ }
+
+ // The for loop is necessary even if size == 1. r could be
+ // utf8.RuneError.
+ //
+ // fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD".
+ if b >= spaceByte && b <= tildeByte && b != percentByte {
+ buf.WriteByte(b)
+ } else {
+ buf.WriteString(fmt.Sprintf("%%%02X", b))
+ }
}
+ msg = msg[size:]
}
return buf.String()
}
@@ -531,10 +545,14 @@ func (w *bufWriter) Write(b []byte) (n int, err error) {
if w.err != nil {
return 0, w.err
}
- n = copy(w.buf[w.offset:], b)
- w.offset += n
- if w.offset >= w.batchSize {
- err = w.Flush()
+ for len(b) > 0 {
+ nn := copy(w.buf[w.offset:], b)
+ b = b[nn:]
+ w.offset += nn
+ n += nn
+ if w.offset >= w.batchSize {
+ err = w.Flush()
+ }
}
return n, err
}
diff --git a/vendor/google.golang.org/grpc/transport/transport.go b/vendor/google.golang.org/grpc/transport/transport.go
index 2f643a3d0..f51f87888 100644
--- a/vendor/google.golang.org/grpc/transport/transport.go
+++ b/vendor/google.golang.org/grpc/transport/transport.go
@@ -185,13 +185,20 @@ type Stream struct {
headerChan chan struct{} // closed to indicate the end of header metadata.
headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
- header metadata.MD // the received header metadata.
- trailer metadata.MD // the key-value map of trailer metadata.
- headerOk bool // becomes true from the first header is about to send
- state streamState
+ // hdrMu protects header and trailer metadata on the server-side.
+ hdrMu sync.Mutex
+ header metadata.MD // the received header metadata.
+ trailer metadata.MD // the key-value map of trailer metadata.
- status *status.Status // the status error received from the server
+ // On the server-side, headerSent is atomically set to 1 when the headers are sent out.
+ headerSent uint32
+
+ state streamState
+
+ // On client-side it is the status error received from the server.
+ // On server-side it is unused.
+ status *status.Status
bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
@@ -201,6 +208,17 @@ type Stream struct {
contentSubtype string
}
+// isHeaderSent is only valid on the server-side.
+func (s *Stream) isHeaderSent() bool {
+ return atomic.LoadUint32(&s.headerSent) == 1
+}
+
+// updateHeaderSent updates headerSent and returns true
+// if it was alreay set. It is valid only on server-side.
+func (s *Stream) updateHeaderSent() bool {
+ return atomic.SwapUint32(&s.headerSent, 1) == 1
+}
+
func (s *Stream) swapState(st streamState) streamState {
return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st)))
}
@@ -313,10 +331,12 @@ func (s *Stream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
- if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) {
+ if s.isHeaderSent() || s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
+ s.hdrMu.Lock()
s.header = metadata.Join(s.header, md)
+ s.hdrMu.Unlock()
return nil
}
@@ -335,7 +355,12 @@ func (s *Stream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
+ if s.getState() == streamDone {
+ return ErrIllegalHeaderWrite
+ }
+ s.hdrMu.Lock()
s.trailer = metadata.Join(s.trailer, md)
+ s.hdrMu.Unlock()
return nil
}