1/* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19package transport 20 21import ( 22 "context" 23 "errors" 24 "fmt" 25 "io" 26 "net/http" 27 "net/http/httptest" 28 "net/url" 29 "reflect" 30 "sync" 31 "testing" 32 "time" 33 34 "github.com/golang/protobuf/proto" 35 dpb "github.com/golang/protobuf/ptypes/duration" 36 epb "google.golang.org/genproto/googleapis/rpc/errdetails" 37 "google.golang.org/grpc/codes" 38 "google.golang.org/grpc/metadata" 39 "google.golang.org/grpc/status" 40) 41 42func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { 43 type testCase struct { 44 name string 45 req *http.Request 46 wantErr string 47 modrw func(http.ResponseWriter) http.ResponseWriter 48 check func(*serverHandlerTransport, *testCase) error 49 } 50 tests := []testCase{ 51 { 52 name: "http/1.1", 53 req: &http.Request{ 54 ProtoMajor: 1, 55 ProtoMinor: 1, 56 }, 57 wantErr: "gRPC requires HTTP/2", 58 }, 59 { 60 name: "bad method", 61 req: &http.Request{ 62 ProtoMajor: 2, 63 Method: "GET", 64 Header: http.Header{}, 65 RequestURI: "/", 66 }, 67 wantErr: "invalid gRPC request method", 68 }, 69 { 70 name: "bad content type", 71 req: &http.Request{ 72 ProtoMajor: 2, 73 Method: "POST", 74 Header: http.Header{ 75 "Content-Type": {"application/foo"}, 76 }, 77 RequestURI: "/service/foo.bar", 78 }, 79 wantErr: "invalid gRPC request content-type", 80 }, 81 { 82 name: "not flusher", 83 req: &http.Request{ 84 ProtoMajor: 2, 85 Method: "POST", 86 Header: http.Header{ 87 "Content-Type": {"application/grpc"}, 88 }, 89 RequestURI: "/service/foo.bar", 90 }, 91 modrw: func(w http.ResponseWriter) http.ResponseWriter { 92 // Return w without its Flush method 93 type onlyCloseNotifier interface { 94 http.ResponseWriter 95 http.CloseNotifier 96 } 97 return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)} 98 }, 99 wantErr: "gRPC requires a ResponseWriter supporting http.Flusher", 100 }, 101 { 102 name: "valid", 103 req: &http.Request{ 104 ProtoMajor: 2, 105 Method: "POST", 106 Header: http.Header{ 107 "Content-Type": {"application/grpc"}, 108 }, 109 URL: &url.URL{ 110 Path: "/service/foo.bar", 111 }, 112 RequestURI: "/service/foo.bar", 113 }, 114 check: func(t *serverHandlerTransport, tt *testCase) error { 115 if t.req != tt.req { 116 return fmt.Errorf("t.req = %p; want %p", t.req, tt.req) 117 } 118 if t.rw == nil { 119 return errors.New("t.rw = nil; want non-nil") 120 } 121 return nil 122 }, 123 }, 124 { 125 name: "with timeout", 126 req: &http.Request{ 127 ProtoMajor: 2, 128 Method: "POST", 129 Header: http.Header{ 130 "Content-Type": []string{"application/grpc"}, 131 "Grpc-Timeout": {"200m"}, 132 }, 133 URL: &url.URL{ 134 Path: "/service/foo.bar", 135 }, 136 RequestURI: "/service/foo.bar", 137 }, 138 check: func(t *serverHandlerTransport, tt *testCase) error { 139 if !t.timeoutSet { 140 return errors.New("timeout not set") 141 } 142 if want := 200 * time.Millisecond; t.timeout != want { 143 return fmt.Errorf("timeout = %v; want %v", t.timeout, want) 144 } 145 return nil 146 }, 147 }, 148 { 149 name: "with bad timeout", 150 req: &http.Request{ 151 ProtoMajor: 2, 152 Method: "POST", 153 Header: http.Header{ 154 "Content-Type": []string{"application/grpc"}, 155 "Grpc-Timeout": {"tomorrow"}, 156 }, 157 URL: &url.URL{ 158 Path: "/service/foo.bar", 159 }, 160 RequestURI: "/service/foo.bar", 161 }, 162 wantErr: `rpc error: code = Internal desc = malformed time-out: transport: timeout unit is not recognized: "tomorrow"`, 163 }, 164 { 165 name: "with metadata", 166 req: &http.Request{ 167 ProtoMajor: 2, 168 Method: "POST", 169 Header: http.Header{ 170 "Content-Type": []string{"application/grpc"}, 171 "meta-foo": {"foo-val"}, 172 "meta-bar": {"bar-val1", "bar-val2"}, 173 "user-agent": {"x/y a/b"}, 174 }, 175 URL: &url.URL{ 176 Path: "/service/foo.bar", 177 }, 178 RequestURI: "/service/foo.bar", 179 }, 180 check: func(ht *serverHandlerTransport, tt *testCase) error { 181 want := metadata.MD{ 182 "meta-bar": {"bar-val1", "bar-val2"}, 183 "user-agent": {"x/y a/b"}, 184 "meta-foo": {"foo-val"}, 185 "content-type": {"application/grpc"}, 186 } 187 188 if !reflect.DeepEqual(ht.headerMD, want) { 189 return fmt.Errorf("metdata = %#v; want %#v", ht.headerMD, want) 190 } 191 return nil 192 }, 193 }, 194 } 195 196 for _, tt := range tests { 197 rw := newTestHandlerResponseWriter() 198 if tt.modrw != nil { 199 rw = tt.modrw(rw) 200 } 201 got, gotErr := NewServerHandlerTransport(rw, tt.req, nil) 202 if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) { 203 t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr) 204 continue 205 } 206 if gotErr != nil { 207 continue 208 } 209 if tt.check != nil { 210 if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil { 211 t.Errorf("%s: %v", tt.name, err) 212 } 213 } 214 } 215} 216 217type testHandlerResponseWriter struct { 218 *httptest.ResponseRecorder 219 closeNotify chan bool 220} 221 222func (w testHandlerResponseWriter) CloseNotify() <-chan bool { return w.closeNotify } 223func (w testHandlerResponseWriter) Flush() {} 224 225func newTestHandlerResponseWriter() http.ResponseWriter { 226 return testHandlerResponseWriter{ 227 ResponseRecorder: httptest.NewRecorder(), 228 closeNotify: make(chan bool, 1), 229 } 230} 231 232type handleStreamTest struct { 233 t *testing.T 234 bodyw *io.PipeWriter 235 rw testHandlerResponseWriter 236 ht *serverHandlerTransport 237} 238 239func newHandleStreamTest(t *testing.T) *handleStreamTest { 240 bodyr, bodyw := io.Pipe() 241 req := &http.Request{ 242 ProtoMajor: 2, 243 Method: "POST", 244 Header: http.Header{ 245 "Content-Type": {"application/grpc"}, 246 }, 247 URL: &url.URL{ 248 Path: "/service/foo.bar", 249 }, 250 RequestURI: "/service/foo.bar", 251 Body: bodyr, 252 } 253 rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) 254 ht, err := NewServerHandlerTransport(rw, req, nil) 255 if err != nil { 256 t.Fatal(err) 257 } 258 return &handleStreamTest{ 259 t: t, 260 bodyw: bodyw, 261 ht: ht.(*serverHandlerTransport), 262 rw: rw, 263 } 264} 265 266func TestHandlerTransport_HandleStreams(t *testing.T) { 267 st := newHandleStreamTest(t) 268 handleStream := func(s *Stream) { 269 if want := "/service/foo.bar"; s.method != want { 270 t.Errorf("stream method = %q; want %q", s.method, want) 271 } 272 st.bodyw.Close() // no body 273 st.ht.WriteStatus(s, status.New(codes.OK, "")) 274 } 275 st.ht.HandleStreams( 276 func(s *Stream) { go handleStream(s) }, 277 func(ctx context.Context, method string) context.Context { return ctx }, 278 ) 279 wantHeader := http.Header{ 280 "Date": nil, 281 "Content-Type": {"application/grpc"}, 282 "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, 283 "Grpc-Status": {"0"}, 284 } 285 if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { 286 t.Errorf("Header+Trailer Map: %#v; want %#v", st.rw.HeaderMap, wantHeader) 287 } 288} 289 290// Tests that codes.Unimplemented will close the body, per comment in handler_server.go. 291func TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) { 292 handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented") 293} 294 295// Tests that codes.InvalidArgument will close the body, per comment in handler_server.go. 296func TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) { 297 handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg") 298} 299 300func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) { 301 st := newHandleStreamTest(t) 302 303 handleStream := func(s *Stream) { 304 st.ht.WriteStatus(s, status.New(statusCode, msg)) 305 } 306 st.ht.HandleStreams( 307 func(s *Stream) { go handleStream(s) }, 308 func(ctx context.Context, method string) context.Context { return ctx }, 309 ) 310 wantHeader := http.Header{ 311 "Date": nil, 312 "Content-Type": {"application/grpc"}, 313 "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, 314 "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, 315 "Grpc-Message": {encodeGrpcMessage(msg)}, 316 } 317 318 if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { 319 t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader) 320 } 321} 322 323func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { 324 bodyr, bodyw := io.Pipe() 325 req := &http.Request{ 326 ProtoMajor: 2, 327 Method: "POST", 328 Header: http.Header{ 329 "Content-Type": {"application/grpc"}, 330 "Grpc-Timeout": {"200m"}, 331 }, 332 URL: &url.URL{ 333 Path: "/service/foo.bar", 334 }, 335 RequestURI: "/service/foo.bar", 336 Body: bodyr, 337 } 338 rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) 339 ht, err := NewServerHandlerTransport(rw, req, nil) 340 if err != nil { 341 t.Fatal(err) 342 } 343 runStream := func(s *Stream) { 344 defer bodyw.Close() 345 select { 346 case <-s.ctx.Done(): 347 case <-time.After(5 * time.Second): 348 t.Errorf("timeout waiting for ctx.Done") 349 return 350 } 351 err := s.ctx.Err() 352 if err != context.DeadlineExceeded { 353 t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded) 354 return 355 } 356 ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow")) 357 } 358 ht.HandleStreams( 359 func(s *Stream) { go runStream(s) }, 360 func(ctx context.Context, method string) context.Context { return ctx }, 361 ) 362 wantHeader := http.Header{ 363 "Date": nil, 364 "Content-Type": {"application/grpc"}, 365 "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, 366 "Grpc-Status": {"4"}, 367 "Grpc-Message": {encodeGrpcMessage("too slow")}, 368 } 369 if !reflect.DeepEqual(rw.HeaderMap, wantHeader) { 370 t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader) 371 } 372} 373 374// TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that 375// concurrent "WriteStatus"s do not panic writing to closed "writes" channel. 376func TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) { 377 testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) { 378 if want := "/service/foo.bar"; s.method != want { 379 t.Errorf("stream method = %q; want %q", s.method, want) 380 } 381 st.bodyw.Close() // no body 382 383 var wg sync.WaitGroup 384 wg.Add(5) 385 for i := 0; i < 5; i++ { 386 go func() { 387 defer wg.Done() 388 st.ht.WriteStatus(s, status.New(codes.OK, "")) 389 }() 390 } 391 wg.Wait() 392 }) 393} 394 395// TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write" 396// following "WriteStatus" does not panic writing to closed "writes" channel. 397func TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) { 398 testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) { 399 if want := "/service/foo.bar"; s.method != want { 400 t.Errorf("stream method = %q; want %q", s.method, want) 401 } 402 st.bodyw.Close() // no body 403 404 st.ht.WriteStatus(s, status.New(codes.OK, "")) 405 st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{}) 406 }) 407} 408 409func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) { 410 st := newHandleStreamTest(t) 411 st.ht.HandleStreams( 412 func(s *Stream) { go handleStream(st, s) }, 413 func(ctx context.Context, method string) context.Context { return ctx }, 414 ) 415} 416 417func TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) { 418 errDetails := []proto.Message{ 419 &epb.RetryInfo{ 420 RetryDelay: &dpb.Duration{Seconds: 60}, 421 }, 422 &epb.ResourceInfo{ 423 ResourceType: "foo bar", 424 ResourceName: "service.foo.bar", 425 Owner: "User", 426 }, 427 } 428 429 statusCode := codes.ResourceExhausted 430 msg := "you are being throttled" 431 st, err := status.New(statusCode, msg).WithDetails(errDetails...) 432 if err != nil { 433 t.Fatal(err) 434 } 435 436 stBytes, err := proto.Marshal(st.Proto()) 437 if err != nil { 438 t.Fatal(err) 439 } 440 441 hst := newHandleStreamTest(t) 442 handleStream := func(s *Stream) { 443 hst.ht.WriteStatus(s, st) 444 } 445 hst.ht.HandleStreams( 446 func(s *Stream) { go handleStream(s) }, 447 func(ctx context.Context, method string) context.Context { return ctx }, 448 ) 449 wantHeader := http.Header{ 450 "Date": nil, 451 "Content-Type": {"application/grpc"}, 452 "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, 453 "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, 454 "Grpc-Message": {encodeGrpcMessage(msg)}, 455 "Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)}, 456 } 457 458 if !reflect.DeepEqual(hst.rw.HeaderMap, wantHeader) { 459 t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", hst.rw.HeaderMap, wantHeader) 460 } 461} 462