1// Copyright 2017 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package rpcreplay
16
17import (
18	"bytes"
19	"context"
20	"errors"
21	"io"
22	"strings"
23	"testing"
24
25	"cloud.google.com/go/internal/testutil"
26	ipb "cloud.google.com/go/rpcreplay/proto/intstore"
27	rpb "cloud.google.com/go/rpcreplay/proto/rpcreplay"
28	"github.com/golang/protobuf/proto"
29	"github.com/google/go-cmp/cmp"
30	"github.com/google/go-cmp/cmp/cmpopts"
31	"google.golang.org/grpc"
32	"google.golang.org/grpc/codes"
33	"google.golang.org/grpc/status"
34)
35
36func TestRecordIO(t *testing.T) {
37	buf := &bytes.Buffer{}
38	want := []byte{1, 2, 3}
39	if err := writeRecord(buf, want); err != nil {
40		t.Fatal(err)
41	}
42	got, err := readRecord(buf)
43	if err != nil {
44		t.Fatal(err)
45	}
46	if !bytes.Equal(got, want) {
47		t.Errorf("got %v, want %v", got, want)
48	}
49}
50
51func TestHeaderIO(t *testing.T) {
52	buf := &bytes.Buffer{}
53	want := []byte{1, 2, 3}
54	if err := writeHeader(buf, want); err != nil {
55		t.Fatal(err)
56	}
57	got, err := readHeader(buf)
58	if err != nil {
59		t.Fatal(err)
60	}
61	if !testutil.Equal(got, want) {
62		t.Errorf("got %v, want %v", got, want)
63	}
64
65	// readHeader errors
66	for _, contents := range []string{"", "badmagic", "gRPCReplay"} {
67		if _, err := readHeader(bytes.NewBufferString(contents)); err == nil {
68			t.Errorf("%q: got nil, want error", contents)
69		}
70	}
71}
72
73func TestEntryIO(t *testing.T) {
74	for i, want := range []*entry{
75		{
76			kind:     rpb.Entry_REQUEST,
77			method:   "method",
78			msg:      message{msg: &rpb.Entry{}},
79			refIndex: 7,
80		},
81		{
82			kind:     rpb.Entry_RESPONSE,
83			method:   "method",
84			msg:      message{err: status.Error(codes.NotFound, "not found")},
85			refIndex: 8,
86		},
87		{
88			kind:     rpb.Entry_RECV,
89			method:   "method",
90			msg:      message{err: io.EOF},
91			refIndex: 3,
92		},
93	} {
94		buf := &bytes.Buffer{}
95		if err := writeEntry(buf, want); err != nil {
96			t.Fatal(err)
97		}
98		got, err := readEntry(buf)
99		if err != nil {
100			t.Fatal(err)
101		}
102		if !got.equal(want) {
103			t.Errorf("#%d: got %v, want %v", i, got, want)
104		}
105	}
106}
107
// initialState is the opaque payload that record passes to NewRecorderWriter
// and that replay and TestRecord expect to read back from the recording header.
var initialState = []byte{0x01, 0x02, 0x03}
109
110func TestRecord(t *testing.T) {
111	buf := record(t, testService)
112
113	gotIstate, err := readHeader(buf)
114	if err != nil {
115		t.Fatal(err)
116	}
117	if !testutil.Equal(gotIstate, initialState) {
118		t.Fatalf("got %v, want %v", gotIstate, initialState)
119	}
120	item := &ipb.Item{Name: "a", Value: 1}
121	wantEntries := []*entry{
122		// Set
123		{
124			kind:   rpb.Entry_REQUEST,
125			method: "/intstore.IntStore/Set",
126			msg:    message{msg: item},
127		},
128		{
129			kind:     rpb.Entry_RESPONSE,
130			msg:      message{msg: &ipb.SetResponse{PrevValue: 0}},
131			refIndex: 1,
132		},
133		// Get
134		{
135			kind:   rpb.Entry_REQUEST,
136			method: "/intstore.IntStore/Get",
137			msg:    message{msg: &ipb.GetRequest{Name: "a"}},
138		},
139		{
140			kind:     rpb.Entry_RESPONSE,
141			msg:      message{msg: item},
142			refIndex: 3,
143		},
144		{
145			kind:   rpb.Entry_REQUEST,
146			method: "/intstore.IntStore/Get",
147			msg:    message{msg: &ipb.GetRequest{Name: "x"}},
148		},
149		{
150			kind:     rpb.Entry_RESPONSE,
151			msg:      message{err: status.Error(codes.NotFound, `"x"`)},
152			refIndex: 5,
153		},
154		// ListItems
155		{ // entry #7
156			kind:   rpb.Entry_CREATE_STREAM,
157			method: "/intstore.IntStore/ListItems",
158		},
159		{
160			kind:     rpb.Entry_SEND,
161			msg:      message{msg: &ipb.ListItemsRequest{}},
162			refIndex: 7,
163		},
164		{
165			kind:     rpb.Entry_RECV,
166			msg:      message{msg: item},
167			refIndex: 7,
168		},
169		{
170			kind:     rpb.Entry_RECV,
171			msg:      message{err: io.EOF},
172			refIndex: 7,
173		},
174		// SetStream
175		{ // entry #11
176			kind:   rpb.Entry_CREATE_STREAM,
177			method: "/intstore.IntStore/SetStream",
178		},
179		{
180			kind:     rpb.Entry_SEND,
181			msg:      message{msg: &ipb.Item{Name: "b", Value: 2}},
182			refIndex: 11,
183		},
184		{
185			kind:     rpb.Entry_SEND,
186			msg:      message{msg: &ipb.Item{Name: "c", Value: 3}},
187			refIndex: 11,
188		},
189		{
190			kind:     rpb.Entry_RECV,
191			msg:      message{msg: &ipb.Summary{Count: 2}},
192			refIndex: 11,
193		},
194
195		// StreamChat
196		{ // entry #15
197			kind:   rpb.Entry_CREATE_STREAM,
198			method: "/intstore.IntStore/StreamChat",
199		},
200		{
201			kind:     rpb.Entry_SEND,
202			msg:      message{msg: &ipb.Item{Name: "d", Value: 4}},
203			refIndex: 15,
204		},
205		{
206			kind:     rpb.Entry_RECV,
207			msg:      message{msg: &ipb.Item{Name: "d", Value: 4}},
208			refIndex: 15,
209		},
210		{
211			kind:     rpb.Entry_SEND,
212			msg:      message{msg: &ipb.Item{Name: "e", Value: 5}},
213			refIndex: 15,
214		},
215		{
216			kind:     rpb.Entry_RECV,
217			msg:      message{msg: &ipb.Item{Name: "e", Value: 5}},
218			refIndex: 15,
219		},
220		{
221			kind:     rpb.Entry_RECV,
222			msg:      message{err: io.EOF},
223			refIndex: 15,
224		},
225	}
226	for i, w := range wantEntries {
227		g, err := readEntry(buf)
228		if err != nil {
229			t.Fatalf("#%d: %v", i+1, err)
230		}
231		if !g.equal(w) {
232			t.Errorf("#%d:\ngot  %+v\nwant %+v", i+1, g, w)
233		}
234	}
235	g, err := readEntry(buf)
236	if err != nil {
237		t.Fatal(err)
238	}
239	if g != nil {
240		t.Errorf("\ngot  %+v\nwant nil", g)
241	}
242}
243
244func TestReplay(t *testing.T) {
245	buf := record(t, testService)
246	replay(t, buf, testService)
247}
248
249func record(t *testing.T, run func(*testing.T, *grpc.ClientConn)) *bytes.Buffer {
250	srv := newIntStoreServer()
251	defer srv.stop()
252
253	buf := &bytes.Buffer{}
254	rec, err := NewRecorderWriter(buf, initialState)
255	if err != nil {
256		t.Fatal(err)
257	}
258	conn, err := grpc.Dial(srv.Addr,
259		append([]grpc.DialOption{grpc.WithInsecure()}, rec.DialOptions()...)...)
260	if err != nil {
261		t.Fatal(err)
262	}
263	defer conn.Close()
264	run(t, conn)
265	if err := rec.Close(); err != nil {
266		t.Fatal(err)
267	}
268	return buf
269}
270
271func replay(t *testing.T, buf *bytes.Buffer, run func(*testing.T, *grpc.ClientConn)) {
272	rep, err := NewReplayerReader(buf)
273	if err != nil {
274		t.Fatal(err)
275	}
276	defer rep.Close()
277	if got, want := rep.Initial(), initialState; !testutil.Equal(got, want) {
278		t.Fatalf("got %v, want %v", got, want)
279	}
280	// Replay the test.
281	conn, err := rep.Connection()
282	if err != nil {
283		t.Fatal(err)
284	}
285	defer conn.Close()
286	run(t, conn)
287}
288
289func testService(t *testing.T, conn *grpc.ClientConn) {
290	client := ipb.NewIntStoreClient(conn)
291	ctx := context.Background()
292	item := &ipb.Item{Name: "a", Value: 1}
293	res, err := client.Set(ctx, item)
294	if err != nil {
295		t.Fatal(err)
296	}
297	if res.PrevValue != 0 {
298		t.Errorf("got %d, want 0", res.PrevValue)
299	}
300	got, err := client.Get(ctx, &ipb.GetRequest{Name: "a"})
301	if err != nil {
302		t.Fatal(err)
303	}
304	if !proto.Equal(got, item) {
305		t.Errorf("got %v, want %v", got, item)
306	}
307	_, err = client.Get(ctx, &ipb.GetRequest{Name: "x"})
308	if err == nil {
309		t.Fatal("got nil, want error")
310	}
311	if _, ok := status.FromError(err); !ok {
312		t.Errorf("got error type %T, want a grpc/status.Status", err)
313	}
314
315	gotItems := listItems(t, client, 0)
316	compareLists(t, gotItems, []*ipb.Item{item})
317
318	ssc, err := client.SetStream(ctx)
319	if err != nil {
320		t.Fatal(err)
321	}
322
323	must := func(err error) {
324		if err != nil {
325			t.Fatal(err)
326		}
327	}
328
329	for i, name := range []string{"b", "c"} {
330		must(ssc.Send(&ipb.Item{Name: name, Value: int32(i + 2)}))
331	}
332	summary, err := ssc.CloseAndRecv()
333	if err != nil {
334		t.Fatal(err)
335	}
336	if got, want := summary.Count, int32(2); got != want {
337		t.Fatalf("got %d, want %d", got, want)
338	}
339
340	chatc, err := client.StreamChat(ctx)
341	if err != nil {
342		t.Fatal(err)
343	}
344	for i, name := range []string{"d", "e"} {
345		item := &ipb.Item{Name: name, Value: int32(i + 4)}
346		must(chatc.Send(item))
347		got, err := chatc.Recv()
348		if err != nil {
349			t.Fatal(err)
350		}
351		if !proto.Equal(got, item) {
352			t.Errorf("got %v, want %v", got, item)
353		}
354	}
355	must(chatc.CloseSend())
356	if _, err := chatc.Recv(); err != io.EOF {
357		t.Fatalf("got %v, want EOF", err)
358	}
359}
360
361func listItems(t *testing.T, client ipb.IntStoreClient, greaterThan int) []*ipb.Item {
362	t.Helper()
363	lic, err := client.ListItems(context.Background(), &ipb.ListItemsRequest{GreaterThan: int32(greaterThan)})
364	if err != nil {
365		t.Fatal(err)
366	}
367	var items []*ipb.Item
368	for i := 0; ; i++ {
369		item, err := lic.Recv()
370		if err == io.EOF {
371			break
372		}
373		if err != nil {
374			t.Fatal(err)
375		}
376		items = append(items, item)
377	}
378	return items
379}
380
381func compareLists(t *testing.T, got, want []*ipb.Item) {
382	t.Helper()
383	diff := cmp.Diff(got, want, cmp.Comparer(proto.Equal), cmpopts.SortSlices(func(i1, i2 *ipb.Item) bool {
384		return i1.Value < i2.Value
385	}))
386	if diff != "" {
387		t.Error(diff)
388	}
389}
390
391func TestRecorderBeforeFunc(t *testing.T) {
392	var tests = []struct {
393		name                           string
394		msg, wantRespMsg, wantEntryMsg *ipb.Item
395		f                              func(string, proto.Message) error
396		wantErr                        bool
397	}{
398		{
399			name:         "BeforeFunc should modify messages saved, but not alter what is sent/received to/from services",
400			msg:          &ipb.Item{Name: "foo", Value: 1},
401			wantEntryMsg: &ipb.Item{Name: "bar", Value: 2},
402			wantRespMsg:  &ipb.Item{Name: "foo", Value: 1},
403			f: func(method string, m proto.Message) error {
404				// This callback only runs when Set is called.
405				if !strings.HasSuffix(method, "Set") {
406					return nil
407				}
408				if _, ok := m.(*ipb.Item); !ok {
409					return nil
410				}
411
412				item := m.(*ipb.Item)
413				item.Name = "bar"
414				item.Value = 2
415				return nil
416			},
417		},
418		{
419			name:        "BeforeFunc should not be able to alter returned responses",
420			msg:         &ipb.Item{Name: "foo", Value: 1},
421			wantRespMsg: &ipb.Item{Name: "foo", Value: 1},
422			f: func(method string, m proto.Message) error {
423				// This callback only runs when Get is called.
424				if !strings.HasSuffix(method, "Get") {
425					return nil
426				}
427				if _, ok := m.(*ipb.Item); !ok {
428					return nil
429				}
430
431				item := m.(*ipb.Item)
432				item.Value = 2
433				return nil
434			},
435		},
436		{
437			name: "Errors should cause the RPC send to fail",
438			msg:  &ipb.Item{},
439			f: func(_ string, _ proto.Message) error {
440				return errors.New("err")
441			},
442			wantErr: true,
443		},
444	}
445
446	for _, tc := range tests {
447		// Wrap test cases in a func so defers execute correctly.
448		func() {
449			srv := newIntStoreServer()
450			defer srv.stop()
451
452			var b bytes.Buffer
453			r, err := NewRecorderWriter(&b, nil)
454			if err != nil {
455				t.Error(err)
456				return
457			}
458			r.BeforeFunc = tc.f
459			ctx := context.Background()
460			conn, err := grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, r.DialOptions()...)...)
461			if err != nil {
462				t.Error(err)
463				return
464			}
465			defer conn.Close()
466
467			client := ipb.NewIntStoreClient(conn)
468			_, err = client.Set(ctx, tc.msg)
469			switch {
470			case err != nil && !tc.wantErr:
471				t.Error(err)
472				return
473			case err == nil && tc.wantErr:
474				t.Errorf("got nil; want error")
475				return
476			case err != nil:
477				// Error found as expected, don't check Get().
478				return
479			}
480
481			if tc.wantRespMsg != nil {
482				got, err := client.Get(ctx, &ipb.GetRequest{Name: tc.msg.GetName()})
483				if err != nil {
484					t.Error(err)
485					return
486				}
487				if !cmp.Equal(got, tc.wantRespMsg) {
488					t.Errorf("got %+v; want %+v", got, tc.wantRespMsg)
489				}
490			}
491
492			r.Close()
493
494			if tc.wantEntryMsg != nil {
495				_, _ = readHeader(&b)
496				e, err := readEntry(&b)
497				if err != nil {
498					t.Error(err)
499					return
500				}
501				got := e.msg.msg.(*ipb.Item)
502				if !cmp.Equal(got, tc.wantEntryMsg) {
503					t.Errorf("got %v; want %v", got, tc.wantEntryMsg)
504				}
505			}
506		}()
507	}
508}
509
510func TestReplayerBeforeFunc(t *testing.T) {
511	var tests = []struct {
512		name        string
513		msg, reqMsg *ipb.Item
514		f           func(string, proto.Message) error
515		wantErr     bool
516	}{
517		{
518			name:   "BeforeFunc should modify messages sent before they are passed to the replayer",
519			msg:    &ipb.Item{Name: "foo", Value: 1},
520			reqMsg: &ipb.Item{Name: "bar", Value: 1},
521			f: func(method string, m proto.Message) error {
522				item := m.(*ipb.Item)
523				item.Name = "foo"
524				return nil
525			},
526		},
527		{
528			name: "Errors should cause the RPC send to fail",
529			msg:  &ipb.Item{},
530			f: func(_ string, _ proto.Message) error {
531				return errors.New("err")
532			},
533			wantErr: true,
534		},
535	}
536
537	for _, tc := range tests {
538		// Wrap test cases in a func so defers execute correctly.
539		func() {
540			srv := newIntStoreServer()
541			defer srv.stop()
542
543			var b bytes.Buffer
544			rec, err := NewRecorderWriter(&b, nil)
545			if err != nil {
546				t.Error(err)
547				return
548			}
549			ctx := context.Background()
550			conn, err := grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, rec.DialOptions()...)...)
551			if err != nil {
552				t.Error(err)
553				return
554			}
555			defer conn.Close()
556
557			client := ipb.NewIntStoreClient(conn)
558			_, err = client.Set(ctx, tc.msg)
559			if err != nil {
560				t.Error(err)
561				return
562			}
563			rec.Close()
564
565			rep, err := NewReplayerReader(&b)
566			if err != nil {
567				t.Error(err)
568				return
569			}
570			rep.BeforeFunc = tc.f
571			conn, err = grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, rep.DialOptions()...)...)
572			if err != nil {
573				t.Error(err)
574				return
575			}
576			defer conn.Close()
577
578			client = ipb.NewIntStoreClient(conn)
579			_, err = client.Set(ctx, tc.reqMsg)
580			switch {
581			case err != nil && !tc.wantErr:
582				t.Error(err)
583			case err == nil && tc.wantErr:
584				t.Errorf("got nil; want error")
585			}
586		}()
587	}
588}
589
590func TestOutOfOrderStreamReplay(t *testing.T) {
591	// Check that streams are matched by method and first request sent, if any.
592
593	items := []*ipb.Item{
594		{Name: "a", Value: 1},
595		{Name: "b", Value: 2},
596		{Name: "c", Value: 3},
597	}
598	run := func(t *testing.T, conn *grpc.ClientConn, arg1, arg2 int) {
599		client := ipb.NewIntStoreClient(conn)
600		ctx := context.Background()
601		// Set some items.
602		for _, item := range items {
603			_, err := client.Set(ctx, item)
604			if err != nil {
605				t.Fatal(err)
606			}
607		}
608		// List them twice, with different requests.
609		compareLists(t, listItems(t, client, arg1), items[arg1:])
610		compareLists(t, listItems(t, client, arg2), items[arg2:])
611	}
612
613	srv := newIntStoreServer()
614	defer srv.stop()
615
616	// Replay in the same order.
617	buf := record(t, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) })
618	replay(t, buf, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) })
619
620	// Replay in a different order.
621	buf = record(t, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) })
622	replay(t, buf, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 2, 1) })
623}
624