1// Copyright 2017 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package rpcreplay
16
17import (
18	"bytes"
19	"context"
20	"errors"
21	"io"
22	"strings"
23	"testing"
24
25	"cloud.google.com/go/internal/testutil"
26	ipb "cloud.google.com/go/rpcreplay/proto/intstore"
27	rpb "cloud.google.com/go/rpcreplay/proto/rpcreplay"
28	"github.com/golang/protobuf/proto"
29	"github.com/google/go-cmp/cmp"
30	"github.com/google/go-cmp/cmp/cmpopts"
31	"google.golang.org/grpc"
32	"google.golang.org/grpc/codes"
33	"google.golang.org/grpc/status"
34	"google.golang.org/protobuf/testing/protocmp"
35)
36
37func TestRecordIO(t *testing.T) {
38	buf := &bytes.Buffer{}
39	want := []byte{1, 2, 3}
40	if err := writeRecord(buf, want); err != nil {
41		t.Fatal(err)
42	}
43	got, err := readRecord(buf)
44	if err != nil {
45		t.Fatal(err)
46	}
47	if !bytes.Equal(got, want) {
48		t.Errorf("got %v, want %v", got, want)
49	}
50}
51
52func TestHeaderIO(t *testing.T) {
53	buf := &bytes.Buffer{}
54	want := []byte{1, 2, 3}
55	if err := writeHeader(buf, want); err != nil {
56		t.Fatal(err)
57	}
58	got, err := readHeader(buf)
59	if err != nil {
60		t.Fatal(err)
61	}
62	if !testutil.Equal(got, want) {
63		t.Errorf("got %v, want %v", got, want)
64	}
65
66	// readHeader errors
67	for _, contents := range []string{"", "badmagic", "gRPCReplay"} {
68		if _, err := readHeader(bytes.NewBufferString(contents)); err == nil {
69			t.Errorf("%q: got nil, want error", contents)
70		}
71	}
72}
73
74func TestEntryIO(t *testing.T) {
75	for i, want := range []*entry{
76		{
77			kind:     rpb.Entry_REQUEST,
78			method:   "method",
79			msg:      message{msg: &rpb.Entry{}},
80			refIndex: 7,
81		},
82		{
83			kind:     rpb.Entry_RESPONSE,
84			method:   "method",
85			msg:      message{err: status.Error(codes.NotFound, "not found")},
86			refIndex: 8,
87		},
88		{
89			kind:     rpb.Entry_RECV,
90			method:   "method",
91			msg:      message{err: io.EOF},
92			refIndex: 3,
93		},
94	} {
95		buf := &bytes.Buffer{}
96		if err := writeEntry(buf, want); err != nil {
97			t.Fatal(err)
98		}
99		got, err := readEntry(buf)
100		if err != nil {
101			t.Fatal(err)
102		}
103		if !got.equal(want) {
104			t.Errorf("#%d: got %v, want %v", i, got, want)
105		}
106	}
107}
108
// initialState is the payload that record passes to NewRecorderWriter as the
// recording's initial state. TestRecord expects readHeader to return it, and
// replay expects Replayer.Initial to return it.
var initialState = []byte{1, 2, 3}
110
// TestRecord records testService against a live server, then decodes the
// recording. It checks that the header carries initialState and that every
// entry matches wantEntries, in order.
//
// Entries are numbered from 1, matching the "#%d" in failure messages. The
// refIndex of a RESPONSE, SEND, or RECV entry is the number of the REQUEST or
// CREATE_STREAM entry it belongs to. In this table only REQUEST and
// CREATE_STREAM entries carry a method. After the last expected entry, the
// recording must be exhausted: readEntry returns a nil entry and a nil error.
func TestRecord(t *testing.T) {
	buf := record(t, testService)

	gotIstate, err := readHeader(buf)
	if err != nil {
		t.Fatal(err)
	}
	if !testutil.Equal(gotIstate, initialState) {
		t.Fatalf("got %v, want %v", gotIstate, initialState)
	}
	item := &ipb.Item{Name: "a", Value: 1}
	wantEntries := []*entry{
		// Set
		{
			kind:   rpb.Entry_REQUEST,
			method: "/intstore.IntStore/Set",
			msg:    message{msg: item},
		},
		{
			kind:     rpb.Entry_RESPONSE,
			msg:      message{msg: &ipb.SetResponse{PrevValue: 0}},
			refIndex: 1,
		},
		// Get
		{
			kind:   rpb.Entry_REQUEST,
			method: "/intstore.IntStore/Get",
			msg:    message{msg: &ipb.GetRequest{Name: "a"}},
		},
		{
			kind:     rpb.Entry_RESPONSE,
			msg:      message{msg: item},
			refIndex: 3,
		},
		{
			kind:   rpb.Entry_REQUEST,
			method: "/intstore.IntStore/Get",
			msg:    message{msg: &ipb.GetRequest{Name: "x"}},
		},
		{
			// The failed Get of "x" is recorded as a status error, not a message.
			kind:     rpb.Entry_RESPONSE,
			msg:      message{err: status.Error(codes.NotFound, `"x"`)},
			refIndex: 5,
		},
		// ListItems
		{ // entry #7
			kind:   rpb.Entry_CREATE_STREAM,
			method: "/intstore.IntStore/ListItems",
		},
		{
			kind:     rpb.Entry_SEND,
			msg:      message{msg: &ipb.ListItemsRequest{}},
			refIndex: 7,
		},
		{
			kind:     rpb.Entry_RECV,
			msg:      message{msg: item},
			refIndex: 7,
		},
		{
			// End of the ListItems stream is recorded as an io.EOF receive.
			kind:     rpb.Entry_RECV,
			msg:      message{err: io.EOF},
			refIndex: 7,
		},
		// SetStream
		{ // entry #11
			kind:   rpb.Entry_CREATE_STREAM,
			method: "/intstore.IntStore/SetStream",
		},
		{
			kind:     rpb.Entry_SEND,
			msg:      message{msg: &ipb.Item{Name: "b", Value: 2}},
			refIndex: 11,
		},
		{
			kind:     rpb.Entry_SEND,
			msg:      message{msg: &ipb.Item{Name: "c", Value: 3}},
			refIndex: 11,
		},
		{
			kind:     rpb.Entry_RECV,
			msg:      message{msg: &ipb.Summary{Count: 2}},
			refIndex: 11,
		},

		// StreamChat
		{ // entry #15
			kind:   rpb.Entry_CREATE_STREAM,
			method: "/intstore.IntStore/StreamChat",
		},
		{
			kind:     rpb.Entry_SEND,
			msg:      message{msg: &ipb.Item{Name: "d", Value: 4}},
			refIndex: 15,
		},
		{
			kind:     rpb.Entry_RECV,
			msg:      message{msg: &ipb.Item{Name: "d", Value: 4}},
			refIndex: 15,
		},
		{
			kind:     rpb.Entry_SEND,
			msg:      message{msg: &ipb.Item{Name: "e", Value: 5}},
			refIndex: 15,
		},
		{
			kind:     rpb.Entry_RECV,
			msg:      message{msg: &ipb.Item{Name: "e", Value: 5}},
			refIndex: 15,
		},
		{
			kind:     rpb.Entry_RECV,
			msg:      message{err: io.EOF},
			refIndex: 15,
		},
	}
	// Report entries 1-based so the numbers line up with refIndex values.
	for i, w := range wantEntries {
		g, err := readEntry(buf)
		if err != nil {
			t.Fatalf("#%d: %v", i+1, err)
		}
		if !g.equal(w) {
			t.Errorf("#%d:\ngot  %+v\nwant %+v", i+1, g, w)
		}
	}
	// Nothing may follow the expected entries: readEntry must return nil, nil.
	g, err := readEntry(buf)
	if err != nil {
		t.Fatal(err)
	}
	if g != nil {
		t.Errorf("\ngot  %+v\nwant nil", g)
	}
}
244
245func TestReplay(t *testing.T) {
246	buf := record(t, testService)
247	replay(t, buf, testService)
248}
249
250func record(t *testing.T, run func(*testing.T, *grpc.ClientConn)) *bytes.Buffer {
251	srv := newIntStoreServer()
252	defer srv.stop()
253
254	buf := &bytes.Buffer{}
255	rec, err := NewRecorderWriter(buf, initialState)
256	if err != nil {
257		t.Fatal(err)
258	}
259	conn, err := grpc.Dial(srv.Addr,
260		append([]grpc.DialOption{grpc.WithInsecure()}, rec.DialOptions()...)...)
261	if err != nil {
262		t.Fatal(err)
263	}
264	defer conn.Close()
265	run(t, conn)
266	if err := rec.Close(); err != nil {
267		t.Fatal(err)
268	}
269	return buf
270}
271
272func replay(t *testing.T, buf *bytes.Buffer, run func(*testing.T, *grpc.ClientConn)) {
273	rep, err := NewReplayerReader(buf)
274	if err != nil {
275		t.Fatal(err)
276	}
277	defer rep.Close()
278	if got, want := rep.Initial(), initialState; !testutil.Equal(got, want) {
279		t.Fatalf("got %v, want %v", got, want)
280	}
281	// Replay the test.
282	conn, err := rep.Connection()
283	if err != nil {
284		t.Fatal(err)
285	}
286	defer conn.Close()
287	run(t, conn)
288}
289
290func testService(t *testing.T, conn *grpc.ClientConn) {
291	client := ipb.NewIntStoreClient(conn)
292	ctx := context.Background()
293	item := &ipb.Item{Name: "a", Value: 1}
294	res, err := client.Set(ctx, item)
295	if err != nil {
296		t.Fatal(err)
297	}
298	if res.PrevValue != 0 {
299		t.Errorf("got %d, want 0", res.PrevValue)
300	}
301	got, err := client.Get(ctx, &ipb.GetRequest{Name: "a"})
302	if err != nil {
303		t.Fatal(err)
304	}
305	if !proto.Equal(got, item) {
306		t.Errorf("got %v, want %v", got, item)
307	}
308	_, err = client.Get(ctx, &ipb.GetRequest{Name: "x"})
309	if err == nil {
310		t.Fatal("got nil, want error")
311	}
312	if _, ok := status.FromError(err); !ok {
313		t.Errorf("got error type %T, want a grpc/status.Status", err)
314	}
315
316	gotItems := listItems(t, client, 0)
317	compareLists(t, gotItems, []*ipb.Item{item})
318
319	ssc, err := client.SetStream(ctx)
320	if err != nil {
321		t.Fatal(err)
322	}
323
324	must := func(err error) {
325		if err != nil {
326			t.Fatal(err)
327		}
328	}
329
330	for i, name := range []string{"b", "c"} {
331		must(ssc.Send(&ipb.Item{Name: name, Value: int32(i + 2)}))
332	}
333	summary, err := ssc.CloseAndRecv()
334	if err != nil {
335		t.Fatal(err)
336	}
337	if got, want := summary.Count, int32(2); got != want {
338		t.Fatalf("got %d, want %d", got, want)
339	}
340
341	chatc, err := client.StreamChat(ctx)
342	if err != nil {
343		t.Fatal(err)
344	}
345	for i, name := range []string{"d", "e"} {
346		item := &ipb.Item{Name: name, Value: int32(i + 4)}
347		must(chatc.Send(item))
348		got, err := chatc.Recv()
349		if err != nil {
350			t.Fatal(err)
351		}
352		if !proto.Equal(got, item) {
353			t.Errorf("got %v, want %v", got, item)
354		}
355	}
356	must(chatc.CloseSend())
357	if _, err := chatc.Recv(); err != io.EOF {
358		t.Fatalf("got %v, want EOF", err)
359	}
360}
361
362func listItems(t *testing.T, client ipb.IntStoreClient, greaterThan int) []*ipb.Item {
363	t.Helper()
364	lic, err := client.ListItems(context.Background(), &ipb.ListItemsRequest{GreaterThan: int32(greaterThan)})
365	if err != nil {
366		t.Fatal(err)
367	}
368	var items []*ipb.Item
369	for i := 0; ; i++ {
370		item, err := lic.Recv()
371		if err == io.EOF {
372			break
373		}
374		if err != nil {
375			t.Fatal(err)
376		}
377		items = append(items, item)
378	}
379	return items
380}
381
382func compareLists(t *testing.T, got, want []*ipb.Item) {
383	t.Helper()
384	diff := cmp.Diff(got, want, cmp.Comparer(proto.Equal), cmpopts.SortSlices(func(i1, i2 *ipb.Item) bool {
385		return i1.Value < i2.Value
386	}))
387	if diff != "" {
388		t.Error(diff)
389	}
390}
391
392func TestRecorderBeforeFunc(t *testing.T) {
393	var tests = []struct {
394		name                           string
395		msg, wantRespMsg, wantEntryMsg *ipb.Item
396		f                              func(string, proto.Message) error
397		wantErr                        bool
398	}{
399		{
400			name:         "BeforeFunc should modify messages saved, but not alter what is sent/received to/from services",
401			msg:          &ipb.Item{Name: "foo", Value: 1},
402			wantEntryMsg: &ipb.Item{Name: "bar", Value: 2},
403			wantRespMsg:  &ipb.Item{Name: "foo", Value: 1},
404			f: func(method string, m proto.Message) error {
405				// This callback only runs when Set is called.
406				if !strings.HasSuffix(method, "Set") {
407					return nil
408				}
409				if _, ok := m.(*ipb.Item); !ok {
410					return nil
411				}
412
413				item := m.(*ipb.Item)
414				item.Name = "bar"
415				item.Value = 2
416				return nil
417			},
418		},
419		{
420			name:        "BeforeFunc should not be able to alter returned responses",
421			msg:         &ipb.Item{Name: "foo", Value: 1},
422			wantRespMsg: &ipb.Item{Name: "foo", Value: 1},
423			f: func(method string, m proto.Message) error {
424				// This callback only runs when Get is called.
425				if !strings.HasSuffix(method, "Get") {
426					return nil
427				}
428				if _, ok := m.(*ipb.Item); !ok {
429					return nil
430				}
431
432				item := m.(*ipb.Item)
433				item.Value = 2
434				return nil
435			},
436		},
437		{
438			name: "Errors should cause the RPC send to fail",
439			msg:  &ipb.Item{},
440			f: func(_ string, _ proto.Message) error {
441				return errors.New("err")
442			},
443			wantErr: true,
444		},
445	}
446
447	for _, tc := range tests {
448		// Wrap test cases in a func so defers execute correctly.
449		func() {
450			srv := newIntStoreServer()
451			defer srv.stop()
452
453			var b bytes.Buffer
454			r, err := NewRecorderWriter(&b, nil)
455			if err != nil {
456				t.Error(err)
457				return
458			}
459			r.BeforeFunc = tc.f
460			ctx := context.Background()
461			conn, err := grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, r.DialOptions()...)...)
462			if err != nil {
463				t.Error(err)
464				return
465			}
466			defer conn.Close()
467
468			client := ipb.NewIntStoreClient(conn)
469			_, err = client.Set(ctx, tc.msg)
470			switch {
471			case err != nil && !tc.wantErr:
472				t.Error(err)
473				return
474			case err == nil && tc.wantErr:
475				t.Errorf("got nil; want error")
476				return
477			case err != nil:
478				// Error found as expected, don't check Get().
479				return
480			}
481
482			if tc.wantRespMsg != nil {
483				got, err := client.Get(ctx, &ipb.GetRequest{Name: tc.msg.GetName()})
484				if err != nil {
485					t.Error(err)
486					return
487				}
488				if !cmp.Equal(got, tc.wantRespMsg, protocmp.Transform()) {
489					t.Errorf("got %+v; want %+v", got, tc.wantRespMsg)
490				}
491			}
492
493			r.Close()
494
495			if tc.wantEntryMsg != nil {
496				_, _ = readHeader(&b)
497				e, err := readEntry(&b)
498				if err != nil {
499					t.Error(err)
500					return
501				}
502				got := e.msg.msg.(*ipb.Item)
503				if !cmp.Equal(got, tc.wantEntryMsg, protocmp.Transform()) {
504					t.Errorf("got %v; want %v", got, tc.wantEntryMsg)
505				}
506			}
507		}()
508	}
509}
510
511func TestReplayerBeforeFunc(t *testing.T) {
512	var tests = []struct {
513		name        string
514		msg, reqMsg *ipb.Item
515		f           func(string, proto.Message) error
516		wantErr     bool
517	}{
518		{
519			name:   "BeforeFunc should modify messages sent before they are passed to the replayer",
520			msg:    &ipb.Item{Name: "foo", Value: 1},
521			reqMsg: &ipb.Item{Name: "bar", Value: 1},
522			f: func(method string, m proto.Message) error {
523				item := m.(*ipb.Item)
524				item.Name = "foo"
525				return nil
526			},
527		},
528		{
529			name: "Errors should cause the RPC send to fail",
530			msg:  &ipb.Item{},
531			f: func(_ string, _ proto.Message) error {
532				return errors.New("err")
533			},
534			wantErr: true,
535		},
536	}
537
538	for _, tc := range tests {
539		// Wrap test cases in a func so defers execute correctly.
540		func() {
541			srv := newIntStoreServer()
542			defer srv.stop()
543
544			var b bytes.Buffer
545			rec, err := NewRecorderWriter(&b, nil)
546			if err != nil {
547				t.Error(err)
548				return
549			}
550			ctx := context.Background()
551			conn, err := grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, rec.DialOptions()...)...)
552			if err != nil {
553				t.Error(err)
554				return
555			}
556			defer conn.Close()
557
558			client := ipb.NewIntStoreClient(conn)
559			_, err = client.Set(ctx, tc.msg)
560			if err != nil {
561				t.Error(err)
562				return
563			}
564			rec.Close()
565
566			rep, err := NewReplayerReader(&b)
567			if err != nil {
568				t.Error(err)
569				return
570			}
571			rep.BeforeFunc = tc.f
572			conn, err = grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, rep.DialOptions()...)...)
573			if err != nil {
574				t.Error(err)
575				return
576			}
577			defer conn.Close()
578
579			client = ipb.NewIntStoreClient(conn)
580			_, err = client.Set(ctx, tc.reqMsg)
581			switch {
582			case err != nil && !tc.wantErr:
583				t.Error(err)
584			case err == nil && tc.wantErr:
585				t.Errorf("got nil; want error")
586			}
587		}()
588	}
589}
590
591func TestOutOfOrderStreamReplay(t *testing.T) {
592	// Check that streams are matched by method and first request sent, if any.
593
594	items := []*ipb.Item{
595		{Name: "a", Value: 1},
596		{Name: "b", Value: 2},
597		{Name: "c", Value: 3},
598	}
599	run := func(t *testing.T, conn *grpc.ClientConn, arg1, arg2 int) {
600		client := ipb.NewIntStoreClient(conn)
601		ctx := context.Background()
602		// Set some items.
603		for _, item := range items {
604			_, err := client.Set(ctx, item)
605			if err != nil {
606				t.Fatal(err)
607			}
608		}
609		// List them twice, with different requests.
610		compareLists(t, listItems(t, client, arg1), items[arg1:])
611		compareLists(t, listItems(t, client, arg2), items[arg2:])
612	}
613
614	srv := newIntStoreServer()
615	defer srv.stop()
616
617	// Replay in the same order.
618	buf := record(t, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) })
619	replay(t, buf, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) })
620
621	// Replay in a different order.
622	buf = record(t, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) })
623	replay(t, buf, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 2, 1) })
624}
625