1package grpcurl_test 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "io" 8 "net" 9 "os" 10 "reflect" 11 "strings" 12 "testing" 13 "time" 14 15 "github.com/golang/protobuf/jsonpb" //lint:ignore SA1019 we have to import this because it appears in exported API 16 "github.com/golang/protobuf/proto" //lint:ignore SA1019 we have to import this because it appears in exported API 17 "github.com/jhump/protoreflect/desc" 18 "github.com/jhump/protoreflect/grpcreflect" 19 "google.golang.org/grpc" 20 "google.golang.org/grpc/codes" 21 "google.golang.org/grpc/metadata" 22 "google.golang.org/grpc/reflection" 23 reflectpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" 24 "google.golang.org/grpc/status" 25 26 . "github.com/fullstorydev/grpcurl" 27 grpcurl_testing "github.com/fullstorydev/grpcurl/internal/testing" 28 jsonpbtest "github.com/fullstorydev/grpcurl/internal/testing/jsonpb_test_proto" 29) 30 31var ( 32 sourceProtoset DescriptorSource 33 sourceProtoFiles DescriptorSource 34 ccNoReflect *grpc.ClientConn 35 36 sourceReflect DescriptorSource 37 ccReflect *grpc.ClientConn 38 39 descSources []descSourceCase 40) 41 42type descSourceCase struct { 43 name string 44 source DescriptorSource 45 includeRefl bool 46} 47 48// NB: These tests intentionally use the deprecated InvokeRpc since that 49// calls the other (non-deprecated InvokeRPC). That allows the tests to 50// easily exercise both functions. 51 52func TestMain(m *testing.M) { 53 var err error 54 sourceProtoset, err = DescriptorSourceFromProtoSets("internal/testing/test.protoset") 55 if err != nil { 56 panic(err) 57 } 58 sourceProtoFiles, err = DescriptorSourceFromProtoFiles([]string{"internal/testing"}, "test.proto") 59 if err != nil { 60 panic(err) 61 } 62 63 // Create a server that includes the reflection service 64 svrReflect := grpc.NewServer() 65 grpcurl_testing.RegisterTestServiceServer(svrReflect, grpcurl_testing.TestServer{}) 66 reflection.Register(svrReflect) 67 var portReflect int 68 if l, err := net.Listen("tcp", "127.0.0.1:0"); err != nil { 69 panic(err) 70 } else { 71 portReflect = l.Addr().(*net.TCPAddr).Port 72 go svrReflect.Serve(l) 73 } 74 defer svrReflect.Stop() 75 76 // And a corresponding client 77 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 78 defer cancel() 79 if ccReflect, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portReflect), 80 grpc.WithInsecure(), grpc.WithBlock()); err != nil { 81 panic(err) 82 } 83 defer ccReflect.Close() 84 refClient := grpcreflect.NewClient(context.Background(), reflectpb.NewServerReflectionClient(ccReflect)) 85 defer refClient.Reset() 86 87 sourceReflect = DescriptorSourceFromServer(context.Background(), refClient) 88 89 // Also create a server that does *not* include the reflection service 90 svrProtoset := grpc.NewServer() 91 grpcurl_testing.RegisterTestServiceServer(svrProtoset, grpcurl_testing.TestServer{}) 92 var portProtoset int 93 if l, err := net.Listen("tcp", "127.0.0.1:0"); err != nil { 94 panic(err) 95 } else { 96 portProtoset = l.Addr().(*net.TCPAddr).Port 97 go svrProtoset.Serve(l) 98 } 99 defer svrProtoset.Stop() 100 101 // And a corresponding client 102 ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) 103 defer cancel() 104 if ccNoReflect, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portProtoset), 105 grpc.WithInsecure(), grpc.WithBlock()); err != nil { 106 panic(err) 107 } 108 defer ccNoReflect.Close() 109 110 descSources = []descSourceCase{ 111 {"protoset", sourceProtoset, false}, 112 {"proto", sourceProtoFiles, false}, 113 {"reflect", sourceReflect, true}, 114 } 115 116 os.Exit(m.Run()) 117} 118 119func TestServerDoesNotSupportReflection(t *testing.T) { 120 refClient := grpcreflect.NewClient(context.Background(), reflectpb.NewServerReflectionClient(ccNoReflect)) 121 defer refClient.Reset() 122 123 refSource := DescriptorSourceFromServer(context.Background(), refClient) 124 125 _, err := ListServices(refSource) 126 if err != ErrReflectionNotSupported { 127 t.Errorf("ListServices should have returned ErrReflectionNotSupported; instead got %v", err) 128 } 129 130 _, err = ListMethods(refSource, "SomeService") 131 if err != ErrReflectionNotSupported { 132 t.Errorf("ListMethods should have returned ErrReflectionNotSupported; instead got %v", err) 133 } 134 135 err = InvokeRpc(context.Background(), refSource, ccNoReflect, "FooService/Method", nil, nil, nil) 136 // InvokeRpc wraps the error, so we just verify the returned error includes the right message 137 if err == nil || !strings.Contains(err.Error(), ErrReflectionNotSupported.Error()) { 138 t.Errorf("InvokeRpc should have returned ErrReflectionNotSupported; instead got %v", err) 139 } 140} 141 142func TestProtosetWithImports(t *testing.T) { 143 sourceProtoset, err := DescriptorSourceFromProtoSets("internal/testing/example.protoset") 144 if err != nil { 145 t.Fatalf("failed to load protoset: %v", err) 146 } 147 // really shallow check of the loaded descriptors 148 if sd, err := sourceProtoset.FindSymbol("TestService"); err != nil { 149 t.Errorf("failed to find TestService in protoset: %v", err) 150 } else if sd == nil { 151 t.Errorf("FindSymbol returned nil for TestService") 152 } else if _, ok := sd.(*desc.ServiceDescriptor); !ok { 153 t.Errorf("FindSymbol returned wrong kind of descriptor for TestService: %T", sd) 154 } 155 if md, err := sourceProtoset.FindSymbol("TestRequest"); err != nil { 156 t.Errorf("failed to find TestRequest in protoset: %v", err) 157 } else if md == nil { 158 t.Errorf("FindSymbol returned nil for TestRequest") 159 } else if _, ok := md.(*desc.MessageDescriptor); !ok { 160 t.Errorf("FindSymbol returned wrong kind of descriptor for TestRequest: %T", md) 161 } 162} 163 164func TestListServices(t *testing.T) { 165 for _, ds := range descSources { 166 t.Run(ds.name, func(t *testing.T) { 167 doTestListServices(t, ds.source, ds.includeRefl) 168 }) 169 } 170} 171 172func doTestListServices(t *testing.T, source DescriptorSource, includeReflection bool) { 173 names, err := ListServices(source) 174 if err != nil { 175 t.Fatalf("failed to list services: %v", err) 176 } 177 var expected []string 178 if includeReflection { 179 // when using server reflection, we see the TestService as well as the ServerReflection service 180 expected = []string{"grpc.reflection.v1alpha.ServerReflection", "testing.TestService"} 181 } else { 182 // without reflection, we see all services defined in the same test.proto file, which is the 183 // TestService as well as UnimplementedService 184 expected = []string{"testing.TestService", "testing.UnimplementedService"} 185 } 186 if !reflect.DeepEqual(expected, names) { 187 t.Errorf("ListServices returned wrong results: wanted %v, got %v", expected, names) 188 } 189} 190 191func TestListMethods(t *testing.T) { 192 for _, ds := range descSources { 193 t.Run(ds.name, func(t *testing.T) { 194 doTestListMethods(t, ds.source, ds.includeRefl) 195 }) 196 } 197} 198 199func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection bool) { 200 names, err := ListMethods(source, "testing.TestService") 201 if err != nil { 202 t.Fatalf("failed to list methods for TestService: %v", err) 203 } 204 expected := []string{ 205 "testing.TestService.EmptyCall", 206 "testing.TestService.FullDuplexCall", 207 "testing.TestService.HalfDuplexCall", 208 "testing.TestService.StreamingInputCall", 209 "testing.TestService.StreamingOutputCall", 210 "testing.TestService.UnaryCall", 211 } 212 if !reflect.DeepEqual(expected, names) { 213 t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names) 214 } 215 216 if includeReflection { 217 // when using server reflection, we see the TestService as well as the ServerReflection service 218 names, err = ListMethods(source, "grpc.reflection.v1alpha.ServerReflection") 219 if err != nil { 220 t.Fatalf("failed to list methods for ServerReflection: %v", err) 221 } 222 expected = []string{"grpc.reflection.v1alpha.ServerReflection.ServerReflectionInfo"} 223 } else { 224 // without reflection, we see all services defined in the same test.proto file, which is the 225 // TestService as well as UnimplementedService 226 names, err = ListMethods(source, "testing.UnimplementedService") 227 if err != nil { 228 t.Fatalf("failed to list methods for ServerReflection: %v", err) 229 } 230 expected = []string{"testing.UnimplementedService.UnimplementedCall"} 231 } 232 if !reflect.DeepEqual(expected, names) { 233 t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names) 234 } 235 236 // force an error 237 _, err = ListMethods(source, "FooService") 238 if err != nil && !strings.Contains(err.Error(), "Symbol not found: FooService") { 239 t.Errorf("ListMethods should have returned 'not found' error but instead returned %v", err) 240 } 241} 242 243func TestGetAllFiles(t *testing.T) { 244 expectedFiles := []string{"test.proto"} 245 // server reflection picks up filename from linked in Go package, 246 // which indicates "grpc_testing/test.proto", not our local copy. 247 expectedFilesWithReflection := [][]string{ 248 {"grpc_reflection_v1alpha/reflection.proto", "test.proto"}, 249 // depending on the version of grpc, the filenames could be prefixed with "interop/" and "reflection/" 250 {"reflection/grpc_reflection_v1alpha/reflection.proto", "test.proto"}, 251 } 252 253 for _, ds := range descSources { 254 t.Run(ds.name, func(t *testing.T) { 255 files, err := GetAllFiles(ds.source) 256 if err != nil { 257 t.Fatalf("failed to get all files: %v", err) 258 } 259 names := fileNames(files) 260 match := false 261 var expected []string 262 if ds.includeRefl { 263 for _, expectedNames := range expectedFilesWithReflection { 264 expected = expectedNames 265 if reflect.DeepEqual(expected, names) { 266 match = true 267 break 268 } 269 } 270 } else { 271 expected = expectedFiles 272 match = reflect.DeepEqual(expected, names) 273 } 274 if !match { 275 t.Errorf("GetAllFiles returned wrong results: wanted %v, got %v", expected, names) 276 } 277 }) 278 } 279 280 // try cases with more complicated set of files 281 otherSourceProtoset, err := DescriptorSourceFromProtoSets("internal/testing/test.protoset", "internal/testing/example.protoset") 282 if err != nil { 283 t.Fatal(err.Error()) 284 } 285 otherSourceProtoFiles, err := DescriptorSourceFromProtoFiles([]string{"internal/testing"}, "test.proto", "example.proto") 286 if err != nil { 287 t.Fatal(err.Error()) 288 } 289 otherDescSources := []descSourceCase{ 290 {"protoset[b]", otherSourceProtoset, false}, 291 {"proto[b]", otherSourceProtoFiles, false}, 292 } 293 expectedFiles = []string{ 294 "example.proto", 295 "example2.proto", 296 "google/protobuf/any.proto", 297 "google/protobuf/descriptor.proto", 298 "google/protobuf/empty.proto", 299 "google/protobuf/timestamp.proto", 300 "test.proto", 301 } 302 for _, ds := range otherDescSources { 303 t.Run(ds.name, func(t *testing.T) { 304 files, err := GetAllFiles(ds.source) 305 if err != nil { 306 t.Fatalf("failed to get all files: %v", err) 307 } 308 names := fileNames(files) 309 if !reflect.DeepEqual(expectedFiles, names) { 310 t.Errorf("GetAllFiles returned wrong results: wanted %v, got %v", expectedFiles, names) 311 } 312 }) 313 } 314} 315 316func TestExpandHeaders(t *testing.T) { 317 inHeaders := []string{"key1: ${value}", "key2: bar", "key3: ${woo", "key4: woo}", "key5: ${TEST}", 318 "key6: ${TEST_VAR}", "${TEST}: ${TEST_VAR}", "key8: ${EMPTY}"} 319 os.Setenv("value", "value") 320 os.Setenv("TEST", "value5") 321 os.Setenv("TEST_VAR", "value6") 322 os.Setenv("EMPTY", "") 323 expectedHeaders := map[string]bool{"key1: value": true, "key2: bar": true, "key3: ${woo": true, "key4: woo}": true, 324 "key5: value5": true, "key6: value6": true, "value5: value6": true, "key8: ": true} 325 326 outHeaders, err := ExpandHeaders(inHeaders) 327 if err != nil { 328 t.Errorf("The ExpandHeaders function generated an unexpected error %s", err) 329 } 330 for _, expandedHeader := range outHeaders { 331 if _, ok := expectedHeaders[expandedHeader]; !ok { 332 t.Errorf("The ExpandHeaders function has returned an unexpected header. Received unexpected header %s", expandedHeader) 333 } 334 } 335 336 badHeaders := []string{"key: ${DNE}"} 337 _, err = ExpandHeaders(badHeaders) 338 if err == nil { 339 t.Errorf("The ExpandHeaders function should return an error for missing environment variables %q", badHeaders) 340 } 341} 342 343func fileNames(files []*desc.FileDescriptor) []string { 344 names := make([]string, len(files)) 345 for i, f := range files { 346 names[i] = f.GetName() 347 } 348 return names 349} 350 351const expectKnownType = `{ 352 "dur": "0s", 353 "ts": "1970-01-01T00:00:00Z", 354 "dbl": 0, 355 "flt": 0, 356 "i64": "0", 357 "u64": "0", 358 "i32": 0, 359 "u32": 0, 360 "bool": false, 361 "str": "", 362 "bytes": null, 363 "st": {"google.protobuf.Struct": "supports arbitrary JSON objects"}, 364 "an": {"@type": "type.googleapis.com/google.protobuf.Empty", "value": {}}, 365 "lv": [{"google.protobuf.ListValue": "is an array of arbitrary JSON values"}], 366 "val": {"google.protobuf.Value": "supports arbitrary JSON"} 367}` 368 369func TestMakeTemplateKnownTypes(t *testing.T) { 370 descriptor, err := desc.LoadMessageDescriptorForMessage((*jsonpbtest.KnownTypes)(nil)) 371 if err != nil { 372 t.Fatalf("failed to load descriptor: %v", err) 373 } 374 message := MakeTemplate(descriptor) 375 376 jsm := jsonpb.Marshaler{EmitDefaults: true} 377 out, err := jsm.MarshalToString(message) 378 if err != nil { 379 t.Fatalf("failed to marshal to JSON: %v", err) 380 } 381 382 // make sure template JSON matches expected 383 var actual, expected interface{} 384 if err := json.Unmarshal([]byte(out), &actual); err != nil { 385 t.Fatalf("failed to parse actual JSON: %v", err) 386 } 387 if err := json.Unmarshal([]byte(expectKnownType), &expected); err != nil { 388 t.Fatalf("failed to parse expected JSON: %v", err) 389 } 390 391 if !reflect.DeepEqual(actual, expected) { 392 t.Errorf("template message is not as expected; want:\n%s\ngot:\n%s", expectKnownType, out) 393 } 394} 395 396func TestDescribe(t *testing.T) { 397 for _, ds := range descSources { 398 t.Run(ds.name, func(t *testing.T) { 399 doTestDescribe(t, ds.source) 400 }) 401 } 402} 403 404func doTestDescribe(t *testing.T, source DescriptorSource) { 405 sym := "testing.TestService.EmptyCall" 406 dsc, err := source.FindSymbol(sym) 407 if err != nil { 408 t.Fatalf("failed to get descriptor for %q: %v", sym, err) 409 } 410 if _, ok := dsc.(*desc.MethodDescriptor); !ok { 411 t.Fatalf("descriptor for %q was a %T (expecting a MethodDescriptor)", sym, dsc) 412 } 413 txt := proto.MarshalTextString(dsc.AsProto()) 414 expected := 415 `name: "EmptyCall" 416input_type: ".testing.Empty" 417output_type: ".testing.Empty" 418` 419 if expected != txt { 420 t.Errorf("descriptor mismatch: expected %s, got %s", expected, txt) 421 } 422 423 sym = "testing.StreamingOutputCallResponse" 424 dsc, err = source.FindSymbol(sym) 425 if err != nil { 426 t.Fatalf("failed to get descriptor for %q: %v", sym, err) 427 } 428 if _, ok := dsc.(*desc.MessageDescriptor); !ok { 429 t.Fatalf("descriptor for %q was a %T (expecting a MessageDescriptor)", sym, dsc) 430 } 431 txt = proto.MarshalTextString(dsc.AsProto()) 432 expected = 433 `name: "StreamingOutputCallResponse" 434field: < 435 name: "payload" 436 number: 1 437 label: LABEL_OPTIONAL 438 type: TYPE_MESSAGE 439 type_name: ".testing.Payload" 440 json_name: "payload" 441> 442` 443 if expected != txt { 444 t.Errorf("descriptor mismatch: expected %s, got %s", expected, txt) 445 } 446 447 _, err = source.FindSymbol("FooService") 448 if err != nil && !strings.Contains(err.Error(), "Symbol not found: FooService") { 449 t.Errorf("FindSymbol should have returned 'not found' error but instead returned %v", err) 450 } 451} 452 453const ( 454 // type == COMPRESSABLE, but that is default (since it has 455 // numeric value == 0) and thus doesn't actually get included 456 // on the wire 457 payload1 = `{ 458 "payload": { 459 "body": "SXQncyBCdXNpbmVzcyBUaW1l" 460 } 461}` 462 payload2 = `{ 463 "payload": { 464 "type": "RANDOM", 465 "body": "Rm91eCBkdSBGYUZh" 466 } 467}` 468 payload3 = `{ 469 "payload": { 470 "type": "UNCOMPRESSABLE", 471 "body": "SGlwaG9wb3BvdGFtdXMgdnMuIFJoeW1lbm9jZXJvcw==" 472 } 473}` 474) 475 476func getCC(includeRefl bool) *grpc.ClientConn { 477 if includeRefl { 478 return ccReflect 479 } else { 480 return ccNoReflect 481 } 482} 483 484func TestUnary(t *testing.T) { 485 for _, ds := range descSources { 486 t.Run(ds.name, func(t *testing.T) { 487 doTestUnary(t, getCC(ds.includeRefl), ds.source) 488 }) 489 } 490} 491 492func doTestUnary(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { 493 // Success 494 h := &handler{reqMessages: []string{payload1}} 495 err := InvokeRpc(context.Background(), source, cc, "testing.TestService/UnaryCall", makeHeaders(codes.OK), h, h.getRequestData) 496 if err != nil { 497 t.Fatalf("unexpected error during RPC: %v", err) 498 } 499 500 if h.check(t, "testing.TestService.UnaryCall", codes.OK, 1, 1) { 501 if h.respMessages[0] != payload1 { 502 t.Errorf("unexpected response from RPC: expecting %s; got %s", payload1, h.respMessages[0]) 503 } 504 } 505 506 // Failure 507 h = &handler{reqMessages: []string{payload1}} 508 err = InvokeRpc(context.Background(), source, cc, "testing.TestService/UnaryCall", makeHeaders(codes.NotFound), h, h.getRequestData) 509 if err != nil { 510 t.Fatalf("unexpected error during RPC: %v", err) 511 } 512 513 h.check(t, "testing.TestService.UnaryCall", codes.NotFound, 1, 0) 514} 515 516func TestClientStream(t *testing.T) { 517 for _, ds := range descSources { 518 t.Run(ds.name, func(t *testing.T) { 519 doTestClientStream(t, getCC(ds.includeRefl), ds.source) 520 }) 521 } 522} 523 524func doTestClientStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { 525 // Success 526 h := &handler{reqMessages: []string{payload1, payload2, payload3}} 527 err := InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingInputCall", makeHeaders(codes.OK), h, h.getRequestData) 528 if err != nil { 529 t.Fatalf("unexpected error during RPC: %v", err) 530 } 531 532 if h.check(t, "testing.TestService.StreamingInputCall", codes.OK, 3, 1) { 533 expected := 534 `{ 535 "aggregatedPayloadSize": 61 536}` 537 if h.respMessages[0] != expected { 538 t.Errorf("unexpected response from RPC: expecting %s; got %s", expected, h.respMessages[0]) 539 } 540 } 541 542 // Fail fast (server rejects as soon as possible) 543 h = &handler{reqMessages: []string{payload1, payload2, payload3}} 544 err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingInputCall", makeHeaders(codes.InvalidArgument), h, h.getRequestData) 545 if err != nil { 546 t.Fatalf("unexpected error during RPC: %v", err) 547 } 548 549 h.check(t, "testing.TestService.StreamingInputCall", codes.InvalidArgument, -3, 0) 550 551 // Fail late (server waits until stream is complete to reject) 552 h = &handler{reqMessages: []string{payload1, payload2, payload3}} 553 err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingInputCall", makeHeaders(codes.Internal, true), h, h.getRequestData) 554 if err != nil { 555 t.Fatalf("unexpected error during RPC: %v", err) 556 } 557 558 h.check(t, "testing.TestService.StreamingInputCall", codes.Internal, 3, 0) 559} 560 561func TestServerStream(t *testing.T) { 562 for _, ds := range descSources { 563 t.Run(ds.name, func(t *testing.T) { 564 doTestServerStream(t, getCC(ds.includeRefl), ds.source) 565 }) 566 } 567} 568 569func doTestServerStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { 570 req := &grpcurl_testing.StreamingOutputCallRequest{ 571 ResponseType: grpcurl_testing.PayloadType_COMPRESSABLE, 572 ResponseParameters: []*grpcurl_testing.ResponseParameters{ 573 {Size: 10}, {Size: 20}, {Size: 30}, {Size: 40}, {Size: 50}, 574 }, 575 } 576 payload, err := (&jsonpb.Marshaler{}).MarshalToString(req) 577 if err != nil { 578 t.Fatalf("failed to construct request: %v", err) 579 } 580 581 // Success 582 h := &handler{reqMessages: []string{payload}} 583 err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingOutputCall", makeHeaders(codes.OK), h, h.getRequestData) 584 if err != nil { 585 t.Fatalf("unexpected error during RPC: %v", err) 586 } 587 588 if h.check(t, "testing.TestService.StreamingOutputCall", codes.OK, 1, 5) { 589 resp := &grpcurl_testing.StreamingOutputCallResponse{} 590 for i, msg := range h.respMessages { 591 if err := jsonpb.UnmarshalString(msg, resp); err != nil { 592 t.Errorf("failed to parse response %d: %v", i+1, err) 593 } 594 if resp.Payload.GetType() != grpcurl_testing.PayloadType_COMPRESSABLE { 595 t.Errorf("response %d has wrong payload type; expecting %v, got %v", i, grpcurl_testing.PayloadType_COMPRESSABLE, resp.Payload.Type) 596 } 597 if len(resp.Payload.Body) != (i+1)*10 { 598 t.Errorf("response %d has wrong payload size; expecting %d, got %d", i, (i+1)*10, len(resp.Payload.Body)) 599 } 600 resp.Reset() 601 } 602 } 603 604 // Fail fast (server rejects as soon as possible) 605 h = &handler{reqMessages: []string{payload}} 606 err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingOutputCall", makeHeaders(codes.Aborted), h, h.getRequestData) 607 if err != nil { 608 t.Fatalf("unexpected error during RPC: %v", err) 609 } 610 611 h.check(t, "testing.TestService.StreamingOutputCall", codes.Aborted, 1, 0) 612 613 // Fail late (server waits until stream is complete to reject) 614 h = &handler{reqMessages: []string{payload}} 615 err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingOutputCall", makeHeaders(codes.AlreadyExists, true), h, h.getRequestData) 616 if err != nil { 617 t.Fatalf("unexpected error during RPC: %v", err) 618 } 619 620 h.check(t, "testing.TestService.StreamingOutputCall", codes.AlreadyExists, 1, 5) 621} 622 623func TestHalfDuplexStream(t *testing.T) { 624 for _, ds := range descSources { 625 t.Run(ds.name, func(t *testing.T) { 626 doTestHalfDuplexStream(t, getCC(ds.includeRefl), ds.source) 627 }) 628 } 629} 630 631func doTestHalfDuplexStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { 632 reqs := []string{payload1, payload2, payload3} 633 634 // Success 635 h := &handler{reqMessages: reqs} 636 err := InvokeRpc(context.Background(), source, cc, "testing.TestService/HalfDuplexCall", makeHeaders(codes.OK), h, h.getRequestData) 637 if err != nil { 638 t.Fatalf("unexpected error during RPC: %v", err) 639 } 640 641 if h.check(t, "testing.TestService.HalfDuplexCall", codes.OK, 3, 3) { 642 for i, resp := range h.respMessages { 643 if resp != reqs[i] { 644 t.Errorf("unexpected response %d from RPC:\nexpecting %q\ngot %q", i, reqs[i], resp) 645 } 646 } 647 } 648 649 // Fail fast (server rejects as soon as possible) 650 h = &handler{reqMessages: reqs} 651 err = InvokeRpc(context.Background(), source, cc, "testing.TestService/HalfDuplexCall", makeHeaders(codes.Canceled), h, h.getRequestData) 652 if err != nil { 653 t.Fatalf("unexpected error during RPC: %v", err) 654 } 655 656 h.check(t, "testing.TestService.HalfDuplexCall", codes.Canceled, -3, 0) 657 658 // Fail late (server waits until stream is complete to reject) 659 h = &handler{reqMessages: reqs} 660 err = InvokeRpc(context.Background(), source, cc, "testing.TestService/HalfDuplexCall", makeHeaders(codes.DataLoss, true), h, h.getRequestData) 661 if err != nil { 662 t.Fatalf("unexpected error during RPC: %v", err) 663 } 664 665 h.check(t, "testing.TestService.HalfDuplexCall", codes.DataLoss, 3, 3) 666} 667 668func TestFullDuplexStream(t *testing.T) { 669 for _, ds := range descSources { 670 t.Run(ds.name, func(t *testing.T) { 671 doTestFullDuplexStream(t, getCC(ds.includeRefl), ds.source) 672 }) 673 } 674} 675 676func doTestFullDuplexStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { 677 reqs := make([]string, 3) 678 req := &grpcurl_testing.StreamingOutputCallRequest{ 679 ResponseType: grpcurl_testing.PayloadType_RANDOM, 680 } 681 for i := range reqs { 682 req.ResponseParameters = append(req.ResponseParameters, &grpcurl_testing.ResponseParameters{Size: int32((i + 1) * 10)}) 683 payload, err := (&jsonpb.Marshaler{}).MarshalToString(req) 684 if err != nil { 685 t.Fatalf("failed to construct request %d: %v", i, err) 686 } 687 reqs[i] = payload 688 } 689 690 // Success 691 h := &handler{reqMessages: reqs} 692 err := InvokeRpc(context.Background(), source, cc, "testing.TestService/FullDuplexCall", makeHeaders(codes.OK), h, h.getRequestData) 693 if err != nil { 694 t.Fatalf("unexpected error during RPC: %v", err) 695 } 696 697 if h.check(t, "testing.TestService.FullDuplexCall", codes.OK, 3, 6) { 698 resp := &grpcurl_testing.StreamingOutputCallResponse{} 699 i := 0 700 for j := 1; j < 3; j++ { 701 // three requests 702 for k := 0; k < j; k++ { 703 // 1 response for first request, 2 for second, etc 704 msg := h.respMessages[i] 705 if err := jsonpb.UnmarshalString(msg, resp); err != nil { 706 t.Errorf("failed to parse response %d: %v", i+1, err) 707 } 708 if resp.Payload.GetType() != grpcurl_testing.PayloadType_RANDOM { 709 t.Errorf("response %d has wrong payload type; expecting %v, got %v", i, grpcurl_testing.PayloadType_RANDOM, resp.Payload.Type) 710 } 711 if len(resp.Payload.Body) != (k+1)*10 { 712 t.Errorf("response %d has wrong payload size; expecting %d, got %d", i, (k+1)*10, len(resp.Payload.Body)) 713 } 714 resp.Reset() 715 716 i++ 717 } 718 } 719 } 720 721 // Fail fast (server rejects as soon as possible) 722 h = &handler{reqMessages: reqs} 723 err = InvokeRpc(context.Background(), source, cc, "testing.TestService/FullDuplexCall", makeHeaders(codes.PermissionDenied), h, h.getRequestData) 724 if err != nil { 725 t.Fatalf("unexpected error during RPC: %v", err) 726 } 727 728 h.check(t, "testing.TestService.FullDuplexCall", codes.PermissionDenied, -3, 0) 729 730 // Fail late (server waits until stream is complete to reject) 731 h = &handler{reqMessages: reqs} 732 err = InvokeRpc(context.Background(), source, cc, "testing.TestService/FullDuplexCall", makeHeaders(codes.ResourceExhausted, true), h, h.getRequestData) 733 if err != nil { 734 t.Fatalf("unexpected error during RPC: %v", err) 735 } 736 737 h.check(t, "testing.TestService.FullDuplexCall", codes.ResourceExhausted, 3, 6) 738} 739 740type handler struct { 741 method *desc.MethodDescriptor 742 methodCount int 743 reqHeaders metadata.MD 744 reqHeadersCount int 745 reqMessages []string 746 reqMessagesCount int 747 respHeaders metadata.MD 748 respHeadersCount int 749 respMessages []string 750 respTrailers metadata.MD 751 respStatus *status.Status 752 respTrailersCount int 753} 754 755func (h *handler) getRequestData() ([]byte, error) { 756 // we don't use a mutex, though this method will be called from different goroutine 757 // than other methods for bidi calls, because this method does not share any state 758 // with the other methods. 759 h.reqMessagesCount++ 760 if h.reqMessagesCount > len(h.reqMessages) { 761 return nil, io.EOF 762 } 763 if h.reqMessagesCount > 1 { 764 // insert delay between messages in request stream 765 time.Sleep(time.Millisecond * 50) 766 } 767 return []byte(h.reqMessages[h.reqMessagesCount-1]), nil 768} 769 770func (h *handler) OnResolveMethod(md *desc.MethodDescriptor) { 771 h.methodCount++ 772 h.method = md 773} 774 775func (h *handler) OnSendHeaders(md metadata.MD) { 776 h.reqHeadersCount++ 777 h.reqHeaders = md 778} 779 780func (h *handler) OnReceiveHeaders(md metadata.MD) { 781 h.respHeadersCount++ 782 h.respHeaders = md 783} 784 785func (h *handler) OnReceiveResponse(msg proto.Message) { 786 jsm := jsonpb.Marshaler{Indent: " "} 787 respStr, err := jsm.MarshalToString(msg) 788 if err != nil { 789 panic(fmt.Errorf("failed to generate JSON form of response message: %v", err)) 790 } 791 h.respMessages = append(h.respMessages, respStr) 792} 793 794func (h *handler) OnReceiveTrailers(stat *status.Status, md metadata.MD) { 795 h.respTrailersCount++ 796 h.respTrailers = md 797 h.respStatus = stat 798} 799 800func (h *handler) check(t *testing.T, expectedMethod string, expectedCode codes.Code, expectedRequestQueries, expectedResponses int) bool { 801 // verify a few things were only ever called once 802 if h.methodCount != 1 { 803 t.Errorf("expected grpcurl to invoke OnResolveMethod once; was %d", h.methodCount) 804 } 805 if h.reqHeadersCount != 1 { 806 t.Errorf("expected grpcurl to invoke OnSendHeaders once; was %d", h.reqHeadersCount) 807 } 808 if h.reqHeadersCount != 1 { 809 t.Errorf("expected grpcurl to invoke OnSendHeaders once; was %d", h.reqHeadersCount) 810 } 811 if h.respHeadersCount != 1 { 812 t.Errorf("expected grpcurl to invoke OnReceiveHeaders once; was %d", h.respHeadersCount) 813 } 814 if h.respTrailersCount != 1 { 815 t.Errorf("expected grpcurl to invoke OnReceiveTrailers once; was %d", h.respTrailersCount) 816 } 817 818 // check other stuff against given expectations 819 if h.method.GetFullyQualifiedName() != expectedMethod { 820 t.Errorf("wrong method: expecting %v, got %v", expectedMethod, h.method.GetFullyQualifiedName()) 821 } 822 if h.respStatus.Code() != expectedCode { 823 t.Errorf("wrong code: expecting %v, got %v", expectedCode, h.respStatus.Code()) 824 } 825 if expectedRequestQueries < 0 { 826 // negative expectation means "negate and expect up to that number; could be fewer" 827 if h.reqMessagesCount > -expectedRequestQueries+1 { 828 // the + 1 is because there will be an extra query that returns EOF 829 t.Errorf("wrong number of messages queried: expecting no more than %v, got %v", -expectedRequestQueries, h.reqMessagesCount-1) 830 } 831 } else { 832 if h.reqMessagesCount != expectedRequestQueries+1 { 833 // the + 1 is because there will be an extra query that returns EOF 834 t.Errorf("wrong number of messages queried: expecting %v, got %v", expectedRequestQueries, h.reqMessagesCount-1) 835 } 836 } 837 if len(h.respMessages) != expectedResponses { 838 t.Errorf("wrong number of messages received: expecting %v, got %v", expectedResponses, len(h.respMessages)) 839 } 840 841 // also check headers and trailers came through as expected 842 v := h.respHeaders["some-fake-header-1"] 843 if len(v) != 1 || v[0] != "val1" { 844 t.Errorf("wrong request header for %q: %v", "some-fake-header-1", v) 845 } 846 v = h.respHeaders["some-fake-header-2"] 847 if len(v) != 1 || v[0] != "val2" { 848 t.Errorf("wrong request header for %q: %v", "some-fake-header-2", v) 849 } 850 v = h.respTrailers["some-fake-trailer-1"] 851 if len(v) != 1 || v[0] != "valA" { 852 t.Errorf("wrong request header for %q: %v", "some-fake-trailer-1", v) 853 } 854 v = h.respTrailers["some-fake-trailer-2"] 855 if len(v) != 1 || v[0] != "valB" { 856 t.Errorf("wrong request header for %q: %v", "some-fake-trailer-2", v) 857 } 858 859 return len(h.respMessages) == expectedResponses 860} 861 862func makeHeaders(code codes.Code, failLate ...bool) []string { 863 if len(failLate) > 1 { 864 panic("incorrect use of makeContext; should be at most one failLate flag") 865 } 866 867 hdrs := append(make([]string, 0, 5), 868 fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyHeaders, "some-fake-header-1: val1"), 869 fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyHeaders, "some-fake-header-2: val2"), 870 fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyTrailers, "some-fake-trailer-1: valA"), 871 fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyTrailers, "some-fake-trailer-2: valB")) 872 if code != codes.OK { 873 if len(failLate) > 0 && failLate[0] { 874 hdrs = append(hdrs, fmt.Sprintf("%s: %d", grpcurl_testing.MetadataFailLate, code)) 875 } else { 876 hdrs = append(hdrs, fmt.Sprintf("%s: %d", grpcurl_testing.MetadataFailEarly, code)) 877 } 878 } 879 880 return hdrs 881} 882