summaryrefslogtreecommitdiffstats
path: root/vendor/google.golang.org/grpc/transport/http_util.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/transport/http_util.go')
-rw-r--r--vendor/google.golang.org/grpc/transport/http_util.go54
1 files changed, 36 insertions, 18 deletions
diff --git a/vendor/google.golang.org/grpc/transport/http_util.go b/vendor/google.golang.org/grpc/transport/http_util.go
index a35586608..7d15c7d74 100644
--- a/vendor/google.golang.org/grpc/transport/http_util.go
+++ b/vendor/google.golang.org/grpc/transport/http_util.go
@@ -28,6 +28,7 @@ import (
"strconv"
"strings"
"time"
+ "unicode/utf8"
"github.com/golang/protobuf/proto"
"golang.org/x/net/http2"
@@ -437,16 +438,17 @@ func decodeTimeout(s string) (time.Duration, error) {
const (
spaceByte = ' '
- tildaByte = '~'
+ tildeByte = '~'
percentByte = '%'
)
// encodeGrpcMessage is used to encode status code in header field
-// "grpc-message".
-// It checks to see if each individual byte in msg is an
-// allowable byte, and then either percent encoding or passing it through.
-// When percent encoding, the byte is converted into hexadecimal notation
-// with a '%' prepended.
+// "grpc-message". It does percent encoding and also replaces invalid utf-8
+// characters with Unicode replacement character.
+//
+// It checks to see if each individual byte in msg is an allowable byte, and
+// then either percent encoding or passing it through. When percent encoding,
+// the byte is converted into hexadecimal notation with a '%' prepended.
func encodeGrpcMessage(msg string) string {
if msg == "" {
return ""
@@ -454,7 +456,7 @@ func encodeGrpcMessage(msg string) string {
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
- if !(c >= spaceByte && c < tildaByte && c != percentByte) {
+ if !(c >= spaceByte && c <= tildeByte && c != percentByte) {
return encodeGrpcMessageUnchecked(msg)
}
}
@@ -463,14 +465,26 @@ func encodeGrpcMessage(msg string) string {
func encodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
- lenMsg := len(msg)
- for i := 0; i < lenMsg; i++ {
- c := msg[i]
- if c >= spaceByte && c < tildaByte && c != percentByte {
- buf.WriteByte(c)
- } else {
- buf.WriteString(fmt.Sprintf("%%%02X", c))
+ for len(msg) > 0 {
+ r, size := utf8.DecodeRuneInString(msg)
+ for _, b := range []byte(string(r)) {
+ if size > 1 {
+ // If size > 1, r is not ascii. Always do percent encoding.
+ buf.WriteString(fmt.Sprintf("%%%02X", b))
+ continue
+ }
+
+ // The for loop is necessary even if size == 1. r could be
+ // utf8.RuneError.
+ //
+ // fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD".
+ if b >= spaceByte && b <= tildeByte && b != percentByte {
+ buf.WriteByte(b)
+ } else {
+ buf.WriteString(fmt.Sprintf("%%%02X", b))
+ }
}
+ msg = msg[size:]
}
return buf.String()
}
@@ -531,10 +545,14 @@ func (w *bufWriter) Write(b []byte) (n int, err error) {
if w.err != nil {
return 0, w.err
}
- n = copy(w.buf[w.offset:], b)
- w.offset += n
- if w.offset >= w.batchSize {
- err = w.Flush()
+ for len(b) > 0 {
+ nn := copy(w.buf[w.offset:], b)
+ b = b[nn:]
+ w.offset += nn
+ n += nn
+ if w.offset >= w.batchSize {
+ err = w.Flush()
+ }
}
return n, err
}