1/*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package grpc
20
21import (
22	"context"
23	"fmt"
24	"io"
25	"math"
26	"net"
27	"strconv"
28	"strings"
29	"sync"
30	"testing"
31	"time"
32
33	"google.golang.org/grpc/codes"
34	"google.golang.org/grpc/internal/transport"
35	"google.golang.org/grpc/status"
36)
37
var (
	// expectedRequest is the only payload the test server treats as a
	// successful request.
	expectedRequest = "ping"
	// expectedResponse is what the test server sends back once the request
	// stream ends (io.EOF) after carrying only expectedRequest messages.
	expectedResponse = "pong"
	// weirdError is a status message containing fmt verbs; it is used to check
	// that such messages reach the client unmangled.
	weirdError = "format verbs: %v%s"
	// sizeLargeErr is the length (1 MiB) of the status message the server
	// returns for unrecognized payloads.
	sizeLargeErr = 1 << 20
	// canceled counts "canceled" requests that actually reached the server.
	canceled int
)
45
// testCodec is a trivial codec used by these tests: every message is a
// *string whose bytes go on the wire unchanged.
type testCodec struct{}

// Marshal returns the bytes of the string v points to. v must be a *string.
func (testCodec) Marshal(v interface{}) ([]byte, error) {
	str := v.(*string)
	out := []byte(*str)
	return out, nil
}

// Unmarshal stores data, as a string, into the string v points to. v must be
// a *string.
func (testCodec) Unmarshal(data []byte, v interface{}) error {
	dst := v.(*string)
	*dst = string(data)
	return nil
}

// String returns the codec's name, "test".
func (testCodec) String() string {
	const name = "test"
	return name
}
61
62type testStreamHandler struct {
63	port string
64	t    transport.ServerTransport
65}
66
67func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
68	p := &parser{r: s}
69	for {
70		pf, req, err := p.recvMsg(math.MaxInt32)
71		if err == io.EOF {
72			break
73		}
74		if err != nil {
75			return
76		}
77		if pf != compressionNone {
78			t.Errorf("Received the mistaken message format %d, want %d", pf, compressionNone)
79			return
80		}
81		var v string
82		codec := testCodec{}
83		if err := codec.Unmarshal(req, &v); err != nil {
84			t.Errorf("Failed to unmarshal the received message: %v", err)
85			return
86		}
87		if v == "weird error" {
88			h.t.WriteStatus(s, status.New(codes.Internal, weirdError))
89			return
90		}
91		if v == "canceled" {
92			canceled++
93			h.t.WriteStatus(s, status.New(codes.Internal, ""))
94			return
95		}
96		if v == "port" {
97			h.t.WriteStatus(s, status.New(codes.Internal, h.port))
98			return
99		}
100
101		if v != expectedRequest {
102			h.t.WriteStatus(s, status.New(codes.Internal, strings.Repeat("A", sizeLargeErr)))
103			return
104		}
105	}
106	// send a response back to end the stream.
107	data, err := encode(testCodec{}, &expectedResponse)
108	if err != nil {
109		t.Errorf("Failed to encode the response: %v", err)
110		return
111	}
112	hdr, payload := msgHeader(data, nil)
113	h.t.Write(s, hdr, payload, &transport.Options{})
114	h.t.WriteStatus(s, status.New(codes.OK, ""))
115}
116
117type server struct {
118	lis        net.Listener
119	port       string
120	addr       string
121	startedErr chan error // sent nil or an error after server starts
122	mu         sync.Mutex
123	conns      map[transport.ServerTransport]bool
124}
125
126type ctxKey string
127
128func newTestServer() *server {
129	return &server{startedErr: make(chan error, 1)}
130}
131
132// start starts server. Other goroutines should block on s.startedErr for further operations.
133func (s *server) start(t *testing.T, port int, maxStreams uint32) {
134	var err error
135	if port == 0 {
136		s.lis, err = net.Listen("tcp", "localhost:0")
137	} else {
138		s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port))
139	}
140	if err != nil {
141		s.startedErr <- fmt.Errorf("failed to listen: %v", err)
142		return
143	}
144	s.addr = s.lis.Addr().String()
145	_, p, err := net.SplitHostPort(s.addr)
146	if err != nil {
147		s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err)
148		return
149	}
150	s.port = p
151	s.conns = make(map[transport.ServerTransport]bool)
152	s.startedErr <- nil
153	for {
154		conn, err := s.lis.Accept()
155		if err != nil {
156			return
157		}
158		config := &transport.ServerConfig{
159			MaxStreams: maxStreams,
160		}
161		st, err := transport.NewServerTransport("http2", conn, config)
162		if err != nil {
163			continue
164		}
165		s.mu.Lock()
166		if s.conns == nil {
167			s.mu.Unlock()
168			st.Close()
169			return
170		}
171		s.conns[st] = true
172		s.mu.Unlock()
173		h := &testStreamHandler{
174			port: s.port,
175			t:    st,
176		}
177		go st.HandleStreams(func(s *transport.Stream) {
178			go h.handleStream(t, s)
179		}, func(ctx context.Context, method string) context.Context {
180			return ctx
181		})
182	}
183}
184
185func (s *server) wait(t *testing.T, timeout time.Duration) {
186	select {
187	case err := <-s.startedErr:
188		if err != nil {
189			t.Fatal(err)
190		}
191	case <-time.After(timeout):
192		t.Fatalf("Timed out after %v waiting for server to be ready", timeout)
193	}
194}
195
196func (s *server) stop() {
197	s.lis.Close()
198	s.mu.Lock()
199	for c := range s.conns {
200		c.Close()
201	}
202	s.conns = nil
203	s.mu.Unlock()
204}
205
206func setUp(t *testing.T, port int, maxStreams uint32) (*server, *ClientConn) {
207	return setUpWithOptions(t, port, maxStreams)
208}
209
210func setUpWithOptions(t *testing.T, port int, maxStreams uint32, dopts ...DialOption) (*server, *ClientConn) {
211	server := newTestServer()
212	go server.start(t, port, maxStreams)
213	server.wait(t, 2*time.Second)
214	addr := "localhost:" + server.port
215	dopts = append(dopts, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
216	cc, err := Dial(addr, dopts...)
217	if err != nil {
218		t.Fatalf("Failed to create ClientConn: %v", err)
219	}
220	return server, cc
221}
222
223func (s) TestUnaryClientInterceptor(t *testing.T) {
224	parentKey := ctxKey("parentKey")
225
226	interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
227		if ctx.Value(parentKey) == nil {
228			t.Fatalf("interceptor should have %v in context", parentKey)
229		}
230		return invoker(ctx, method, req, reply, cc, opts...)
231	}
232
233	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(interceptor))
234	defer func() {
235		cc.Close()
236		server.stop()
237	}()
238
239	var reply string
240	ctx := context.Background()
241	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
242	if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
243		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
244	}
245}
246
247func (s) TestChainUnaryClientInterceptor(t *testing.T) {
248	var (
249		parentKey    = ctxKey("parentKey")
250		firstIntKey  = ctxKey("firstIntKey")
251		secondIntKey = ctxKey("secondIntKey")
252	)
253
254	firstInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
255		if ctx.Value(parentKey) == nil {
256			t.Fatalf("first interceptor should have %v in context", parentKey)
257		}
258		if ctx.Value(firstIntKey) != nil {
259			t.Fatalf("first interceptor should not have %v in context", firstIntKey)
260		}
261		if ctx.Value(secondIntKey) != nil {
262			t.Fatalf("first interceptor should not have %v in context", secondIntKey)
263		}
264		firstCtx := context.WithValue(ctx, firstIntKey, 1)
265		err := invoker(firstCtx, method, req, reply, cc, opts...)
266		*(reply.(*string)) += "1"
267		return err
268	}
269
270	secondInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
271		if ctx.Value(parentKey) == nil {
272			t.Fatalf("second interceptor should have %v in context", parentKey)
273		}
274		if ctx.Value(firstIntKey) == nil {
275			t.Fatalf("second interceptor should have %v in context", firstIntKey)
276		}
277		if ctx.Value(secondIntKey) != nil {
278			t.Fatalf("second interceptor should not have %v in context", secondIntKey)
279		}
280		secondCtx := context.WithValue(ctx, secondIntKey, 2)
281		err := invoker(secondCtx, method, req, reply, cc, opts...)
282		*(reply.(*string)) += "2"
283		return err
284	}
285
286	lastInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
287		if ctx.Value(parentKey) == nil {
288			t.Fatalf("last interceptor should have %v in context", parentKey)
289		}
290		if ctx.Value(firstIntKey) == nil {
291			t.Fatalf("last interceptor should have %v in context", firstIntKey)
292		}
293		if ctx.Value(secondIntKey) == nil {
294			t.Fatalf("last interceptor should have %v in context", secondIntKey)
295		}
296		err := invoker(ctx, method, req, reply, cc, opts...)
297		*(reply.(*string)) += "3"
298		return err
299	}
300
301	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainUnaryInterceptor(firstInt, secondInt, lastInt))
302	defer func() {
303		cc.Close()
304		server.stop()
305	}()
306
307	var reply string
308	ctx := context.Background()
309	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
310	if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse+"321" {
311		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
312	}
313}
314
315func (s) TestChainOnBaseUnaryClientInterceptor(t *testing.T) {
316	var (
317		parentKey  = ctxKey("parentKey")
318		baseIntKey = ctxKey("baseIntKey")
319	)
320
321	baseInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
322		if ctx.Value(parentKey) == nil {
323			t.Fatalf("base interceptor should have %v in context", parentKey)
324		}
325		if ctx.Value(baseIntKey) != nil {
326			t.Fatalf("base interceptor should not have %v in context", baseIntKey)
327		}
328		baseCtx := context.WithValue(ctx, baseIntKey, 1)
329		return invoker(baseCtx, method, req, reply, cc, opts...)
330	}
331
332	chainInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
333		if ctx.Value(parentKey) == nil {
334			t.Fatalf("chain interceptor should have %v in context", parentKey)
335		}
336		if ctx.Value(baseIntKey) == nil {
337			t.Fatalf("chain interceptor should have %v in context", baseIntKey)
338		}
339		return invoker(ctx, method, req, reply, cc, opts...)
340	}
341
342	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(baseInt), WithChainUnaryInterceptor(chainInt))
343	defer func() {
344		cc.Close()
345		server.stop()
346	}()
347
348	var reply string
349	ctx := context.Background()
350	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
351	if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
352		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
353	}
354}
355
356func (s) TestChainStreamClientInterceptor(t *testing.T) {
357	var (
358		parentKey    = ctxKey("parentKey")
359		firstIntKey  = ctxKey("firstIntKey")
360		secondIntKey = ctxKey("secondIntKey")
361	)
362
363	firstInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
364		if ctx.Value(parentKey) == nil {
365			t.Fatalf("first interceptor should have %v in context", parentKey)
366		}
367		if ctx.Value(firstIntKey) != nil {
368			t.Fatalf("first interceptor should not have %v in context", firstIntKey)
369		}
370		if ctx.Value(secondIntKey) != nil {
371			t.Fatalf("first interceptor should not have %v in context", secondIntKey)
372		}
373		firstCtx := context.WithValue(ctx, firstIntKey, 1)
374		return streamer(firstCtx, desc, cc, method, opts...)
375	}
376
377	secondInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
378		if ctx.Value(parentKey) == nil {
379			t.Fatalf("second interceptor should have %v in context", parentKey)
380		}
381		if ctx.Value(firstIntKey) == nil {
382			t.Fatalf("second interceptor should have %v in context", firstIntKey)
383		}
384		if ctx.Value(secondIntKey) != nil {
385			t.Fatalf("second interceptor should not have %v in context", secondIntKey)
386		}
387		secondCtx := context.WithValue(ctx, secondIntKey, 2)
388		return streamer(secondCtx, desc, cc, method, opts...)
389	}
390
391	lastInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
392		if ctx.Value(parentKey) == nil {
393			t.Fatalf("last interceptor should have %v in context", parentKey)
394		}
395		if ctx.Value(firstIntKey) == nil {
396			t.Fatalf("last interceptor should have %v in context", firstIntKey)
397		}
398		if ctx.Value(secondIntKey) == nil {
399			t.Fatalf("last interceptor should have %v in context", secondIntKey)
400		}
401		return streamer(ctx, desc, cc, method, opts...)
402	}
403
404	server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainStreamInterceptor(firstInt, secondInt, lastInt))
405	defer func() {
406		cc.Close()
407		server.stop()
408	}()
409
410	ctx := context.Background()
411	parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
412	_, err := cc.NewStream(parentCtx, &StreamDesc{}, "/foo/bar")
413	if err != nil {
414		t.Fatalf("grpc.NewStream(_, _, _) = %v, want <nil>", err)
415	}
416}
417
418func (s) TestInvoke(t *testing.T) {
419	server, cc := setUp(t, 0, math.MaxUint32)
420	var reply string
421	if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
422		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
423	}
424	cc.Close()
425	server.stop()
426}
427
428func (s) TestInvokeLargeErr(t *testing.T) {
429	server, cc := setUp(t, 0, math.MaxUint32)
430	var reply string
431	req := "hello"
432	err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
433	if _, ok := status.FromError(err); !ok {
434		t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
435	}
436	if status.Code(err) != codes.Internal || len(errorDesc(err)) != sizeLargeErr {
437		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want an error of code %d and desc size %d", err, codes.Internal, sizeLargeErr)
438	}
439	cc.Close()
440	server.stop()
441}
442
443// TestInvokeErrorSpecialChars checks that error messages don't get mangled.
444func (s) TestInvokeErrorSpecialChars(t *testing.T) {
445	server, cc := setUp(t, 0, math.MaxUint32)
446	var reply string
447	req := "weird error"
448	err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
449	if _, ok := status.FromError(err); !ok {
450		t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
451	}
452	if got, want := errorDesc(err), weirdError; got != want {
453		t.Fatalf("grpc.Invoke(_, _, _, _, _) error = %q, want %q", got, want)
454	}
455	cc.Close()
456	server.stop()
457}
458
459// TestInvokeCancel checks that an Invoke with a canceled context is not sent.
460func (s) TestInvokeCancel(t *testing.T) {
461	server, cc := setUp(t, 0, math.MaxUint32)
462	var reply string
463	req := "canceled"
464	for i := 0; i < 100; i++ {
465		ctx, cancel := context.WithCancel(context.Background())
466		cancel()
467		cc.Invoke(ctx, "/foo/bar", &req, &reply)
468	}
469	if canceled != 0 {
470		t.Fatalf("received %d of 100 canceled requests", canceled)
471	}
472	cc.Close()
473	server.stop()
474}
475
476// TestInvokeCancelClosedNonFail checks that a canceled non-failfast RPC
477// on a closed client will terminate.
478func (s) TestInvokeCancelClosedNonFailFast(t *testing.T) {
479	server, cc := setUp(t, 0, math.MaxUint32)
480	var reply string
481	cc.Close()
482	req := "hello"
483	ctx, cancel := context.WithCancel(context.Background())
484	cancel()
485	if err := cc.Invoke(ctx, "/foo/bar", &req, &reply, WaitForReady(true)); err == nil {
486		t.Fatalf("canceled invoke on closed connection should fail")
487	}
488	server.stop()
489}
490