diff options
Diffstat (limited to 'vendor/google.golang.org/grpc/transport')
-rw-r--r-- | vendor/google.golang.org/grpc/transport/bdp_estimator.go | 140 | ||||
-rw-r--r-- | vendor/google.golang.org/grpc/transport/controlbuf.go | 769 | ||||
-rw-r--r-- | vendor/google.golang.org/grpc/transport/flowcontrol.go | 236 | ||||
-rw-r--r-- | vendor/google.golang.org/grpc/transport/go16.go | 51 | ||||
-rw-r--r-- | vendor/google.golang.org/grpc/transport/go17.go | 52 | ||||
-rw-r--r-- | vendor/google.golang.org/grpc/transport/handler_server.go | 451 | ||||
-rw-r--r-- | vendor/google.golang.org/grpc/transport/http2_client.go | 1284 | ||||
-rw-r--r-- | vendor/google.golang.org/grpc/transport/http2_server.go | 1136 | ||||
-rw-r--r-- | vendor/google.golang.org/grpc/transport/http_util.go | 574 | ||||
-rw-r--r-- | vendor/google.golang.org/grpc/transport/log.go | 50 | ||||
-rw-r--r-- | vendor/google.golang.org/grpc/transport/transport.go | 683 |
11 files changed, 5426 insertions, 0 deletions
diff --git a/vendor/google.golang.org/grpc/transport/bdp_estimator.go b/vendor/google.golang.org/grpc/transport/bdp_estimator.go new file mode 100644 index 000000000..63cd2627c --- /dev/null +++ b/vendor/google.golang.org/grpc/transport/bdp_estimator.go @@ -0,0 +1,140 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "sync" + "time" +) + +const ( + // bdpLimit is the maximum value the flow control windows + // will be increased to. + bdpLimit = (1 << 20) * 4 + // alpha is a constant factor used to keep a moving average + // of RTTs. + alpha = 0.9 + // If the current bdp sample is greater than or equal to + // our beta * our estimated bdp and the current bandwidth + // sample is the maximum bandwidth observed so far, we + // increase our bbp estimate by a factor of gamma. + beta = 0.66 + // To put our bdp to be smaller than or equal to twice the real BDP, + // we should multiply our current sample with 4/3, however to round things out + // we use 2 as the multiplication factor. + gamma = 2 +) + +// Adding arbitrary data to ping so that its ack can be identified. +// Easter-egg: what does the ping message say? +var bdpPing = &ping{data: [8]byte{2, 4, 16, 16, 9, 14, 7, 7}} + +type bdpEstimator struct { + // sentAt is the time when the ping was sent. + sentAt time.Time + + mu sync.Mutex + // bdp is the current bdp estimate. + bdp uint32 + // sample is the number of bytes received in one measurement cycle. + sample uint32 + // bwMax is the maximum bandwidth noted so far (bytes/sec). + bwMax float64 + // bool to keep track of the beginning of a new measurement cycle. + isSent bool + // Callback to update the window sizes. + updateFlowControl func(n uint32) + // sampleCount is the number of samples taken so far. + sampleCount uint64 + // round trip time (seconds) + rtt float64 +} + +// timesnap registers the time bdp ping was sent out so that +// network rtt can be calculated when its ack is received. +// It is called (by controller) when the bdpPing is +// being written on the wire. +func (b *bdpEstimator) timesnap(d [8]byte) { + if bdpPing.data != d { + return + } + b.sentAt = time.Now() +} + +// add adds bytes to the current sample for calculating bdp. +// It returns true only if a ping must be sent. This can be used +// by the caller (handleData) to make decision about batching +// a window update with it. +func (b *bdpEstimator) add(n uint32) bool { + b.mu.Lock() + defer b.mu.Unlock() + if b.bdp == bdpLimit { + return false + } + if !b.isSent { + b.isSent = true + b.sample = n + b.sentAt = time.Time{} + b.sampleCount++ + return true + } + b.sample += n + return false +} + +// calculate is called when an ack for a bdp ping is received. +// Here we calculate the current bdp and bandwidth sample and +// decide if the flow control windows should go up. +func (b *bdpEstimator) calculate(d [8]byte) { + // Check if the ping acked for was the bdp ping. + if bdpPing.data != d { + return + } + b.mu.Lock() + rttSample := time.Since(b.sentAt).Seconds() + if b.sampleCount < 10 { + // Bootstrap rtt with an average of first 10 rtt samples. + b.rtt += (rttSample - b.rtt) / float64(b.sampleCount) + } else { + // Heed to the recent past more. + b.rtt += (rttSample - b.rtt) * float64(alpha) + } + b.isSent = false + // The number of bytes accumulated so far in the sample is smaller + // than or equal to 1.5 times the real BDP on a saturated connection. + bwCurrent := float64(b.sample) / (b.rtt * float64(1.5)) + if bwCurrent > b.bwMax { + b.bwMax = bwCurrent + } + // If the current sample (which is smaller than or equal to the 1.5 times the real BDP) is + // greater than or equal to 2/3rd our perceived bdp AND this is the maximum bandwidth seen so far, we + // should update our perception of the network BDP. + if float64(b.sample) >= beta*float64(b.bdp) && bwCurrent == b.bwMax && b.bdp != bdpLimit { + sampleFloat := float64(b.sample) + b.bdp = uint32(gamma * sampleFloat) + if b.bdp > bdpLimit { + b.bdp = bdpLimit + } + bdp := b.bdp + b.mu.Unlock() + b.updateFlowControl(bdp) + return + } + b.mu.Unlock() +} diff --git a/vendor/google.golang.org/grpc/transport/controlbuf.go b/vendor/google.golang.org/grpc/transport/controlbuf.go new file mode 100644 index 000000000..e147cd51b --- /dev/null +++ b/vendor/google.golang.org/grpc/transport/controlbuf.go @@ -0,0 +1,769 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "bytes" + "fmt" + "runtime" + "sync" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" +) + +type itemNode struct { + it interface{} + next *itemNode +} + +type itemList struct { + head *itemNode + tail *itemNode +} + +func (il *itemList) enqueue(i interface{}) { + n := &itemNode{it: i} + if il.tail == nil { + il.head, il.tail = n, n + return + } + il.tail.next = n + il.tail = n +} + +// peek returns the first item in the list without removing it from the +// list. +func (il *itemList) peek() interface{} { + return il.head.it +} + +func (il *itemList) dequeue() interface{} { + if il.head == nil { + return nil + } + i := il.head.it + il.head = il.head.next + if il.head == nil { + il.tail = nil + } + return i +} + +func (il *itemList) dequeueAll() *itemNode { + h := il.head + il.head, il.tail = nil, nil + return h +} + +func (il *itemList) isEmpty() bool { + return il.head == nil +} + +// The following defines various control items which could flow through +// the control buffer of transport. They represent different aspects of +// control tasks, e.g., flow control, settings, streaming resetting, etc. + +type headerFrame struct { + streamID uint32 + hf []hpack.HeaderField + endStream bool // Valid on server side. + initStream func(uint32) (bool, error) // Used only on the client side. + onWrite func() + wq *writeQuota // write quota for the stream created. + cleanup *cleanupStream // Valid on the server side. + onOrphaned func(error) // Valid on client-side +} + +type cleanupStream struct { + streamID uint32 + idPtr *uint32 + rst bool + rstCode http2.ErrCode + onWrite func() +} + +type dataFrame struct { + streamID uint32 + endStream bool + h []byte + d []byte + // onEachWrite is called every time + // a part of d is written out. + onEachWrite func() +} + +type incomingWindowUpdate struct { + streamID uint32 + increment uint32 +} + +type outgoingWindowUpdate struct { + streamID uint32 + increment uint32 +} + +type incomingSettings struct { + ss []http2.Setting +} + +type outgoingSettings struct { + ss []http2.Setting +} + +type settingsAck struct { +} + +type incomingGoAway struct { +} + +type goAway struct { + code http2.ErrCode + debugData []byte + headsUp bool + closeConn bool +} + +type ping struct { + ack bool + data [8]byte +} + +type outFlowControlSizeRequest struct { + resp chan uint32 +} + +type outStreamState int + +const ( + active outStreamState = iota + empty + waitingOnStreamQuota +) + +type outStream struct { + id uint32 + state outStreamState + itl *itemList + bytesOutStanding int + wq *writeQuota + + next *outStream + prev *outStream +} + +func (s *outStream) deleteSelf() { + if s.prev != nil { + s.prev.next = s.next + } + if s.next != nil { + s.next.prev = s.prev + } + s.next, s.prev = nil, nil +} + +type outStreamList struct { + // Following are sentinel objects that mark the + // beginning and end of the list. They do not + // contain any item lists. All valid objects are + // inserted in between them. + // This is needed so that an outStream object can + // deleteSelf() in O(1) time without knowing which + // list it belongs to. + head *outStream + tail *outStream +} + +func newOutStreamList() *outStreamList { + head, tail := new(outStream), new(outStream) + head.next = tail + tail.prev = head + return &outStreamList{ + head: head, + tail: tail, + } +} + +func (l *outStreamList) enqueue(s *outStream) { + e := l.tail.prev + e.next = s + s.prev = e + s.next = l.tail + l.tail.prev = s +} + +// remove from the beginning of the list. +func (l *outStreamList) dequeue() *outStream { + b := l.head.next + if b == l.tail { + return nil + } + b.deleteSelf() + return b +} + +type controlBuffer struct { + ch chan struct{} + done <-chan struct{} + mu sync.Mutex + consumerWaiting bool + list *itemList + err error +} + +func newControlBuffer(done <-chan struct{}) *controlBuffer { + return &controlBuffer{ + ch: make(chan struct{}, 1), + list: &itemList{}, + done: done, + } +} + +func (c *controlBuffer) put(it interface{}) error { + _, err := c.executeAndPut(nil, it) + return err +} + +func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it interface{}) (bool, error) { + var wakeUp bool + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return false, c.err + } + if f != nil { + if !f(it) { // f wasn't successful + c.mu.Unlock() + return false, nil + } + } + if c.consumerWaiting { + wakeUp = true + c.consumerWaiting = false + } + c.list.enqueue(it) + c.mu.Unlock() + if wakeUp { + select { + case c.ch <- struct{}{}: + default: + } + } + return true, nil +} + +func (c *controlBuffer) get(block bool) (interface{}, error) { + for { + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return nil, c.err + } + if !c.list.isEmpty() { + h := c.list.dequeue() + c.mu.Unlock() + return h, nil + } + if !block { + c.mu.Unlock() + return nil, nil + } + c.consumerWaiting = true + c.mu.Unlock() + select { + case <-c.ch: + case <-c.done: + c.finish() + return nil, ErrConnClosing + } + } +} + +func (c *controlBuffer) finish() { + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return + } + c.err = ErrConnClosing + // There may be headers for streams in the control buffer. + // These streams need to be cleaned out since the transport + // is still not aware of these yet. + for head := c.list.dequeueAll(); head != nil; head = head.next { + hdr, ok := head.it.(*headerFrame) + if !ok { + continue + } + if hdr.onOrphaned != nil { // It will be nil on the server-side. + hdr.onOrphaned(ErrConnClosing) + } + } + c.mu.Unlock() +} + +type side int + +const ( + clientSide side = iota + serverSide +) + +type loopyWriter struct { + side side + cbuf *controlBuffer + sendQuota uint32 + oiws uint32 // outbound initial window size. + estdStreams map[uint32]*outStream // Established streams. + activeStreams *outStreamList // Streams that are sending data. + framer *framer + hBuf *bytes.Buffer // The buffer for HPACK encoding. + hEnc *hpack.Encoder // HPACK encoder. + bdpEst *bdpEstimator + draining bool + + // Side-specific handlers + ssGoAwayHandler func(*goAway) (bool, error) +} + +func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator) *loopyWriter { + var buf bytes.Buffer + l := &loopyWriter{ + side: s, + cbuf: cbuf, + sendQuota: defaultWindowSize, + oiws: defaultWindowSize, + estdStreams: make(map[uint32]*outStream), + activeStreams: newOutStreamList(), + framer: fr, + hBuf: &buf, + hEnc: hpack.NewEncoder(&buf), + bdpEst: bdpEst, + } + return l +} + +const minBatchSize = 1000 + +// run should be run in a separate goroutine. +func (l *loopyWriter) run() { + var ( + it interface{} + err error + isEmpty bool + ) + defer func() { + errorf("transport: loopyWriter.run returning. Err: %v", err) + }() + for { + it, err = l.cbuf.get(true) + if err != nil { + return + } + if err = l.handle(it); err != nil { + return + } + if _, err = l.processData(); err != nil { + return + } + gosched := true + hasdata: + for { + it, err = l.cbuf.get(false) + if err != nil { + return + } + if it != nil { + if err = l.handle(it); err != nil { + return + } + if _, err = l.processData(); err != nil { + return + } + continue hasdata + } + if isEmpty, err = l.processData(); err != nil { + return + } + if !isEmpty { + continue hasdata + } + if gosched { + gosched = false + if l.framer.writer.offset < minBatchSize { + runtime.Gosched() + continue hasdata + } + } + l.framer.writer.Flush() + break hasdata + + } + } +} + +func (l *loopyWriter) outgoingWindowUpdateHandler(w *outgoingWindowUpdate) error { + return l.framer.fr.WriteWindowUpdate(w.streamID, w.increment) +} + +func (l *loopyWriter) incomingWindowUpdateHandler(w *incomingWindowUpdate) error { + // Otherwise update the quota. + if w.streamID == 0 { + l.sendQuota += w.increment + return nil + } + // Find the stream and update it. + if str, ok := l.estdStreams[w.streamID]; ok { + str.bytesOutStanding -= int(w.increment) + if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota > 0 && str.state == waitingOnStreamQuota { + str.state = active + l.activeStreams.enqueue(str) + return nil + } + } + return nil +} + +func (l *loopyWriter) outgoingSettingsHandler(s *outgoingSettings) error { + return l.framer.fr.WriteSettings(s.ss...) +} + +func (l *loopyWriter) incomingSettingsHandler(s *incomingSettings) error { + if err := l.applySettings(s.ss); err != nil { + return err + } + return l.framer.fr.WriteSettingsAck() +} + +func (l *loopyWriter) headerHandler(h *headerFrame) error { + if l.side == serverSide { + if h.endStream { // Case 1.A: Server wants to close stream. + // Make sure it's not a trailers only response. + if str, ok := l.estdStreams[h.streamID]; ok { + if str.state != empty { // either active or waiting on stream quota. + // add it str's list of items. + str.itl.enqueue(h) + return nil + } + } + if err := l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite); err != nil { + return err + } + return l.cleanupStreamHandler(h.cleanup) + } + // Case 1.B: Server is responding back with headers. + str := &outStream{ + state: empty, + itl: &itemList{}, + wq: h.wq, + } + l.estdStreams[h.streamID] = str + return l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite) + } + // Case 2: Client wants to originate stream. + str := &outStream{ + id: h.streamID, + state: empty, + itl: &itemList{}, + wq: h.wq, + } + str.itl.enqueue(h) + return l.originateStream(str) +} + +func (l *loopyWriter) originateStream(str *outStream) error { + hdr := str.itl.dequeue().(*headerFrame) + sendPing, err := hdr.initStream(str.id) + if err != nil { + if err == ErrConnClosing { + return err + } + // Other errors(errStreamDrain) need not close transport. + return nil + } + if err = l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil { + return err + } + l.estdStreams[str.id] = str + if sendPing { + return l.pingHandler(&ping{data: [8]byte{}}) + } + return nil +} + +func (l *loopyWriter) writeHeader(streamID uint32, endStream bool, hf []hpack.HeaderField, onWrite func()) error { + if onWrite != nil { + onWrite() + } + l.hBuf.Reset() + for _, f := range hf { + if err := l.hEnc.WriteField(f); err != nil { + warningf("transport: loopyWriter.writeHeader encountered error while encoding headers:", err) + } + } + var ( + err error + endHeaders, first bool + ) + first = true + for !endHeaders { + size := l.hBuf.Len() + if size > http2MaxFrameLen { + size = http2MaxFrameLen + } else { + endHeaders = true + } + if first { + first = false + err = l.framer.fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: streamID, + BlockFragment: l.hBuf.Next(size), + EndStream: endStream, + EndHeaders: endHeaders, + }) + } else { + err = l.framer.fr.WriteContinuation( + streamID, + endHeaders, + l.hBuf.Next(size), + ) + } + if err != nil { + return err + } + } + return nil +} + +func (l *loopyWriter) preprocessData(df *dataFrame) error { + str, ok := l.estdStreams[df.streamID] + if !ok { + return nil + } + // If we got data for a stream it means that + // stream was originated and the headers were sent out. + str.itl.enqueue(df) + if str.state == empty { + str.state = active + l.activeStreams.enqueue(str) + } + return nil +} + +func (l *loopyWriter) pingHandler(p *ping) error { + if !p.ack { + l.bdpEst.timesnap(p.data) + } + return l.framer.fr.WritePing(p.ack, p.data) + +} + +func (l *loopyWriter) outFlowControlSizeRequestHandler(o *outFlowControlSizeRequest) error { + o.resp <- l.sendQuota + return nil +} + +func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error { + c.onWrite() + if str, ok := l.estdStreams[c.streamID]; ok { + // On the server side it could be a trailers-only response or + // a RST_STREAM before stream initialization thus the stream might + // not be established yet. + delete(l.estdStreams, c.streamID) + str.deleteSelf() + } + if c.rst { // If RST_STREAM needs to be sent. + if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil { + return err + } + } + if l.side == clientSide && l.draining && len(l.estdStreams) == 0 { + return ErrConnClosing + } + return nil +} + +func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error { + if l.side == clientSide { + l.draining = true + if len(l.estdStreams) == 0 { + return ErrConnClosing + } + } + return nil +} + +func (l *loopyWriter) goAwayHandler(g *goAway) error { + // Handling of outgoing GoAway is very specific to side. + if l.ssGoAwayHandler != nil { + draining, err := l.ssGoAwayHandler(g) + if err != nil { + return err + } + l.draining = draining + } + return nil +} + +func (l *loopyWriter) handle(i interface{}) error { + switch i := i.(type) { + case *incomingWindowUpdate: + return l.incomingWindowUpdateHandler(i) + case *outgoingWindowUpdate: + return l.outgoingWindowUpdateHandler(i) + case *incomingSettings: + return l.incomingSettingsHandler(i) + case *outgoingSettings: + return l.outgoingSettingsHandler(i) + case *headerFrame: + return l.headerHandler(i) + case *cleanupStream: + return l.cleanupStreamHandler(i) + case *incomingGoAway: + return l.incomingGoAwayHandler(i) + case *dataFrame: + return l.preprocessData(i) + case *ping: + return l.pingHandler(i) + case *goAway: + return l.goAwayHandler(i) + case *outFlowControlSizeRequest: + return l.outFlowControlSizeRequestHandler(i) + default: + return fmt.Errorf("transport: unknown control message type %T", i) + } +} + +func (l *loopyWriter) applySettings(ss []http2.Setting) error { + for _, s := range ss { + switch s.ID { + case http2.SettingInitialWindowSize: + o := l.oiws + l.oiws = s.Val + if o < l.oiws { + // If the new limit is greater make all depleted streams active. + for _, stream := range l.estdStreams { + if stream.state == waitingOnStreamQuota { + stream.state = active + l.activeStreams.enqueue(stream) + } + } + } + } + } + return nil +} + +func (l *loopyWriter) processData() (bool, error) { + if l.sendQuota == 0 { + return true, nil + } + str := l.activeStreams.dequeue() + if str == nil { + return true, nil + } + dataItem := str.itl.peek().(*dataFrame) + if len(dataItem.h) == 0 && len(dataItem.d) == 0 { + // Client sends out empty data frame with endStream = true + if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil { + return false, err + } + str.itl.dequeue() + if str.itl.isEmpty() { + str.state = empty + } else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers. + if err := l.writeHeader(trailer.streamID, trailer.endStream, trailer.hf, trailer.onWrite); err != nil { + return false, err + } + if err := l.cleanupStreamHandler(trailer.cleanup); err != nil { + return false, nil + } + } else { + l.activeStreams.enqueue(str) + } + return false, nil + } + var ( + idx int + buf []byte + ) + if len(dataItem.h) != 0 { // data header has not been written out yet. + buf = dataItem.h + } else { + idx = 1 + buf = dataItem.d + } + size := http2MaxFrameLen + if len(buf) < size { + size = len(buf) + } + if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 { + str.state = waitingOnStreamQuota + return false, nil + } else if strQuota < size { + size = strQuota + } + + if l.sendQuota < uint32(size) { + size = int(l.sendQuota) + } + // Now that outgoing flow controls are checked we can replenish str's write quota + str.wq.replenish(size) + var endStream bool + // This last data message on this stream and all + // of it can be written in this go. + if dataItem.endStream && size == len(buf) { + // buf contains either data or it contains header but data is empty. + if idx == 1 || len(dataItem.d) == 0 { + endStream = true + } + } + if dataItem.onEachWrite != nil { + dataItem.onEachWrite() + } + if err := l.framer.fr.WriteData(dataItem.streamID, endStream, buf[:size]); err != nil { + return false, err + } + buf = buf[size:] + str.bytesOutStanding += size + l.sendQuota -= uint32(size) + if idx == 0 { + dataItem.h = buf + } else { + dataItem.d = buf + } + + if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // All the data from that message was written out. + str.itl.dequeue() + } + if str.itl.isEmpty() { + str.state = empty + } else if trailer, ok := str.itl.peek().(*headerFrame); ok { // The next item is trailers. + if err := l.writeHeader(trailer.streamID, trailer.endStream, trailer.hf, trailer.onWrite); err != nil { + return false, err + } + if err := l.cleanupStreamHandler(trailer.cleanup); err != nil { + return false, err + } + } else if int(l.oiws)-str.bytesOutStanding <= 0 { // Ran out of stream quota. + str.state = waitingOnStreamQuota + } else { // Otherwise add it back to the list of active streams. + l.activeStreams.enqueue(str) + } + return false, nil +} diff --git a/vendor/google.golang.org/grpc/transport/flowcontrol.go b/vendor/google.golang.org/grpc/transport/flowcontrol.go new file mode 100644 index 000000000..378f5c450 --- /dev/null +++ b/vendor/google.golang.org/grpc/transport/flowcontrol.go @@ -0,0 +1,236 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "fmt" + "math" + "sync" + "sync/atomic" + "time" +) + +const ( + // The default value of flow control window size in HTTP2 spec. + defaultWindowSize = 65535 + // The initial window size for flow control. + initialWindowSize = defaultWindowSize // for an RPC + infinity = time.Duration(math.MaxInt64) + defaultClientKeepaliveTime = infinity + defaultClientKeepaliveTimeout = 20 * time.Second + defaultMaxStreamsClient = 100 + defaultMaxConnectionIdle = infinity + defaultMaxConnectionAge = infinity + defaultMaxConnectionAgeGrace = infinity + defaultServerKeepaliveTime = 2 * time.Hour + defaultServerKeepaliveTimeout = 20 * time.Second + defaultKeepalivePolicyMinTime = 5 * time.Minute + // max window limit set by HTTP2 Specs. + maxWindowSize = math.MaxInt32 + // defaultWriteQuota is the default value for number of data + // bytes that each stream can schedule before some of it being + // flushed out. + defaultWriteQuota = 64 * 1024 +) + +// writeQuota is a soft limit on the amount of data a stream can +// schedule before some of it is written out. +type writeQuota struct { + quota int32 + // get waits on read from when quota goes less than or equal to zero. + // replenish writes on it when quota goes positive again. + ch chan struct{} + // done is triggered in error case. + done <-chan struct{} +} + +func newWriteQuota(sz int32, done <-chan struct{}) *writeQuota { + return &writeQuota{ + quota: sz, + ch: make(chan struct{}, 1), + done: done, + } +} + +func (w *writeQuota) get(sz int32) error { + for { + if atomic.LoadInt32(&w.quota) > 0 { + atomic.AddInt32(&w.quota, -sz) + return nil + } + select { + case <-w.ch: + continue + case <-w.done: + return errStreamDone + } + } +} + +func (w *writeQuota) replenish(n int) { + sz := int32(n) + a := atomic.AddInt32(&w.quota, sz) + b := a - sz + if b <= 0 && a > 0 { + select { + case w.ch <- struct{}{}: + default: + } + } +} + +type trInFlow struct { + limit uint32 + unacked uint32 + effectiveWindowSize uint32 +} + +func (f *trInFlow) newLimit(n uint32) uint32 { + d := n - f.limit + f.limit = n + f.updateEffectiveWindowSize() + return d +} + +func (f *trInFlow) onData(n uint32) uint32 { + f.unacked += n + if f.unacked >= f.limit/4 { + w := f.unacked + f.unacked = 0 + f.updateEffectiveWindowSize() + return w + } + f.updateEffectiveWindowSize() + return 0 +} + +func (f *trInFlow) reset() uint32 { + w := f.unacked + f.unacked = 0 + f.updateEffectiveWindowSize() + return w +} + +func (f *trInFlow) updateEffectiveWindowSize() { + atomic.StoreUint32(&f.effectiveWindowSize, f.limit-f.unacked) +} + +func (f *trInFlow) getSize() uint32 { + return atomic.LoadUint32(&f.effectiveWindowSize) +} + +// TODO(mmukhi): Simplify this code. +// inFlow deals with inbound flow control +type inFlow struct { + mu sync.Mutex + // The inbound flow control limit for pending data. + limit uint32 + // pendingData is the overall data which have been received but not been + // consumed by applications. + pendingData uint32 + // The amount of data the application has consumed but grpc has not sent + // window update for them. Used to reduce window update frequency. + pendingUpdate uint32 + // delta is the extra window update given by receiver when an application + // is reading data bigger in size than the inFlow limit. + delta uint32 +} + +// newLimit updates the inflow window to a new value n. +// It assumes that n is always greater than the old limit. +func (f *inFlow) newLimit(n uint32) uint32 { + f.mu.Lock() + d := n - f.limit + f.limit = n + f.mu.Unlock() + return d +} + +func (f *inFlow) maybeAdjust(n uint32) uint32 { + if n > uint32(math.MaxInt32) { + n = uint32(math.MaxInt32) + } + f.mu.Lock() + // estSenderQuota is the receiver's view of the maximum number of bytes the sender + // can send without a window update. + estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate)) + // estUntransmittedData is the maximum number of bytes the sends might not have put + // on the wire yet. A value of 0 or less means that we have already received all or + // more bytes than the application is requesting to read. + estUntransmittedData := int32(n - f.pendingData) // Casting into int32 since it could be negative. + // This implies that unless we send a window update, the sender won't be able to send all the bytes + // for this message. Therefore we must send an update over the limit since there's an active read + // request from the application. + if estUntransmittedData > estSenderQuota { + // Sender's window shouldn't go more than 2^31 - 1 as specified in the HTTP spec. + if f.limit+n > maxWindowSize { + f.delta = maxWindowSize - f.limit + } else { + // Send a window update for the whole message and not just the difference between + // estUntransmittedData and estSenderQuota. This will be helpful in case the message + // is padded; We will fallback on the current available window(at least a 1/4th of the limit). + f.delta = n + } + f.mu.Unlock() + return f.delta + } + f.mu.Unlock() + return 0 +} + +// onData is invoked when some data frame is received. It updates pendingData. +func (f *inFlow) onData(n uint32) error { + f.mu.Lock() + f.pendingData += n + if f.pendingData+f.pendingUpdate > f.limit+f.delta { + limit := f.limit + rcvd := f.pendingData + f.pendingUpdate + f.mu.Unlock() + return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, limit) + } + f.mu.Unlock() + return nil +} + +// onRead is invoked when the application reads the data. It returns the window size +// to be sent to the peer. +func (f *inFlow) onRead(n uint32) uint32 { + f.mu.Lock() + if f.pendingData == 0 { + f.mu.Unlock() + return 0 + } + f.pendingData -= n + if n > f.delta { + n -= f.delta + f.delta = 0 + } else { + f.delta -= n + n = 0 + } + f.pendingUpdate += n + if f.pendingUpdate >= f.limit/4 { + wu := f.pendingUpdate + f.pendingUpdate = 0 + f.mu.Unlock() + return wu + } + f.mu.Unlock() + return 0 +} diff --git a/vendor/google.golang.org/grpc/transport/go16.go b/vendor/google.golang.org/grpc/transport/go16.go new file mode 100644 index 000000000..5babcf9b8 --- /dev/null +++ b/vendor/google.golang.org/grpc/transport/go16.go @@ -0,0 +1,51 @@ +// +build go1.6,!go1.7 + +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "net" + "net/http" + + "google.golang.org/grpc/codes" + + "golang.org/x/net/context" +) + +// dialContext connects to the address on the named network. +func dialContext(ctx context.Context, network, address string) (net.Conn, error) { + return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address) +} + +// ContextErr converts the error from context package into a StreamError. +func ContextErr(err error) StreamError { + switch err { + case context.DeadlineExceeded: + return streamErrorf(codes.DeadlineExceeded, "%v", err) + case context.Canceled: + return streamErrorf(codes.Canceled, "%v", err) + } + return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) +} + +// contextFromRequest returns a background context. +func contextFromRequest(r *http.Request) context.Context { + return context.Background() +} diff --git a/vendor/google.golang.org/grpc/transport/go17.go b/vendor/google.golang.org/grpc/transport/go17.go new file mode 100644 index 000000000..b7fa6bdb9 --- /dev/null +++ b/vendor/google.golang.org/grpc/transport/go17.go @@ -0,0 +1,52 @@ +// +build go1.7 + +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "context" + "net" + "net/http" + + "google.golang.org/grpc/codes" + + netctx "golang.org/x/net/context" +) + +// dialContext connects to the address on the named network. +func dialContext(ctx context.Context, network, address string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, network, address) +} + +// ContextErr converts the error from context package into a StreamError. +func ContextErr(err error) StreamError { + switch err { + case context.DeadlineExceeded, netctx.DeadlineExceeded: + return streamErrorf(codes.DeadlineExceeded, "%v", err) + case context.Canceled, netctx.Canceled: + return streamErrorf(codes.Canceled, "%v", err) + } + return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) +} + +// contextFromRequest returns a context from the HTTP Request. +func contextFromRequest(r *http.Request) context.Context { + return r.Context() +} diff --git a/vendor/google.golang.org/grpc/transport/handler_server.go b/vendor/google.golang.org/grpc/transport/handler_server.go new file mode 100644 index 000000000..f71b74821 --- /dev/null +++ b/vendor/google.golang.org/grpc/transport/handler_server.go @@ -0,0 +1,451 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// This file is the implementation of a gRPC server using HTTP/2 which +// uses the standard Go http2 Server implementation (via the +// http.Handler interface), rather than speaking low-level HTTP/2 +// frames itself. It is the implementation of *grpc.Server.ServeHTTP. + +package transport + +import ( + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/context" + "golang.org/x/net/http2" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" +) + +// NewServerHandlerTransport returns a ServerTransport handling gRPC +// from inside an http.Handler. It requires that the http Server +// supports HTTP/2. +func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) { + if r.ProtoMajor != 2 { + return nil, errors.New("gRPC requires HTTP/2") + } + if r.Method != "POST" { + return nil, errors.New("invalid gRPC request method") + } + contentType := r.Header.Get("Content-Type") + // TODO: do we assume contentType is lowercase? we did before + contentSubtype, validContentType := contentSubtype(contentType) + if !validContentType { + return nil, errors.New("invalid gRPC request content-type") + } + if _, ok := w.(http.Flusher); !ok { + return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher") + } + if _, ok := w.(http.CloseNotifier); !ok { + return nil, errors.New("gRPC requires a ResponseWriter supporting http.CloseNotifier") + } + + st := &serverHandlerTransport{ + rw: w, + req: r, + closedCh: make(chan struct{}), + writes: make(chan func()), + contentType: contentType, + contentSubtype: contentSubtype, + stats: stats, + } + + if v := r.Header.Get("grpc-timeout"); v != "" { + to, err := decodeTimeout(v) + if err != nil { + return nil, streamErrorf(codes.Internal, "malformed time-out: %v", err) + } + st.timeoutSet = true + st.timeout = to + } + + metakv := []string{"content-type", contentType} + if r.Host != "" { + metakv = append(metakv, ":authority", r.Host) + } + for k, vv := range r.Header { + k = strings.ToLower(k) + if isReservedHeader(k) && !isWhitelistedHeader(k) { + continue + } + for _, v := range vv { + v, err := decodeMetadataHeader(k, v) + if err != nil { + return nil, streamErrorf(codes.Internal, "malformed binary metadata: %v", err) + } + metakv = append(metakv, k, v) + } + } + st.headerMD = metadata.Pairs(metakv...) + + return st, nil +} + +// serverHandlerTransport is an implementation of ServerTransport +// which replies to exactly one gRPC request (exactly one HTTP request), +// using the net/http.Handler interface. This http.Handler is guaranteed +// at this point to be speaking over HTTP/2, so it's able to speak valid +// gRPC. +type serverHandlerTransport struct { + rw http.ResponseWriter + req *http.Request + timeoutSet bool + timeout time.Duration + didCommonHeaders bool + + headerMD metadata.MD + + closeOnce sync.Once + closedCh chan struct{} // closed on Close + + // writes is a channel of code to run serialized in the + // ServeHTTP (HandleStreams) goroutine. The channel is closed + // when WriteStatus is called. + writes chan func() + + // block concurrent WriteStatus calls + // e.g. grpc/(*serverStream).SendMsg/RecvMsg + writeStatusMu sync.Mutex + + // we just mirror the request content-type + contentType string + // we store both contentType and contentSubtype so we don't keep recreating them + // TODO make sure this is consistent across handler_server and http2_server + contentSubtype string + + stats stats.Handler +} + +func (ht *serverHandlerTransport) Close() error { + ht.closeOnce.Do(ht.closeCloseChanOnce) + return nil +} + +func (ht *serverHandlerTransport) closeCloseChanOnce() { close(ht.closedCh) } + +func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) } + +// strAddr is a net.Addr backed by either a TCP "ip:port" string, or +// the empty string if unknown. +type strAddr string + +func (a strAddr) Network() string { + if a != "" { + // Per the documentation on net/http.Request.RemoteAddr, if this is + // set, it's set to the IP:port of the peer (hence, TCP): + // https://golang.org/pkg/net/http/#Request + // + // If we want to support Unix sockets later, we can + // add our own grpc-specific convention within the + // grpc codebase to set RemoteAddr to a different + // format, or probably better: we can attach it to the + // context and use that from serverHandlerTransport.RemoteAddr. + return "tcp" + } + return "" +} + +func (a strAddr) String() string { return string(a) } + +// do runs fn in the ServeHTTP goroutine. +func (ht *serverHandlerTransport) do(fn func()) error { + // Avoid a panic writing to closed channel. Imperfect but maybe good enough. + select { + case <-ht.closedCh: + return ErrConnClosing + default: + select { + case ht.writes <- fn: + return nil + case <-ht.closedCh: + return ErrConnClosing + } + } +} + +func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error { + ht.writeStatusMu.Lock() + defer ht.writeStatusMu.Unlock() + + err := ht.do(func() { + ht.writeCommonHeaders(s) + + // And flush, in case no header or body has been sent yet. + // This forces a separation of headers and trailers if this is the + // first call (for example, in end2end tests's TestNoService). + ht.rw.(http.Flusher).Flush() + + h := ht.rw.Header() + h.Set("Grpc-Status", fmt.Sprintf("%d", st.Code())) + if m := st.Message(); m != "" { + h.Set("Grpc-Message", encodeGrpcMessage(m)) + } + + if p := st.Proto(); p != nil && len(p.Details) > 0 { + stBytes, err := proto.Marshal(p) + if err != nil { + // TODO: return error instead, when callers are able to handle it. + panic(err) + } + + h.Set("Grpc-Status-Details-Bin", encodeBinHeader(stBytes)) + } + + if md := s.Trailer(); len(md) > 0 { + for k, vv := range md { + // Clients don't tolerate reading restricted headers after some non restricted ones were sent. + if isReservedHeader(k) { + continue + } + for _, v := range vv { + // http2 ResponseWriter mechanism to send undeclared Trailers after + // the headers have possibly been written. + h.Add(http2.TrailerPrefix+k, encodeMetadataHeader(k, v)) + } + } + } + }) + + if err == nil { // transport has not been closed + if ht.stats != nil { + ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{}) + } + ht.Close() + close(ht.writes) + } + return err +} + +// writeCommonHeaders sets common headers on the first write +// call (Write, WriteHeader, or WriteStatus). +func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) { + if ht.didCommonHeaders { + return + } + ht.didCommonHeaders = true + + h := ht.rw.Header() + h["Date"] = nil // suppress Date to make tests happy; TODO: restore + h.Set("Content-Type", ht.contentType) + + // Predeclare trailers we'll set later in WriteStatus (after the body). + // This is a SHOULD in the HTTP RFC, and the way you add (known) + // Trailers per the net/http.ResponseWriter contract. + // See https://golang.org/pkg/net/http/#ResponseWriter + // and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers + h.Add("Trailer", "Grpc-Status") + h.Add("Trailer", "Grpc-Message") + h.Add("Trailer", "Grpc-Status-Details-Bin") + + if s.sendCompress != "" { + h.Set("Grpc-Encoding", s.sendCompress) + } +} + +func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { + return ht.do(func() { + ht.writeCommonHeaders(s) + ht.rw.Write(hdr) + ht.rw.Write(data) + if !opts.Delay { + ht.rw.(http.Flusher).Flush() + } + }) +} + +func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { + err := ht.do(func() { + ht.writeCommonHeaders(s) + h := ht.rw.Header() + for k, vv := range md { + // Clients don't tolerate reading restricted headers after some non restricted ones were sent. + if isReservedHeader(k) { + continue + } + for _, v := range vv { + v = encodeMetadataHeader(k, v) + h.Add(k, v) + } + } + ht.rw.WriteHeader(200) + ht.rw.(http.Flusher).Flush() + }) + + if err == nil { + if ht.stats != nil { + ht.stats.HandleRPC(s.Context(), &stats.OutHeader{}) + } + } + return err +} + +func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) { + // With this transport type there will be exactly 1 stream: this HTTP request. + + ctx := contextFromRequest(ht.req) + var cancel context.CancelFunc + if ht.timeoutSet { + ctx, cancel = context.WithTimeout(ctx, ht.timeout) + } else { + ctx, cancel = context.WithCancel(ctx) + } + + // requestOver is closed when either the request's context is done + // or the status has been written via WriteStatus. + requestOver := make(chan struct{}) + + // clientGone receives a single value if peer is gone, either + // because the underlying connection is dead or because the + // peer sends an http2 RST_STREAM. + clientGone := ht.rw.(http.CloseNotifier).CloseNotify() + go func() { + select { + case <-requestOver: + return + case <-ht.closedCh: + case <-clientGone: + } + cancel() + }() + + req := ht.req + + s := &Stream{ + id: 0, // irrelevant + requestRead: func(int) {}, + cancel: cancel, + buf: newRecvBuffer(), + st: ht, + method: req.URL.Path, + recvCompress: req.Header.Get("grpc-encoding"), + contentSubtype: ht.contentSubtype, + } + pr := &peer.Peer{ + Addr: ht.RemoteAddr(), + } + if req.TLS != nil { + pr.AuthInfo = credentials.TLSInfo{State: *req.TLS} + } + ctx = metadata.NewIncomingContext(ctx, ht.headerMD) + s.ctx = peer.NewContext(ctx, pr) + if ht.stats != nil { + s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + inHeader := &stats.InHeader{ + FullMethod: s.method, + RemoteAddr: ht.RemoteAddr(), + Compression: s.recvCompress, + } + ht.stats.HandleRPC(s.ctx, inHeader) + } + s.trReader = &transportReader{ + reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf}, + windowHandler: func(int) {}, + } + + // readerDone is closed when the Body.Read-ing goroutine exits. + readerDone := make(chan struct{}) + go func() { + defer close(readerDone) + + // TODO: minimize garbage, optimize recvBuffer code/ownership + const readSize = 8196 + for buf := make([]byte, readSize); ; { + n, err := req.Body.Read(buf) + if n > 0 { + s.buf.put(recvMsg{data: buf[:n:n]}) + buf = buf[n:] + } + if err != nil { + s.buf.put(recvMsg{err: mapRecvMsgError(err)}) + return + } + if len(buf) == 0 { + buf = make([]byte, readSize) + } + } + }() + + // startStream is provided by the *grpc.Server's serveStreams. + // It starts a goroutine serving s and exits immediately. + // The goroutine that is started is the one that then calls + // into ht, calling WriteHeader, Write, WriteStatus, Close, etc. + startStream(s) + + ht.runStream() + close(requestOver) + + // Wait for reading goroutine to finish. + req.Body.Close() + <-readerDone +} + +func (ht *serverHandlerTransport) runStream() { + for { + select { + case fn, ok := <-ht.writes: + if !ok { + return + } + fn() + case <-ht.closedCh: + return + } + } +} + +func (ht *serverHandlerTransport) IncrMsgSent() {} + +func (ht *serverHandlerTransport) IncrMsgRecv() {} + +func (ht *serverHandlerTransport) Drain() { + panic("Drain() is not implemented") +} + +// mapRecvMsgError returns the non-nil err into the appropriate +// error value as expected by callers of *grpc.parser.recvMsg. +// In particular, in can only be: +// * io.EOF +// * io.ErrUnexpectedEOF +// * of type transport.ConnectionError +// * of type transport.StreamError +func mapRecvMsgError(err error) error { + if err == io.EOF || err == io.ErrUnexpectedEOF { + return err + } + if se, ok := err.(http2.StreamError); ok { + if code, ok := http2ErrConvTab[se.Code]; ok { + return StreamError{ + Code: code, + Desc: se.Error(), + } + } + } + return connectionErrorf(true, err, err.Error()) +} diff --git a/vendor/google.golang.org/grpc/transport/http2_client.go b/vendor/google.golang.org/grpc/transport/http2_client.go new file mode 100644 index 000000000..1fdabd954 --- /dev/null +++ b/vendor/google.golang.org/grpc/transport/http2_client.go @@ -0,0 +1,1284 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "io" + "math" + "net" + "strings" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/context" + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + + "google.golang.org/grpc/channelz" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" +) + +// http2Client implements the ClientTransport interface with HTTP2. +type http2Client struct { + ctx context.Context + cancel context.CancelFunc + ctxDone <-chan struct{} // Cache the ctx.Done() chan. + userAgent string + md interface{} + conn net.Conn // underlying communication channel + loopy *loopyWriter + remoteAddr net.Addr + localAddr net.Addr + authInfo credentials.AuthInfo // auth info about the connection + + readerDone chan struct{} // sync point to enable testing. + writerDone chan struct{} // sync point to enable testing. + // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) + // that the server sent GoAway on this transport. + goAway chan struct{} + // awakenKeepalive is used to wake up keepalive when after it has gone dormant. + awakenKeepalive chan struct{} + + framer *framer + // controlBuf delivers all the control related tasks (e.g., window + // updates, reset streams, and various settings) to the controller. + controlBuf *controlBuffer + fc *trInFlow + // The scheme used: https if TLS is on, http otherwise. + scheme string + + isSecure bool + + creds []credentials.PerRPCCredentials + + // Boolean to keep track of reading activity on transport. + // 1 is true and 0 is false. + activity uint32 // Accessed atomically. + kp keepalive.ClientParameters + + statsHandler stats.Handler + + initialWindowSize int32 + + bdpEst *bdpEstimator + // onSuccess is a callback that client transport calls upon + // receiving server preface to signal that a succefull HTTP2 + // connection was established. + onSuccess func() + + maxConcurrentStreams uint32 + streamQuota int64 + streamsQuotaAvailable chan struct{} + waitingStreams uint32 + nextID uint32 + + mu sync.Mutex // guard the following variables + state transportState + activeStreams map[uint32]*Stream + // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. + prevGoAwayID uint32 + // goAwayReason records the http2.ErrCode and debug data received with the + // GoAway frame. + goAwayReason GoAwayReason + + // Fields below are for channelz metric collection. + channelzID int64 // channelz unique identification number + czmu sync.RWMutex + kpCount int64 + // The number of streams that have started, including already finished ones. + streamsStarted int64 + // The number of streams that have ended successfully by receiving EoS bit set + // frame from server. + streamsSucceeded int64 + streamsFailed int64 + lastStreamCreated time.Time + msgSent int64 + msgRecv int64 + lastMsgSent time.Time + lastMsgRecv time.Time +} + +func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) { + if fn != nil { + return fn(ctx, addr) + } + return dialContext(ctx, "tcp", addr) +} + +func isTemporary(err error) bool { + switch err := err.(type) { + case interface { + Temporary() bool + }: + return err.Temporary() + case interface { + Timeout() bool + }: + // Timeouts may be resolved upon retry, and are thus treated as + // temporary. + return err.Timeout() + } + return true +} + +// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 +// and starts to receive messages on it. Non-nil error returns if construction +// fails. +func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts ConnectOptions, onSuccess func()) (_ ClientTransport, err error) { + scheme := "http" + ctx, cancel := context.WithCancel(ctx) + defer func() { + if err != nil { + cancel() + } + }() + + conn, err := dial(connectCtx, opts.Dialer, addr.Addr) + if err != nil { + if opts.FailOnNonTempDialError { + return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err) + } + return nil, connectionErrorf(true, err, "transport: Error while dialing %v", err) + } + // Any further errors will close the underlying connection + defer func(conn net.Conn) { + if err != nil { + conn.Close() + } + }(conn) + var ( + isSecure bool + authInfo credentials.AuthInfo + ) + if creds := opts.TransportCredentials; creds != nil { + scheme = "https" + conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Authority, conn) + if err != nil { + return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err) + } + isSecure = true + } + kp := opts.KeepaliveParams + // Validate keepalive parameters. + if kp.Time == 0 { + kp.Time = defaultClientKeepaliveTime + } + if kp.Timeout == 0 { + kp.Timeout = defaultClientKeepaliveTimeout + } + dynamicWindow := true + icwz := int32(initialWindowSize) + if opts.InitialConnWindowSize >= defaultWindowSize { + icwz = opts.InitialConnWindowSize + dynamicWindow = false + } + writeBufSize := defaultWriteBufSize + if opts.WriteBufferSize > 0 { + writeBufSize = opts.WriteBufferSize + } + readBufSize := defaultReadBufSize + if opts.ReadBufferSize > 0 { + readBufSize = opts.ReadBufferSize + } + t := &http2Client{ + ctx: ctx, + ctxDone: ctx.Done(), // Cache Done chan. + cancel: cancel, + userAgent: opts.UserAgent, + md: addr.Metadata, + conn: conn, + remoteAddr: conn.RemoteAddr(), + localAddr: conn.LocalAddr(), + authInfo: authInfo, + readerDone: make(chan struct{}), + writerDone: make(chan struct{}), + goAway: make(chan struct{}), + awakenKeepalive: make(chan struct{}, 1), + framer: newFramer(conn, writeBufSize, readBufSize), + fc: &trInFlow{limit: uint32(icwz)}, + scheme: scheme, + activeStreams: make(map[uint32]*Stream), + isSecure: isSecure, + creds: opts.PerRPCCredentials, + kp: kp, + statsHandler: opts.StatsHandler, + initialWindowSize: initialWindowSize, + onSuccess: onSuccess, + nextID: 1, + maxConcurrentStreams: defaultMaxStreamsClient, + streamQuota: defaultMaxStreamsClient, + streamsQuotaAvailable: make(chan struct{}, 1), + } + t.controlBuf = newControlBuffer(t.ctxDone) + if opts.InitialWindowSize >= defaultWindowSize { + t.initialWindowSize = opts.InitialWindowSize + dynamicWindow = false + } + if dynamicWindow { + t.bdpEst = &bdpEstimator{ + bdp: initialWindowSize, + updateFlowControl: t.updateFlowControl, + } + } + // Make sure awakenKeepalive can't be written upon. + // keepalive routine will make it writable, if need be. + t.awakenKeepalive <- struct{}{} + if t.statsHandler != nil { + t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{ + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + }) + connBegin := &stats.ConnBegin{ + Client: true, + } + t.statsHandler.HandleConn(t.ctx, connBegin) + } + if channelz.IsOn() { + t.channelzID = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, "") + } + // Start the reader goroutine for incoming message. Each transport has + // a dedicated goroutine which reads HTTP2 frame from network. Then it + // dispatches the frame to the corresponding stream entity. + go t.reader() + // Send connection preface to server. + n, err := t.conn.Write(clientPreface) + if err != nil { + t.Close() + return nil, connectionErrorf(true, err, "transport: failed to write client preface: %v", err) + } + if n != len(clientPreface) { + t.Close() + return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + } + if t.initialWindowSize != defaultWindowSize { + err = t.framer.fr.WriteSettings(http2.Setting{ + ID: http2.SettingInitialWindowSize, + Val: uint32(t.initialWindowSize), + }) + } else { + err = t.framer.fr.WriteSettings() + } + if err != nil { + t.Close() + return nil, connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) + } + // Adjust the connection flow control window if needed. + if delta := uint32(icwz - defaultWindowSize); delta > 0 { + if err := t.framer.fr.WriteWindowUpdate(0, delta); err != nil { + t.Close() + return nil, connectionErrorf(true, err, "transport: failed to write window update: %v", err) + } + } + t.framer.writer.Flush() + go func() { + t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst) + t.loopy.run() + t.conn.Close() + close(t.writerDone) + }() + if t.kp.Time != infinity { + go t.keepalive() + } + return t, nil +} + +func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { + // TODO(zhaoq): Handle uint32 overflow of Stream.id. + s := &Stream{ + done: make(chan struct{}), + method: callHdr.Method, + sendCompress: callHdr.SendCompress, + buf: newRecvBuffer(), + headerChan: make(chan struct{}), + contentSubtype: callHdr.ContentSubtype, + } + s.wq = newWriteQuota(defaultWriteQuota, s.done) + s.requestRead = func(n int) { + t.adjustWindow(s, uint32(n)) + } + // The client side stream context should have exactly the same life cycle with the user provided context. + // That means, s.ctx should be read-only. And s.ctx is done iff ctx is done. + // So we use the original context here instead of creating a copy. + s.ctx = ctx + s.trReader = &transportReader{ + reader: &recvBufferReader{ + ctx: s.ctx, + ctxDone: s.ctx.Done(), + recv: s.buf, + }, + windowHandler: func(n int) { + t.updateWindow(s, uint32(n)) + }, + } + return s +} + +func (t *http2Client) getPeer() *peer.Peer { + pr := &peer.Peer{ + Addr: t.remoteAddr, + } + // Attach Auth info if there is any. + if t.authInfo != nil { + pr.AuthInfo = t.authInfo + } + return pr +} + +func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) { + aud := t.createAudience(callHdr) + authData, err := t.getTrAuthData(ctx, aud) + if err != nil { + return nil, err + } + callAuthData, err := t.getCallAuthData(ctx, aud, callHdr) + if err != nil { + return nil, err + } + // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + // Make the slice of certain predictable size to reduce allocations made by append. + hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te + hfLen += len(authData) + len(callAuthData) + headerFields := make([]hpack.HeaderField, 0, hfLen) + headerFields = append(headerFields, hpack.HeaderField{Name: ":method", Value: "POST"}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(callHdr.ContentSubtype)}) + headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent}) + headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"}) + + if callHdr.SendCompress != "" { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) + } + if dl, ok := ctx.Deadline(); ok { + // Send out timeout regardless its value. The server can detect timeout context by itself. + // TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire. + timeout := dl.Sub(time.Now()) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)}) + } + for k, v := range authData { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + for k, v := range callAuthData { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + if b := stats.OutgoingTags(ctx); b != nil { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)}) + } + if b := stats.OutgoingTrace(ctx); b != nil { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)}) + } + + if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok { + var k string + for _, vv := range added { + for i, v := range vv { + if i%2 == 0 { + k = v + continue + } + // HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set. + if isReservedHeader(k) { + continue + } + headerFields = append(headerFields, hpack.HeaderField{Name: strings.ToLower(k), Value: encodeMetadataHeader(k, v)}) + } + } + for k, vv := range md { + // HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set. + if isReservedHeader(k) { + continue + } + for _, v := range vv { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + } + } + if md, ok := t.md.(*metadata.MD); ok { + for k, vv := range *md { + if isReservedHeader(k) { + continue + } + for _, v := range vv { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + } + } + return headerFields, nil +} + +func (t *http2Client) createAudience(callHdr *CallHdr) string { + // Create an audience string only if needed. + if len(t.creds) == 0 && callHdr.Creds == nil { + return "" + } + // Construct URI required to get auth request metadata. + // Omit port if it is the default one. + host := strings.TrimSuffix(callHdr.Host, ":443") + pos := strings.LastIndex(callHdr.Method, "/") + if pos == -1 { + pos = len(callHdr.Method) + } + return "https://" + host + callHdr.Method[:pos] +} + +func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[string]string, error) { + authData := map[string]string{} + for _, c := range t.creds { + data, err := c.GetRequestMetadata(ctx, audience) + if err != nil { + if _, ok := status.FromError(err); ok { + return nil, err + } + + return nil, streamErrorf(codes.Unauthenticated, "transport: %v", err) + } + for k, v := range data { + // Capital header names are illegal in HTTP/2. + k = strings.ToLower(k) + authData[k] = v + } + } + return authData, nil +} + +func (t *http2Client) getCallAuthData(ctx context.Context, audience string, callHdr *CallHdr) (map[string]string, error) { + callAuthData := map[string]string{} + // Check if credentials.PerRPCCredentials were provided via call options. + // Note: if these credentials are provided both via dial options and call + // options, then both sets of credentials will be applied. + if callCreds := callHdr.Creds; callCreds != nil { + if !t.isSecure && callCreds.RequireTransportSecurity() { + return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") + } + data, err := callCreds.GetRequestMetadata(ctx, audience) + if err != nil { + return nil, streamErrorf(codes.Internal, "transport: %v", err) + } + for k, v := range data { + // Capital header names are illegal in HTTP/2 + k = strings.ToLower(k) + callAuthData[k] = v + } + } + return callAuthData, nil +} + +// NewStream creates a stream and registers it into the transport as "active" +// streams. +func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { + ctx = peer.NewContext(ctx, t.getPeer()) + headerFields, err := t.createHeaderFields(ctx, callHdr) + if err != nil { + return nil, err + } + s := t.newStream(ctx, callHdr) + cleanup := func(err error) { + if s.swapState(streamDone) == streamDone { + // If it was already done, return. + return + } + // The stream was unprocessed by the server. + atomic.StoreUint32(&s.unprocessed, 1) + s.write(recvMsg{err: err}) + close(s.done) + // If headerChan isn't closed, then close it. + if atomic.SwapUint32(&s.headerDone, 1) == 0 { + close(s.headerChan) + } + + } + hdr := &headerFrame{ + hf: headerFields, + endStream: false, + initStream: func(id uint32) (bool, error) { + t.mu.Lock() + if state := t.state; state != reachable { + t.mu.Unlock() + // Do a quick cleanup. + err := error(errStreamDrain) + if state == closing { + err = ErrConnClosing + } + cleanup(err) + return false, err + } + t.activeStreams[id] = s + if channelz.IsOn() { + t.czmu.Lock() + t.streamsStarted++ + t.lastStreamCreated = time.Now() + t.czmu.Unlock() + } + var sendPing bool + // If the number of active streams change from 0 to 1, then check if keepalive + // has gone dormant. If so, wake it up. + if len(t.activeStreams) == 1 { + select { + case t.awakenKeepalive <- struct{}{}: + sendPing = true + // Fill the awakenKeepalive channel again as this channel must be + // kept non-writable except at the point that the keepalive() + // goroutine is waiting either to be awaken or shutdown. + t.awakenKeepalive <- struct{}{} + default: + } + } + t.mu.Unlock() + return sendPing, nil + }, + onOrphaned: cleanup, + wq: s.wq, + } + firstTry := true + var ch chan struct{} + checkForStreamQuota := func(it interface{}) bool { + if t.streamQuota <= 0 { // Can go negative if server decreases it. + if firstTry { + t.waitingStreams++ + } + ch = t.streamsQuotaAvailable + return false + } + if !firstTry { + t.waitingStreams-- + } + t.streamQuota-- + h := it.(*headerFrame) + h.streamID = t.nextID + t.nextID += 2 + s.id = h.streamID + s.fc = &inFlow{limit: uint32(t.initialWindowSize)} + if t.streamQuota > 0 && t.waitingStreams > 0 { + select { + case t.streamsQuotaAvailable <- struct{}{}: + default: + } + } + return true + } + for { + success, err := t.controlBuf.executeAndPut(checkForStreamQuota, hdr) + if err != nil { + return nil, err + } + if success { + break + } + firstTry = false + select { + case <-ch: + case <-s.ctx.Done(): + return nil, ContextErr(s.ctx.Err()) + case <-t.goAway: + return nil, errStreamDrain + case <-t.ctx.Done(): + return nil, ErrConnClosing + } + } + if t.statsHandler != nil { + outHeader := &stats.OutHeader{ + Client: true, + FullMethod: callHdr.Method, + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + Compression: callHdr.SendCompress, + } + t.statsHandler.HandleRPC(s.ctx, outHeader) + } + return s, nil +} + +// CloseStream clears the footprint of a stream when the stream is not needed any more. +// This must not be executed in reader's goroutine. +func (t *http2Client) CloseStream(s *Stream, err error) { + var ( + rst bool + rstCode http2.ErrCode + ) + if err != nil { + rst = true + rstCode = http2.ErrCodeCancel + } + t.closeStream(s, err, rst, rstCode, nil, nil, false) +} + +func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) { + // Set stream status to done. + if s.swapState(streamDone) == streamDone { + // If it was already done, return. + return + } + // status and trailers can be updated here without any synchronization because the stream goroutine will + // only read it after it sees an io.EOF error from read or write and we'll write those errors + // only after updating this. + s.status = st + if len(mdata) > 0 { + s.trailer = mdata + } + if err != nil { + // This will unblock reads eventually. + s.write(recvMsg{err: err}) + } + // This will unblock write. + close(s.done) + // If headerChan isn't closed, then close it. + if atomic.SwapUint32(&s.headerDone, 1) == 0 { + close(s.headerChan) + } + cleanup := &cleanupStream{ + streamID: s.id, + onWrite: func() { + t.mu.Lock() + if t.activeStreams != nil { + delete(t.activeStreams, s.id) + } + t.mu.Unlock() + if channelz.IsOn() { + t.czmu.Lock() + if eosReceived { + t.streamsSucceeded++ + } else { + t.streamsFailed++ + } + t.czmu.Unlock() + } + }, + rst: rst, + rstCode: rstCode, + } + addBackStreamQuota := func(interface{}) bool { + t.streamQuota++ + if t.streamQuota > 0 && t.waitingStreams > 0 { + select { + case t.streamsQuotaAvailable <- struct{}{}: + default: + } + } + return true + } + t.controlBuf.executeAndPut(addBackStreamQuota, cleanup) +} + +// Close kicks off the shutdown process of the transport. This should be called +// only once on a transport. Once it is called, the transport should not be +// accessed any more. +func (t *http2Client) Close() error { + t.mu.Lock() + // Make sure we only Close once. + if t.state == closing { + t.mu.Unlock() + return nil + } + t.state = closing + streams := t.activeStreams + t.activeStreams = nil + t.mu.Unlock() + t.controlBuf.finish() + t.cancel() + err := t.conn.Close() + if channelz.IsOn() { + channelz.RemoveEntry(t.channelzID) + } + // Notify all active streams. + for _, s := range streams { + t.closeStream(s, ErrConnClosing, false, http2.ErrCodeNo, nil, nil, false) + } + if t.statsHandler != nil { + connEnd := &stats.ConnEnd{ + Client: true, + } + t.statsHandler.HandleConn(t.ctx, connEnd) + } + return err +} + +// GracefulClose sets the state to draining, which prevents new streams from +// being created and causes the transport to be closed when the last active +// stream is closed. If there are no active streams, the transport is closed +// immediately. This does nothing if the transport is already draining or +// closing. +func (t *http2Client) GracefulClose() error { + t.mu.Lock() + // Make sure we move to draining only from active. + if t.state == draining || t.state == closing { + t.mu.Unlock() + return nil + } + t.state = draining + active := len(t.activeStreams) + t.mu.Unlock() + if active == 0 { + return t.Close() + } + return nil +} + +// Write formats the data into HTTP2 data frame(s) and sends it out. The caller +// should proceed only if Write returns nil. +func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { + if opts.Last { + // If it's the last message, update stream state. + if !s.compareAndSwapState(streamActive, streamWriteDone) { + return errStreamDone + } + } else if s.getState() != streamActive { + return errStreamDone + } + df := &dataFrame{ + streamID: s.id, + endStream: opts.Last, + } + if hdr != nil || data != nil { // If it's not an empty data frame. + // Add some data to grpc message header so that we can equally + // distribute bytes across frames. + emptyLen := http2MaxFrameLen - len(hdr) + if emptyLen > len(data) { + emptyLen = len(data) + } + hdr = append(hdr, data[:emptyLen]...) + data = data[emptyLen:] + df.h, df.d = hdr, data + // TODO(mmukhi): The above logic in this if can be moved to loopyWriter's data handler. + if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { + return err + } + } + return t.controlBuf.put(df) +} + +func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) { + t.mu.Lock() + defer t.mu.Unlock() + s, ok := t.activeStreams[f.Header().StreamID] + return s, ok +} + +// adjustWindow sends out extra window update over the initial window size +// of stream if the application is requesting data larger in size than +// the window. +func (t *http2Client) adjustWindow(s *Stream, n uint32) { + if w := s.fc.maybeAdjust(n); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) + } +} + +// updateWindow adjusts the inbound quota for the stream. +// Window updates will be sent out when the cumulative quota +// exceeds the corresponding threshold. +func (t *http2Client) updateWindow(s *Stream, n uint32) { + if w := s.fc.onRead(n); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) + } +} + +// updateFlowControl updates the incoming flow control windows +// for the transport and the stream based on the current bdp +// estimation. +func (t *http2Client) updateFlowControl(n uint32) { + t.mu.Lock() + for _, s := range t.activeStreams { + s.fc.newLimit(n) + } + t.mu.Unlock() + updateIWS := func(interface{}) bool { + t.initialWindowSize = int32(n) + return true + } + t.controlBuf.executeAndPut(updateIWS, &outgoingWindowUpdate{streamID: 0, increment: t.fc.newLimit(n)}) + t.controlBuf.put(&outgoingSettings{ + ss: []http2.Setting{ + { + ID: http2.SettingInitialWindowSize, + Val: n, + }, + }, + }) +} + +func (t *http2Client) handleData(f *http2.DataFrame) { + size := f.Header().Length + var sendBDPPing bool + if t.bdpEst != nil { + sendBDPPing = t.bdpEst.add(size) + } + // Decouple connection's flow control from application's read. + // An update on connection's flow control should not depend on + // whether user application has read the data or not. Such a + // restriction is already imposed on the stream's flow control, + // and therefore the sender will be blocked anyways. + // Decoupling the connection flow control will prevent other + // active(fast) streams from starving in presence of slow or + // inactive streams. + // + if w := t.fc.onData(size); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + } + if sendBDPPing { + // Avoid excessive ping detection (e.g. in an L7 proxy) + // by sending a window update prior to the BDP ping. + + if w := t.fc.reset(); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + } + + t.controlBuf.put(bdpPing) + } + // Select the right stream to dispatch. + s, ok := t.getStream(f) + if !ok { + return + } + if size > 0 { + if err := s.fc.onData(size); err != nil { + t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false) + return + } + if f.Header().Flags.Has(http2.FlagDataPadded) { + if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) + } + } + // TODO(bradfitz, zhaoq): A copy is required here because there is no + // guarantee f.Data() is consumed before the arrival of next frame. + // Can this copy be eliminated? + if len(f.Data()) > 0 { + data := make([]byte, len(f.Data())) + copy(data, f.Data()) + s.write(recvMsg{data: data}) + } + } + // The server has closed the stream without sending trailers. Record that + // the read direction is closed, and set the status appropriately. + if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true) + } +} + +func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { + s, ok := t.getStream(f) + if !ok { + return + } + if f.ErrCode == http2.ErrCodeRefusedStream { + // The stream was unprocessed by the server. + atomic.StoreUint32(&s.unprocessed, 1) + } + statusCode, ok := http2ErrConvTab[f.ErrCode] + if !ok { + warningf("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error %v", f.ErrCode) + statusCode = codes.Unknown + } + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false) +} + +func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) { + if f.IsAck() { + return + } + var maxStreams *uint32 + var ss []http2.Setting + f.ForeachSetting(func(s http2.Setting) error { + if s.ID == http2.SettingMaxConcurrentStreams { + maxStreams = new(uint32) + *maxStreams = s.Val + return nil + } + ss = append(ss, s) + return nil + }) + if isFirst && maxStreams == nil { + maxStreams = new(uint32) + *maxStreams = math.MaxUint32 + } + sf := &incomingSettings{ + ss: ss, + } + if maxStreams == nil { + t.controlBuf.put(sf) + return + } + updateStreamQuota := func(interface{}) bool { + delta := int64(*maxStreams) - int64(t.maxConcurrentStreams) + t.maxConcurrentStreams = *maxStreams + t.streamQuota += delta + if delta > 0 && t.waitingStreams > 0 { + close(t.streamsQuotaAvailable) // wake all of them up. + t.streamsQuotaAvailable = make(chan struct{}, 1) + } + return true + } + t.controlBuf.executeAndPut(updateStreamQuota, sf) +} + +func (t *http2Client) handlePing(f *http2.PingFrame) { + if f.IsAck() { + // Maybe it's a BDP ping. + if t.bdpEst != nil { + t.bdpEst.calculate(f.Data) + } + return + } + pingAck := &ping{ack: true} + copy(pingAck.data[:], f.Data[:]) + t.controlBuf.put(pingAck) +} + +func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { + t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + return + } + if f.ErrCode == http2.ErrCodeEnhanceYourCalm { + infof("Client received GoAway with http2.ErrCodeEnhanceYourCalm.") + } + id := f.LastStreamID + if id > 0 && id%2 != 1 { + t.mu.Unlock() + t.Close() + return + } + // A client can receive multiple GoAways from the server (see + // https://github.com/grpc/grpc-go/issues/1387). The idea is that the first + // GoAway will be sent with an ID of MaxInt32 and the second GoAway will be + // sent after an RTT delay with the ID of the last stream the server will + // process. + // + // Therefore, when we get the first GoAway we don't necessarily close any + // streams. While in case of second GoAway we close all streams created after + // the GoAwayId. This way streams that were in-flight while the GoAway from + // server was being sent don't get killed. + select { + case <-t.goAway: // t.goAway has been closed (i.e.,multiple GoAways). + // If there are multiple GoAways the first one should always have an ID greater than the following ones. + if id > t.prevGoAwayID { + t.mu.Unlock() + t.Close() + return + } + default: + t.setGoAwayReason(f) + close(t.goAway) + t.state = draining + t.controlBuf.put(&incomingGoAway{}) + } + // All streams with IDs greater than the GoAwayId + // and smaller than the previous GoAway ID should be killed. + upperLimit := t.prevGoAwayID + if upperLimit == 0 { // This is the first GoAway Frame. + upperLimit = math.MaxUint32 // Kill all streams after the GoAway ID. + } + for streamID, stream := range t.activeStreams { + if streamID > id && streamID <= upperLimit { + // The stream was unprocessed by the server. + atomic.StoreUint32(&stream.unprocessed, 1) + t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) + } + } + t.prevGoAwayID = id + active := len(t.activeStreams) + t.mu.Unlock() + if active == 0 { + t.Close() + } +} + +// setGoAwayReason sets the value of t.goAwayReason based +// on the GoAway frame received. +// It expects a lock on transport's mutext to be held by +// the caller. +func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) { + t.goAwayReason = GoAwayNoReason + switch f.ErrCode { + case http2.ErrCodeEnhanceYourCalm: + if string(f.DebugData()) == "too_many_pings" { + t.goAwayReason = GoAwayTooManyPings + } + } +} + +func (t *http2Client) GetGoAwayReason() GoAwayReason { + t.mu.Lock() + defer t.mu.Unlock() + return t.goAwayReason +} + +func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { + t.controlBuf.put(&incomingWindowUpdate{ + streamID: f.Header().StreamID, + increment: f.Increment, + }) +} + +// operateHeaders takes action on the decoded headers. +func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { + s, ok := t.getStream(frame) + if !ok { + return + } + atomic.StoreUint32(&s.bytesReceived, 1) + var state decodeState + if err := state.decodeResponseHeader(frame); err != nil { + t.closeStream(s, err, true, http2.ErrCodeProtocol, nil, nil, false) + // Something wrong. Stops reading even when there is remaining. + return + } + + endStream := frame.StreamEnded() + var isHeader bool + defer func() { + if t.statsHandler != nil { + if isHeader { + inHeader := &stats.InHeader{ + Client: true, + WireLength: int(frame.Header().Length), + } + t.statsHandler.HandleRPC(s.ctx, inHeader) + } else { + inTrailer := &stats.InTrailer{ + Client: true, + WireLength: int(frame.Header().Length), + } + t.statsHandler.HandleRPC(s.ctx, inTrailer) + } + } + }() + // If headers haven't been received yet. + if atomic.SwapUint32(&s.headerDone, 1) == 0 { + if !endStream { + // Headers frame is not actually a trailers-only frame. + isHeader = true + // These values can be set without any synchronization because + // stream goroutine will read it only after seeing a closed + // headerChan which we'll close after setting this. + s.recvCompress = state.encoding + if len(state.mdata) > 0 { + s.header = state.mdata + } + } + close(s.headerChan) + } + if !endStream { + return + } + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, state.status(), state.mdata, true) +} + +// reader runs as a separate goroutine in charge of reading data from network +// connection. +// +// TODO(zhaoq): currently one reader per transport. Investigate whether this is +// optimal. +// TODO(zhaoq): Check the validity of the incoming frame sequence. +func (t *http2Client) reader() { + defer close(t.readerDone) + // Check the validity of server preface. + frame, err := t.framer.fr.ReadFrame() + if err != nil { + t.Close() + return + } + atomic.CompareAndSwapUint32(&t.activity, 0, 1) + sf, ok := frame.(*http2.SettingsFrame) + if !ok { + t.Close() + return + } + t.onSuccess() + t.handleSettings(sf, true) + + // loop to keep reading incoming messages on this transport. + for { + frame, err := t.framer.fr.ReadFrame() + atomic.CompareAndSwapUint32(&t.activity, 0, 1) + if err != nil { + // Abort an active stream if the http2.Framer returns a + // http2.StreamError. This can happen only if the server's response + // is malformed http2. + if se, ok := err.(http2.StreamError); ok { + t.mu.Lock() + s := t.activeStreams[se.StreamID] + t.mu.Unlock() + if s != nil { + // use error detail to provide better err message + t.closeStream(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.fr.ErrorDetail()), true, http2.ErrCodeProtocol, nil, nil, false) + } + continue + } else { + // Transport error. + t.Close() + return + } + } + switch frame := frame.(type) { + case *http2.MetaHeadersFrame: + t.operateHeaders(frame) + case *http2.DataFrame: + t.handleData(frame) + case *http2.RSTStreamFrame: + t.handleRSTStream(frame) + case *http2.SettingsFrame: + t.handleSettings(frame, false) + case *http2.PingFrame: + t.handlePing(frame) + case *http2.GoAwayFrame: + t.handleGoAway(frame) + case *http2.WindowUpdateFrame: + t.handleWindowUpdate(frame) + default: + errorf("transport: http2Client.reader got unhandled frame type %v.", frame) + } + } +} + +// keepalive running in a separate goroutune makes sure the connection is alive by sending pings. +func (t *http2Client) keepalive() { + p := &ping{data: [8]byte{}} + timer := time.NewTimer(t.kp.Time) + for { + select { + case <-timer.C: + if atomic.CompareAndSwapUint32(&t.activity, 1, 0) { + timer.Reset(t.kp.Time) + continue + } + // Check if keepalive should go dormant. + t.mu.Lock() + if len(t.activeStreams) < 1 && !t.kp.PermitWithoutStream { + // Make awakenKeepalive writable. + <-t.awakenKeepalive + t.mu.Unlock() + select { + case <-t.awakenKeepalive: + // If the control gets here a ping has been sent + // need to reset the timer with keepalive.Timeout. + case <-t.ctx.Done(): + return + } + } else { + t.mu.Unlock() + if channelz.IsOn() { + t.czmu.Lock() + t.kpCount++ + t.czmu.Unlock() + } + // Send ping. + t.controlBuf.put(p) + } + + // By the time control gets here a ping has been sent one way or the other. + timer.Reset(t.kp.Timeout) + select { + case <-timer.C: + if atomic.CompareAndSwapUint32(&t.activity, 1, 0) { + timer.Reset(t.kp.Time) + continue + } + t.Close() + return + case <-t.ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return + } + case <-t.ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return + } + } +} + +func (t *http2Client) Error() <-chan struct{} { + return t.ctx.Done() +} + +func (t *http2Client) GoAway() <-chan struct{} { + return t.goAway +} + +func (t *http2Client) ChannelzMetric() *channelz.SocketInternalMetric { + t.czmu.RLock() + s := channelz.SocketInternalMetric{ + StreamsStarted: t.streamsStarted, + StreamsSucceeded: t.streamsSucceeded, + StreamsFailed: t.streamsFailed, + MessagesSent: t.msgSent, + MessagesReceived: t.msgRecv, + KeepAlivesSent: t.kpCount, + LastLocalStreamCreatedTimestamp: t.lastStreamCreated, + LastMessageSentTimestamp: t.lastMsgSent, + LastMessageReceivedTimestamp: t.lastMsgRecv, + LocalFlowControlWindow: int64(t.fc.getSize()), + //socket options + LocalAddr: t.localAddr, + RemoteAddr: t.remoteAddr, + // Security + // RemoteName : + } + t.czmu.RUnlock() + s.RemoteFlowControlWindow = t.getOutFlowWindow() + return &s +} + +func (t *http2Client) IncrMsgSent() { + t.czmu.Lock() + t.msgSent++ + t.lastMsgSent = time.Now() + t.czmu.Unlock() +} + +func (t *http2Client) IncrMsgRecv() { + t.czmu.Lock() + t.msgRecv++ + t.lastMsgRecv = time.Now() + t.czmu.Unlock() +} + +func (t *http2Client) getOutFlowWindow() int64 { + resp := make(chan uint32, 1) + timer := time.NewTimer(time.Second) + defer timer.Stop() + t.controlBuf.put(&outFlowControlSizeRequest{resp}) + select { + case sz := <-resp: + return int64(sz) + case <-t.ctxDone: + return -1 + case <-timer.C: + return -2 + } +} diff --git a/vendor/google.golang.org/grpc/transport/http2_server.go b/vendor/google.golang.org/grpc/transport/http2_server.go new file mode 100644 index 000000000..8b93e222e --- /dev/null +++ b/vendor/google.golang.org/grpc/transport/http2_server.go @@ -0,0 +1,1136 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "bytes" + "errors" + "fmt" + "io" + "math" + "math/rand" + "net" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/context" + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + + "google.golang.org/grpc/channelz" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" + "google.golang.org/grpc/tap" +) + +// ErrIllegalHeaderWrite indicates that setting header is illegal because of +// the stream's state. +var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHeader was already called") + +// http2Server implements the ServerTransport interface with HTTP2. +type http2Server struct { + ctx context.Context + ctxDone <-chan struct{} // Cache the context.Done() chan + cancel context.CancelFunc + conn net.Conn + loopy *loopyWriter + readerDone chan struct{} // sync point to enable testing. + writerDone chan struct{} // sync point to enable testing. + remoteAddr net.Addr + localAddr net.Addr + maxStreamID uint32 // max stream ID ever seen + authInfo credentials.AuthInfo // auth info about the connection + inTapHandle tap.ServerInHandle + framer *framer + // The max number of concurrent streams. + maxStreams uint32 + // controlBuf delivers all the control related tasks (e.g., window + // updates, reset streams, and various settings) to the controller. + controlBuf *controlBuffer + fc *trInFlow + stats stats.Handler + // Flag to keep track of reading activity on transport. + // 1 is true and 0 is false. + activity uint32 // Accessed atomically. + // Keepalive and max-age parameters for the server. + kp keepalive.ServerParameters + + // Keepalive enforcement policy. + kep keepalive.EnforcementPolicy + // The time instance last ping was received. + lastPingAt time.Time + // Number of times the client has violated keepalive ping policy so far. + pingStrikes uint8 + // Flag to signify that number of ping strikes should be reset to 0. + // This is set whenever data or header frames are sent. + // 1 means yes. + resetPingStrikes uint32 // Accessed atomically. + initialWindowSize int32 + bdpEst *bdpEstimator + + mu sync.Mutex // guard the following + + // drainChan is initialized when drain(...) is called the first time. + // After which the server writes out the first GoAway(with ID 2^31-1) frame. + // Then an independent goroutine will be launched to later send the second GoAway. + // During this time we don't want to write another first GoAway(with ID 2^31 -1) frame. + // Thus call to drain(...) will be a no-op if drainChan is already initialized since draining is + // already underway. + drainChan chan struct{} + state transportState + activeStreams map[uint32]*Stream + // idle is the time instant when the connection went idle. + // This is either the beginning of the connection or when the number of + // RPCs go down to 0. + // When the connection is busy, this value is set to 0. + idle time.Time + + // Fields below are for channelz metric collection. + channelzID int64 // channelz unique identification number + czmu sync.RWMutex + kpCount int64 + // The number of streams that have started, including already finished ones. + streamsStarted int64 + // The number of streams that have ended successfully by sending frame with + // EoS bit set. + streamsSucceeded int64 + streamsFailed int64 + lastStreamCreated time.Time + msgSent int64 + msgRecv int64 + lastMsgSent time.Time + lastMsgRecv time.Time +} + +// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is +// returned if something goes wrong. +func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { + writeBufSize := defaultWriteBufSize + if config.WriteBufferSize > 0 { + writeBufSize = config.WriteBufferSize + } + readBufSize := defaultReadBufSize + if config.ReadBufferSize > 0 { + readBufSize = config.ReadBufferSize + } + framer := newFramer(conn, writeBufSize, readBufSize) + // Send initial settings as connection preface to client. + var isettings []http2.Setting + // TODO(zhaoq): Have a better way to signal "no limit" because 0 is + // permitted in the HTTP2 spec. + maxStreams := config.MaxStreams + if maxStreams == 0 { + maxStreams = math.MaxUint32 + } else { + isettings = append(isettings, http2.Setting{ + ID: http2.SettingMaxConcurrentStreams, + Val: maxStreams, + }) + } + dynamicWindow := true + iwz := int32(initialWindowSize) + if config.InitialWindowSize >= defaultWindowSize { + iwz = config.InitialWindowSize + dynamicWindow = false + } + icwz := int32(initialWindowSize) + if config.InitialConnWindowSize >= defaultWindowSize { + icwz = config.InitialConnWindowSize + dynamicWindow = false + } + if iwz != defaultWindowSize { + isettings = append(isettings, http2.Setting{ + ID: http2.SettingInitialWindowSize, + Val: uint32(iwz)}) + } + if err := framer.fr.WriteSettings(isettings...); err != nil { + return nil, connectionErrorf(false, err, "transport: %v", err) + } + // Adjust the connection flow control window if needed. + if delta := uint32(icwz - defaultWindowSize); delta > 0 { + if err := framer.fr.WriteWindowUpdate(0, delta); err != nil { + return nil, connectionErrorf(false, err, "transport: %v", err) + } + } + kp := config.KeepaliveParams + if kp.MaxConnectionIdle == 0 { + kp.MaxConnectionIdle = defaultMaxConnectionIdle + } + if kp.MaxConnectionAge == 0 { + kp.MaxConnectionAge = defaultMaxConnectionAge + } + // Add a jitter to MaxConnectionAge. + kp.MaxConnectionAge += getJitter(kp.MaxConnectionAge) + if kp.MaxConnectionAgeGrace == 0 { + kp.MaxConnectionAgeGrace = defaultMaxConnectionAgeGrace + } + if kp.Time == 0 { + kp.Time = defaultServerKeepaliveTime + } + if kp.Timeout == 0 { + kp.Timeout = defaultServerKeepaliveTimeout + } + kep := config.KeepalivePolicy + if kep.MinTime == 0 { + kep.MinTime = defaultKeepalivePolicyMinTime + } + ctx, cancel := context.WithCancel(context.Background()) + t := &http2Server{ + ctx: ctx, + cancel: cancel, + ctxDone: ctx.Done(), + conn: conn, + remoteAddr: conn.RemoteAddr(), + localAddr: conn.LocalAddr(), + authInfo: config.AuthInfo, + framer: framer, + readerDone: make(chan struct{}), + writerDone: make(chan struct{}), + maxStreams: maxStreams, + inTapHandle: config.InTapHandle, + fc: &trInFlow{limit: uint32(icwz)}, + state: reachable, + activeStreams: make(map[uint32]*Stream), + stats: config.StatsHandler, + kp: kp, + idle: time.Now(), + kep: kep, + initialWindowSize: iwz, + } + t.controlBuf = newControlBuffer(t.ctxDone) + if dynamicWindow { + t.bdpEst = &bdpEstimator{ + bdp: initialWindowSize, + updateFlowControl: t.updateFlowControl, + } + } + if t.stats != nil { + t.ctx = t.stats.TagConn(t.ctx, &stats.ConnTagInfo{ + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + }) + connBegin := &stats.ConnBegin{} + t.stats.HandleConn(t.ctx, connBegin) + } + if channelz.IsOn() { + t.channelzID = channelz.RegisterNormalSocket(t, config.ChannelzParentID, "") + } + t.framer.writer.Flush() + + defer func() { + if err != nil { + t.Close() + } + }() + + // Check the validity of client preface. + preface := make([]byte, len(clientPreface)) + if _, err := io.ReadFull(t.conn, preface); err != nil { + return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to receive the preface from client: %v", err) + } + if !bytes.Equal(preface, clientPreface) { + return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams received bogus greeting from client: %q", preface) + } + + frame, err := t.framer.fr.ReadFrame() + if err == io.EOF || err == io.ErrUnexpectedEOF { + return nil, err + } + if err != nil { + return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to read initial settings frame: %v", err) + } + atomic.StoreUint32(&t.activity, 1) + sf, ok := frame.(*http2.SettingsFrame) + if !ok { + return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams saw invalid preface type %T from client", frame) + } + t.handleSettings(sf) + + go func() { + t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst) + t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler + t.loopy.run() + t.conn.Close() + close(t.writerDone) + }() + go t.keepalive() + return t, nil +} + +// operateHeader takes action on the decoded headers. +func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) { + streamID := frame.Header().StreamID + var state decodeState + for _, hf := range frame.Fields { + if err := state.processHeaderField(hf); err != nil { + if se, ok := err.(StreamError); ok { + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: statusCodeConvTab[se.Code], + onWrite: func() {}, + }) + } + return + } + } + + buf := newRecvBuffer() + s := &Stream{ + id: streamID, + st: t, + buf: buf, + fc: &inFlow{limit: uint32(t.initialWindowSize)}, + recvCompress: state.encoding, + method: state.method, + contentSubtype: state.contentSubtype, + } + if frame.StreamEnded() { + // s is just created by the caller. No lock needed. + s.state = streamReadDone + } + if state.timeoutSet { + s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout) + } else { + s.ctx, s.cancel = context.WithCancel(t.ctx) + } + pr := &peer.Peer{ + Addr: t.remoteAddr, + } + // Attach Auth info if there is any. + if t.authInfo != nil { + pr.AuthInfo = t.authInfo + } + s.ctx = peer.NewContext(s.ctx, pr) + // Attach the received metadata to the context. + if len(state.mdata) > 0 { + s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata) + } + if state.statsTags != nil { + s.ctx = stats.SetIncomingTags(s.ctx, state.statsTags) + } + if state.statsTrace != nil { + s.ctx = stats.SetIncomingTrace(s.ctx, state.statsTrace) + } + if t.inTapHandle != nil { + var err error + info := &tap.Info{ + FullMethodName: state.method, + } + s.ctx, err = t.inTapHandle(s.ctx, info) + if err != nil { + warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) + t.controlBuf.put(&cleanupStream{ + streamID: s.id, + rst: true, + rstCode: http2.ErrCodeRefusedStream, + onWrite: func() {}, + }) + return + } + } + t.mu.Lock() + if t.state != reachable { + t.mu.Unlock() + return + } + if uint32(len(t.activeStreams)) >= t.maxStreams { + t.mu.Unlock() + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeRefusedStream, + onWrite: func() {}, + }) + return + } + if streamID%2 != 1 || streamID <= t.maxStreamID { + t.mu.Unlock() + // illegal gRPC stream id. + errorf("transport: http2Server.HandleStreams received an illegal stream id: %v", streamID) + return true + } + t.maxStreamID = streamID + t.activeStreams[streamID] = s + if len(t.activeStreams) == 1 { + t.idle = time.Time{} + } + t.mu.Unlock() + if channelz.IsOn() { + t.czmu.Lock() + t.streamsStarted++ + t.lastStreamCreated = time.Now() + t.czmu.Unlock() + } + s.requestRead = func(n int) { + t.adjustWindow(s, uint32(n)) + } + s.ctx = traceCtx(s.ctx, s.method) + if t.stats != nil { + s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + inHeader := &stats.InHeader{ + FullMethod: s.method, + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + Compression: s.recvCompress, + WireLength: int(frame.Header().Length), + } + t.stats.HandleRPC(s.ctx, inHeader) + } + s.ctxDone = s.ctx.Done() + s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) + s.trReader = &transportReader{ + reader: &recvBufferReader{ + ctx: s.ctx, + ctxDone: s.ctxDone, + recv: s.buf, + }, + windowHandler: func(n int) { + t.updateWindow(s, uint32(n)) + }, + } + handle(s) + return +} + +// HandleStreams receives incoming streams using the given handler. This is +// typically run in a separate goroutine. +// traceCtx attaches trace to ctx and returns the new context. +func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) { + defer close(t.readerDone) + for { + frame, err := t.framer.fr.ReadFrame() + atomic.StoreUint32(&t.activity, 1) + if err != nil { + if se, ok := err.(http2.StreamError); ok { + warningf("transport: http2Server.HandleStreams encountered http2.StreamError: %v", se) + t.mu.Lock() + s := t.activeStreams[se.StreamID] + t.mu.Unlock() + if s != nil { + t.closeStream(s, true, se.Code, nil, false) + } else { + t.controlBuf.put(&cleanupStream{ + streamID: se.StreamID, + rst: true, + rstCode: se.Code, + onWrite: func() {}, + }) + } + continue + } + if err == io.EOF || err == io.ErrUnexpectedEOF { + t.Close() + return + } + warningf("transport: http2Server.HandleStreams failed to read frame: %v", err) + t.Close() + return + } + switch frame := frame.(type) { + case *http2.MetaHeadersFrame: + if t.operateHeaders(frame, handle, traceCtx) { + t.Close() + break + } + case *http2.DataFrame: + t.handleData(frame) + case *http2.RSTStreamFrame: + t.handleRSTStream(frame) + case *http2.SettingsFrame: + t.handleSettings(frame) + case *http2.PingFrame: + t.handlePing(frame) + case *http2.WindowUpdateFrame: + t.handleWindowUpdate(frame) + case *http2.GoAwayFrame: + // TODO: Handle GoAway from the client appropriately. + default: + errorf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame) + } + } +} + +func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) { + t.mu.Lock() + defer t.mu.Unlock() + if t.activeStreams == nil { + // The transport is closing. + return nil, false + } + s, ok := t.activeStreams[f.Header().StreamID] + if !ok { + // The stream is already done. + return nil, false + } + return s, true +} + +// adjustWindow sends out extra window update over the initial window size +// of stream if the application is requesting data larger in size than +// the window. +func (t *http2Server) adjustWindow(s *Stream, n uint32) { + if w := s.fc.maybeAdjust(n); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) + } + +} + +// updateWindow adjusts the inbound quota for the stream and the transport. +// Window updates will deliver to the controller for sending when +// the cumulative quota exceeds the corresponding threshold. +func (t *http2Server) updateWindow(s *Stream, n uint32) { + if w := s.fc.onRead(n); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, + increment: w, + }) + } +} + +// updateFlowControl updates the incoming flow control windows +// for the transport and the stream based on the current bdp +// estimation. +func (t *http2Server) updateFlowControl(n uint32) { + t.mu.Lock() + for _, s := range t.activeStreams { + s.fc.newLimit(n) + } + t.initialWindowSize = int32(n) + t.mu.Unlock() + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: t.fc.newLimit(n), + }) + t.controlBuf.put(&outgoingSettings{ + ss: []http2.Setting{ + { + ID: http2.SettingInitialWindowSize, + Val: n, + }, + }, + }) + +} + +func (t *http2Server) handleData(f *http2.DataFrame) { + size := f.Header().Length + var sendBDPPing bool + if t.bdpEst != nil { + sendBDPPing = t.bdpEst.add(size) + } + // Decouple connection's flow control from application's read. + // An update on connection's flow control should not depend on + // whether user application has read the data or not. Such a + // restriction is already imposed on the stream's flow control, + // and therefore the sender will be blocked anyways. + // Decoupling the connection flow control will prevent other + // active(fast) streams from starving in presence of slow or + // inactive streams. + if w := t.fc.onData(size); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + } + if sendBDPPing { + // Avoid excessive ping detection (e.g. in an L7 proxy) + // by sending a window update prior to the BDP ping. + if w := t.fc.reset(); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + } + t.controlBuf.put(bdpPing) + } + // Select the right stream to dispatch. + s, ok := t.getStream(f) + if !ok { + return + } + if size > 0 { + if err := s.fc.onData(size); err != nil { + t.closeStream(s, true, http2.ErrCodeFlowControl, nil, false) + return + } + if f.Header().Flags.Has(http2.FlagDataPadded) { + if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) + } + } + // TODO(bradfitz, zhaoq): A copy is required here because there is no + // guarantee f.Data() is consumed before the arrival of next frame. + // Can this copy be eliminated? + if len(f.Data()) > 0 { + data := make([]byte, len(f.Data())) + copy(data, f.Data()) + s.write(recvMsg{data: data}) + } + } + if f.Header().Flags.Has(http2.FlagDataEndStream) { + // Received the end of stream from the client. + s.compareAndSwapState(streamActive, streamReadDone) + s.write(recvMsg{err: io.EOF}) + } +} + +func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { + s, ok := t.getStream(f) + if !ok { + return + } + t.closeStream(s, false, 0, nil, false) +} + +func (t *http2Server) handleSettings(f *http2.SettingsFrame) { + if f.IsAck() { + return + } + var ss []http2.Setting + f.ForeachSetting(func(s http2.Setting) error { + ss = append(ss, s) + return nil + }) + t.controlBuf.put(&incomingSettings{ + ss: ss, + }) +} + +const ( + maxPingStrikes = 2 + defaultPingTimeout = 2 * time.Hour +) + +func (t *http2Server) handlePing(f *http2.PingFrame) { + if f.IsAck() { + if f.Data == goAwayPing.data && t.drainChan != nil { + close(t.drainChan) + return + } + // Maybe it's a BDP ping. + if t.bdpEst != nil { + t.bdpEst.calculate(f.Data) + } + return + } + pingAck := &ping{ack: true} + copy(pingAck.data[:], f.Data[:]) + t.controlBuf.put(pingAck) + + now := time.Now() + defer func() { + t.lastPingAt = now + }() + // A reset ping strikes means that we don't need to check for policy + // violation for this ping and the pingStrikes counter should be set + // to 0. + if atomic.CompareAndSwapUint32(&t.resetPingStrikes, 1, 0) { + t.pingStrikes = 0 + return + } + t.mu.Lock() + ns := len(t.activeStreams) + t.mu.Unlock() + if ns < 1 && !t.kep.PermitWithoutStream { + // Keepalive shouldn't be active thus, this new ping should + // have come after at least defaultPingTimeout. + if t.lastPingAt.Add(defaultPingTimeout).After(now) { + t.pingStrikes++ + } + } else { + // Check if keepalive policy is respected. + if t.lastPingAt.Add(t.kep.MinTime).After(now) { + t.pingStrikes++ + } + } + + if t.pingStrikes > maxPingStrikes { + // Send goaway and close the connection. + errorf("transport: Got too many pings from the client, closing the connection.") + t.controlBuf.put(&goAway{code: http2.ErrCodeEnhanceYourCalm, debugData: []byte("too_many_pings"), closeConn: true}) + } +} + +func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) { + t.controlBuf.put(&incomingWindowUpdate{ + streamID: f.Header().StreamID, + increment: f.Increment, + }) +} + +// WriteHeader sends the header metedata md back to the client. +func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { + if s.headerOk || s.getState() == streamDone { + return ErrIllegalHeaderWrite + } + s.headerOk = true + if md.Len() > 0 { + if s.header.Len() > 0 { + s.header = metadata.Join(s.header, md) + } else { + s.header = md + } + } + md = s.header + // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else. + headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) + if s.sendCompress != "" { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) + } + for k, vv := range md { + if isReservedHeader(k) { + // Clients don't tolerate reading restricted headers after some non restricted ones were sent. + continue + } + for _, v := range vv { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + } + t.controlBuf.put(&headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: false, + onWrite: func() { + atomic.StoreUint32(&t.resetPingStrikes, 1) + }, + wq: s.wq, + }) + if t.stats != nil { + // Note: WireLength is not set in outHeader. + // TODO(mmukhi): Revisit this later, if needed. + outHeader := &stats.OutHeader{} + t.stats.HandleRPC(s.Context(), outHeader) + } + return nil +} + +// WriteStatus sends stream status to the client and terminates the stream. +// There is no further I/O operations being able to perform on this stream. +// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early +// OK is adopted. +func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { + if !s.headerOk && s.header.Len() > 0 { + if err := t.WriteHeader(s, nil); err != nil { + return err + } + } else { + if s.getState() == streamDone { + return nil + } + } + // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else. + if !s.headerOk { + headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) + } + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) + + if p := st.Proto(); p != nil && len(p.Details) > 0 { + stBytes, err := proto.Marshal(p) + if err != nil { + // TODO: return error instead, when callers are able to handle it. + panic(err) + } + + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)}) + } + + // Attach the trailer metadata. + for k, vv := range s.trailer { + // Clients don't tolerate reading restricted headers after some non restricted ones were sent. + if isReservedHeader(k) { + continue + } + for _, v := range vv { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + } + trailer := &headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: true, + onWrite: func() { + atomic.StoreUint32(&t.resetPingStrikes, 1) + }, + } + t.closeStream(s, false, 0, trailer, true) + if t.stats != nil { + t.stats.HandleRPC(s.Context(), &stats.OutTrailer{}) + } + return nil +} + +// Write converts the data into HTTP2 data frame and sends it out. Non-nil error +// is returns if it fails (e.g., framing error, transport error). +func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { + if !s.headerOk { // Headers haven't been written yet. + if err := t.WriteHeader(s, nil); err != nil { + // TODO(mmukhi, dfawley): Make sure this is the right code to return. + return streamErrorf(codes.Internal, "transport: %v", err) + } + } else { + // Writing headers checks for this condition. + if s.getState() == streamDone { + // TODO(mmukhi, dfawley): Should the server write also return io.EOF? + s.cancel() + select { + case <-t.ctx.Done(): + return ErrConnClosing + default: + } + return ContextErr(s.ctx.Err()) + } + } + // Add some data to header frame so that we can equally distribute bytes across frames. + emptyLen := http2MaxFrameLen - len(hdr) + if emptyLen > len(data) { + emptyLen = len(data) + } + hdr = append(hdr, data[:emptyLen]...) + data = data[emptyLen:] + df := &dataFrame{ + streamID: s.id, + h: hdr, + d: data, + onEachWrite: func() { + atomic.StoreUint32(&t.resetPingStrikes, 1) + }, + } + if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { + select { + case <-t.ctx.Done(): + return ErrConnClosing + default: + } + return ContextErr(s.ctx.Err()) + } + return t.controlBuf.put(df) +} + +// keepalive running in a separate goroutine does the following: +// 1. Gracefully closes an idle connection after a duration of keepalive.MaxConnectionIdle. +// 2. Gracefully closes any connection after a duration of keepalive.MaxConnectionAge. +// 3. Forcibly closes a connection after an additive period of keepalive.MaxConnectionAgeGrace over keepalive.MaxConnectionAge. +// 4. Makes sure a connection is alive by sending pings with a frequency of keepalive.Time and closes a non-responsive connection +// after an additional duration of keepalive.Timeout. +func (t *http2Server) keepalive() { + p := &ping{} + var pingSent bool + maxIdle := time.NewTimer(t.kp.MaxConnectionIdle) + maxAge := time.NewTimer(t.kp.MaxConnectionAge) + keepalive := time.NewTimer(t.kp.Time) + // NOTE: All exit paths of this function should reset their + // respective timers. A failure to do so will cause the + // following clean-up to deadlock and eventually leak. + defer func() { + if !maxIdle.Stop() { + <-maxIdle.C + } + if !maxAge.Stop() { + <-maxAge.C + } + if !keepalive.Stop() { + <-keepalive.C + } + }() + for { + select { + case <-maxIdle.C: + t.mu.Lock() + idle := t.idle + if idle.IsZero() { // The connection is non-idle. + t.mu.Unlock() + maxIdle.Reset(t.kp.MaxConnectionIdle) + continue + } + val := t.kp.MaxConnectionIdle - time.Since(idle) + t.mu.Unlock() + if val <= 0 { + // The connection has been idle for a duration of keepalive.MaxConnectionIdle or more. + // Gracefully close the connection. + t.drain(http2.ErrCodeNo, []byte{}) + // Resetting the timer so that the clean-up doesn't deadlock. + maxIdle.Reset(infinity) + return + } + maxIdle.Reset(val) + case <-maxAge.C: + t.drain(http2.ErrCodeNo, []byte{}) + maxAge.Reset(t.kp.MaxConnectionAgeGrace) + select { + case <-maxAge.C: + // Close the connection after grace period. + t.Close() + // Resetting the timer so that the clean-up doesn't deadlock. + maxAge.Reset(infinity) + case <-t.ctx.Done(): + } + return + case <-keepalive.C: + if atomic.CompareAndSwapUint32(&t.activity, 1, 0) { + pingSent = false + keepalive.Reset(t.kp.Time) + continue + } + if pingSent { + t.Close() + // Resetting the timer so that the clean-up doesn't deadlock. + keepalive.Reset(infinity) + return + } + pingSent = true + if channelz.IsOn() { + t.czmu.Lock() + t.kpCount++ + t.czmu.Unlock() + } + t.controlBuf.put(p) + keepalive.Reset(t.kp.Timeout) + case <-t.ctx.Done(): + return + } + } +} + +// Close starts shutting down the http2Server transport. +// TODO(zhaoq): Now the destruction is not blocked on any pending streams. This +// could cause some resource issue. Revisit this later. +func (t *http2Server) Close() error { + t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + return errors.New("transport: Close() was already called") + } + t.state = closing + streams := t.activeStreams + t.activeStreams = nil + t.mu.Unlock() + t.controlBuf.finish() + t.cancel() + err := t.conn.Close() + if channelz.IsOn() { + channelz.RemoveEntry(t.channelzID) + } + // Cancel all active streams. + for _, s := range streams { + s.cancel() + } + if t.stats != nil { + connEnd := &stats.ConnEnd{} + t.stats.HandleConn(t.ctx, connEnd) + } + return err +} + +// closeStream clears the footprint of a stream when the stream is not needed +// any more. +func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) { + if s.swapState(streamDone) == streamDone { + // If the stream was already done, return. + return + } + // In case stream sending and receiving are invoked in separate + // goroutines (e.g., bi-directional streaming), cancel needs to be + // called to interrupt the potential blocking on other goroutines. + s.cancel() + cleanup := &cleanupStream{ + streamID: s.id, + rst: rst, + rstCode: rstCode, + onWrite: func() { + t.mu.Lock() + if t.activeStreams != nil { + delete(t.activeStreams, s.id) + if len(t.activeStreams) == 0 { + t.idle = time.Now() + } + } + t.mu.Unlock() + if channelz.IsOn() { + t.czmu.Lock() + if eosReceived { + t.streamsSucceeded++ + } else { + t.streamsFailed++ + } + t.czmu.Unlock() + } + }, + } + if hdr != nil { + hdr.cleanup = cleanup + t.controlBuf.put(hdr) + } else { + t.controlBuf.put(cleanup) + } +} + +func (t *http2Server) RemoteAddr() net.Addr { + return t.remoteAddr +} + +func (t *http2Server) Drain() { + t.drain(http2.ErrCodeNo, []byte{}) +} + +func (t *http2Server) drain(code http2.ErrCode, debugData []byte) { + t.mu.Lock() + defer t.mu.Unlock() + if t.drainChan != nil { + return + } + t.drainChan = make(chan struct{}) + t.controlBuf.put(&goAway{code: code, debugData: debugData, headsUp: true}) +} + +var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} + +// Handles outgoing GoAway and returns true if loopy needs to put itself +// in draining mode. +func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { + t.mu.Lock() + if t.state == closing { // TODO(mmukhi): This seems unnecessary. + t.mu.Unlock() + // The transport is closing. + return false, ErrConnClosing + } + sid := t.maxStreamID + if !g.headsUp { + // Stop accepting more streams now. + t.state = draining + if len(t.activeStreams) == 0 { + g.closeConn = true + } + t.mu.Unlock() + if err := t.framer.fr.WriteGoAway(sid, g.code, g.debugData); err != nil { + return false, err + } + if g.closeConn { + // Abruptly close the connection following the GoAway (via + // loopywriter). But flush out what's inside the buffer first. + t.framer.writer.Flush() + return false, fmt.Errorf("transport: Connection closing") + } + return true, nil + } + t.mu.Unlock() + // For a graceful close, send out a GoAway with stream ID of MaxUInt32, + // Follow that with a ping and wait for the ack to come back or a timer + // to expire. During this time accept new streams since they might have + // originated before the GoAway reaches the client. + // After getting the ack or timer expiration send out another GoAway this + // time with an ID of the max stream server intends to process. + if err := t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil { + return false, err + } + if err := t.framer.fr.WritePing(false, goAwayPing.data); err != nil { + return false, err + } + go func() { + timer := time.NewTimer(time.Minute) + defer timer.Stop() + select { + case <-t.drainChan: + case <-timer.C: + case <-t.ctx.Done(): + return + } + t.controlBuf.put(&goAway{code: g.code, debugData: g.debugData}) + }() + return false, nil +} + +func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric { + t.czmu.RLock() + s := channelz.SocketInternalMetric{ + StreamsStarted: t.streamsStarted, + StreamsSucceeded: t.streamsSucceeded, + StreamsFailed: t.streamsFailed, + MessagesSent: t.msgSent, + MessagesReceived: t.msgRecv, + KeepAlivesSent: t.kpCount, + LastRemoteStreamCreatedTimestamp: t.lastStreamCreated, + LastMessageSentTimestamp: t.lastMsgSent, + LastMessageReceivedTimestamp: t.lastMsgRecv, + LocalFlowControlWindow: int64(t.fc.getSize()), + //socket options + LocalAddr: t.localAddr, + RemoteAddr: t.remoteAddr, + // Security + // RemoteName : + } + t.czmu.RUnlock() + s.RemoteFlowControlWindow = t.getOutFlowWindow() + return &s +} + +func (t *http2Server) IncrMsgSent() { + t.czmu.Lock() + t.msgSent++ + t.lastMsgSent = time.Now() + t.czmu.Unlock() +} + +func (t *http2Server) IncrMsgRecv() { + t.czmu.Lock() + t.msgRecv++ + t.lastMsgRecv = time.Now() + t.czmu.Unlock() +} + +func (t *http2Server) getOutFlowWindow() int64 { + resp := make(chan uint32) + timer := time.NewTimer(time.Second) + defer timer.Stop() + t.controlBuf.put(&outFlowControlSizeRequest{resp}) + select { + case sz := <-resp: + return int64(sz) + case <-t.ctxDone: + return -1 + case <-timer.C: + return -2 + } +} + +var rgen = rand.New(rand.NewSource(time.Now().UnixNano())) + +func getJitter(v time.Duration) time.Duration { + if v == infinity { + return 0 + } + // Generate a jitter between +/- 10% of the value. + r := int64(v / 10) + j := rgen.Int63n(2*r) - r + return time.Duration(j) +} diff --git a/vendor/google.golang.org/grpc/transport/http_util.go b/vendor/google.golang.org/grpc/transport/http_util.go new file mode 100644 index 000000000..a35586608 --- /dev/null +++ b/vendor/google.golang.org/grpc/transport/http_util.go @@ -0,0 +1,574 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "bufio" + "bytes" + "encoding/base64" + "fmt" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + spb "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + // http2MaxFrameLen specifies the max length of a HTTP2 frame. + http2MaxFrameLen = 16384 // 16KB frame + // http://http2.github.io/http2-spec/#SettingValues + http2InitHeaderTableSize = 4096 + // http2IOBufSize specifies the buffer size for sending frames. + defaultWriteBufSize = 32 * 1024 + defaultReadBufSize = 32 * 1024 + // baseContentType is the base content-type for gRPC. This is a valid + // content-type on it's own, but can also include a content-subtype such as + // "proto" as a suffix after "+" or ";". See + // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests + // for more details. + baseContentType = "application/grpc" +) + +var ( + clientPreface = []byte(http2.ClientPreface) + http2ErrConvTab = map[http2.ErrCode]codes.Code{ + http2.ErrCodeNo: codes.Internal, + http2.ErrCodeProtocol: codes.Internal, + http2.ErrCodeInternal: codes.Internal, + http2.ErrCodeFlowControl: codes.ResourceExhausted, + http2.ErrCodeSettingsTimeout: codes.Internal, + http2.ErrCodeStreamClosed: codes.Internal, + http2.ErrCodeFrameSize: codes.Internal, + http2.ErrCodeRefusedStream: codes.Unavailable, + http2.ErrCodeCancel: codes.Canceled, + http2.ErrCodeCompression: codes.Internal, + http2.ErrCodeConnect: codes.Internal, + http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted, + http2.ErrCodeInadequateSecurity: codes.PermissionDenied, + http2.ErrCodeHTTP11Required: codes.Internal, + } + statusCodeConvTab = map[codes.Code]http2.ErrCode{ + codes.Internal: http2.ErrCodeInternal, + codes.Canceled: http2.ErrCodeCancel, + codes.Unavailable: http2.ErrCodeRefusedStream, + codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm, + codes.PermissionDenied: http2.ErrCodeInadequateSecurity, + } + httpStatusConvTab = map[int]codes.Code{ + // 400 Bad Request - INTERNAL. + http.StatusBadRequest: codes.Internal, + // 401 Unauthorized - UNAUTHENTICATED. + http.StatusUnauthorized: codes.Unauthenticated, + // 403 Forbidden - PERMISSION_DENIED. + http.StatusForbidden: codes.PermissionDenied, + // 404 Not Found - UNIMPLEMENTED. + http.StatusNotFound: codes.Unimplemented, + // 429 Too Many Requests - UNAVAILABLE. + http.StatusTooManyRequests: codes.Unavailable, + // 502 Bad Gateway - UNAVAILABLE. + http.StatusBadGateway: codes.Unavailable, + // 503 Service Unavailable - UNAVAILABLE. + http.StatusServiceUnavailable: codes.Unavailable, + // 504 Gateway timeout - UNAVAILABLE. + http.StatusGatewayTimeout: codes.Unavailable, + } +) + +// Records the states during HPACK decoding. Must be reset once the +// decoding of the entire headers are finished. +type decodeState struct { + encoding string + // statusGen caches the stream status received from the trailer the server + // sent. Client side only. Do not access directly. After all trailers are + // parsed, use the status method to retrieve the status. + statusGen *status.Status + // rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not + // intended for direct access outside of parsing. + rawStatusCode *int + rawStatusMsg string + httpStatus *int + // Server side only fields. + timeoutSet bool + timeout time.Duration + method string + // key-value metadata map from the peer. + mdata map[string][]string + statsTags []byte + statsTrace []byte + contentSubtype string +} + +// isReservedHeader checks whether hdr belongs to HTTP2 headers +// reserved by gRPC protocol. Any other headers are classified as the +// user-specified metadata. +func isReservedHeader(hdr string) bool { + if hdr != "" && hdr[0] == ':' { + return true + } + switch hdr { + case "content-type", + "user-agent", + "grpc-message-type", + "grpc-encoding", + "grpc-message", + "grpc-status", + "grpc-timeout", + "grpc-status-details-bin", + "te": + return true + default: + return false + } +} + +// isWhitelistedHeader checks whether hdr should be propagated +// into metadata visible to users. +func isWhitelistedHeader(hdr string) bool { + switch hdr { + case ":authority", "user-agent": + return true + default: + return false + } +} + +// contentSubtype returns the content-subtype for the given content-type. The +// given content-type must be a valid content-type that starts with +// "application/grpc". A content-subtype will follow "application/grpc" after a +// "+" or ";". See +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. +// +// If contentType is not a valid content-type for gRPC, the boolean +// will be false, otherwise true. If content-type == "application/grpc", +// "application/grpc+", or "application/grpc;", the boolean will be true, +// but no content-subtype will be returned. +// +// contentType is assumed to be lowercase already. +func contentSubtype(contentType string) (string, bool) { + if contentType == baseContentType { + return "", true + } + if !strings.HasPrefix(contentType, baseContentType) { + return "", false + } + // guaranteed since != baseContentType and has baseContentType prefix + switch contentType[len(baseContentType)] { + case '+', ';': + // this will return true for "application/grpc+" or "application/grpc;" + // which the previous validContentType function tested to be valid, so we + // just say that no content-subtype is specified in this case + return contentType[len(baseContentType)+1:], true + default: + return "", false + } +} + +// contentSubtype is assumed to be lowercase +func contentType(contentSubtype string) string { + if contentSubtype == "" { + return baseContentType + } + return baseContentType + "+" + contentSubtype +} + +func (d *decodeState) status() *status.Status { + if d.statusGen == nil { + // No status-details were provided; generate status using code/msg. + d.statusGen = status.New(codes.Code(int32(*(d.rawStatusCode))), d.rawStatusMsg) + } + return d.statusGen +} + +const binHdrSuffix = "-bin" + +func encodeBinHeader(v []byte) string { + return base64.RawStdEncoding.EncodeToString(v) +} + +func decodeBinHeader(v string) ([]byte, error) { + if len(v)%4 == 0 { + // Input was padded, or padding was not necessary. + return base64.StdEncoding.DecodeString(v) + } + return base64.RawStdEncoding.DecodeString(v) +} + +func encodeMetadataHeader(k, v string) string { + if strings.HasSuffix(k, binHdrSuffix) { + return encodeBinHeader(([]byte)(v)) + } + return v +} + +func decodeMetadataHeader(k, v string) (string, error) { + if strings.HasSuffix(k, binHdrSuffix) { + b, err := decodeBinHeader(v) + return string(b), err + } + return v, nil +} + +func (d *decodeState) decodeResponseHeader(frame *http2.MetaHeadersFrame) error { + for _, hf := range frame.Fields { + if err := d.processHeaderField(hf); err != nil { + return err + } + } + + // If grpc status exists, no need to check further. + if d.rawStatusCode != nil || d.statusGen != nil { + return nil + } + + // If grpc status doesn't exist and http status doesn't exist, + // then it's a malformed header. + if d.httpStatus == nil { + return streamErrorf(codes.Internal, "malformed header: doesn't contain status(gRPC or HTTP)") + } + + if *(d.httpStatus) != http.StatusOK { + code, ok := httpStatusConvTab[*(d.httpStatus)] + if !ok { + code = codes.Unknown + } + return streamErrorf(code, http.StatusText(*(d.httpStatus))) + } + + // gRPC status doesn't exist and http status is OK. + // Set rawStatusCode to be unknown and return nil error. + // So that, if the stream has ended this Unknown status + // will be propagated to the user. + // Otherwise, it will be ignored. In which case, status from + // a later trailer, that has StreamEnded flag set, is propagated. + code := int(codes.Unknown) + d.rawStatusCode = &code + return nil + +} + +func (d *decodeState) addMetadata(k, v string) { + if d.mdata == nil { + d.mdata = make(map[string][]string) + } + d.mdata[k] = append(d.mdata[k], v) +} + +func (d *decodeState) processHeaderField(f hpack.HeaderField) error { + switch f.Name { + case "content-type": + contentSubtype, validContentType := contentSubtype(f.Value) + if !validContentType { + return streamErrorf(codes.Internal, "transport: received the unexpected content-type %q", f.Value) + } + d.contentSubtype = contentSubtype + // TODO: do we want to propagate the whole content-type in the metadata, + // or come up with a way to just propagate the content-subtype if it was set? + // ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"} + // in the metadata? + d.addMetadata(f.Name, f.Value) + case "grpc-encoding": + d.encoding = f.Value + case "grpc-status": + code, err := strconv.Atoi(f.Value) + if err != nil { + return streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err) + } + d.rawStatusCode = &code + case "grpc-message": + d.rawStatusMsg = decodeGrpcMessage(f.Value) + case "grpc-status-details-bin": + v, err := decodeBinHeader(f.Value) + if err != nil { + return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + } + s := &spb.Status{} + if err := proto.Unmarshal(v, s); err != nil { + return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + } + d.statusGen = status.FromProto(s) + case "grpc-timeout": + d.timeoutSet = true + var err error + if d.timeout, err = decodeTimeout(f.Value); err != nil { + return streamErrorf(codes.Internal, "transport: malformed time-out: %v", err) + } + case ":path": + d.method = f.Value + case ":status": + code, err := strconv.Atoi(f.Value) + if err != nil { + return streamErrorf(codes.Internal, "transport: malformed http-status: %v", err) + } + d.httpStatus = &code + case "grpc-tags-bin": + v, err := decodeBinHeader(f.Value) + if err != nil { + return streamErrorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err) + } + d.statsTags = v + d.addMetadata(f.Name, string(v)) + case "grpc-trace-bin": + v, err := decodeBinHeader(f.Value) + if err != nil { + return streamErrorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err) + } + d.statsTrace = v + d.addMetadata(f.Name, string(v)) + default: + if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) { + break + } + v, err := decodeMetadataHeader(f.Name, f.Value) + if err != nil { + errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err) + return nil + } + d.addMetadata(f.Name, v) + } + return nil +} + +type timeoutUnit uint8 + +const ( + hour timeoutUnit = 'H' + minute timeoutUnit = 'M' + second timeoutUnit = 'S' + millisecond timeoutUnit = 'm' + microsecond timeoutUnit = 'u' + nanosecond timeoutUnit = 'n' +) + +func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) { + switch u { + case hour: + return time.Hour, true + case minute: + return time.Minute, true + case second: + return time.Second, true + case millisecond: + return time.Millisecond, true + case microsecond: + return time.Microsecond, true + case nanosecond: + return time.Nanosecond, true + default: + } + return +} + +const maxTimeoutValue int64 = 100000000 - 1 + +// div does integer division and round-up the result. Note that this is +// equivalent to (d+r-1)/r but has less chance to overflow. +func div(d, r time.Duration) int64 { + if m := d % r; m > 0 { + return int64(d/r + 1) + } + return int64(d / r) +} + +// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it. +func encodeTimeout(t time.Duration) string { + if t <= 0 { + return "0n" + } + if d := div(t, time.Nanosecond); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "n" + } + if d := div(t, time.Microsecond); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "u" + } + if d := div(t, time.Millisecond); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "m" + } + if d := div(t, time.Second); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "S" + } + if d := div(t, time.Minute); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "M" + } + // Note that maxTimeoutValue * time.Hour > MaxInt64. + return strconv.FormatInt(div(t, time.Hour), 10) + "H" +} + +func decodeTimeout(s string) (time.Duration, error) { + size := len(s) + if size < 2 { + return 0, fmt.Errorf("transport: timeout string is too short: %q", s) + } + unit := timeoutUnit(s[size-1]) + d, ok := timeoutUnitToDuration(unit) + if !ok { + return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s) + } + t, err := strconv.ParseInt(s[:size-1], 10, 64) + if err != nil { + return 0, err + } + return d * time.Duration(t), nil +} + +const ( + spaceByte = ' ' + tildaByte = '~' + percentByte = '%' +) + +// encodeGrpcMessage is used to encode status code in header field +// "grpc-message". +// It checks to see if each individual byte in msg is an +// allowable byte, and then either percent encoding or passing it through. +// When percent encoding, the byte is converted into hexadecimal notation +// with a '%' prepended. +func encodeGrpcMessage(msg string) string { + if msg == "" { + return "" + } + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + c := msg[i] + if !(c >= spaceByte && c < tildaByte && c != percentByte) { + return encodeGrpcMessageUnchecked(msg) + } + } + return msg +} + +func encodeGrpcMessageUnchecked(msg string) string { + var buf bytes.Buffer + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + c := msg[i] + if c >= spaceByte && c < tildaByte && c != percentByte { + buf.WriteByte(c) + } else { + buf.WriteString(fmt.Sprintf("%%%02X", c)) + } + } + return buf.String() +} + +// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage. +func decodeGrpcMessage(msg string) string { + if msg == "" { + return "" + } + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + if msg[i] == percentByte && i+2 < lenMsg { + return decodeGrpcMessageUnchecked(msg) + } + } + return msg +} + +func decodeGrpcMessageUnchecked(msg string) string { + var buf bytes.Buffer + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + c := msg[i] + if c == percentByte && i+2 < lenMsg { + parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8) + if err != nil { + buf.WriteByte(c) + } else { + buf.WriteByte(byte(parsed)) + i += 2 + } + } else { + buf.WriteByte(c) + } + } + return buf.String() +} + +type bufWriter struct { + buf []byte + offset int + batchSize int + conn net.Conn + err error + + onFlush func() +} + +func newBufWriter(conn net.Conn, batchSize int) *bufWriter { + return &bufWriter{ + buf: make([]byte, batchSize*2), + batchSize: batchSize, + conn: conn, + } +} + +func (w *bufWriter) Write(b []byte) (n int, err error) { + if w.err != nil { + return 0, w.err + } + n = copy(w.buf[w.offset:], b) + w.offset += n + if w.offset >= w.batchSize { + err = w.Flush() + } + return n, err +} + +func (w *bufWriter) Flush() error { + if w.err != nil { + return w.err + } + if w.offset == 0 { + return nil + } + if w.onFlush != nil { + w.onFlush() + } + _, w.err = w.conn.Write(w.buf[:w.offset]) + w.offset = 0 + return w.err +} + +type framer struct { + writer *bufWriter + fr *http2.Framer +} + +func newFramer(conn net.Conn, writeBufferSize, readBufferSize int) *framer { + r := bufio.NewReaderSize(conn, readBufferSize) + w := newBufWriter(conn, writeBufferSize) + f := &framer{ + writer: w, + fr: http2.NewFramer(w, r), + } + // Opt-in to Frame reuse API on framer to reduce garbage. + // Frames aren't safe to read from after a subsequent call to ReadFrame. + f.fr.SetReuseFrames() + f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil) + return f +} diff --git a/vendor/google.golang.org/grpc/transport/log.go b/vendor/google.golang.org/grpc/transport/log.go new file mode 100644 index 000000000..ac8e358c5 --- /dev/null +++ b/vendor/google.golang.org/grpc/transport/log.go @@ -0,0 +1,50 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// This file contains wrappers for grpclog functions. +// The transport package only logs to verbose level 2 by default. + +package transport + +import "google.golang.org/grpc/grpclog" + +const logLevel = 2 + +func infof(format string, args ...interface{}) { + if grpclog.V(logLevel) { + grpclog.Infof(format, args...) + } +} + +func warningf(format string, args ...interface{}) { + if grpclog.V(logLevel) { + grpclog.Warningf(format, args...) + } +} + +func errorf(format string, args ...interface{}) { + if grpclog.V(logLevel) { + grpclog.Errorf(format, args...) + } +} + +func fatalf(format string, args ...interface{}) { + if grpclog.V(logLevel) { + grpclog.Fatalf(format, args...) + } +} diff --git a/vendor/google.golang.org/grpc/transport/transport.go b/vendor/google.golang.org/grpc/transport/transport.go new file mode 100644 index 000000000..2f643a3d0 --- /dev/null +++ b/vendor/google.golang.org/grpc/transport/transport.go @@ -0,0 +1,683 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package transport defines and implements message oriented communication +// channel to complete various transactions (e.g., an RPC). It is meant for +// grpc-internal usage and is not intended to be imported directly by users. +package transport // externally used as import "google.golang.org/grpc/transport" + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" + "google.golang.org/grpc/tap" +) + +// recvMsg represents the received msg from the transport. All transport +// protocol specific info has been removed. +type recvMsg struct { + data []byte + // nil: received some data + // io.EOF: stream is completed. data is nil. + // other non-nil error: transport failure. data is nil. + err error +} + +// recvBuffer is an unbounded channel of recvMsg structs. +// Note recvBuffer differs from controlBuffer only in that recvBuffer +// holds a channel of only recvMsg structs instead of objects implementing "item" interface. +// recvBuffer is written to much more often than +// controlBuffer and using strict recvMsg structs helps avoid allocation in "recvBuffer.put" +type recvBuffer struct { + c chan recvMsg + mu sync.Mutex + backlog []recvMsg + err error +} + +func newRecvBuffer() *recvBuffer { + b := &recvBuffer{ + c: make(chan recvMsg, 1), + } + return b +} + +func (b *recvBuffer) put(r recvMsg) { + b.mu.Lock() + if b.err != nil { + b.mu.Unlock() + // An error had occurred earlier, don't accept more + // data or errors. + return + } + b.err = r.err + if len(b.backlog) == 0 { + select { + case b.c <- r: + b.mu.Unlock() + return + default: + } + } + b.backlog = append(b.backlog, r) + b.mu.Unlock() +} + +func (b *recvBuffer) load() { + b.mu.Lock() + if len(b.backlog) > 0 { + select { + case b.c <- b.backlog[0]: + b.backlog[0] = recvMsg{} + b.backlog = b.backlog[1:] + default: + } + } + b.mu.Unlock() +} + +// get returns the channel that receives a recvMsg in the buffer. +// +// Upon receipt of a recvMsg, the caller should call load to send another +// recvMsg onto the channel if there is any. +func (b *recvBuffer) get() <-chan recvMsg { + return b.c +} + +// +// recvBufferReader implements io.Reader interface to read the data from +// recvBuffer. +type recvBufferReader struct { + ctx context.Context + ctxDone <-chan struct{} // cache of ctx.Done() (for performance). + recv *recvBuffer + last []byte // Stores the remaining data in the previous calls. + err error +} + +// Read reads the next len(p) bytes from last. If last is drained, it tries to +// read additional data from recv. It blocks if there no additional data available +// in recv. If Read returns any non-nil error, it will continue to return that error. +func (r *recvBufferReader) Read(p []byte) (n int, err error) { + if r.err != nil { + return 0, r.err + } + n, r.err = r.read(p) + return n, r.err +} + +func (r *recvBufferReader) read(p []byte) (n int, err error) { + if r.last != nil && len(r.last) > 0 { + // Read remaining data left in last call. + copied := copy(p, r.last) + r.last = r.last[copied:] + return copied, nil + } + select { + case <-r.ctxDone: + return 0, ContextErr(r.ctx.Err()) + case m := <-r.recv.get(): + r.recv.load() + if m.err != nil { + return 0, m.err + } + copied := copy(p, m.data) + r.last = m.data[copied:] + return copied, nil + } +} + +type streamState uint32 + +const ( + streamActive streamState = iota + streamWriteDone // EndStream sent + streamReadDone // EndStream received + streamDone // the entire stream is finished. +) + +// Stream represents an RPC in the transport layer. +type Stream struct { + id uint32 + st ServerTransport // nil for client side Stream + ctx context.Context // the associated context of the stream + cancel context.CancelFunc // always nil for client side Stream + done chan struct{} // closed at the end of stream to unblock writers. On the client side. + ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance) + method string // the associated RPC method of the stream + recvCompress string + sendCompress string + buf *recvBuffer + trReader io.Reader + fc *inFlow + recvQuota uint32 + wq *writeQuota + + // Callback to state application's intentions to read data. This + // is used to adjust flow control, if needed. + requestRead func(int) + + headerChan chan struct{} // closed to indicate the end of header metadata. + headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. + header metadata.MD // the received header metadata. + trailer metadata.MD // the key-value map of trailer metadata. + + headerOk bool // becomes true from the first header is about to send + state streamState + + status *status.Status // the status error received from the server + + bytesReceived uint32 // indicates whether any bytes have been received on this stream + unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream + + // contentSubtype is the content-subtype for requests. + // this must be lowercase or the behavior is undefined. + contentSubtype string +} + +func (s *Stream) swapState(st streamState) streamState { + return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st))) +} + +func (s *Stream) compareAndSwapState(oldState, newState streamState) bool { + return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState)) +} + +func (s *Stream) getState() streamState { + return streamState(atomic.LoadUint32((*uint32)(&s.state))) +} + +func (s *Stream) waitOnHeader() error { + if s.headerChan == nil { + // On the server headerChan is always nil since a stream originates + // only after having received headers. + return nil + } + select { + case <-s.ctx.Done(): + return ContextErr(s.ctx.Err()) + case <-s.headerChan: + return nil + } +} + +// RecvCompress returns the compression algorithm applied to the inbound +// message. It is empty string if there is no compression applied. +func (s *Stream) RecvCompress() string { + if err := s.waitOnHeader(); err != nil { + return "" + } + return s.recvCompress +} + +// SetSendCompress sets the compression algorithm to the stream. +func (s *Stream) SetSendCompress(str string) { + s.sendCompress = str +} + +// Done returns a chanel which is closed when it receives the final status +// from the server. +func (s *Stream) Done() <-chan struct{} { + return s.done +} + +// Header acquires the key-value pairs of header metadata once it +// is available. It blocks until i) the metadata is ready or ii) there is no +// header metadata or iii) the stream is canceled/expired. +func (s *Stream) Header() (metadata.MD, error) { + err := s.waitOnHeader() + // Even if the stream is closed, header is returned if available. + select { + case <-s.headerChan: + if s.header == nil { + return nil, nil + } + return s.header.Copy(), nil + default: + } + return nil, err +} + +// Trailer returns the cached trailer metedata. Note that if it is not called +// after the entire stream is done, it could return an empty MD. Client +// side only. +// It can be safely read only after stream has ended that is either read +// or write have returned io.EOF. +func (s *Stream) Trailer() metadata.MD { + c := s.trailer.Copy() + return c +} + +// ServerTransport returns the underlying ServerTransport for the stream. +// The client side stream always returns nil. +func (s *Stream) ServerTransport() ServerTransport { + return s.st +} + +// ContentSubtype returns the content-subtype for a request. For example, a +// content-subtype of "proto" will result in a content-type of +// "application/grpc+proto". This will always be lowercase. See +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. +func (s *Stream) ContentSubtype() string { + return s.contentSubtype +} + +// Context returns the context of the stream. +func (s *Stream) Context() context.Context { + return s.ctx +} + +// Method returns the method for the stream. +func (s *Stream) Method() string { + return s.method +} + +// Status returns the status received from the server. +// Status can be read safely only after the stream has ended, +// that is, read or write has returned io.EOF. +func (s *Stream) Status() *status.Status { + return s.status +} + +// SetHeader sets the header metadata. This can be called multiple times. +// Server side only. +// This should not be called in parallel to other data writes. +func (s *Stream) SetHeader(md metadata.MD) error { + if md.Len() == 0 { + return nil + } + if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) { + return ErrIllegalHeaderWrite + } + s.header = metadata.Join(s.header, md) + return nil +} + +// SendHeader sends the given header metadata. The given metadata is +// combined with any metadata set by previous calls to SetHeader and +// then written to the transport stream. +func (s *Stream) SendHeader(md metadata.MD) error { + t := s.ServerTransport() + return t.WriteHeader(s, md) +} + +// SetTrailer sets the trailer metadata which will be sent with the RPC status +// by the server. This can be called multiple times. Server side only. +// This should not be called parallel to other data writes. +func (s *Stream) SetTrailer(md metadata.MD) error { + if md.Len() == 0 { + return nil + } + s.trailer = metadata.Join(s.trailer, md) + return nil +} + +func (s *Stream) write(m recvMsg) { + s.buf.put(m) +} + +// Read reads all p bytes from the wire for this stream. +func (s *Stream) Read(p []byte) (n int, err error) { + // Don't request a read if there was an error earlier + if er := s.trReader.(*transportReader).er; er != nil { + return 0, er + } + s.requestRead(len(p)) + return io.ReadFull(s.trReader, p) +} + +// tranportReader reads all the data available for this Stream from the transport and +// passes them into the decoder, which converts them into a gRPC message stream. +// The error is io.EOF when the stream is done or another non-nil error if +// the stream broke. +type transportReader struct { + reader io.Reader + // The handler to control the window update procedure for both this + // particular stream and the associated transport. + windowHandler func(int) + er error +} + +func (t *transportReader) Read(p []byte) (n int, err error) { + n, err = t.reader.Read(p) + if err != nil { + t.er = err + return + } + t.windowHandler(n) + return +} + +// BytesReceived indicates whether any bytes have been received on this stream. +func (s *Stream) BytesReceived() bool { + return atomic.LoadUint32(&s.bytesReceived) == 1 +} + +// Unprocessed indicates whether the server did not process this stream -- +// i.e. it sent a refused stream or GOAWAY including this stream ID. +func (s *Stream) Unprocessed() bool { + return atomic.LoadUint32(&s.unprocessed) == 1 +} + +// GoString is implemented by Stream so context.String() won't +// race when printing %#v. +func (s *Stream) GoString() string { + return fmt.Sprintf("<stream: %p, %v>", s, s.method) +} + +// state of transport +type transportState int + +const ( + reachable transportState = iota + closing + draining +) + +// ServerConfig consists of all the configurations to establish a server transport. +type ServerConfig struct { + MaxStreams uint32 + AuthInfo credentials.AuthInfo + InTapHandle tap.ServerInHandle + StatsHandler stats.Handler + KeepaliveParams keepalive.ServerParameters + KeepalivePolicy keepalive.EnforcementPolicy + InitialWindowSize int32 + InitialConnWindowSize int32 + WriteBufferSize int + ReadBufferSize int + ChannelzParentID int64 +} + +// NewServerTransport creates a ServerTransport with conn or non-nil error +// if it fails. +func NewServerTransport(protocol string, conn net.Conn, config *ServerConfig) (ServerTransport, error) { + return newHTTP2Server(conn, config) +} + +// ConnectOptions covers all relevant options for communicating with the server. +type ConnectOptions struct { + // UserAgent is the application user agent. + UserAgent string + // Authority is the :authority pseudo-header to use. This field has no effect if + // TransportCredentials is set. + Authority string + // Dialer specifies how to dial a network address. + Dialer func(context.Context, string) (net.Conn, error) + // FailOnNonTempDialError specifies if gRPC fails on non-temporary dial errors. + FailOnNonTempDialError bool + // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs. + PerRPCCredentials []credentials.PerRPCCredentials + // TransportCredentials stores the Authenticator required to setup a client connection. + TransportCredentials credentials.TransportCredentials + // KeepaliveParams stores the keepalive parameters. + KeepaliveParams keepalive.ClientParameters + // StatsHandler stores the handler for stats. + StatsHandler stats.Handler + // InitialWindowSize sets the initial window size for a stream. + InitialWindowSize int32 + // InitialConnWindowSize sets the initial window size for a connection. + InitialConnWindowSize int32 + // WriteBufferSize sets the size of write buffer which in turn determines how much data can be batched before it's written on the wire. + WriteBufferSize int + // ReadBufferSize sets the size of read buffer, which in turn determines how much data can be read at most for one read syscall. + ReadBufferSize int + // ChannelzParentID sets the addrConn id which initiate the creation of this client transport. + ChannelzParentID int64 +} + +// TargetInfo contains the information of the target such as network address and metadata. +type TargetInfo struct { + Addr string + Metadata interface{} + Authority string +} + +// NewClientTransport establishes the transport with the required ConnectOptions +// and returns it to the caller. +func NewClientTransport(connectCtx, ctx context.Context, target TargetInfo, opts ConnectOptions, onSuccess func()) (ClientTransport, error) { + return newHTTP2Client(connectCtx, ctx, target, opts, onSuccess) +} + +// Options provides additional hints and information for message +// transmission. +type Options struct { + // Last indicates whether this write is the last piece for + // this stream. + Last bool + + // Delay is a hint to the transport implementation for whether + // the data could be buffered for a batching write. The + // transport implementation may ignore the hint. + Delay bool +} + +// CallHdr carries the information of a particular RPC. +type CallHdr struct { + // Host specifies the peer's host. + Host string + + // Method specifies the operation to perform. + Method string + + // SendCompress specifies the compression algorithm applied on + // outbound message. + SendCompress string + + // Creds specifies credentials.PerRPCCredentials for a call. + Creds credentials.PerRPCCredentials + + // Flush indicates whether a new stream command should be sent + // to the peer without waiting for the first data. This is + // only a hint. + // If it's true, the transport may modify the flush decision + // for performance purposes. + // If it's false, new stream will never be flushed. + Flush bool + + // ContentSubtype specifies the content-subtype for a request. For example, a + // content-subtype of "proto" will result in a content-type of + // "application/grpc+proto". The value of ContentSubtype must be all + // lowercase, otherwise the behavior is undefined. See + // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests + // for more details. + ContentSubtype string +} + +// ClientTransport is the common interface for all gRPC client-side transport +// implementations. +type ClientTransport interface { + // Close tears down this transport. Once it returns, the transport + // should not be accessed any more. The caller must make sure this + // is called only once. + Close() error + + // GracefulClose starts to tear down the transport. It stops accepting + // new RPCs and wait the completion of the pending RPCs. + GracefulClose() error + + // Write sends the data for the given stream. A nil stream indicates + // the write is to be performed on the transport as a whole. + Write(s *Stream, hdr []byte, data []byte, opts *Options) error + + // NewStream creates a Stream for an RPC. + NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) + + // CloseStream clears the footprint of a stream when the stream is + // not needed any more. The err indicates the error incurred when + // CloseStream is called. Must be called when a stream is finished + // unless the associated transport is closing. + CloseStream(stream *Stream, err error) + + // Error returns a channel that is closed when some I/O error + // happens. Typically the caller should have a goroutine to monitor + // this in order to take action (e.g., close the current transport + // and create a new one) in error case. It should not return nil + // once the transport is initiated. + Error() <-chan struct{} + + // GoAway returns a channel that is closed when ClientTransport + // receives the draining signal from the server (e.g., GOAWAY frame in + // HTTP/2). + GoAway() <-chan struct{} + + // GetGoAwayReason returns the reason why GoAway frame was received. + GetGoAwayReason() GoAwayReason + + // IncrMsgSent increments the number of message sent through this transport. + IncrMsgSent() + + // IncrMsgRecv increments the number of message received through this transport. + IncrMsgRecv() +} + +// ServerTransport is the common interface for all gRPC server-side transport +// implementations. +// +// Methods may be called concurrently from multiple goroutines, but +// Write methods for a given Stream will be called serially. +type ServerTransport interface { + // HandleStreams receives incoming streams using the given handler. + HandleStreams(func(*Stream), func(context.Context, string) context.Context) + + // WriteHeader sends the header metadata for the given stream. + // WriteHeader may not be called on all streams. + WriteHeader(s *Stream, md metadata.MD) error + + // Write sends the data for the given stream. + // Write may not be called on all streams. + Write(s *Stream, hdr []byte, data []byte, opts *Options) error + + // WriteStatus sends the status of a stream to the client. WriteStatus is + // the final call made on a stream and always occurs. + WriteStatus(s *Stream, st *status.Status) error + + // Close tears down the transport. Once it is called, the transport + // should not be accessed any more. All the pending streams and their + // handlers will be terminated asynchronously. + Close() error + + // RemoteAddr returns the remote network address. + RemoteAddr() net.Addr + + // Drain notifies the client this ServerTransport stops accepting new RPCs. + Drain() + + // IncrMsgSent increments the number of message sent through this transport. + IncrMsgSent() + + // IncrMsgRecv increments the number of message received through this transport. + IncrMsgRecv() +} + +// streamErrorf creates an StreamError with the specified error code and description. +func streamErrorf(c codes.Code, format string, a ...interface{}) StreamError { + return StreamError{ + Code: c, + Desc: fmt.Sprintf(format, a...), + } +} + +// connectionErrorf creates an ConnectionError with the specified error description. +func connectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError { + return ConnectionError{ + Desc: fmt.Sprintf(format, a...), + temp: temp, + err: e, + } +} + +// ConnectionError is an error that results in the termination of the +// entire connection and the retry of all the active streams. +type ConnectionError struct { + Desc string + temp bool + err error +} + +func (e ConnectionError) Error() string { + return fmt.Sprintf("connection error: desc = %q", e.Desc) +} + +// Temporary indicates if this connection error is temporary or fatal. +func (e ConnectionError) Temporary() bool { + return e.temp +} + +// Origin returns the original error of this connection error. +func (e ConnectionError) Origin() error { + // Never return nil error here. + // If the original error is nil, return itself. + if e.err == nil { + return e + } + return e.err +} + +var ( + // ErrConnClosing indicates that the transport is closing. + ErrConnClosing = connectionErrorf(true, nil, "transport is closing") + // errStreamDrain indicates that the stream is rejected because the + // connection is draining. This could be caused by goaway or balancer + // removing the address. + errStreamDrain = streamErrorf(codes.Unavailable, "the connection is draining") + // errStreamDone is returned from write at the client side to indiacte application + // layer of an error. + errStreamDone = errors.New("the stream is done") + // StatusGoAway indicates that the server sent a GOAWAY that included this + // stream's ID in unprocessed RPCs. + statusGoAway = status.New(codes.Unavailable, "the stream is rejected because server is draining the connection") +) + +// TODO: See if we can replace StreamError with status package errors. + +// StreamError is an error that only affects one stream within a connection. +type StreamError struct { + Code codes.Code + Desc string +} + +func (e StreamError) Error() string { + return fmt.Sprintf("stream error: code = %s desc = %q", e.Code, e.Desc) +} + +// GoAwayReason contains the reason for the GoAway frame received. +type GoAwayReason uint8 + +const ( + // GoAwayInvalid indicates that no GoAway frame is received. + GoAwayInvalid GoAwayReason = 0 + // GoAwayNoReason is the default value when GoAway frame is received. + GoAwayNoReason GoAwayReason = 1 + // GoAwayTooManyPings indicates that a GoAway frame with + // ErrCodeEnhanceYourCalm was received and that the debug data said + // "too_many_pings". + GoAwayTooManyPings GoAwayReason = 2 +) |