1/* 2 * 3 * Copyright 2014, Google Inc. 4 * All rights reserved. 5 * 6 * Redistribution and use in source and binary forms, with or without 7 * modification, are permitted provided that the following conditions are 8 * met: 9 * 10 * * Redistributions of source code must retain the above copyright 11 * notice, this list of conditions and the following disclaimer. 12 * * Redistributions in binary form must reproduce the above 13 * copyright notice, this list of conditions and the following disclaimer 14 * in the documentation and/or other materials provided with the 15 * distribution. 16 * * Neither the name of Google Inc. nor the names of its 17 * contributors may be used to endorse or promote products derived from 18 * this software without specific prior written permission. 19 * 20 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 * 32 */ 33 34package transport 35 36import ( 37 "bufio" 38 "bytes" 39 "fmt" 40 "io" 41 "net" 42 "strconv" 43 "strings" 44 "sync/atomic" 45 "time" 46 47 "golang.org/x/net/http2" 48 "golang.org/x/net/http2/hpack" 49 "google.golang.org/grpc/codes" 50 "google.golang.org/grpc/grpclog" 51 "google.golang.org/grpc/metadata" 52) 53 54const ( 55 // The primary user agent 56 primaryUA = "grpc-go/1.0" 57 // http2MaxFrameLen specifies the max length of a HTTP2 frame. 58 http2MaxFrameLen = 16384 // 16KB frame 59 // http://http2.github.io/http2-spec/#SettingValues 60 http2InitHeaderTableSize = 4096 61 // http2IOBufSize specifies the buffer size for sending frames. 62 http2IOBufSize = 32 * 1024 63) 64 65var ( 66 clientPreface = []byte(http2.ClientPreface) 67 http2ErrConvTab = map[http2.ErrCode]codes.Code{ 68 http2.ErrCodeNo: codes.Internal, 69 http2.ErrCodeProtocol: codes.Internal, 70 http2.ErrCodeInternal: codes.Internal, 71 http2.ErrCodeFlowControl: codes.ResourceExhausted, 72 http2.ErrCodeSettingsTimeout: codes.Internal, 73 http2.ErrCodeStreamClosed: codes.Internal, 74 http2.ErrCodeFrameSize: codes.Internal, 75 http2.ErrCodeRefusedStream: codes.Unavailable, 76 http2.ErrCodeCancel: codes.Canceled, 77 http2.ErrCodeCompression: codes.Internal, 78 http2.ErrCodeConnect: codes.Internal, 79 http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted, 80 http2.ErrCodeInadequateSecurity: codes.PermissionDenied, 81 http2.ErrCodeHTTP11Required: codes.FailedPrecondition, 82 } 83 statusCodeConvTab = map[codes.Code]http2.ErrCode{ 84 codes.Internal: http2.ErrCodeInternal, 85 codes.Canceled: http2.ErrCodeCancel, 86 codes.Unavailable: http2.ErrCodeRefusedStream, 87 codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm, 88 codes.PermissionDenied: http2.ErrCodeInadequateSecurity, 89 } 90) 91 92// Records the states during HPACK decoding. Must be reset once the 93// decoding of the entire headers are finished. 94type decodeState struct { 95 err error // first error encountered decoding 96 97 encoding string 98 // statusCode caches the stream status received from the trailer 99 // the server sent. Client side only. 100 statusCode codes.Code 101 statusDesc string 102 // Server side only fields. 103 timeoutSet bool 104 timeout time.Duration 105 method string 106 // key-value metadata map from the peer. 107 mdata map[string][]string 108} 109 110// isReservedHeader checks whether hdr belongs to HTTP2 headers 111// reserved by gRPC protocol. Any other headers are classified as the 112// user-specified metadata. 113func isReservedHeader(hdr string) bool { 114 if hdr != "" && hdr[0] == ':' { 115 return true 116 } 117 switch hdr { 118 case "content-type", 119 "grpc-message-type", 120 "grpc-encoding", 121 "grpc-message", 122 "grpc-status", 123 "grpc-timeout", 124 "te": 125 return true 126 default: 127 return false 128 } 129} 130 131// isWhitelistedPseudoHeader checks whether hdr belongs to HTTP2 pseudoheaders 132// that should be propagated into metadata visible to users. 133func isWhitelistedPseudoHeader(hdr string) bool { 134 switch hdr { 135 case ":authority": 136 return true 137 default: 138 return false 139 } 140} 141 142func (d *decodeState) setErr(err error) { 143 if d.err == nil { 144 d.err = err 145 } 146} 147 148func validContentType(t string) bool { 149 e := "application/grpc" 150 if !strings.HasPrefix(t, e) { 151 return false 152 } 153 // Support variations on the content-type 154 // (e.g. "application/grpc+blah", "application/grpc;blah"). 155 if len(t) > len(e) && t[len(e)] != '+' && t[len(e)] != ';' { 156 return false 157 } 158 return true 159} 160 161func (d *decodeState) processHeaderField(f hpack.HeaderField) { 162 switch f.Name { 163 case "content-type": 164 if !validContentType(f.Value) { 165 d.setErr(streamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value)) 166 return 167 } 168 case "grpc-encoding": 169 d.encoding = f.Value 170 case "grpc-status": 171 code, err := strconv.Atoi(f.Value) 172 if err != nil { 173 d.setErr(streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err)) 174 return 175 } 176 d.statusCode = codes.Code(code) 177 case "grpc-message": 178 d.statusDesc = decodeGrpcMessage(f.Value) 179 case "grpc-timeout": 180 d.timeoutSet = true 181 var err error 182 d.timeout, err = decodeTimeout(f.Value) 183 if err != nil { 184 d.setErr(streamErrorf(codes.Internal, "transport: malformed time-out: %v", err)) 185 return 186 } 187 case ":path": 188 d.method = f.Value 189 default: 190 if !isReservedHeader(f.Name) || isWhitelistedPseudoHeader(f.Name) { 191 if f.Name == "user-agent" { 192 i := strings.LastIndex(f.Value, " ") 193 if i == -1 { 194 // There is no application user agent string being set. 195 return 196 } 197 // Extract the application user agent string. 198 f.Value = f.Value[:i] 199 } 200 if d.mdata == nil { 201 d.mdata = make(map[string][]string) 202 } 203 k, v, err := metadata.DecodeKeyValue(f.Name, f.Value) 204 if err != nil { 205 grpclog.Printf("Failed to decode (%q, %q): %v", f.Name, f.Value, err) 206 return 207 } 208 d.mdata[k] = append(d.mdata[k], v) 209 } 210 } 211} 212 213type timeoutUnit uint8 214 215const ( 216 hour timeoutUnit = 'H' 217 minute timeoutUnit = 'M' 218 second timeoutUnit = 'S' 219 millisecond timeoutUnit = 'm' 220 microsecond timeoutUnit = 'u' 221 nanosecond timeoutUnit = 'n' 222) 223 224func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) { 225 switch u { 226 case hour: 227 return time.Hour, true 228 case minute: 229 return time.Minute, true 230 case second: 231 return time.Second, true 232 case millisecond: 233 return time.Millisecond, true 234 case microsecond: 235 return time.Microsecond, true 236 case nanosecond: 237 return time.Nanosecond, true 238 default: 239 } 240 return 241} 242 243const maxTimeoutValue int64 = 100000000 - 1 244 245// div does integer division and round-up the result. Note that this is 246// equivalent to (d+r-1)/r but has less chance to overflow. 247func div(d, r time.Duration) int64 { 248 if m := d % r; m > 0 { 249 return int64(d/r + 1) 250 } 251 return int64(d / r) 252} 253 254// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it. 255func encodeTimeout(t time.Duration) string { 256 if t <= 0 { 257 return "0n" 258 } 259 if d := div(t, time.Nanosecond); d <= maxTimeoutValue { 260 return strconv.FormatInt(d, 10) + "n" 261 } 262 if d := div(t, time.Microsecond); d <= maxTimeoutValue { 263 return strconv.FormatInt(d, 10) + "u" 264 } 265 if d := div(t, time.Millisecond); d <= maxTimeoutValue { 266 return strconv.FormatInt(d, 10) + "m" 267 } 268 if d := div(t, time.Second); d <= maxTimeoutValue { 269 return strconv.FormatInt(d, 10) + "S" 270 } 271 if d := div(t, time.Minute); d <= maxTimeoutValue { 272 return strconv.FormatInt(d, 10) + "M" 273 } 274 // Note that maxTimeoutValue * time.Hour > MaxInt64. 275 return strconv.FormatInt(div(t, time.Hour), 10) + "H" 276} 277 278func decodeTimeout(s string) (time.Duration, error) { 279 size := len(s) 280 if size < 2 { 281 return 0, fmt.Errorf("transport: timeout string is too short: %q", s) 282 } 283 unit := timeoutUnit(s[size-1]) 284 d, ok := timeoutUnitToDuration(unit) 285 if !ok { 286 return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s) 287 } 288 t, err := strconv.ParseInt(s[:size-1], 10, 64) 289 if err != nil { 290 return 0, err 291 } 292 return d * time.Duration(t), nil 293} 294 295const ( 296 spaceByte = ' ' 297 tildaByte = '~' 298 percentByte = '%' 299) 300 301// encodeGrpcMessage is used to encode status code in header field 302// "grpc-message". 303// It checks to see if each individual byte in msg is an 304// allowable byte, and then either percent encoding or passing it through. 305// When percent encoding, the byte is converted into hexadecimal notation 306// with a '%' prepended. 307func encodeGrpcMessage(msg string) string { 308 if msg == "" { 309 return "" 310 } 311 lenMsg := len(msg) 312 for i := 0; i < lenMsg; i++ { 313 c := msg[i] 314 if !(c >= spaceByte && c < tildaByte && c != percentByte) { 315 return encodeGrpcMessageUnchecked(msg) 316 } 317 } 318 return msg 319} 320 321func encodeGrpcMessageUnchecked(msg string) string { 322 var buf bytes.Buffer 323 lenMsg := len(msg) 324 for i := 0; i < lenMsg; i++ { 325 c := msg[i] 326 if c >= spaceByte && c < tildaByte && c != percentByte { 327 buf.WriteByte(c) 328 } else { 329 buf.WriteString(fmt.Sprintf("%%%02X", c)) 330 } 331 } 332 return buf.String() 333} 334 335// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage. 336func decodeGrpcMessage(msg string) string { 337 if msg == "" { 338 return "" 339 } 340 lenMsg := len(msg) 341 for i := 0; i < lenMsg; i++ { 342 if msg[i] == percentByte && i+2 < lenMsg { 343 return decodeGrpcMessageUnchecked(msg) 344 } 345 } 346 return msg 347} 348 349func decodeGrpcMessageUnchecked(msg string) string { 350 var buf bytes.Buffer 351 lenMsg := len(msg) 352 for i := 0; i < lenMsg; i++ { 353 c := msg[i] 354 if c == percentByte && i+2 < lenMsg { 355 parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8) 356 if err != nil { 357 buf.WriteByte(c) 358 } else { 359 buf.WriteByte(byte(parsed)) 360 i += 2 361 } 362 } else { 363 buf.WriteByte(c) 364 } 365 } 366 return buf.String() 367} 368 369type framer struct { 370 numWriters int32 371 reader io.Reader 372 writer *bufio.Writer 373 fr *http2.Framer 374} 375 376func newFramer(conn net.Conn) *framer { 377 f := &framer{ 378 reader: bufio.NewReaderSize(conn, http2IOBufSize), 379 writer: bufio.NewWriterSize(conn, http2IOBufSize), 380 } 381 f.fr = http2.NewFramer(f.writer, f.reader) 382 f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil) 383 return f 384} 385 386func (f *framer) adjustNumWriters(i int32) int32 { 387 return atomic.AddInt32(&f.numWriters, i) 388} 389 390// The following writeXXX functions can only be called when the caller gets 391// unblocked from writableChan channel (i.e., owns the privilege to write). 392 393func (f *framer) writeContinuation(forceFlush bool, streamID uint32, endHeaders bool, headerBlockFragment []byte) error { 394 if err := f.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil { 395 return err 396 } 397 if forceFlush { 398 return f.writer.Flush() 399 } 400 return nil 401} 402 403func (f *framer) writeData(forceFlush bool, streamID uint32, endStream bool, data []byte) error { 404 if err := f.fr.WriteData(streamID, endStream, data); err != nil { 405 return err 406 } 407 if forceFlush { 408 return f.writer.Flush() 409 } 410 return nil 411} 412 413func (f *framer) writeGoAway(forceFlush bool, maxStreamID uint32, code http2.ErrCode, debugData []byte) error { 414 if err := f.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { 415 return err 416 } 417 if forceFlush { 418 return f.writer.Flush() 419 } 420 return nil 421} 422 423func (f *framer) writeHeaders(forceFlush bool, p http2.HeadersFrameParam) error { 424 if err := f.fr.WriteHeaders(p); err != nil { 425 return err 426 } 427 if forceFlush { 428 return f.writer.Flush() 429 } 430 return nil 431} 432 433func (f *framer) writePing(forceFlush, ack bool, data [8]byte) error { 434 if err := f.fr.WritePing(ack, data); err != nil { 435 return err 436 } 437 if forceFlush { 438 return f.writer.Flush() 439 } 440 return nil 441} 442 443func (f *framer) writePriority(forceFlush bool, streamID uint32, p http2.PriorityParam) error { 444 if err := f.fr.WritePriority(streamID, p); err != nil { 445 return err 446 } 447 if forceFlush { 448 return f.writer.Flush() 449 } 450 return nil 451} 452 453func (f *framer) writePushPromise(forceFlush bool, p http2.PushPromiseParam) error { 454 if err := f.fr.WritePushPromise(p); err != nil { 455 return err 456 } 457 if forceFlush { 458 return f.writer.Flush() 459 } 460 return nil 461} 462 463func (f *framer) writeRSTStream(forceFlush bool, streamID uint32, code http2.ErrCode) error { 464 if err := f.fr.WriteRSTStream(streamID, code); err != nil { 465 return err 466 } 467 if forceFlush { 468 return f.writer.Flush() 469 } 470 return nil 471} 472 473func (f *framer) writeSettings(forceFlush bool, settings ...http2.Setting) error { 474 if err := f.fr.WriteSettings(settings...); err != nil { 475 return err 476 } 477 if forceFlush { 478 return f.writer.Flush() 479 } 480 return nil 481} 482 483func (f *framer) writeSettingsAck(forceFlush bool) error { 484 if err := f.fr.WriteSettingsAck(); err != nil { 485 return err 486 } 487 if forceFlush { 488 return f.writer.Flush() 489 } 490 return nil 491} 492 493func (f *framer) writeWindowUpdate(forceFlush bool, streamID, incr uint32) error { 494 if err := f.fr.WriteWindowUpdate(streamID, incr); err != nil { 495 return err 496 } 497 if forceFlush { 498 return f.writer.Flush() 499 } 500 return nil 501} 502 503func (f *framer) flushWrite() error { 504 return f.writer.Flush() 505} 506 507func (f *framer) readFrame() (http2.Frame, error) { 508 return f.fr.ReadFrame() 509} 510 511func (f *framer) errorDetail() error { 512 return f.fr.ErrorDetail() 513} 514