1package grpcurl_test
2
3import (
4	"context"
5	"encoding/json"
6	"fmt"
7	"io"
8	"net"
9	"os"
10	"reflect"
11	"strings"
12	"testing"
13	"time"
14
15	"github.com/golang/protobuf/jsonpb" //lint:ignore SA1019 we have to import this because it appears in exported API
16	"github.com/golang/protobuf/proto"  //lint:ignore SA1019 we have to import this because it appears in exported API
17	"github.com/jhump/protoreflect/desc"
18	"github.com/jhump/protoreflect/grpcreflect"
19	"google.golang.org/grpc"
20	"google.golang.org/grpc/codes"
21	"google.golang.org/grpc/metadata"
22	"google.golang.org/grpc/reflection"
23	reflectpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
24	"google.golang.org/grpc/status"
25
26	. "github.com/fullstorydev/grpcurl"
27	grpcurl_testing "github.com/fullstorydev/grpcurl/internal/testing"
28	jsonpbtest "github.com/fullstorydev/grpcurl/internal/testing/jsonpb_test_proto"
29)
30
// Fixtures shared by the tests in this file. TestMain initializes all of them
// before any test runs.
var (
	// Descriptor sources that do not use server reflection (one loaded from a
	// protoset, one from .proto files), plus the client connection to the test
	// server that does not register the reflection service.
	sourceProtoset   DescriptorSource
	sourceProtoFiles DescriptorSource
	ccNoReflect      *grpc.ClientConn

	// Descriptor source backed by server reflection, plus the client connection
	// to the test server that registers the reflection service.
	sourceReflect DescriptorSource
	ccReflect     *grpc.ClientConn

	// descSources lists every descriptor source the table-driven tests run against.
	descSources []descSourceCase
)

// descSourceCase is one descriptor source under test. name labels the subtest.
// includeRefl reports whether source is backed by server reflection. Tests use
// it to choose the expected results and the matching client connection (getCC).
type descSourceCase struct {
	name        string
	source      DescriptorSource
	includeRefl bool
}
47
// NB: These tests intentionally use the deprecated InvokeRpc, since it in turn
// calls the non-deprecated InvokeRPC. That lets the tests exercise both
// functions easily.
51
52func TestMain(m *testing.M) {
53	var err error
54	sourceProtoset, err = DescriptorSourceFromProtoSets("internal/testing/test.protoset")
55	if err != nil {
56		panic(err)
57	}
58	sourceProtoFiles, err = DescriptorSourceFromProtoFiles([]string{"internal/testing"}, "test.proto")
59	if err != nil {
60		panic(err)
61	}
62
63	// Create a server that includes the reflection service
64	svrReflect := grpc.NewServer()
65	grpcurl_testing.RegisterTestServiceServer(svrReflect, grpcurl_testing.TestServer{})
66	reflection.Register(svrReflect)
67	var portReflect int
68	if l, err := net.Listen("tcp", "127.0.0.1:0"); err != nil {
69		panic(err)
70	} else {
71		portReflect = l.Addr().(*net.TCPAddr).Port
72		go svrReflect.Serve(l)
73	}
74	defer svrReflect.Stop()
75
76	// And a corresponding client
77	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
78	defer cancel()
79	if ccReflect, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portReflect),
80		grpc.WithInsecure(), grpc.WithBlock()); err != nil {
81		panic(err)
82	}
83	defer ccReflect.Close()
84	refClient := grpcreflect.NewClient(context.Background(), reflectpb.NewServerReflectionClient(ccReflect))
85	defer refClient.Reset()
86
87	sourceReflect = DescriptorSourceFromServer(context.Background(), refClient)
88
89	// Also create a server that does *not* include the reflection service
90	svrProtoset := grpc.NewServer()
91	grpcurl_testing.RegisterTestServiceServer(svrProtoset, grpcurl_testing.TestServer{})
92	var portProtoset int
93	if l, err := net.Listen("tcp", "127.0.0.1:0"); err != nil {
94		panic(err)
95	} else {
96		portProtoset = l.Addr().(*net.TCPAddr).Port
97		go svrProtoset.Serve(l)
98	}
99	defer svrProtoset.Stop()
100
101	// And a corresponding client
102	ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
103	defer cancel()
104	if ccNoReflect, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portProtoset),
105		grpc.WithInsecure(), grpc.WithBlock()); err != nil {
106		panic(err)
107	}
108	defer ccNoReflect.Close()
109
110	descSources = []descSourceCase{
111		{"protoset", sourceProtoset, false},
112		{"proto", sourceProtoFiles, false},
113		{"reflect", sourceReflect, true},
114	}
115
116	os.Exit(m.Run())
117}
118
119func TestServerDoesNotSupportReflection(t *testing.T) {
120	refClient := grpcreflect.NewClient(context.Background(), reflectpb.NewServerReflectionClient(ccNoReflect))
121	defer refClient.Reset()
122
123	refSource := DescriptorSourceFromServer(context.Background(), refClient)
124
125	_, err := ListServices(refSource)
126	if err != ErrReflectionNotSupported {
127		t.Errorf("ListServices should have returned ErrReflectionNotSupported; instead got %v", err)
128	}
129
130	_, err = ListMethods(refSource, "SomeService")
131	if err != ErrReflectionNotSupported {
132		t.Errorf("ListMethods should have returned ErrReflectionNotSupported; instead got %v", err)
133	}
134
135	err = InvokeRpc(context.Background(), refSource, ccNoReflect, "FooService/Method", nil, nil, nil)
136	// InvokeRpc wraps the error, so we just verify the returned error includes the right message
137	if err == nil || !strings.Contains(err.Error(), ErrReflectionNotSupported.Error()) {
138		t.Errorf("InvokeRpc should have returned ErrReflectionNotSupported; instead got %v", err)
139	}
140}
141
142func TestProtosetWithImports(t *testing.T) {
143	sourceProtoset, err := DescriptorSourceFromProtoSets("internal/testing/example.protoset")
144	if err != nil {
145		t.Fatalf("failed to load protoset: %v", err)
146	}
147	// really shallow check of the loaded descriptors
148	if sd, err := sourceProtoset.FindSymbol("TestService"); err != nil {
149		t.Errorf("failed to find TestService in protoset: %v", err)
150	} else if sd == nil {
151		t.Errorf("FindSymbol returned nil for TestService")
152	} else if _, ok := sd.(*desc.ServiceDescriptor); !ok {
153		t.Errorf("FindSymbol returned wrong kind of descriptor for TestService: %T", sd)
154	}
155	if md, err := sourceProtoset.FindSymbol("TestRequest"); err != nil {
156		t.Errorf("failed to find TestRequest in protoset: %v", err)
157	} else if md == nil {
158		t.Errorf("FindSymbol returned nil for TestRequest")
159	} else if _, ok := md.(*desc.MessageDescriptor); !ok {
160		t.Errorf("FindSymbol returned wrong kind of descriptor for TestRequest: %T", md)
161	}
162}
163
164func TestListServices(t *testing.T) {
165	for _, ds := range descSources {
166		t.Run(ds.name, func(t *testing.T) {
167			doTestListServices(t, ds.source, ds.includeRefl)
168		})
169	}
170}
171
172func doTestListServices(t *testing.T, source DescriptorSource, includeReflection bool) {
173	names, err := ListServices(source)
174	if err != nil {
175		t.Fatalf("failed to list services: %v", err)
176	}
177	var expected []string
178	if includeReflection {
179		// when using server reflection, we see the TestService as well as the ServerReflection service
180		expected = []string{"grpc.reflection.v1alpha.ServerReflection", "testing.TestService"}
181	} else {
182		// without reflection, we see all services defined in the same test.proto file, which is the
183		// TestService as well as UnimplementedService
184		expected = []string{"testing.TestService", "testing.UnimplementedService"}
185	}
186	if !reflect.DeepEqual(expected, names) {
187		t.Errorf("ListServices returned wrong results: wanted %v, got %v", expected, names)
188	}
189}
190
191func TestListMethods(t *testing.T) {
192	for _, ds := range descSources {
193		t.Run(ds.name, func(t *testing.T) {
194			doTestListMethods(t, ds.source, ds.includeRefl)
195		})
196	}
197}
198
199func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection bool) {
200	names, err := ListMethods(source, "testing.TestService")
201	if err != nil {
202		t.Fatalf("failed to list methods for TestService: %v", err)
203	}
204	expected := []string{
205		"testing.TestService.EmptyCall",
206		"testing.TestService.FullDuplexCall",
207		"testing.TestService.HalfDuplexCall",
208		"testing.TestService.StreamingInputCall",
209		"testing.TestService.StreamingOutputCall",
210		"testing.TestService.UnaryCall",
211	}
212	if !reflect.DeepEqual(expected, names) {
213		t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names)
214	}
215
216	if includeReflection {
217		// when using server reflection, we see the TestService as well as the ServerReflection service
218		names, err = ListMethods(source, "grpc.reflection.v1alpha.ServerReflection")
219		if err != nil {
220			t.Fatalf("failed to list methods for ServerReflection: %v", err)
221		}
222		expected = []string{"grpc.reflection.v1alpha.ServerReflection.ServerReflectionInfo"}
223	} else {
224		// without reflection, we see all services defined in the same test.proto file, which is the
225		// TestService as well as UnimplementedService
226		names, err = ListMethods(source, "testing.UnimplementedService")
227		if err != nil {
228			t.Fatalf("failed to list methods for ServerReflection: %v", err)
229		}
230		expected = []string{"testing.UnimplementedService.UnimplementedCall"}
231	}
232	if !reflect.DeepEqual(expected, names) {
233		t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names)
234	}
235
236	// force an error
237	_, err = ListMethods(source, "FooService")
238	if err != nil && !strings.Contains(err.Error(), "Symbol not found: FooService") {
239		t.Errorf("ListMethods should have returned 'not found' error but instead returned %v", err)
240	}
241}
242
243func TestGetAllFiles(t *testing.T) {
244	expectedFiles := []string{"test.proto"}
245	// server reflection picks up filename from linked in Go package,
246	// which indicates "grpc_testing/test.proto", not our local copy.
247	expectedFilesWithReflection := [][]string{
248		{"grpc_reflection_v1alpha/reflection.proto", "test.proto"},
249		// depending on the version of grpc, the filenames could be prefixed with "interop/" and "reflection/"
250		{"reflection/grpc_reflection_v1alpha/reflection.proto", "test.proto"},
251	}
252
253	for _, ds := range descSources {
254		t.Run(ds.name, func(t *testing.T) {
255			files, err := GetAllFiles(ds.source)
256			if err != nil {
257				t.Fatalf("failed to get all files: %v", err)
258			}
259			names := fileNames(files)
260			match := false
261			var expected []string
262			if ds.includeRefl {
263				for _, expectedNames := range expectedFilesWithReflection {
264					expected = expectedNames
265					if reflect.DeepEqual(expected, names) {
266						match = true
267						break
268					}
269				}
270			} else {
271				expected = expectedFiles
272				match = reflect.DeepEqual(expected, names)
273			}
274			if !match {
275				t.Errorf("GetAllFiles returned wrong results: wanted %v, got %v", expected, names)
276			}
277		})
278	}
279
280	// try cases with more complicated set of files
281	otherSourceProtoset, err := DescriptorSourceFromProtoSets("internal/testing/test.protoset", "internal/testing/example.protoset")
282	if err != nil {
283		t.Fatal(err.Error())
284	}
285	otherSourceProtoFiles, err := DescriptorSourceFromProtoFiles([]string{"internal/testing"}, "test.proto", "example.proto")
286	if err != nil {
287		t.Fatal(err.Error())
288	}
289	otherDescSources := []descSourceCase{
290		{"protoset[b]", otherSourceProtoset, false},
291		{"proto[b]", otherSourceProtoFiles, false},
292	}
293	expectedFiles = []string{
294		"example.proto",
295		"example2.proto",
296		"google/protobuf/any.proto",
297		"google/protobuf/descriptor.proto",
298		"google/protobuf/empty.proto",
299		"google/protobuf/timestamp.proto",
300		"test.proto",
301	}
302	for _, ds := range otherDescSources {
303		t.Run(ds.name, func(t *testing.T) {
304			files, err := GetAllFiles(ds.source)
305			if err != nil {
306				t.Fatalf("failed to get all files: %v", err)
307			}
308			names := fileNames(files)
309			if !reflect.DeepEqual(expectedFiles, names) {
310				t.Errorf("GetAllFiles returned wrong results: wanted %v, got %v", expectedFiles, names)
311			}
312		})
313	}
314}
315
316func TestExpandHeaders(t *testing.T) {
317	inHeaders := []string{"key1: ${value}", "key2: bar", "key3: ${woo", "key4: woo}", "key5: ${TEST}",
318		"key6: ${TEST_VAR}", "${TEST}: ${TEST_VAR}", "key8: ${EMPTY}"}
319	os.Setenv("value", "value")
320	os.Setenv("TEST", "value5")
321	os.Setenv("TEST_VAR", "value6")
322	os.Setenv("EMPTY", "")
323	expectedHeaders := map[string]bool{"key1: value": true, "key2: bar": true, "key3: ${woo": true, "key4: woo}": true,
324		"key5: value5": true, "key6: value6": true, "value5: value6": true, "key8: ": true}
325
326	outHeaders, err := ExpandHeaders(inHeaders)
327	if err != nil {
328		t.Errorf("The ExpandHeaders function generated an unexpected error %s", err)
329	}
330	for _, expandedHeader := range outHeaders {
331		if _, ok := expectedHeaders[expandedHeader]; !ok {
332			t.Errorf("The ExpandHeaders function has returned an unexpected header. Received unexpected header %s", expandedHeader)
333		}
334	}
335
336	badHeaders := []string{"key: ${DNE}"}
337	_, err = ExpandHeaders(badHeaders)
338	if err == nil {
339		t.Errorf("The ExpandHeaders function should return an error for missing environment variables %q", badHeaders)
340	}
341}
342
343func fileNames(files []*desc.FileDescriptor) []string {
344	names := make([]string, len(files))
345	for i, f := range files {
346		names[i] = f.GetName()
347	}
348	return names
349}
350
// expectKnownType is the JSON that TestMakeTemplateKnownTypes expects when the
// template MakeTemplate builds for jsonpbtest.KnownTypes is marshaled with
// EmitDefaults. The test compares parsed JSON values, so whitespace and key
// order in this literal are not significant.
const expectKnownType = `{
  "dur": "0s",
  "ts": "1970-01-01T00:00:00Z",
  "dbl": 0,
  "flt": 0,
  "i64": "0",
  "u64": "0",
  "i32": 0,
  "u32": 0,
  "bool": false,
  "str": "",
  "bytes": null,
  "st": {"google.protobuf.Struct": "supports arbitrary JSON objects"},
  "an": {"@type": "type.googleapis.com/google.protobuf.Empty", "value": {}},
  "lv": [{"google.protobuf.ListValue": "is an array of arbitrary JSON values"}],
  "val": {"google.protobuf.Value": "supports arbitrary JSON"}
}`
368
369func TestMakeTemplateKnownTypes(t *testing.T) {
370	descriptor, err := desc.LoadMessageDescriptorForMessage((*jsonpbtest.KnownTypes)(nil))
371	if err != nil {
372		t.Fatalf("failed to load descriptor: %v", err)
373	}
374	message := MakeTemplate(descriptor)
375
376	jsm := jsonpb.Marshaler{EmitDefaults: true}
377	out, err := jsm.MarshalToString(message)
378	if err != nil {
379		t.Fatalf("failed to marshal to JSON: %v", err)
380	}
381
382	// make sure template JSON matches expected
383	var actual, expected interface{}
384	if err := json.Unmarshal([]byte(out), &actual); err != nil {
385		t.Fatalf("failed to parse actual JSON: %v", err)
386	}
387	if err := json.Unmarshal([]byte(expectKnownType), &expected); err != nil {
388		t.Fatalf("failed to parse expected JSON: %v", err)
389	}
390
391	if !reflect.DeepEqual(actual, expected) {
392		t.Errorf("template message is not as expected; want:\n%s\ngot:\n%s", expectKnownType, out)
393	}
394}
395
396func TestDescribe(t *testing.T) {
397	for _, ds := range descSources {
398		t.Run(ds.name, func(t *testing.T) {
399			doTestDescribe(t, ds.source)
400		})
401	}
402}
403
404func doTestDescribe(t *testing.T, source DescriptorSource) {
405	sym := "testing.TestService.EmptyCall"
406	dsc, err := source.FindSymbol(sym)
407	if err != nil {
408		t.Fatalf("failed to get descriptor for %q: %v", sym, err)
409	}
410	if _, ok := dsc.(*desc.MethodDescriptor); !ok {
411		t.Fatalf("descriptor for %q was a %T (expecting a MethodDescriptor)", sym, dsc)
412	}
413	txt := proto.MarshalTextString(dsc.AsProto())
414	expected :=
415		`name: "EmptyCall"
416input_type: ".testing.Empty"
417output_type: ".testing.Empty"
418`
419	if expected != txt {
420		t.Errorf("descriptor mismatch: expected %s, got %s", expected, txt)
421	}
422
423	sym = "testing.StreamingOutputCallResponse"
424	dsc, err = source.FindSymbol(sym)
425	if err != nil {
426		t.Fatalf("failed to get descriptor for %q: %v", sym, err)
427	}
428	if _, ok := dsc.(*desc.MessageDescriptor); !ok {
429		t.Fatalf("descriptor for %q was a %T (expecting a MessageDescriptor)", sym, dsc)
430	}
431	txt = proto.MarshalTextString(dsc.AsProto())
432	expected =
433		`name: "StreamingOutputCallResponse"
434field: <
435  name: "payload"
436  number: 1
437  label: LABEL_OPTIONAL
438  type: TYPE_MESSAGE
439  type_name: ".testing.Payload"
440  json_name: "payload"
441>
442`
443	if expected != txt {
444		t.Errorf("descriptor mismatch: expected %s, got %s", expected, txt)
445	}
446
447	_, err = source.FindSymbol("FooService")
448	if err != nil && !strings.Contains(err.Error(), "Symbol not found: FooService") {
449		t.Errorf("FindSymbol should have returned 'not found' error but instead returned %v", err)
450	}
451}
452
// Canned JSON request messages for the TestService RPCs. The unary and
// half-duplex tests compare responses against them byte-for-byte. Those
// responses are marshaled by handler.OnReceiveResponse with a two-space indent
// and no default values, so the exact formatting here matters.
const (
	// payload1's type is COMPRESSABLE. That is the default (numeric value 0), so
	// it is not included on the wire and does not appear in the JSON either.
	payload1 = `{
  "payload": {
    "body": "SXQncyBCdXNpbmVzcyBUaW1l"
  }
}`
	payload2 = `{
  "payload": {
    "type": "RANDOM",
    "body": "Rm91eCBkdSBGYUZh"
  }
}`
	payload3 = `{
  "payload": {
    "type": "UNCOMPRESSABLE",
    "body": "SGlwaG9wb3BvdGFtdXMgdnMuIFJoeW1lbm9jZXJvcw=="
  }
}`
)
475
476func getCC(includeRefl bool) *grpc.ClientConn {
477	if includeRefl {
478		return ccReflect
479	} else {
480		return ccNoReflect
481	}
482}
483
484func TestUnary(t *testing.T) {
485	for _, ds := range descSources {
486		t.Run(ds.name, func(t *testing.T) {
487			doTestUnary(t, getCC(ds.includeRefl), ds.source)
488		})
489	}
490}
491
492func doTestUnary(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) {
493	// Success
494	h := &handler{reqMessages: []string{payload1}}
495	err := InvokeRpc(context.Background(), source, cc, "testing.TestService/UnaryCall", makeHeaders(codes.OK), h, h.getRequestData)
496	if err != nil {
497		t.Fatalf("unexpected error during RPC: %v", err)
498	}
499
500	if h.check(t, "testing.TestService.UnaryCall", codes.OK, 1, 1) {
501		if h.respMessages[0] != payload1 {
502			t.Errorf("unexpected response from RPC: expecting %s; got %s", payload1, h.respMessages[0])
503		}
504	}
505
506	// Failure
507	h = &handler{reqMessages: []string{payload1}}
508	err = InvokeRpc(context.Background(), source, cc, "testing.TestService/UnaryCall", makeHeaders(codes.NotFound), h, h.getRequestData)
509	if err != nil {
510		t.Fatalf("unexpected error during RPC: %v", err)
511	}
512
513	h.check(t, "testing.TestService.UnaryCall", codes.NotFound, 1, 0)
514}
515
516func TestClientStream(t *testing.T) {
517	for _, ds := range descSources {
518		t.Run(ds.name, func(t *testing.T) {
519			doTestClientStream(t, getCC(ds.includeRefl), ds.source)
520		})
521	}
522}
523
524func doTestClientStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) {
525	// Success
526	h := &handler{reqMessages: []string{payload1, payload2, payload3}}
527	err := InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingInputCall", makeHeaders(codes.OK), h, h.getRequestData)
528	if err != nil {
529		t.Fatalf("unexpected error during RPC: %v", err)
530	}
531
532	if h.check(t, "testing.TestService.StreamingInputCall", codes.OK, 3, 1) {
533		expected :=
534			`{
535  "aggregatedPayloadSize": 61
536}`
537		if h.respMessages[0] != expected {
538			t.Errorf("unexpected response from RPC: expecting %s; got %s", expected, h.respMessages[0])
539		}
540	}
541
542	// Fail fast (server rejects as soon as possible)
543	h = &handler{reqMessages: []string{payload1, payload2, payload3}}
544	err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingInputCall", makeHeaders(codes.InvalidArgument), h, h.getRequestData)
545	if err != nil {
546		t.Fatalf("unexpected error during RPC: %v", err)
547	}
548
549	h.check(t, "testing.TestService.StreamingInputCall", codes.InvalidArgument, -3, 0)
550
551	// Fail late (server waits until stream is complete to reject)
552	h = &handler{reqMessages: []string{payload1, payload2, payload3}}
553	err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingInputCall", makeHeaders(codes.Internal, true), h, h.getRequestData)
554	if err != nil {
555		t.Fatalf("unexpected error during RPC: %v", err)
556	}
557
558	h.check(t, "testing.TestService.StreamingInputCall", codes.Internal, 3, 0)
559}
560
561func TestServerStream(t *testing.T) {
562	for _, ds := range descSources {
563		t.Run(ds.name, func(t *testing.T) {
564			doTestServerStream(t, getCC(ds.includeRefl), ds.source)
565		})
566	}
567}
568
569func doTestServerStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) {
570	req := &grpcurl_testing.StreamingOutputCallRequest{
571		ResponseType: grpcurl_testing.PayloadType_COMPRESSABLE,
572		ResponseParameters: []*grpcurl_testing.ResponseParameters{
573			{Size: 10}, {Size: 20}, {Size: 30}, {Size: 40}, {Size: 50},
574		},
575	}
576	payload, err := (&jsonpb.Marshaler{}).MarshalToString(req)
577	if err != nil {
578		t.Fatalf("failed to construct request: %v", err)
579	}
580
581	// Success
582	h := &handler{reqMessages: []string{payload}}
583	err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingOutputCall", makeHeaders(codes.OK), h, h.getRequestData)
584	if err != nil {
585		t.Fatalf("unexpected error during RPC: %v", err)
586	}
587
588	if h.check(t, "testing.TestService.StreamingOutputCall", codes.OK, 1, 5) {
589		resp := &grpcurl_testing.StreamingOutputCallResponse{}
590		for i, msg := range h.respMessages {
591			if err := jsonpb.UnmarshalString(msg, resp); err != nil {
592				t.Errorf("failed to parse response %d: %v", i+1, err)
593			}
594			if resp.Payload.GetType() != grpcurl_testing.PayloadType_COMPRESSABLE {
595				t.Errorf("response %d has wrong payload type; expecting %v, got %v", i, grpcurl_testing.PayloadType_COMPRESSABLE, resp.Payload.Type)
596			}
597			if len(resp.Payload.Body) != (i+1)*10 {
598				t.Errorf("response %d has wrong payload size; expecting %d, got %d", i, (i+1)*10, len(resp.Payload.Body))
599			}
600			resp.Reset()
601		}
602	}
603
604	// Fail fast (server rejects as soon as possible)
605	h = &handler{reqMessages: []string{payload}}
606	err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingOutputCall", makeHeaders(codes.Aborted), h, h.getRequestData)
607	if err != nil {
608		t.Fatalf("unexpected error during RPC: %v", err)
609	}
610
611	h.check(t, "testing.TestService.StreamingOutputCall", codes.Aborted, 1, 0)
612
613	// Fail late (server waits until stream is complete to reject)
614	h = &handler{reqMessages: []string{payload}}
615	err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingOutputCall", makeHeaders(codes.AlreadyExists, true), h, h.getRequestData)
616	if err != nil {
617		t.Fatalf("unexpected error during RPC: %v", err)
618	}
619
620	h.check(t, "testing.TestService.StreamingOutputCall", codes.AlreadyExists, 1, 5)
621}
622
623func TestHalfDuplexStream(t *testing.T) {
624	for _, ds := range descSources {
625		t.Run(ds.name, func(t *testing.T) {
626			doTestHalfDuplexStream(t, getCC(ds.includeRefl), ds.source)
627		})
628	}
629}
630
631func doTestHalfDuplexStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) {
632	reqs := []string{payload1, payload2, payload3}
633
634	// Success
635	h := &handler{reqMessages: reqs}
636	err := InvokeRpc(context.Background(), source, cc, "testing.TestService/HalfDuplexCall", makeHeaders(codes.OK), h, h.getRequestData)
637	if err != nil {
638		t.Fatalf("unexpected error during RPC: %v", err)
639	}
640
641	if h.check(t, "testing.TestService.HalfDuplexCall", codes.OK, 3, 3) {
642		for i, resp := range h.respMessages {
643			if resp != reqs[i] {
644				t.Errorf("unexpected response %d from RPC:\nexpecting %q\ngot %q", i, reqs[i], resp)
645			}
646		}
647	}
648
649	// Fail fast (server rejects as soon as possible)
650	h = &handler{reqMessages: reqs}
651	err = InvokeRpc(context.Background(), source, cc, "testing.TestService/HalfDuplexCall", makeHeaders(codes.Canceled), h, h.getRequestData)
652	if err != nil {
653		t.Fatalf("unexpected error during RPC: %v", err)
654	}
655
656	h.check(t, "testing.TestService.HalfDuplexCall", codes.Canceled, -3, 0)
657
658	// Fail late (server waits until stream is complete to reject)
659	h = &handler{reqMessages: reqs}
660	err = InvokeRpc(context.Background(), source, cc, "testing.TestService/HalfDuplexCall", makeHeaders(codes.DataLoss, true), h, h.getRequestData)
661	if err != nil {
662		t.Fatalf("unexpected error during RPC: %v", err)
663	}
664
665	h.check(t, "testing.TestService.HalfDuplexCall", codes.DataLoss, 3, 3)
666}
667
668func TestFullDuplexStream(t *testing.T) {
669	for _, ds := range descSources {
670		t.Run(ds.name, func(t *testing.T) {
671			doTestFullDuplexStream(t, getCC(ds.includeRefl), ds.source)
672		})
673	}
674}
675
676func doTestFullDuplexStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) {
677	reqs := make([]string, 3)
678	req := &grpcurl_testing.StreamingOutputCallRequest{
679		ResponseType: grpcurl_testing.PayloadType_RANDOM,
680	}
681	for i := range reqs {
682		req.ResponseParameters = append(req.ResponseParameters, &grpcurl_testing.ResponseParameters{Size: int32((i + 1) * 10)})
683		payload, err := (&jsonpb.Marshaler{}).MarshalToString(req)
684		if err != nil {
685			t.Fatalf("failed to construct request %d: %v", i, err)
686		}
687		reqs[i] = payload
688	}
689
690	// Success
691	h := &handler{reqMessages: reqs}
692	err := InvokeRpc(context.Background(), source, cc, "testing.TestService/FullDuplexCall", makeHeaders(codes.OK), h, h.getRequestData)
693	if err != nil {
694		t.Fatalf("unexpected error during RPC: %v", err)
695	}
696
697	if h.check(t, "testing.TestService.FullDuplexCall", codes.OK, 3, 6) {
698		resp := &grpcurl_testing.StreamingOutputCallResponse{}
699		i := 0
700		for j := 1; j < 3; j++ {
701			// three requests
702			for k := 0; k < j; k++ {
703				// 1 response for first request, 2 for second, etc
704				msg := h.respMessages[i]
705				if err := jsonpb.UnmarshalString(msg, resp); err != nil {
706					t.Errorf("failed to parse response %d: %v", i+1, err)
707				}
708				if resp.Payload.GetType() != grpcurl_testing.PayloadType_RANDOM {
709					t.Errorf("response %d has wrong payload type; expecting %v, got %v", i, grpcurl_testing.PayloadType_RANDOM, resp.Payload.Type)
710				}
711				if len(resp.Payload.Body) != (k+1)*10 {
712					t.Errorf("response %d has wrong payload size; expecting %d, got %d", i, (k+1)*10, len(resp.Payload.Body))
713				}
714				resp.Reset()
715
716				i++
717			}
718		}
719	}
720
721	// Fail fast (server rejects as soon as possible)
722	h = &handler{reqMessages: reqs}
723	err = InvokeRpc(context.Background(), source, cc, "testing.TestService/FullDuplexCall", makeHeaders(codes.PermissionDenied), h, h.getRequestData)
724	if err != nil {
725		t.Fatalf("unexpected error during RPC: %v", err)
726	}
727
728	h.check(t, "testing.TestService.FullDuplexCall", codes.PermissionDenied, -3, 0)
729
730	// Fail late (server waits until stream is complete to reject)
731	h = &handler{reqMessages: reqs}
732	err = InvokeRpc(context.Background(), source, cc, "testing.TestService/FullDuplexCall", makeHeaders(codes.ResourceExhausted, true), h, h.getRequestData)
733	if err != nil {
734		t.Fatalf("unexpected error during RPC: %v", err)
735	}
736
737	h.check(t, "testing.TestService.FullDuplexCall", codes.ResourceExhausted, 3, 6)
738}
739
// handler is the event handler the tests pass to InvokeRpc. It supplies request
// messages through getRequestData and records every callback, so that check can
// verify the RPC afterwards:
//   - method, methodCount: set by OnResolveMethod.
//   - reqHeaders, reqHeadersCount: set by OnSendHeaders.
//   - reqMessages: JSON request messages, returned in order by getRequestData.
//     reqMessagesCount counts calls to getRequestData, including the final
//     call that returns io.EOF.
//   - respHeaders, respHeadersCount: set by OnReceiveHeaders.
//   - respMessages: indented JSON form of each response, from OnReceiveResponse.
//   - respTrailers, respStatus, respTrailersCount: set by OnReceiveTrailers.
type handler struct {
	method            *desc.MethodDescriptor
	methodCount       int
	reqHeaders        metadata.MD
	reqHeadersCount   int
	reqMessages       []string
	reqMessagesCount  int
	respHeaders       metadata.MD
	respHeadersCount  int
	respMessages      []string
	respTrailers      metadata.MD
	respStatus        *status.Status
	respTrailersCount int
}
754
755func (h *handler) getRequestData() ([]byte, error) {
756	// we don't use a mutex, though this method will be called from different goroutine
757	// than other methods for bidi calls, because this method does not share any state
758	// with the other methods.
759	h.reqMessagesCount++
760	if h.reqMessagesCount > len(h.reqMessages) {
761		return nil, io.EOF
762	}
763	if h.reqMessagesCount > 1 {
764		// insert delay between messages in request stream
765		time.Sleep(time.Millisecond * 50)
766	}
767	return []byte(h.reqMessages[h.reqMessagesCount-1]), nil
768}
769
770func (h *handler) OnResolveMethod(md *desc.MethodDescriptor) {
771	h.methodCount++
772	h.method = md
773}
774
775func (h *handler) OnSendHeaders(md metadata.MD) {
776	h.reqHeadersCount++
777	h.reqHeaders = md
778}
779
780func (h *handler) OnReceiveHeaders(md metadata.MD) {
781	h.respHeadersCount++
782	h.respHeaders = md
783}
784
785func (h *handler) OnReceiveResponse(msg proto.Message) {
786	jsm := jsonpb.Marshaler{Indent: "  "}
787	respStr, err := jsm.MarshalToString(msg)
788	if err != nil {
789		panic(fmt.Errorf("failed to generate JSON form of response message: %v", err))
790	}
791	h.respMessages = append(h.respMessages, respStr)
792}
793
794func (h *handler) OnReceiveTrailers(stat *status.Status, md metadata.MD) {
795	h.respTrailersCount++
796	h.respTrailers = md
797	h.respStatus = stat
798}
799
800func (h *handler) check(t *testing.T, expectedMethod string, expectedCode codes.Code, expectedRequestQueries, expectedResponses int) bool {
801	// verify a few things were only ever called once
802	if h.methodCount != 1 {
803		t.Errorf("expected grpcurl to invoke OnResolveMethod once; was %d", h.methodCount)
804	}
805	if h.reqHeadersCount != 1 {
806		t.Errorf("expected grpcurl to invoke OnSendHeaders once; was %d", h.reqHeadersCount)
807	}
808	if h.reqHeadersCount != 1 {
809		t.Errorf("expected grpcurl to invoke OnSendHeaders once; was %d", h.reqHeadersCount)
810	}
811	if h.respHeadersCount != 1 {
812		t.Errorf("expected grpcurl to invoke OnReceiveHeaders once; was %d", h.respHeadersCount)
813	}
814	if h.respTrailersCount != 1 {
815		t.Errorf("expected grpcurl to invoke OnReceiveTrailers once; was %d", h.respTrailersCount)
816	}
817
818	// check other stuff against given expectations
819	if h.method.GetFullyQualifiedName() != expectedMethod {
820		t.Errorf("wrong method: expecting %v, got %v", expectedMethod, h.method.GetFullyQualifiedName())
821	}
822	if h.respStatus.Code() != expectedCode {
823		t.Errorf("wrong code: expecting %v, got %v", expectedCode, h.respStatus.Code())
824	}
825	if expectedRequestQueries < 0 {
826		// negative expectation means "negate and expect up to that number; could be fewer"
827		if h.reqMessagesCount > -expectedRequestQueries+1 {
828			// the + 1 is because there will be an extra query that returns EOF
829			t.Errorf("wrong number of messages queried: expecting no more than %v, got %v", -expectedRequestQueries, h.reqMessagesCount-1)
830		}
831	} else {
832		if h.reqMessagesCount != expectedRequestQueries+1 {
833			// the + 1 is because there will be an extra query that returns EOF
834			t.Errorf("wrong number of messages queried: expecting %v, got %v", expectedRequestQueries, h.reqMessagesCount-1)
835		}
836	}
837	if len(h.respMessages) != expectedResponses {
838		t.Errorf("wrong number of messages received: expecting %v, got %v", expectedResponses, len(h.respMessages))
839	}
840
841	// also check headers and trailers came through as expected
842	v := h.respHeaders["some-fake-header-1"]
843	if len(v) != 1 || v[0] != "val1" {
844		t.Errorf("wrong request header for %q: %v", "some-fake-header-1", v)
845	}
846	v = h.respHeaders["some-fake-header-2"]
847	if len(v) != 1 || v[0] != "val2" {
848		t.Errorf("wrong request header for %q: %v", "some-fake-header-2", v)
849	}
850	v = h.respTrailers["some-fake-trailer-1"]
851	if len(v) != 1 || v[0] != "valA" {
852		t.Errorf("wrong request header for %q: %v", "some-fake-trailer-1", v)
853	}
854	v = h.respTrailers["some-fake-trailer-2"]
855	if len(v) != 1 || v[0] != "valB" {
856		t.Errorf("wrong request header for %q: %v", "some-fake-trailer-2", v)
857	}
858
859	return len(h.respMessages) == expectedResponses
860}
861
862func makeHeaders(code codes.Code, failLate ...bool) []string {
863	if len(failLate) > 1 {
864		panic("incorrect use of makeContext; should be at most one failLate flag")
865	}
866
867	hdrs := append(make([]string, 0, 5),
868		fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyHeaders, "some-fake-header-1: val1"),
869		fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyHeaders, "some-fake-header-2: val2"),
870		fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyTrailers, "some-fake-trailer-1: valA"),
871		fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyTrailers, "some-fake-trailer-2: valB"))
872	if code != codes.OK {
873		if len(failLate) > 0 && failLate[0] {
874			hdrs = append(hdrs, fmt.Sprintf("%s: %d", grpcurl_testing.MetadataFailLate, code))
875		} else {
876			hdrs = append(hdrs, fmt.Sprintf("%s: %d", grpcurl_testing.MetadataFailEarly, code))
877		}
878	}
879
880	return hdrs
881}
882