1/*
2 *
3 * Copyright 2014, Google Inc.
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions are
8 * met:
9 *
10 *     * Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
12 *     * Redistributions in binary form must reproduce the above
13 * copyright notice, this list of conditions and the following disclaimer
14 * in the documentation and/or other materials provided with the
15 * distribution.
16 *     * Neither the name of Google Inc. nor the names of its
17 * contributors may be used to endorse or promote products derived from
18 * this software without specific prior written permission.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 *
32 */
33
34package transport
35
36import (
37	"bufio"
38	"bytes"
39	"fmt"
40	"io"
41	"net"
42	"strconv"
43	"strings"
44	"sync/atomic"
45	"time"
46
47	"golang.org/x/net/http2"
48	"golang.org/x/net/http2/hpack"
49	"google.golang.org/grpc/codes"
50	"google.golang.org/grpc/grpclog"
51	"google.golang.org/grpc/metadata"
52)
53
const (
	// primaryUA is the primary user agent string.
	primaryUA = "grpc-go/1.0"
	// http2MaxFrameLen specifies the max length of a HTTP2 frame.
	http2MaxFrameLen = 16384 // 16KB frame
	// http2InitHeaderTableSize is the table size passed to hpack.NewDecoder
	// when newFramer builds the header decoder. See
	// http://http2.github.io/http2-spec/#SettingValues
	http2InitHeaderTableSize = 4096
	// http2IOBufSize specifies the size of the bufio reader and writer that
	// newFramer wraps around the connection.
	http2IOBufSize = 32 * 1024
)
64
var (
	// clientPreface is http2.ClientPreface converted to a byte slice.
	// http2ErrConvTab maps an HTTP/2 error code to the gRPC status code
	// used to represent it.
	clientPreface   = []byte(http2.ClientPreface)
	http2ErrConvTab = map[http2.ErrCode]codes.Code{
		http2.ErrCodeNo:                 codes.Internal,
		http2.ErrCodeProtocol:           codes.Internal,
		http2.ErrCodeInternal:           codes.Internal,
		http2.ErrCodeFlowControl:        codes.ResourceExhausted,
		http2.ErrCodeSettingsTimeout:    codes.Internal,
		http2.ErrCodeStreamClosed:       codes.Internal,
		http2.ErrCodeFrameSize:          codes.Internal,
		http2.ErrCodeRefusedStream:      codes.Unavailable,
		http2.ErrCodeCancel:             codes.Canceled,
		http2.ErrCodeCompression:        codes.Internal,
		http2.ErrCodeConnect:            codes.Internal,
		http2.ErrCodeEnhanceYourCalm:    codes.ResourceExhausted,
		http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
		http2.ErrCodeHTTP11Required:     codes.FailedPrecondition,
	}
	// statusCodeConvTab is the reverse direction: it maps a gRPC status
	// code to an HTTP/2 error code. Only the codes listed here have an
	// entry; a lookup for any other code yields the map's zero value.
	statusCodeConvTab = map[codes.Code]http2.ErrCode{
		codes.Internal:          http2.ErrCodeInternal,
		codes.Canceled:          http2.ErrCodeCancel,
		codes.Unavailable:       http2.ErrCodeRefusedStream,
		codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
		codes.PermissionDenied:  http2.ErrCodeInadequateSecurity,
	}
)
91
// decodeState records the state accumulated during HPACK decoding. It must
// be reset once the decoding of the entire header block is finished.
type decodeState struct {
	err error // first error encountered decoding

	// encoding is the value of the "grpc-encoding" header.
	encoding string
	// statusCode caches the stream status received from the trailer
	// the server sent ("grpc-status"). Client side only.
	statusCode codes.Code
	// statusDesc is the percent-decoded value of the "grpc-message" header.
	statusDesc string
	// Server side only fields. timeoutSet is true once a "grpc-timeout"
	// header has been seen, timeout is its decoded value, and method is
	// the value of the ":path" pseudo-header.
	timeoutSet bool
	timeout    time.Duration
	method     string
	// key-value metadata map from the peer. It is allocated lazily, on the
	// first header that is treated as metadata.
	mdata map[string][]string
}
109
// isReservedHeader reports whether hdr is an HTTP2 header reserved by the
// gRPC protocol: either a pseudo-header (a name starting with ':') or one of
// the fixed names listed below. Every other header is classified as
// user-specified metadata.
func isReservedHeader(hdr string) bool {
	// All pseudo-headers are reserved.
	if strings.HasPrefix(hdr, ":") {
		return true
	}
	switch hdr {
	case "content-type", "grpc-message-type", "grpc-encoding",
		"grpc-message", "grpc-status", "grpc-timeout", "te":
		return true
	}
	return false
}
130
// isWhitelistedPseudoHeader reports whether hdr is an HTTP2 pseudo-header
// that should still be propagated into the metadata visible to users.
// ":authority" is the only such header.
func isWhitelistedPseudoHeader(hdr string) bool {
	return hdr == ":authority"
}
141
142func (d *decodeState) setErr(err error) {
143	if d.err == nil {
144		d.err = err
145	}
146}
147
// validContentType reports whether t is an acceptable gRPC content-type:
// exactly "application/grpc", or that prefix immediately followed by '+' or
// ';' (e.g. "application/grpc+blah", "application/grpc;blah").
func validContentType(t string) bool {
	const base = "application/grpc"
	if !strings.HasPrefix(t, base) {
		return false
	}
	suffix := t[len(base):]
	return suffix == "" || suffix[0] == '+' || suffix[0] == ';'
}
160
161func (d *decodeState) processHeaderField(f hpack.HeaderField) {
162	switch f.Name {
163	case "content-type":
164		if !validContentType(f.Value) {
165			d.setErr(streamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value))
166			return
167		}
168	case "grpc-encoding":
169		d.encoding = f.Value
170	case "grpc-status":
171		code, err := strconv.Atoi(f.Value)
172		if err != nil {
173			d.setErr(streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err))
174			return
175		}
176		d.statusCode = codes.Code(code)
177	case "grpc-message":
178		d.statusDesc = decodeGrpcMessage(f.Value)
179	case "grpc-timeout":
180		d.timeoutSet = true
181		var err error
182		d.timeout, err = decodeTimeout(f.Value)
183		if err != nil {
184			d.setErr(streamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
185			return
186		}
187	case ":path":
188		d.method = f.Value
189	default:
190		if !isReservedHeader(f.Name) || isWhitelistedPseudoHeader(f.Name) {
191			if f.Name == "user-agent" {
192				i := strings.LastIndex(f.Value, " ")
193				if i == -1 {
194					// There is no application user agent string being set.
195					return
196				}
197				// Extract the application user agent string.
198				f.Value = f.Value[:i]
199			}
200			if d.mdata == nil {
201				d.mdata = make(map[string][]string)
202			}
203			k, v, err := metadata.DecodeKeyValue(f.Name, f.Value)
204			if err != nil {
205				grpclog.Printf("Failed to decode (%q, %q): %v", f.Name, f.Value, err)
206				return
207			}
208			d.mdata[k] = append(d.mdata[k], v)
209		}
210	}
211}
212
// timeoutUnit is the single-byte unit suffix of a "grpc-timeout" value.
type timeoutUnit uint8

// The unit suffixes understood by timeoutUnitToDuration.
const (
	hour        timeoutUnit = 'H'
	minute      timeoutUnit = 'M'
	second      timeoutUnit = 'S'
	millisecond timeoutUnit = 'm'
	microsecond timeoutUnit = 'u'
	nanosecond  timeoutUnit = 'n'
)

// timeoutUnitToDuration returns the time.Duration that one unit of u stands
// for. ok is false, and d is zero, when u is not a known unit.
func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
	switch u {
	case hour:
		d = time.Hour
	case minute:
		d = time.Minute
	case second:
		d = time.Second
	case millisecond:
		d = time.Millisecond
	case microsecond:
		d = time.Microsecond
	case nanosecond:
		d = time.Nanosecond
	default:
		return 0, false
	}
	return d, true
}
242
243const maxTimeoutValue int64 = 100000000 - 1
244
245// div does integer division and round-up the result. Note that this is
246// equivalent to (d+r-1)/r but has less chance to overflow.
247func div(d, r time.Duration) int64 {
248	if m := d % r; m > 0 {
249		return int64(d/r + 1)
250	}
251	return int64(d / r)
252}
253
254// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
255func encodeTimeout(t time.Duration) string {
256	if t <= 0 {
257		return "0n"
258	}
259	if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
260		return strconv.FormatInt(d, 10) + "n"
261	}
262	if d := div(t, time.Microsecond); d <= maxTimeoutValue {
263		return strconv.FormatInt(d, 10) + "u"
264	}
265	if d := div(t, time.Millisecond); d <= maxTimeoutValue {
266		return strconv.FormatInt(d, 10) + "m"
267	}
268	if d := div(t, time.Second); d <= maxTimeoutValue {
269		return strconv.FormatInt(d, 10) + "S"
270	}
271	if d := div(t, time.Minute); d <= maxTimeoutValue {
272		return strconv.FormatInt(d, 10) + "M"
273	}
274	// Note that maxTimeoutValue * time.Hour > MaxInt64.
275	return strconv.FormatInt(div(t, time.Hour), 10) + "H"
276}
277
278func decodeTimeout(s string) (time.Duration, error) {
279	size := len(s)
280	if size < 2 {
281		return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
282	}
283	unit := timeoutUnit(s[size-1])
284	d, ok := timeoutUnitToDuration(unit)
285	if !ok {
286		return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
287	}
288	t, err := strconv.ParseInt(s[:size-1], 10, 64)
289	if err != nil {
290		return 0, err
291	}
292	return d * time.Duration(t), nil
293}
294
// Bounds of the byte range that may appear unescaped in a "grpc-message"
// header value, plus the escape character itself.
const (
	spaceByte   = ' '
	tildaByte   = '~'
	percentByte = '%'
)

// grpcMessageByteIsLiteral reports whether c may be sent as-is in a
// "grpc-message" value: any byte from space (0x20) through '~' (0x7E),
// both inclusive, except '%'.
func grpcMessageByteIsLiteral(c byte) bool {
	return c >= spaceByte && c <= tildaByte && c != percentByte
}

// encodeGrpcMessage is used to encode the status message sent in the header
// field "grpc-message".
// It checks whether every byte in msg is an allowable byte and, if so,
// returns msg unchanged. Otherwise each disallowed byte is percent-encoded:
// converted into two upper-case hexadecimal digits with a '%' prepended.
func encodeGrpcMessage(msg string) string {
	for i := 0; i < len(msg); i++ {
		if !grpcMessageByteIsLiteral(msg[i]) {
			return encodeGrpcMessageUnchecked(msg)
		}
	}
	return msg
}

// encodeGrpcMessageUnchecked percent-encodes msg without first checking
// whether any byte actually needs escaping.
func encodeGrpcMessageUnchecked(msg string) string {
	const upperHex = "0123456789ABCDEF"
	var buf bytes.Buffer
	buf.Grow(len(msg))
	for i := 0; i < len(msg); i++ {
		c := msg[i]
		if grpcMessageByteIsLiteral(c) {
			buf.WriteByte(c)
			continue
		}
		// Emit "%XX" by hand; this avoids a fmt call per escaped byte.
		buf.WriteByte(percentByte)
		buf.WriteByte(upperHex[c>>4])
		buf.WriteByte(upperHex[c&0x0f])
	}
	return buf.String()
}
334
335// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
336func decodeGrpcMessage(msg string) string {
337	if msg == "" {
338		return ""
339	}
340	lenMsg := len(msg)
341	for i := 0; i < lenMsg; i++ {
342		if msg[i] == percentByte && i+2 < lenMsg {
343			return decodeGrpcMessageUnchecked(msg)
344		}
345	}
346	return msg
347}
348
349func decodeGrpcMessageUnchecked(msg string) string {
350	var buf bytes.Buffer
351	lenMsg := len(msg)
352	for i := 0; i < lenMsg; i++ {
353		c := msg[i]
354		if c == percentByte && i+2 < lenMsg {
355			parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8)
356			if err != nil {
357				buf.WriteByte(c)
358			} else {
359				buf.WriteByte(byte(parsed))
360				i += 2
361			}
362		} else {
363			buf.WriteByte(c)
364		}
365	}
366	return buf.String()
367}
368
// framer bundles an http2.Framer with the buffered reader and writer it
// operates on, so that callers can write frames and decide separately when
// the buffered output is flushed.
type framer struct {
	// numWriters is a counter that is only modified atomically, through
	// adjustNumWriters.
	numWriters int32
	// reader and writer are the buffered halves created by newFramer
	// around the connection; fr reads from reader and writes to writer.
	reader     io.Reader
	writer     *bufio.Writer
	fr         *http2.Framer
}
375
376func newFramer(conn net.Conn) *framer {
377	f := &framer{
378		reader: bufio.NewReaderSize(conn, http2IOBufSize),
379		writer: bufio.NewWriterSize(conn, http2IOBufSize),
380	}
381	f.fr = http2.NewFramer(f.writer, f.reader)
382	f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
383	return f
384}
385
// adjustNumWriters atomically adds i to f.numWriters and returns the
// resulting value. Pass a negative i to decrement.
func (f *framer) adjustNumWriters(i int32) int32 {
	return atomic.AddInt32(&f.numWriters, i)
}
389
390// The following writeXXX functions can only be called when the caller gets
391// unblocked from writableChan channel (i.e., owns the privilege to write).
392
393func (f *framer) writeContinuation(forceFlush bool, streamID uint32, endHeaders bool, headerBlockFragment []byte) error {
394	if err := f.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil {
395		return err
396	}
397	if forceFlush {
398		return f.writer.Flush()
399	}
400	return nil
401}
402
403func (f *framer) writeData(forceFlush bool, streamID uint32, endStream bool, data []byte) error {
404	if err := f.fr.WriteData(streamID, endStream, data); err != nil {
405		return err
406	}
407	if forceFlush {
408		return f.writer.Flush()
409	}
410	return nil
411}
412
413func (f *framer) writeGoAway(forceFlush bool, maxStreamID uint32, code http2.ErrCode, debugData []byte) error {
414	if err := f.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
415		return err
416	}
417	if forceFlush {
418		return f.writer.Flush()
419	}
420	return nil
421}
422
423func (f *framer) writeHeaders(forceFlush bool, p http2.HeadersFrameParam) error {
424	if err := f.fr.WriteHeaders(p); err != nil {
425		return err
426	}
427	if forceFlush {
428		return f.writer.Flush()
429	}
430	return nil
431}
432
433func (f *framer) writePing(forceFlush, ack bool, data [8]byte) error {
434	if err := f.fr.WritePing(ack, data); err != nil {
435		return err
436	}
437	if forceFlush {
438		return f.writer.Flush()
439	}
440	return nil
441}
442
443func (f *framer) writePriority(forceFlush bool, streamID uint32, p http2.PriorityParam) error {
444	if err := f.fr.WritePriority(streamID, p); err != nil {
445		return err
446	}
447	if forceFlush {
448		return f.writer.Flush()
449	}
450	return nil
451}
452
453func (f *framer) writePushPromise(forceFlush bool, p http2.PushPromiseParam) error {
454	if err := f.fr.WritePushPromise(p); err != nil {
455		return err
456	}
457	if forceFlush {
458		return f.writer.Flush()
459	}
460	return nil
461}
462
463func (f *framer) writeRSTStream(forceFlush bool, streamID uint32, code http2.ErrCode) error {
464	if err := f.fr.WriteRSTStream(streamID, code); err != nil {
465		return err
466	}
467	if forceFlush {
468		return f.writer.Flush()
469	}
470	return nil
471}
472
473func (f *framer) writeSettings(forceFlush bool, settings ...http2.Setting) error {
474	if err := f.fr.WriteSettings(settings...); err != nil {
475		return err
476	}
477	if forceFlush {
478		return f.writer.Flush()
479	}
480	return nil
481}
482
483func (f *framer) writeSettingsAck(forceFlush bool) error {
484	if err := f.fr.WriteSettingsAck(); err != nil {
485		return err
486	}
487	if forceFlush {
488		return f.writer.Flush()
489	}
490	return nil
491}
492
493func (f *framer) writeWindowUpdate(forceFlush bool, streamID, incr uint32) error {
494	if err := f.fr.WriteWindowUpdate(streamID, incr); err != nil {
495		return err
496	}
497	if forceFlush {
498		return f.writer.Flush()
499	}
500	return nil
501}
502
// flushWrite flushes the buffered writer and returns the error, if any,
// reported by Flush.
func (f *framer) flushWrite() error {
	return f.writer.Flush()
}
506
// readFrame returns the result of f.fr.ReadFrame, i.e. the next frame read
// by the underlying http2.Framer or the error it reported.
func (f *framer) readFrame() (http2.Frame, error) {
	return f.fr.ReadFrame()
}
510
// errorDetail returns the result of f.fr.ErrorDetail on the underlying
// http2.Framer.
func (f *framer) errorDetail() error {
	return f.fr.ErrorDetail()
}
514