1/*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19// Package interop contains functions used by interop client/server.
20package interop
21
22import (
23	"context"
24	"fmt"
25	"io"
26	"io/ioutil"
27	"strings"
28	"time"
29
30	"github.com/golang/protobuf/proto"
31	"golang.org/x/oauth2"
32	"golang.org/x/oauth2/google"
33	"google.golang.org/grpc"
34	"google.golang.org/grpc/codes"
35	"google.golang.org/grpc/grpclog"
36	testpb "google.golang.org/grpc/interop/grpc_testing"
37	"google.golang.org/grpc/metadata"
38	"google.golang.org/grpc/status"
39)
40
41var (
42	reqSizes            = []int{27182, 8, 1828, 45904}
43	respSizes           = []int{31415, 9, 2653, 58979}
44	largeReqSize        = 271828
45	largeRespSize       = 314159
46	initialMetadataKey  = "x-grpc-test-echo-initial"
47	trailingMetadataKey = "x-grpc-test-echo-trailing-bin"
48
49	logger = grpclog.Component("interop")
50)
51
52// ClientNewPayload returns a payload of the given type and size.
53func ClientNewPayload(t testpb.PayloadType, size int) *testpb.Payload {
54	if size < 0 {
55		logger.Fatalf("Requested a response with invalid length %d", size)
56	}
57	body := make([]byte, size)
58	switch t {
59	case testpb.PayloadType_COMPRESSABLE:
60	case testpb.PayloadType_UNCOMPRESSABLE:
61		logger.Fatalf("PayloadType UNCOMPRESSABLE is not supported")
62	default:
63		logger.Fatalf("Unsupported payload type: %d", t)
64	}
65	return &testpb.Payload{
66		Type: t,
67		Body: body,
68	}
69}
70
71// DoEmptyUnaryCall performs a unary RPC with empty request and response messages.
72func DoEmptyUnaryCall(tc testpb.TestServiceClient, args ...grpc.CallOption) {
73	reply, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, args...)
74	if err != nil {
75		logger.Fatal("/TestService/EmptyCall RPC failed: ", err)
76	}
77	if !proto.Equal(&testpb.Empty{}, reply) {
78		logger.Fatalf("/TestService/EmptyCall receives %v, want %v", reply, testpb.Empty{})
79	}
80}
81
82// DoLargeUnaryCall performs a unary RPC with large payload in the request and response.
83func DoLargeUnaryCall(tc testpb.TestServiceClient, args ...grpc.CallOption) {
84	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
85	req := &testpb.SimpleRequest{
86		ResponseType: testpb.PayloadType_COMPRESSABLE,
87		ResponseSize: int32(largeRespSize),
88		Payload:      pl,
89	}
90	reply, err := tc.UnaryCall(context.Background(), req, args...)
91	if err != nil {
92		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
93	}
94	t := reply.GetPayload().GetType()
95	s := len(reply.GetPayload().GetBody())
96	if t != testpb.PayloadType_COMPRESSABLE || s != largeRespSize {
97		logger.Fatalf("Got the reply with type %d len %d; want %d, %d", t, s, testpb.PayloadType_COMPRESSABLE, largeRespSize)
98	}
99}
100
101// DoClientStreaming performs a client streaming RPC.
102func DoClientStreaming(tc testpb.TestServiceClient, args ...grpc.CallOption) {
103	stream, err := tc.StreamingInputCall(context.Background(), args...)
104	if err != nil {
105		logger.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err)
106	}
107	var sum int
108	for _, s := range reqSizes {
109		pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, s)
110		req := &testpb.StreamingInputCallRequest{
111			Payload: pl,
112		}
113		if err := stream.Send(req); err != nil {
114			logger.Fatalf("%v has error %v while sending %v", stream, err, req)
115		}
116		sum += s
117	}
118	reply, err := stream.CloseAndRecv()
119	if err != nil {
120		logger.Fatalf("%v.CloseAndRecv() got error %v, want %v", stream, err, nil)
121	}
122	if reply.GetAggregatedPayloadSize() != int32(sum) {
123		logger.Fatalf("%v.CloseAndRecv().GetAggregatePayloadSize() = %v; want %v", stream, reply.GetAggregatedPayloadSize(), sum)
124	}
125}
126
127// DoServerStreaming performs a server streaming RPC.
128func DoServerStreaming(tc testpb.TestServiceClient, args ...grpc.CallOption) {
129	respParam := make([]*testpb.ResponseParameters, len(respSizes))
130	for i, s := range respSizes {
131		respParam[i] = &testpb.ResponseParameters{
132			Size: int32(s),
133		}
134	}
135	req := &testpb.StreamingOutputCallRequest{
136		ResponseType:       testpb.PayloadType_COMPRESSABLE,
137		ResponseParameters: respParam,
138	}
139	stream, err := tc.StreamingOutputCall(context.Background(), req, args...)
140	if err != nil {
141		logger.Fatalf("%v.StreamingOutputCall(_) = _, %v", tc, err)
142	}
143	var rpcStatus error
144	var respCnt int
145	var index int
146	for {
147		reply, err := stream.Recv()
148		if err != nil {
149			rpcStatus = err
150			break
151		}
152		t := reply.GetPayload().GetType()
153		if t != testpb.PayloadType_COMPRESSABLE {
154			logger.Fatalf("Got the reply of type %d, want %d", t, testpb.PayloadType_COMPRESSABLE)
155		}
156		size := len(reply.GetPayload().GetBody())
157		if size != respSizes[index] {
158			logger.Fatalf("Got reply body of length %d, want %d", size, respSizes[index])
159		}
160		index++
161		respCnt++
162	}
163	if rpcStatus != io.EOF {
164		logger.Fatalf("Failed to finish the server streaming rpc: %v", rpcStatus)
165	}
166	if respCnt != len(respSizes) {
167		logger.Fatalf("Got %d reply, want %d", len(respSizes), respCnt)
168	}
169}
170
171// DoPingPong performs ping-pong style bi-directional streaming RPC.
172func DoPingPong(tc testpb.TestServiceClient, args ...grpc.CallOption) {
173	stream, err := tc.FullDuplexCall(context.Background(), args...)
174	if err != nil {
175		logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err)
176	}
177	var index int
178	for index < len(reqSizes) {
179		respParam := []*testpb.ResponseParameters{
180			{
181				Size: int32(respSizes[index]),
182			},
183		}
184		pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, reqSizes[index])
185		req := &testpb.StreamingOutputCallRequest{
186			ResponseType:       testpb.PayloadType_COMPRESSABLE,
187			ResponseParameters: respParam,
188			Payload:            pl,
189		}
190		if err := stream.Send(req); err != nil {
191			logger.Fatalf("%v has error %v while sending %v", stream, err, req)
192		}
193		reply, err := stream.Recv()
194		if err != nil {
195			logger.Fatalf("%v.Recv() = %v", stream, err)
196		}
197		t := reply.GetPayload().GetType()
198		if t != testpb.PayloadType_COMPRESSABLE {
199			logger.Fatalf("Got the reply of type %d, want %d", t, testpb.PayloadType_COMPRESSABLE)
200		}
201		size := len(reply.GetPayload().GetBody())
202		if size != respSizes[index] {
203			logger.Fatalf("Got reply body of length %d, want %d", size, respSizes[index])
204		}
205		index++
206	}
207	if err := stream.CloseSend(); err != nil {
208		logger.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
209	}
210	if _, err := stream.Recv(); err != io.EOF {
211		logger.Fatalf("%v failed to complele the ping pong test: %v", stream, err)
212	}
213}
214
215// DoEmptyStream sets up a bi-directional streaming with zero message.
216func DoEmptyStream(tc testpb.TestServiceClient, args ...grpc.CallOption) {
217	stream, err := tc.FullDuplexCall(context.Background(), args...)
218	if err != nil {
219		logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err)
220	}
221	if err := stream.CloseSend(); err != nil {
222		logger.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
223	}
224	if _, err := stream.Recv(); err != io.EOF {
225		logger.Fatalf("%v failed to complete the empty stream test: %v", stream, err)
226	}
227}
228
229// DoTimeoutOnSleepingServer performs an RPC on a sleep server which causes RPC timeout.
230func DoTimeoutOnSleepingServer(tc testpb.TestServiceClient, args ...grpc.CallOption) {
231	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
232	defer cancel()
233	stream, err := tc.FullDuplexCall(ctx, args...)
234	if err != nil {
235		if status.Code(err) == codes.DeadlineExceeded {
236			return
237		}
238		logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err)
239	}
240	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, 27182)
241	req := &testpb.StreamingOutputCallRequest{
242		ResponseType: testpb.PayloadType_COMPRESSABLE,
243		Payload:      pl,
244	}
245	if err := stream.Send(req); err != nil && err != io.EOF {
246		logger.Fatalf("%v.Send(_) = %v", stream, err)
247	}
248	if _, err := stream.Recv(); status.Code(err) != codes.DeadlineExceeded {
249		logger.Fatalf("%v.Recv() = _, %v, want error code %d", stream, err, codes.DeadlineExceeded)
250	}
251}
252
253// DoComputeEngineCreds performs a unary RPC with compute engine auth.
254func DoComputeEngineCreds(tc testpb.TestServiceClient, serviceAccount, oauthScope string) {
255	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
256	req := &testpb.SimpleRequest{
257		ResponseType:   testpb.PayloadType_COMPRESSABLE,
258		ResponseSize:   int32(largeRespSize),
259		Payload:        pl,
260		FillUsername:   true,
261		FillOauthScope: true,
262	}
263	reply, err := tc.UnaryCall(context.Background(), req)
264	if err != nil {
265		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
266	}
267	user := reply.GetUsername()
268	scope := reply.GetOauthScope()
269	if user != serviceAccount {
270		logger.Fatalf("Got user name %q, want %q.", user, serviceAccount)
271	}
272	if !strings.Contains(oauthScope, scope) {
273		logger.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, oauthScope)
274	}
275}
276
277func getServiceAccountJSONKey(keyFile string) []byte {
278	jsonKey, err := ioutil.ReadFile(keyFile)
279	if err != nil {
280		logger.Fatalf("Failed to read the service account key file: %v", err)
281	}
282	return jsonKey
283}
284
285// DoServiceAccountCreds performs a unary RPC with service account auth.
286func DoServiceAccountCreds(tc testpb.TestServiceClient, serviceAccountKeyFile, oauthScope string) {
287	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
288	req := &testpb.SimpleRequest{
289		ResponseType:   testpb.PayloadType_COMPRESSABLE,
290		ResponseSize:   int32(largeRespSize),
291		Payload:        pl,
292		FillUsername:   true,
293		FillOauthScope: true,
294	}
295	reply, err := tc.UnaryCall(context.Background(), req)
296	if err != nil {
297		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
298	}
299	jsonKey := getServiceAccountJSONKey(serviceAccountKeyFile)
300	user := reply.GetUsername()
301	scope := reply.GetOauthScope()
302	if !strings.Contains(string(jsonKey), user) {
303		logger.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey)
304	}
305	if !strings.Contains(oauthScope, scope) {
306		logger.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, oauthScope)
307	}
308}
309
310// DoJWTTokenCreds performs a unary RPC with JWT token auth.
311func DoJWTTokenCreds(tc testpb.TestServiceClient, serviceAccountKeyFile string) {
312	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
313	req := &testpb.SimpleRequest{
314		ResponseType: testpb.PayloadType_COMPRESSABLE,
315		ResponseSize: int32(largeRespSize),
316		Payload:      pl,
317		FillUsername: true,
318	}
319	reply, err := tc.UnaryCall(context.Background(), req)
320	if err != nil {
321		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
322	}
323	jsonKey := getServiceAccountJSONKey(serviceAccountKeyFile)
324	user := reply.GetUsername()
325	if !strings.Contains(string(jsonKey), user) {
326		logger.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey)
327	}
328}
329
330// GetToken obtains an OAUTH token from the input.
331func GetToken(serviceAccountKeyFile string, oauthScope string) *oauth2.Token {
332	jsonKey := getServiceAccountJSONKey(serviceAccountKeyFile)
333	config, err := google.JWTConfigFromJSON(jsonKey, oauthScope)
334	if err != nil {
335		logger.Fatalf("Failed to get the config: %v", err)
336	}
337	token, err := config.TokenSource(context.Background()).Token()
338	if err != nil {
339		logger.Fatalf("Failed to get the token: %v", err)
340	}
341	return token
342}
343
344// DoOauth2TokenCreds performs a unary RPC with OAUTH2 token auth.
345func DoOauth2TokenCreds(tc testpb.TestServiceClient, serviceAccountKeyFile, oauthScope string) {
346	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
347	req := &testpb.SimpleRequest{
348		ResponseType:   testpb.PayloadType_COMPRESSABLE,
349		ResponseSize:   int32(largeRespSize),
350		Payload:        pl,
351		FillUsername:   true,
352		FillOauthScope: true,
353	}
354	reply, err := tc.UnaryCall(context.Background(), req)
355	if err != nil {
356		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
357	}
358	jsonKey := getServiceAccountJSONKey(serviceAccountKeyFile)
359	user := reply.GetUsername()
360	scope := reply.GetOauthScope()
361	if !strings.Contains(string(jsonKey), user) {
362		logger.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey)
363	}
364	if !strings.Contains(oauthScope, scope) {
365		logger.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, oauthScope)
366	}
367}
368
369// DoPerRPCCreds performs a unary RPC with per RPC OAUTH2 token.
370func DoPerRPCCreds(tc testpb.TestServiceClient, serviceAccountKeyFile, oauthScope string) {
371	jsonKey := getServiceAccountJSONKey(serviceAccountKeyFile)
372	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
373	req := &testpb.SimpleRequest{
374		ResponseType:   testpb.PayloadType_COMPRESSABLE,
375		ResponseSize:   int32(largeRespSize),
376		Payload:        pl,
377		FillUsername:   true,
378		FillOauthScope: true,
379	}
380	token := GetToken(serviceAccountKeyFile, oauthScope)
381	kv := map[string]string{"authorization": token.Type() + " " + token.AccessToken}
382	ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{"authorization": []string{kv["authorization"]}})
383	reply, err := tc.UnaryCall(ctx, req)
384	if err != nil {
385		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
386	}
387	user := reply.GetUsername()
388	scope := reply.GetOauthScope()
389	if !strings.Contains(string(jsonKey), user) {
390		logger.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey)
391	}
392	if !strings.Contains(oauthScope, scope) {
393		logger.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, oauthScope)
394	}
395}
396
397// DoGoogleDefaultCredentials performs an unary RPC with google default credentials
398func DoGoogleDefaultCredentials(tc testpb.TestServiceClient, defaultServiceAccount string) {
399	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
400	req := &testpb.SimpleRequest{
401		ResponseType:   testpb.PayloadType_COMPRESSABLE,
402		ResponseSize:   int32(largeRespSize),
403		Payload:        pl,
404		FillUsername:   true,
405		FillOauthScope: true,
406	}
407	reply, err := tc.UnaryCall(context.Background(), req)
408	if err != nil {
409		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
410	}
411	if reply.GetUsername() != defaultServiceAccount {
412		logger.Fatalf("Got user name %q; wanted %q. ", reply.GetUsername(), defaultServiceAccount)
413	}
414}
415
416// DoComputeEngineChannelCredentials performs an unary RPC with compute engine channel credentials
417func DoComputeEngineChannelCredentials(tc testpb.TestServiceClient, defaultServiceAccount string) {
418	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
419	req := &testpb.SimpleRequest{
420		ResponseType:   testpb.PayloadType_COMPRESSABLE,
421		ResponseSize:   int32(largeRespSize),
422		Payload:        pl,
423		FillUsername:   true,
424		FillOauthScope: true,
425	}
426	reply, err := tc.UnaryCall(context.Background(), req)
427	if err != nil {
428		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
429	}
430	if reply.GetUsername() != defaultServiceAccount {
431		logger.Fatalf("Got user name %q; wanted %q. ", reply.GetUsername(), defaultServiceAccount)
432	}
433}
434
435var testMetadata = metadata.MD{
436	"key1": []string{"value1"},
437	"key2": []string{"value2"},
438}
439
440// DoCancelAfterBegin cancels the RPC after metadata has been sent but before payloads are sent.
441func DoCancelAfterBegin(tc testpb.TestServiceClient, args ...grpc.CallOption) {
442	ctx, cancel := context.WithCancel(metadata.NewOutgoingContext(context.Background(), testMetadata))
443	stream, err := tc.StreamingInputCall(ctx, args...)
444	if err != nil {
445		logger.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err)
446	}
447	cancel()
448	_, err = stream.CloseAndRecv()
449	if status.Code(err) != codes.Canceled {
450		logger.Fatalf("%v.CloseAndRecv() got error code %d, want %d", stream, status.Code(err), codes.Canceled)
451	}
452}
453
454// DoCancelAfterFirstResponse cancels the RPC after receiving the first message from the server.
455func DoCancelAfterFirstResponse(tc testpb.TestServiceClient, args ...grpc.CallOption) {
456	ctx, cancel := context.WithCancel(context.Background())
457	stream, err := tc.FullDuplexCall(ctx, args...)
458	if err != nil {
459		logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err)
460	}
461	respParam := []*testpb.ResponseParameters{
462		{
463			Size: 31415,
464		},
465	}
466	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, 27182)
467	req := &testpb.StreamingOutputCallRequest{
468		ResponseType:       testpb.PayloadType_COMPRESSABLE,
469		ResponseParameters: respParam,
470		Payload:            pl,
471	}
472	if err := stream.Send(req); err != nil {
473		logger.Fatalf("%v has error %v while sending %v", stream, err, req)
474	}
475	if _, err := stream.Recv(); err != nil {
476		logger.Fatalf("%v.Recv() = %v", stream, err)
477	}
478	cancel()
479	if _, err := stream.Recv(); status.Code(err) != codes.Canceled {
480		logger.Fatalf("%v compleled with error code %d, want %d", stream, status.Code(err), codes.Canceled)
481	}
482}
483
484var (
485	initialMetadataValue  = "test_initial_metadata_value"
486	trailingMetadataValue = "\x0a\x0b\x0a\x0b\x0a\x0b"
487	customMetadata        = metadata.Pairs(
488		initialMetadataKey, initialMetadataValue,
489		trailingMetadataKey, trailingMetadataValue,
490	)
491)
492
493func validateMetadata(header, trailer metadata.MD) {
494	if len(header[initialMetadataKey]) != 1 {
495		logger.Fatalf("Expected exactly one header from server. Received %d", len(header[initialMetadataKey]))
496	}
497	if header[initialMetadataKey][0] != initialMetadataValue {
498		logger.Fatalf("Got header %s; want %s", header[initialMetadataKey][0], initialMetadataValue)
499	}
500	if len(trailer[trailingMetadataKey]) != 1 {
501		logger.Fatalf("Expected exactly one trailer from server. Received %d", len(trailer[trailingMetadataKey]))
502	}
503	if trailer[trailingMetadataKey][0] != trailingMetadataValue {
504		logger.Fatalf("Got trailer %s; want %s", trailer[trailingMetadataKey][0], trailingMetadataValue)
505	}
506}
507
508// DoCustomMetadata checks that metadata is echoed back to the client.
509func DoCustomMetadata(tc testpb.TestServiceClient, args ...grpc.CallOption) {
510	// Testing with UnaryCall.
511	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, 1)
512	req := &testpb.SimpleRequest{
513		ResponseType: testpb.PayloadType_COMPRESSABLE,
514		ResponseSize: int32(1),
515		Payload:      pl,
516	}
517	ctx := metadata.NewOutgoingContext(context.Background(), customMetadata)
518	var header, trailer metadata.MD
519	args = append(args, grpc.Header(&header), grpc.Trailer(&trailer))
520	reply, err := tc.UnaryCall(
521		ctx,
522		req,
523		args...,
524	)
525	if err != nil {
526		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
527	}
528	t := reply.GetPayload().GetType()
529	s := len(reply.GetPayload().GetBody())
530	if t != testpb.PayloadType_COMPRESSABLE || s != 1 {
531		logger.Fatalf("Got the reply with type %d len %d; want %d, %d", t, s, testpb.PayloadType_COMPRESSABLE, 1)
532	}
533	validateMetadata(header, trailer)
534
535	// Testing with FullDuplex.
536	stream, err := tc.FullDuplexCall(ctx, args...)
537	if err != nil {
538		logger.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
539	}
540	respParam := []*testpb.ResponseParameters{
541		{
542			Size: 1,
543		},
544	}
545	streamReq := &testpb.StreamingOutputCallRequest{
546		ResponseType:       testpb.PayloadType_COMPRESSABLE,
547		ResponseParameters: respParam,
548		Payload:            pl,
549	}
550	if err := stream.Send(streamReq); err != nil {
551		logger.Fatalf("%v has error %v while sending %v", stream, err, streamReq)
552	}
553	streamHeader, err := stream.Header()
554	if err != nil {
555		logger.Fatalf("%v.Header() = %v", stream, err)
556	}
557	if _, err := stream.Recv(); err != nil {
558		logger.Fatalf("%v.Recv() = %v", stream, err)
559	}
560	if err := stream.CloseSend(); err != nil {
561		logger.Fatalf("%v.CloseSend() = %v, want <nil>", stream, err)
562	}
563	if _, err := stream.Recv(); err != io.EOF {
564		logger.Fatalf("%v failed to complete the custom metadata test: %v", stream, err)
565	}
566	streamTrailer := stream.Trailer()
567	validateMetadata(streamHeader, streamTrailer)
568}
569
570// DoStatusCodeAndMessage checks that the status code is propagated back to the client.
571func DoStatusCodeAndMessage(tc testpb.TestServiceClient, args ...grpc.CallOption) {
572	var code int32 = 2
573	msg := "test status message"
574	expectedErr := status.Error(codes.Code(code), msg)
575	respStatus := &testpb.EchoStatus{
576		Code:    code,
577		Message: msg,
578	}
579	// Test UnaryCall.
580	req := &testpb.SimpleRequest{
581		ResponseStatus: respStatus,
582	}
583	if _, err := tc.UnaryCall(context.Background(), req, args...); err == nil || err.Error() != expectedErr.Error() {
584		logger.Fatalf("%v.UnaryCall(_, %v) = _, %v, want _, %v", tc, req, err, expectedErr)
585	}
586	// Test FullDuplexCall.
587	stream, err := tc.FullDuplexCall(context.Background(), args...)
588	if err != nil {
589		logger.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
590	}
591	streamReq := &testpb.StreamingOutputCallRequest{
592		ResponseStatus: respStatus,
593	}
594	if err := stream.Send(streamReq); err != nil {
595		logger.Fatalf("%v has error %v while sending %v, want <nil>", stream, err, streamReq)
596	}
597	if err := stream.CloseSend(); err != nil {
598		logger.Fatalf("%v.CloseSend() = %v, want <nil>", stream, err)
599	}
600	if _, err = stream.Recv(); err.Error() != expectedErr.Error() {
601		logger.Fatalf("%v.Recv() returned error %v, want %v", stream, err, expectedErr)
602	}
603}
604
605// DoSpecialStatusMessage verifies Unicode and whitespace is correctly processed
606// in status message.
607func DoSpecialStatusMessage(tc testpb.TestServiceClient, args ...grpc.CallOption) {
608	const (
609		code int32  = 2
610		msg  string = "\t\ntest with whitespace\r\nand Unicode BMP ☺ and non-BMP ��\t\n"
611	)
612	expectedErr := status.Error(codes.Code(code), msg)
613	req := &testpb.SimpleRequest{
614		ResponseStatus: &testpb.EchoStatus{
615			Code:    code,
616			Message: msg,
617		},
618	}
619	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
620	defer cancel()
621	if _, err := tc.UnaryCall(ctx, req, args...); err == nil || err.Error() != expectedErr.Error() {
622		logger.Fatalf("%v.UnaryCall(_, %v) = _, %v, want _, %v", tc, req, err, expectedErr)
623	}
624}
625
626// DoUnimplementedService attempts to call a method from an unimplemented service.
627func DoUnimplementedService(tc testpb.UnimplementedServiceClient) {
628	_, err := tc.UnimplementedCall(context.Background(), &testpb.Empty{})
629	if status.Code(err) != codes.Unimplemented {
630		logger.Fatalf("%v.UnimplementedCall() = _, %v, want _, %v", tc, status.Code(err), codes.Unimplemented)
631	}
632}
633
634// DoUnimplementedMethod attempts to call an unimplemented method.
635func DoUnimplementedMethod(cc *grpc.ClientConn) {
636	var req, reply proto.Message
637	if err := cc.Invoke(context.Background(), "/grpc.testing.TestService/UnimplementedCall", req, reply); err == nil || status.Code(err) != codes.Unimplemented {
638		logger.Fatalf("ClientConn.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented)
639	}
640}
641
642// DoPickFirstUnary runs multiple RPCs (rpcCount) and checks that all requests
643// are sent to the same backend.
644func DoPickFirstUnary(tc testpb.TestServiceClient) {
645	const rpcCount = 100
646
647	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, 1)
648	req := &testpb.SimpleRequest{
649		ResponseType: testpb.PayloadType_COMPRESSABLE,
650		ResponseSize: int32(1),
651		Payload:      pl,
652		FillServerId: true,
653	}
654	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
655	defer cancel()
656	var serverID string
657	for i := 0; i < rpcCount; i++ {
658		resp, err := tc.UnaryCall(ctx, req)
659		if err != nil {
660			logger.Fatalf("iteration %d, failed to do UnaryCall: %v", i, err)
661		}
662		id := resp.ServerId
663		if id == "" {
664			logger.Fatalf("iteration %d, got empty server ID", i)
665		}
666		if i == 0 {
667			serverID = id
668			continue
669		}
670		if serverID != id {
671			logger.Fatalf("iteration %d, got different server ids: %q vs %q", i, serverID, id)
672		}
673	}
674}
675
676type testServer struct{}
677
678// NewTestServer creates a test server for test service.
679func NewTestServer() *testpb.TestServiceService {
680	return testpb.NewTestServiceService(testpb.UnstableTestServiceService(&testServer{}))
681}
682
683func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
684	return new(testpb.Empty), nil
685}
686
687func serverNewPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error) {
688	if size < 0 {
689		return nil, fmt.Errorf("requested a response with invalid length %d", size)
690	}
691	body := make([]byte, size)
692	switch t {
693	case testpb.PayloadType_COMPRESSABLE:
694	case testpb.PayloadType_UNCOMPRESSABLE:
695		return nil, fmt.Errorf("payloadType UNCOMPRESSABLE is not supported")
696	default:
697		return nil, fmt.Errorf("unsupported payload type: %d", t)
698	}
699	return &testpb.Payload{
700		Type: t,
701		Body: body,
702	}, nil
703}
704
705func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
706	st := in.GetResponseStatus()
707	if md, ok := metadata.FromIncomingContext(ctx); ok {
708		if initialMetadata, ok := md[initialMetadataKey]; ok {
709			header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
710			grpc.SendHeader(ctx, header)
711		}
712		if trailingMetadata, ok := md[trailingMetadataKey]; ok {
713			trailer := metadata.Pairs(trailingMetadataKey, trailingMetadata[0])
714			grpc.SetTrailer(ctx, trailer)
715		}
716	}
717	if st != nil && st.Code != 0 {
718		return nil, status.Error(codes.Code(st.Code), st.Message)
719	}
720	pl, err := serverNewPayload(in.GetResponseType(), in.GetResponseSize())
721	if err != nil {
722		return nil, err
723	}
724	return &testpb.SimpleResponse{
725		Payload: pl,
726	}, nil
727}
728
729func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
730	cs := args.GetResponseParameters()
731	for _, c := range cs {
732		if us := c.GetIntervalUs(); us > 0 {
733			time.Sleep(time.Duration(us) * time.Microsecond)
734		}
735		pl, err := serverNewPayload(args.GetResponseType(), c.GetSize())
736		if err != nil {
737			return err
738		}
739		if err := stream.Send(&testpb.StreamingOutputCallResponse{
740			Payload: pl,
741		}); err != nil {
742			return err
743		}
744	}
745	return nil
746}
747
748func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInputCallServer) error {
749	var sum int
750	for {
751		in, err := stream.Recv()
752		if err == io.EOF {
753			return stream.SendAndClose(&testpb.StreamingInputCallResponse{
754				AggregatedPayloadSize: int32(sum),
755			})
756		}
757		if err != nil {
758			return err
759		}
760		p := in.GetPayload().GetBody()
761		sum += len(p)
762	}
763}
764
765func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
766	if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
767		if initialMetadata, ok := md[initialMetadataKey]; ok {
768			header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
769			stream.SendHeader(header)
770		}
771		if trailingMetadata, ok := md[trailingMetadataKey]; ok {
772			trailer := metadata.Pairs(trailingMetadataKey, trailingMetadata[0])
773			stream.SetTrailer(trailer)
774		}
775	}
776	for {
777		in, err := stream.Recv()
778		if err == io.EOF {
779			// read done.
780			return nil
781		}
782		if err != nil {
783			return err
784		}
785		st := in.GetResponseStatus()
786		if st != nil && st.Code != 0 {
787			return status.Error(codes.Code(st.Code), st.Message)
788		}
789		cs := in.GetResponseParameters()
790		for _, c := range cs {
791			if us := c.GetIntervalUs(); us > 0 {
792				time.Sleep(time.Duration(us) * time.Microsecond)
793			}
794			pl, err := serverNewPayload(in.GetResponseType(), c.GetSize())
795			if err != nil {
796				return err
797			}
798			if err := stream.Send(&testpb.StreamingOutputCallResponse{
799				Payload: pl,
800			}); err != nil {
801				return err
802			}
803		}
804	}
805}
806
807func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServer) error {
808	var msgBuf []*testpb.StreamingOutputCallRequest
809	for {
810		in, err := stream.Recv()
811		if err == io.EOF {
812			// read done.
813			break
814		}
815		if err != nil {
816			return err
817		}
818		msgBuf = append(msgBuf, in)
819	}
820	for _, m := range msgBuf {
821		cs := m.GetResponseParameters()
822		for _, c := range cs {
823			if us := c.GetIntervalUs(); us > 0 {
824				time.Sleep(time.Duration(us) * time.Microsecond)
825			}
826			pl, err := serverNewPayload(m.GetResponseType(), c.GetSize())
827			if err != nil {
828				return err
829			}
830			if err := stream.Send(&testpb.StreamingOutputCallResponse{
831				Payload: pl,
832			}); err != nil {
833				return err
834			}
835		}
836	}
837	return nil
838}
839