1/*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package grpc
20
21import (
22	"context"
23	"fmt"
24	"io"
25	"math"
26	"net"
27	"strconv"
28	"strings"
29	"sync"
30	"testing"
31	"time"
32
33	"google.golang.org/grpc/codes"
34	"google.golang.org/grpc/internal/transport"
35	"google.golang.org/grpc/status"
36)
37
// Fixtures shared by testStreamHandler and the tests in this file.
var (
	// expectedRequest is the only payload the test server accepts without
	// ending the stream early.
	expectedRequest = "ping"

	// expectedResponse is the message the server sends once the client has
	// half-closed a stream made up solely of expectedRequest messages.
	expectedResponse = "pong"

	// weirdError deliberately contains fmt verbs so tests can check that the
	// status message reaches the client unmangled.
	weirdError = "format verbs: %v%s"

	// sizeLargeErr is the length (1 MiB) of the error message the server
	// returns for any request it does not recognize.
	sizeLargeErr = 1 << 20

	// canceled counts "canceled" requests that reached the server.
	// NOTE(review): it is incremented from server handler goroutines and read
	// from the test goroutine without synchronization; an atomic counter
	// would make it race-free.
	canceled int
)

// defaultTestTimeout is the deadline given to each RPC context in these tests.
const defaultTestTimeout = time.Second * 10
47
// testCodec is a minimal codec whose wire format is the raw bytes of a Go
// string. Values given to Marshal and Unmarshal must be *string; any other
// type panics on the type assertion.
type testCodec struct{}

// Marshal returns the bytes of the string that v points to.
func (testCodec) Marshal(v interface{}) ([]byte, error) {
	sp := v.(*string)
	return []byte(*sp), nil
}

// Unmarshal stores data, converted to a string, into the *string v.
func (testCodec) Unmarshal(data []byte, v interface{}) error {
	dst := v.(*string)
	*dst = string(data)
	return nil
}

// String reports the codec's name, "test".
func (testCodec) String() string {
	return "test"
}
63
64type testStreamHandler struct {
65	port string
66	t    transport.ServerTransport
67}
68
69func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
70	p := &parser{r: s}
71	for {
72		pf, req, err := p.recvMsg(math.MaxInt32)
73		if err == io.EOF {
74			break
75		}
76		if err != nil {
77			return
78		}
79		if pf != compressionNone {
80			t.Errorf("Received the mistaken message format %d, want %d", pf, compressionNone)
81			return
82		}
83		var v string
84		codec := testCodec{}
85		if err := codec.Unmarshal(req, &v); err != nil {
86			t.Errorf("Failed to unmarshal the received message: %v", err)
87			return
88		}
89		if v == "weird error" {
90			h.t.WriteStatus(s, status.New(codes.Internal, weirdError))
91			return
92		}
93		if v == "canceled" {
94			canceled++
95			h.t.WriteStatus(s, status.New(codes.Internal, ""))
96			return
97		}
98		if v == "port" {
99			h.t.WriteStatus(s, status.New(codes.Internal, h.port))
100			return
101		}
102
103		if v != expectedRequest {
104			h.t.WriteStatus(s, status.New(codes.Internal, strings.Repeat("A", sizeLargeErr)))
105			return
106		}
107	}
108	// send a response back to end the stream.
109	data, err := encode(testCodec{}, &expectedResponse)
110	if err != nil {
111		t.Errorf("Failed to encode the response: %v", err)
112		return
113	}
114	hdr, payload := msgHeader(data, nil)
115	h.t.Write(s, hdr, payload, &transport.Options{})
116	h.t.WriteStatus(s, status.New(codes.OK, ""))
117}
118
119type server struct {
120	lis        net.Listener
121	port       string
122	addr       string
123	startedErr chan error // sent nil or an error after server starts
124	mu         sync.Mutex
125	conns      map[transport.ServerTransport]bool
126}
127
128type ctxKey string
129
130func newTestServer() *server {
131	return &server{startedErr: make(chan error, 1)}
132}
133
134// start starts server. Other goroutines should block on s.startedErr for further operations.
135func (s *server) start(t *testing.T, port int, maxStreams uint32) {
136	var err error
137	if port == 0 {
138		s.lis, err = net.Listen("tcp", "localhost:0")
139	} else {
140		s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port))
141	}
142	if err != nil {
143		s.startedErr <- fmt.Errorf("failed to listen: %v", err)
144		return
145	}
146	s.addr = s.lis.Addr().String()
147	_, p, err := net.SplitHostPort(s.addr)
148	if err != nil {
149		s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err)
150		return
151	}
152	s.port = p
153	s.conns = make(map[transport.ServerTransport]bool)
154	s.startedErr <- nil
155	for {
156		conn, err := s.lis.Accept()
157		if err != nil {
158			return
159		}
160		config := &transport.ServerConfig{
161			MaxStreams: maxStreams,
162		}
163		st, err := transport.NewServerTransport(conn, config)
164		if err != nil {
165			continue
166		}
167		s.mu.Lock()
168		if s.conns == nil {
169			s.mu.Unlock()
170			st.Close()
171			return
172		}
173		s.conns[st] = true
174		s.mu.Unlock()
175		h := &testStreamHandler{
176			port: s.port,
177			t:    st,
178		}
179		go st.HandleStreams(func(s *transport.Stream) {
180			go h.handleStream(t, s)
181		}, func(ctx context.Context, method string) context.Context {
182			return ctx
183		})
184	}
185}
186
187func (s *server) wait(t *testing.T, timeout time.Duration) {
188	select {
189	case err := <-s.startedErr:
190		if err != nil {
191			t.Fatal(err)
192		}
193	case <-time.After(timeout):
194		t.Fatalf("Timed out after %v waiting for server to be ready", timeout)
195	}
196}
197
198func (s *server) stop() {
199	s.lis.Close()
200	s.mu.Lock()
201	for c := range s.conns {
202		c.Close()
203	}
204	s.conns = nil
205	s.mu.Unlock()
206}
207
208func setUp(t *testing.T, port int, maxStreams uint32) (*server, *ClientConn) {
209	return setUpWithOptions(t, port, maxStreams)
210}
211
212func setUpWithOptions(t *testing.T, port int, maxStreams uint32, dopts ...DialOption) (*server, *ClientConn) {
213	server := newTestServer()
214	go server.start(t, port, maxStreams)
215	server.wait(t, 2*time.Second)
216	addr := "localhost:" + server.port
217	dopts = append(dopts, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
218	cc, err := Dial(addr, dopts...)
219	if err != nil {
220		t.Fatalf("Failed to create ClientConn: %v", err)
221	}
222	return server, cc
223}
224
225func (s) TestUnaryClientInterceptor(t *testing.T) {
226	parentKey := ctxKey("parentKey")
227
228	interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
229		if ctx.Value(parentKey) == nil {
230			t.Fatalf("interceptor should have %v in context", parentKey)
231		}
232		return invoker(ctx, method, req, reply, cc, opts...)
233	}
234
235	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(interceptor))
236	defer func() {
237		cc.Close()
238		server.stop()
239	}()
240
241	var reply string
242	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
243	defer cancel()
244	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
245	if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
246		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
247	}
248}
249
250func (s) TestChainUnaryClientInterceptor(t *testing.T) {
251	var (
252		parentKey    = ctxKey("parentKey")
253		firstIntKey  = ctxKey("firstIntKey")
254		secondIntKey = ctxKey("secondIntKey")
255	)
256
257	firstInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
258		if ctx.Value(parentKey) == nil {
259			t.Fatalf("first interceptor should have %v in context", parentKey)
260		}
261		if ctx.Value(firstIntKey) != nil {
262			t.Fatalf("first interceptor should not have %v in context", firstIntKey)
263		}
264		if ctx.Value(secondIntKey) != nil {
265			t.Fatalf("first interceptor should not have %v in context", secondIntKey)
266		}
267		firstCtx := context.WithValue(ctx, firstIntKey, 1)
268		err := invoker(firstCtx, method, req, reply, cc, opts...)
269		*(reply.(*string)) += "1"
270		return err
271	}
272
273	secondInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
274		if ctx.Value(parentKey) == nil {
275			t.Fatalf("second interceptor should have %v in context", parentKey)
276		}
277		if ctx.Value(firstIntKey) == nil {
278			t.Fatalf("second interceptor should have %v in context", firstIntKey)
279		}
280		if ctx.Value(secondIntKey) != nil {
281			t.Fatalf("second interceptor should not have %v in context", secondIntKey)
282		}
283		secondCtx := context.WithValue(ctx, secondIntKey, 2)
284		err := invoker(secondCtx, method, req, reply, cc, opts...)
285		*(reply.(*string)) += "2"
286		return err
287	}
288
289	lastInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
290		if ctx.Value(parentKey) == nil {
291			t.Fatalf("last interceptor should have %v in context", parentKey)
292		}
293		if ctx.Value(firstIntKey) == nil {
294			t.Fatalf("last interceptor should have %v in context", firstIntKey)
295		}
296		if ctx.Value(secondIntKey) == nil {
297			t.Fatalf("last interceptor should have %v in context", secondIntKey)
298		}
299		err := invoker(ctx, method, req, reply, cc, opts...)
300		*(reply.(*string)) += "3"
301		return err
302	}
303
304	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainUnaryInterceptor(firstInt, secondInt, lastInt))
305	defer func() {
306		cc.Close()
307		server.stop()
308	}()
309
310	var reply string
311	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
312	defer cancel()
313	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
314	if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse+"321" {
315		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
316	}
317}
318
319func (s) TestChainOnBaseUnaryClientInterceptor(t *testing.T) {
320	var (
321		parentKey  = ctxKey("parentKey")
322		baseIntKey = ctxKey("baseIntKey")
323	)
324
325	baseInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
326		if ctx.Value(parentKey) == nil {
327			t.Fatalf("base interceptor should have %v in context", parentKey)
328		}
329		if ctx.Value(baseIntKey) != nil {
330			t.Fatalf("base interceptor should not have %v in context", baseIntKey)
331		}
332		baseCtx := context.WithValue(ctx, baseIntKey, 1)
333		return invoker(baseCtx, method, req, reply, cc, opts...)
334	}
335
336	chainInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
337		if ctx.Value(parentKey) == nil {
338			t.Fatalf("chain interceptor should have %v in context", parentKey)
339		}
340		if ctx.Value(baseIntKey) == nil {
341			t.Fatalf("chain interceptor should have %v in context", baseIntKey)
342		}
343		return invoker(ctx, method, req, reply, cc, opts...)
344	}
345
346	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(baseInt), WithChainUnaryInterceptor(chainInt))
347	defer func() {
348		cc.Close()
349		server.stop()
350	}()
351
352	var reply string
353	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
354	defer cancel()
355	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
356	if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
357		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
358	}
359}
360
361func (s) TestChainStreamClientInterceptor(t *testing.T) {
362	var (
363		parentKey    = ctxKey("parentKey")
364		firstIntKey  = ctxKey("firstIntKey")
365		secondIntKey = ctxKey("secondIntKey")
366	)
367
368	firstInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
369		if ctx.Value(parentKey) == nil {
370			t.Fatalf("first interceptor should have %v in context", parentKey)
371		}
372		if ctx.Value(firstIntKey) != nil {
373			t.Fatalf("first interceptor should not have %v in context", firstIntKey)
374		}
375		if ctx.Value(secondIntKey) != nil {
376			t.Fatalf("first interceptor should not have %v in context", secondIntKey)
377		}
378		firstCtx := context.WithValue(ctx, firstIntKey, 1)
379		return streamer(firstCtx, desc, cc, method, opts...)
380	}
381
382	secondInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
383		if ctx.Value(parentKey) == nil {
384			t.Fatalf("second interceptor should have %v in context", parentKey)
385		}
386		if ctx.Value(firstIntKey) == nil {
387			t.Fatalf("second interceptor should have %v in context", firstIntKey)
388		}
389		if ctx.Value(secondIntKey) != nil {
390			t.Fatalf("second interceptor should not have %v in context", secondIntKey)
391		}
392		secondCtx := context.WithValue(ctx, secondIntKey, 2)
393		return streamer(secondCtx, desc, cc, method, opts...)
394	}
395
396	lastInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
397		if ctx.Value(parentKey) == nil {
398			t.Fatalf("last interceptor should have %v in context", parentKey)
399		}
400		if ctx.Value(firstIntKey) == nil {
401			t.Fatalf("last interceptor should have %v in context", firstIntKey)
402		}
403		if ctx.Value(secondIntKey) == nil {
404			t.Fatalf("last interceptor should have %v in context", secondIntKey)
405		}
406		return streamer(ctx, desc, cc, method, opts...)
407	}
408
409	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainStreamInterceptor(firstInt, secondInt, lastInt))
410	defer func() {
411		cc.Close()
412		server.stop()
413	}()
414
415	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
416	defer cancel()
417	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
418	_, err := cc.NewStream(parentCtx, &StreamDesc{}, "/foo/bar")
419	if err != nil {
420		t.Fatalf("grpc.NewStream(_, _, _) = %v, want <nil>", err)
421	}
422}
423
424func (s) TestInvoke(t *testing.T) {
425	server, cc := setUp(t, 0, math.MaxUint32)
426	var reply string
427	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
428	defer cancel()
429	if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
430		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
431	}
432	cc.Close()
433	server.stop()
434}
435
436func (s) TestInvokeLargeErr(t *testing.T) {
437	server, cc := setUp(t, 0, math.MaxUint32)
438	var reply string
439	req := "hello"
440	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
441	defer cancel()
442	err := cc.Invoke(ctx, "/foo/bar", &req, &reply)
443	if _, ok := status.FromError(err); !ok {
444		t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
445	}
446	if status.Code(err) != codes.Internal || len(errorDesc(err)) != sizeLargeErr {
447		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want an error of code %d and desc size %d", err, codes.Internal, sizeLargeErr)
448	}
449	cc.Close()
450	server.stop()
451}
452
453// TestInvokeErrorSpecialChars checks that error messages don't get mangled.
454func (s) TestInvokeErrorSpecialChars(t *testing.T) {
455	server, cc := setUp(t, 0, math.MaxUint32)
456	var reply string
457	req := "weird error"
458	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
459	defer cancel()
460	err := cc.Invoke(ctx, "/foo/bar", &req, &reply)
461	if _, ok := status.FromError(err); !ok {
462		t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
463	}
464	if got, want := errorDesc(err), weirdError; got != want {
465		t.Fatalf("grpc.Invoke(_, _, _, _, _) error = %q, want %q", got, want)
466	}
467	cc.Close()
468	server.stop()
469}
470
471// TestInvokeCancel checks that an Invoke with a canceled context is not sent.
472func (s) TestInvokeCancel(t *testing.T) {
473	server, cc := setUp(t, 0, math.MaxUint32)
474	var reply string
475	req := "canceled"
476	for i := 0; i < 100; i++ {
477		ctx, cancel := context.WithCancel(context.Background())
478		cancel()
479		cc.Invoke(ctx, "/foo/bar", &req, &reply)
480	}
481	if canceled != 0 {
482		t.Fatalf("received %d of 100 canceled requests", canceled)
483	}
484	cc.Close()
485	server.stop()
486}
487
488// TestInvokeCancelClosedNonFail checks that a canceled non-failfast RPC
489// on a closed client will terminate.
490func (s) TestInvokeCancelClosedNonFailFast(t *testing.T) {
491	server, cc := setUp(t, 0, math.MaxUint32)
492	var reply string
493	cc.Close()
494	req := "hello"
495	ctx, cancel := context.WithCancel(context.Background())
496	cancel()
497	if err := cc.Invoke(ctx, "/foo/bar", &req, &reply, WaitForReady(true)); err == nil {
498		t.Fatalf("canceled invoke on closed connection should fail")
499	}
500	server.stop()
501}
502