1/*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package transport
20
21import (
22	"bufio"
23	"bytes"
24	"encoding/base64"
25	"fmt"
26	"net"
27	"net/http"
28	"strconv"
29	"strings"
30	"time"
31
32	"github.com/golang/protobuf/proto"
33	"golang.org/x/net/http2"
34	"golang.org/x/net/http2/hpack"
35	spb "google.golang.org/genproto/googleapis/rpc/status"
36	"google.golang.org/grpc/codes"
37	"google.golang.org/grpc/status"
38)
39
const (
	// http2MaxFrameLen specifies the max length of a HTTP2 frame.
	http2MaxFrameLen = 16384 // 16KB frame
	// http2InitHeaderTableSize is the initial HPACK header table size (the
	// HTTP/2 default for SETTINGS_HEADER_TABLE_SIZE); newFramer sizes its HPACK
	// decoder with it.
	// http://http2.github.io/http2-spec/#SettingValues
	http2InitHeaderTableSize = 4096
	// defaultWriteBufSize and defaultReadBufSize are default buffer sizes, in
	// bytes, for writing and reading frames (presumably passed to newFramer by
	// callers outside this file — confirm).
	defaultWriteBufSize = 32 * 1024
	defaultReadBufSize  = 32 * 1024
	// baseContentType is the base content-type for gRPC.  This is a valid
	// content-type on its own, but can also include a content-subtype such as
	// "proto" as a suffix after "+" or ";".  See
	// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
	// for more details.
	baseContentType = "application/grpc"
)
55
56var (
57	clientPreface   = []byte(http2.ClientPreface)
58	http2ErrConvTab = map[http2.ErrCode]codes.Code{
59		http2.ErrCodeNo:                 codes.Internal,
60		http2.ErrCodeProtocol:           codes.Internal,
61		http2.ErrCodeInternal:           codes.Internal,
62		http2.ErrCodeFlowControl:        codes.ResourceExhausted,
63		http2.ErrCodeSettingsTimeout:    codes.Internal,
64		http2.ErrCodeStreamClosed:       codes.Internal,
65		http2.ErrCodeFrameSize:          codes.Internal,
66		http2.ErrCodeRefusedStream:      codes.Unavailable,
67		http2.ErrCodeCancel:             codes.Canceled,
68		http2.ErrCodeCompression:        codes.Internal,
69		http2.ErrCodeConnect:            codes.Internal,
70		http2.ErrCodeEnhanceYourCalm:    codes.ResourceExhausted,
71		http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
72		http2.ErrCodeHTTP11Required:     codes.Internal,
73	}
74	statusCodeConvTab = map[codes.Code]http2.ErrCode{
75		codes.Internal:          http2.ErrCodeInternal,
76		codes.Canceled:          http2.ErrCodeCancel,
77		codes.Unavailable:       http2.ErrCodeRefusedStream,
78		codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
79		codes.PermissionDenied:  http2.ErrCodeInadequateSecurity,
80	}
81	httpStatusConvTab = map[int]codes.Code{
82		// 400 Bad Request - INTERNAL.
83		http.StatusBadRequest: codes.Internal,
84		// 401 Unauthorized  - UNAUTHENTICATED.
85		http.StatusUnauthorized: codes.Unauthenticated,
86		// 403 Forbidden - PERMISSION_DENIED.
87		http.StatusForbidden: codes.PermissionDenied,
88		// 404 Not Found - UNIMPLEMENTED.
89		http.StatusNotFound: codes.Unimplemented,
90		// 429 Too Many Requests - UNAVAILABLE.
91		http.StatusTooManyRequests: codes.Unavailable,
92		// 502 Bad Gateway - UNAVAILABLE.
93		http.StatusBadGateway: codes.Unavailable,
94		// 503 Service Unavailable - UNAVAILABLE.
95		http.StatusServiceUnavailable: codes.Unavailable,
96		// 504 Gateway timeout - UNAVAILABLE.
97		http.StatusGatewayTimeout: codes.Unavailable,
98	}
99)
100
// decodeState records the state accumulated while HPACK-decoding one header
// block (see processHeaderField). It must be reset once the decoding of the
// entire headers is finished.
type decodeState struct {
	// encoding is the value of the "grpc-encoding" header, if received.
	encoding string
	// statusGen caches the stream status received from the trailer the server
	// sent.  Client side only.  Do not access directly.  After all trailers are
	// parsed, use the status method to retrieve the status. It is set directly
	// when a "grpc-status-details-bin" trailer is decoded.
	statusGen *status.Status
	// rawStatusCode and rawStatusMsg are set from the raw "grpc-status" and
	// "grpc-message" trailer fields and are not intended for direct access
	// outside of parsing. httpStatus holds the parsed ":status" pseudo-header.
	// A nil pointer means the corresponding header has not been seen.
	rawStatusCode *int
	rawStatusMsg  string
	httpStatus    *int
	// Server side only fields: the parsed "grpc-timeout" (timeoutSet reports
	// that the header was present) and the ":path" pseudo-header.
	timeoutSet bool
	timeout    time.Duration
	method     string
	// key-value metadata map from the peer. statsTags and statsTrace hold the
	// decoded "grpc-tags-bin" and "grpc-trace-bin" values; contentSubtype is
	// the subtype parsed out of "content-type".
	mdata          map[string][]string
	statsTags      []byte
	statsTrace     []byte
	contentSubtype string
}
124
// isReservedHeader reports whether hdr is an HTTP/2 pseudo-header (one whose
// name begins with ':') or one of the headers the gRPC protocol keeps for its
// own use. Every other header is treated as user-specified metadata.
func isReservedHeader(hdr string) bool {
	if len(hdr) > 0 && hdr[0] == ':' {
		return true
	}
	switch hdr {
	case "te", "content-type", "user-agent":
		return true
	case "grpc-encoding", "grpc-message", "grpc-message-type",
		"grpc-status", "grpc-status-details-bin", "grpc-timeout":
		return true
	}
	return false
}
147
// isWhitelistedHeader reports whether hdr, although reserved, should still be
// surfaced to users as part of the received metadata.
func isWhitelistedHeader(hdr string) bool {
	return hdr == ":authority" || hdr == "user-agent"
}
158
159// contentSubtype returns the content-subtype for the given content-type.  The
160// given content-type must be a valid content-type that starts with
161// "application/grpc". A content-subtype will follow "application/grpc" after a
162// "+" or ";". See
163// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
164// more details.
165//
166// If contentType is not a valid content-type for gRPC, the boolean
167// will be false, otherwise true. If content-type == "application/grpc",
168// "application/grpc+", or "application/grpc;", the boolean will be true,
169// but no content-subtype will be returned.
170//
171// contentType is assumed to be lowercase already.
172func contentSubtype(contentType string) (string, bool) {
173	if contentType == baseContentType {
174		return "", true
175	}
176	if !strings.HasPrefix(contentType, baseContentType) {
177		return "", false
178	}
179	// guaranteed since != baseContentType and has baseContentType prefix
180	switch contentType[len(baseContentType)] {
181	case '+', ';':
182		// this will return true for "application/grpc+" or "application/grpc;"
183		// which the previous validContentType function tested to be valid, so we
184		// just say that no content-subtype is specified in this case
185		return contentType[len(baseContentType)+1:], true
186	default:
187		return "", false
188	}
189}
190
191// contentSubtype is assumed to be lowercase
192func contentType(contentSubtype string) string {
193	if contentSubtype == "" {
194		return baseContentType
195	}
196	return baseContentType + "+" + contentSubtype
197}
198
199func (d *decodeState) status() *status.Status {
200	if d.statusGen == nil {
201		// No status-details were provided; generate status using code/msg.
202		d.statusGen = status.New(codes.Code(int32(*(d.rawStatusCode))), d.rawStatusMsg)
203	}
204	return d.statusGen
205}
206
// binHdrSuffix marks metadata keys whose values are binary; such values are
// base64-encoded on the wire (see encodeMetadataHeader/decodeMetadataHeader).
const binHdrSuffix = "-bin"
208
// encodeBinHeader base64-encodes v for use as a binary header value, using the
// standard alphabet with no padding.
func encodeBinHeader(v []byte) string {
	enc := base64.RawStdEncoding
	return enc.EncodeToString(v)
}
212
// decodeBinHeader reverses encodeBinHeader. It also accepts padded input from
// peers: a value whose length is a multiple of four goes through the padded
// decoder, and anything else through the unpadded one.
func decodeBinHeader(v string) ([]byte, error) {
	enc := base64.RawStdEncoding
	if len(v)%4 == 0 {
		// Either padding is present or none was needed.
		enc = base64.StdEncoding
	}
	return enc.DecodeString(v)
}
220
221func encodeMetadataHeader(k, v string) string {
222	if strings.HasSuffix(k, binHdrSuffix) {
223		return encodeBinHeader(([]byte)(v))
224	}
225	return v
226}
227
228func decodeMetadataHeader(k, v string) (string, error) {
229	if strings.HasSuffix(k, binHdrSuffix) {
230		b, err := decodeBinHeader(v)
231		return string(b), err
232	}
233	return v, nil
234}
235
236func (d *decodeState) decodeResponseHeader(frame *http2.MetaHeadersFrame) error {
237	for _, hf := range frame.Fields {
238		if err := d.processHeaderField(hf); err != nil {
239			return err
240		}
241	}
242
243	// If grpc status exists, no need to check further.
244	if d.rawStatusCode != nil || d.statusGen != nil {
245		return nil
246	}
247
248	// If grpc status doesn't exist and http status doesn't exist,
249	// then it's a malformed header.
250	if d.httpStatus == nil {
251		return streamErrorf(codes.Internal, "malformed header: doesn't contain status(gRPC or HTTP)")
252	}
253
254	if *(d.httpStatus) != http.StatusOK {
255		code, ok := httpStatusConvTab[*(d.httpStatus)]
256		if !ok {
257			code = codes.Unknown
258		}
259		return streamErrorf(code, http.StatusText(*(d.httpStatus)))
260	}
261
262	// gRPC status doesn't exist and http status is OK.
263	// Set rawStatusCode to be unknown and return nil error.
264	// So that, if the stream has ended this Unknown status
265	// will be propagated to the user.
266	// Otherwise, it will be ignored. In which case, status from
267	// a later trailer, that has StreamEnded flag set, is propagated.
268	code := int(codes.Unknown)
269	d.rawStatusCode = &code
270	return nil
271
272}
273
274func (d *decodeState) addMetadata(k, v string) {
275	if d.mdata == nil {
276		d.mdata = make(map[string][]string)
277	}
278	d.mdata[k] = append(d.mdata[k], v)
279}
280
281func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
282	switch f.Name {
283	case "content-type":
284		contentSubtype, validContentType := contentSubtype(f.Value)
285		if !validContentType {
286			return streamErrorf(codes.Internal, "transport: received the unexpected content-type %q", f.Value)
287		}
288		d.contentSubtype = contentSubtype
289		// TODO: do we want to propagate the whole content-type in the metadata,
290		// or come up with a way to just propagate the content-subtype if it was set?
291		// ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"}
292		// in the metadata?
293		d.addMetadata(f.Name, f.Value)
294	case "grpc-encoding":
295		d.encoding = f.Value
296	case "grpc-status":
297		code, err := strconv.Atoi(f.Value)
298		if err != nil {
299			return streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err)
300		}
301		d.rawStatusCode = &code
302	case "grpc-message":
303		d.rawStatusMsg = decodeGrpcMessage(f.Value)
304	case "grpc-status-details-bin":
305		v, err := decodeBinHeader(f.Value)
306		if err != nil {
307			return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
308		}
309		s := &spb.Status{}
310		if err := proto.Unmarshal(v, s); err != nil {
311			return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
312		}
313		d.statusGen = status.FromProto(s)
314	case "grpc-timeout":
315		d.timeoutSet = true
316		var err error
317		if d.timeout, err = decodeTimeout(f.Value); err != nil {
318			return streamErrorf(codes.Internal, "transport: malformed time-out: %v", err)
319		}
320	case ":path":
321		d.method = f.Value
322	case ":status":
323		code, err := strconv.Atoi(f.Value)
324		if err != nil {
325			return streamErrorf(codes.Internal, "transport: malformed http-status: %v", err)
326		}
327		d.httpStatus = &code
328	case "grpc-tags-bin":
329		v, err := decodeBinHeader(f.Value)
330		if err != nil {
331			return streamErrorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err)
332		}
333		d.statsTags = v
334		d.addMetadata(f.Name, string(v))
335	case "grpc-trace-bin":
336		v, err := decodeBinHeader(f.Value)
337		if err != nil {
338			return streamErrorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err)
339		}
340		d.statsTrace = v
341		d.addMetadata(f.Name, string(v))
342	default:
343		if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) {
344			break
345		}
346		v, err := decodeMetadataHeader(f.Name, f.Value)
347		if err != nil {
348			errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err)
349			return nil
350		}
351		d.addMetadata(f.Name, v)
352	}
353	return nil
354}
355
// timeoutUnit is the single-character unit suffix of a "grpc-timeout" header
// value (the last byte of the string, see decodeTimeout).
type timeoutUnit uint8

// The recognized timeout units, as understood by timeoutUnitToDuration.
const (
	hour        timeoutUnit = 'H'
	minute      timeoutUnit = 'M'
	second      timeoutUnit = 'S'
	millisecond timeoutUnit = 'm'
	microsecond timeoutUnit = 'u'
	nanosecond  timeoutUnit = 'n'
)
366
367func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
368	switch u {
369	case hour:
370		return time.Hour, true
371	case minute:
372		return time.Minute, true
373	case second:
374		return time.Second, true
375	case millisecond:
376		return time.Millisecond, true
377	case microsecond:
378		return time.Microsecond, true
379	case nanosecond:
380		return time.Nanosecond, true
381	default:
382	}
383	return
384}
385
// maxTimeoutValue is the largest number encodeTimeout emits with a given unit
// before switching to a coarser one: 99999999, the largest value of at most
// eight decimal digits.
const maxTimeoutValue int64 = 100000000 - 1
387
// div divides d by r, rounding up whenever the remainder is positive. For
// positive operands this equals (d+r-1)/r, but it cannot overflow when d is
// close to the largest Duration.
func div(d, r time.Duration) int64 {
	q := int64(d / r)
	if d%r > 0 {
		q++
	}
	return q
}
396
397// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
398func encodeTimeout(t time.Duration) string {
399	if t <= 0 {
400		return "0n"
401	}
402	if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
403		return strconv.FormatInt(d, 10) + "n"
404	}
405	if d := div(t, time.Microsecond); d <= maxTimeoutValue {
406		return strconv.FormatInt(d, 10) + "u"
407	}
408	if d := div(t, time.Millisecond); d <= maxTimeoutValue {
409		return strconv.FormatInt(d, 10) + "m"
410	}
411	if d := div(t, time.Second); d <= maxTimeoutValue {
412		return strconv.FormatInt(d, 10) + "S"
413	}
414	if d := div(t, time.Minute); d <= maxTimeoutValue {
415		return strconv.FormatInt(d, 10) + "M"
416	}
417	// Note that maxTimeoutValue * time.Hour > MaxInt64.
418	return strconv.FormatInt(div(t, time.Hour), 10) + "H"
419}
420
421func decodeTimeout(s string) (time.Duration, error) {
422	size := len(s)
423	if size < 2 {
424		return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
425	}
426	unit := timeoutUnit(s[size-1])
427	d, ok := timeoutUnitToDuration(unit)
428	if !ok {
429		return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
430	}
431	t, err := strconv.ParseInt(s[:size-1], 10, 64)
432	if err != nil {
433		return 0, err
434	}
435	return d * time.Duration(t), nil
436}
437
// Byte bounds for the grpc-message percent-encoding. encodeGrpcMessage passes
// through bytes in the half-open range [spaceByte, tildaByte) except
// percentByte; '~' itself is therefore always percent-encoded.
const (
	spaceByte   = ' '
	tildaByte   = '~'
	percentByte = '%'
)
443
444// encodeGrpcMessage is used to encode status code in header field
445// "grpc-message".
446// It checks to see if each individual byte in msg is an
447// allowable byte, and then either percent encoding or passing it through.
448// When percent encoding, the byte is converted into hexadecimal notation
449// with a '%' prepended.
450func encodeGrpcMessage(msg string) string {
451	if msg == "" {
452		return ""
453	}
454	lenMsg := len(msg)
455	for i := 0; i < lenMsg; i++ {
456		c := msg[i]
457		if !(c >= spaceByte && c < tildaByte && c != percentByte) {
458			return encodeGrpcMessageUnchecked(msg)
459		}
460	}
461	return msg
462}
463
464func encodeGrpcMessageUnchecked(msg string) string {
465	var buf bytes.Buffer
466	lenMsg := len(msg)
467	for i := 0; i < lenMsg; i++ {
468		c := msg[i]
469		if c >= spaceByte && c < tildaByte && c != percentByte {
470			buf.WriteByte(c)
471		} else {
472			buf.WriteString(fmt.Sprintf("%%%02X", c))
473		}
474	}
475	return buf.String()
476}
477
478// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
479func decodeGrpcMessage(msg string) string {
480	if msg == "" {
481		return ""
482	}
483	lenMsg := len(msg)
484	for i := 0; i < lenMsg; i++ {
485		if msg[i] == percentByte && i+2 < lenMsg {
486			return decodeGrpcMessageUnchecked(msg)
487		}
488	}
489	return msg
490}
491
492func decodeGrpcMessageUnchecked(msg string) string {
493	var buf bytes.Buffer
494	lenMsg := len(msg)
495	for i := 0; i < lenMsg; i++ {
496		c := msg[i]
497		if c == percentByte && i+2 < lenMsg {
498			parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8)
499			if err != nil {
500				buf.WriteByte(c)
501			} else {
502				buf.WriteByte(byte(parsed))
503				i += 2
504			}
505		} else {
506			buf.WriteByte(c)
507		}
508	}
509	return buf.String()
510}
511
// bufWriter is a buffered writer over a net.Conn. Writes accumulate in buf and
// are sent to conn once at least batchSize bytes are pending, or on Flush. The
// first write error from conn is sticky: it is stored in err and returned by
// every later Write and Flush.
type bufWriter struct {
	buf       []byte
	offset    int // number of pending bytes at the front of buf
	batchSize int
	conn      net.Conn
	err       error

	// onFlush, if non-nil, is called before buffered data is written to conn.
	onFlush func()
}

// newBufWriter returns a bufWriter that flushes to conn once batchSize bytes
// are buffered. A batchSize of zero or less disables buffering, and each Write
// then goes straight to conn.
func newBufWriter(conn net.Conn, batchSize int) *bufWriter {
	if batchSize < 0 {
		// make would panic on a negative length; treat it as "no buffering".
		batchSize = 0
	}
	return &bufWriter{
		buf:       make([]byte, batchSize*2),
		batchSize: batchSize,
		conn:      conn,
	}
}

// Write buffers b, flushing to conn whenever at least batchSize bytes are
// pending. It keeps copying until all of b is consumed, so payloads larger
// than the buffer's free space are not truncated. As io.Writer requires,
// n < len(b) only when err is non-nil.
func (w *bufWriter) Write(b []byte) (n int, err error) {
	if w.err != nil {
		return 0, w.err
	}
	if w.batchSize == 0 {
		// Buffering disabled: write through, keeping the error sticky.
		n, w.err = w.conn.Write(b)
		return n, w.err
	}
	// Invariant: offset < batchSize at the top of each iteration, because
	// reaching batchSize triggers a Flush that resets offset (or returns).
	// Since len(buf) == 2*batchSize, every copy makes progress.
	for len(b) > 0 {
		copied := copy(w.buf[w.offset:], b)
		b = b[copied:]
		w.offset += copied
		n += copied
		if w.offset >= w.batchSize {
			if err = w.Flush(); err != nil {
				return n, err
			}
		}
	}
	return n, nil
}

// Flush writes any buffered data to conn, calling onFlush first. It is a no-op
// when nothing is buffered, and returns the sticky error if one was recorded.
func (w *bufWriter) Flush() error {
	if w.err != nil {
		return w.err
	}
	if w.offset == 0 {
		return nil
	}
	if w.onFlush != nil {
		w.onFlush()
	}
	_, w.err = w.conn.Write(w.buf[:w.offset])
	w.offset = 0
	return w.err
}
556
557type framer struct {
558	writer *bufWriter
559	fr     *http2.Framer
560}
561
562func newFramer(conn net.Conn, writeBufferSize, readBufferSize int) *framer {
563	r := bufio.NewReaderSize(conn, readBufferSize)
564	w := newBufWriter(conn, writeBufferSize)
565	f := &framer{
566		writer: w,
567		fr:     http2.NewFramer(w, r),
568	}
569	// Opt-in to Frame reuse API on framer to reduce garbage.
570	// Frames aren't safe to read from after a subsequent call to ReadFrame.
571	f.fr.SetReuseFrames()
572	f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
573	return f
574}
575