1// Copyright 2017 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package rpcreplay 16 17import ( 18 "bytes" 19 "context" 20 "errors" 21 "io" 22 "strings" 23 "testing" 24 25 "cloud.google.com/go/internal/testutil" 26 ipb "cloud.google.com/go/rpcreplay/proto/intstore" 27 rpb "cloud.google.com/go/rpcreplay/proto/rpcreplay" 28 "github.com/golang/protobuf/proto" 29 "github.com/google/go-cmp/cmp" 30 "github.com/google/go-cmp/cmp/cmpopts" 31 "google.golang.org/grpc" 32 "google.golang.org/grpc/codes" 33 "google.golang.org/grpc/status" 34) 35 36func TestRecordIO(t *testing.T) { 37 buf := &bytes.Buffer{} 38 want := []byte{1, 2, 3} 39 if err := writeRecord(buf, want); err != nil { 40 t.Fatal(err) 41 } 42 got, err := readRecord(buf) 43 if err != nil { 44 t.Fatal(err) 45 } 46 if !bytes.Equal(got, want) { 47 t.Errorf("got %v, want %v", got, want) 48 } 49} 50 51func TestHeaderIO(t *testing.T) { 52 buf := &bytes.Buffer{} 53 want := []byte{1, 2, 3} 54 if err := writeHeader(buf, want); err != nil { 55 t.Fatal(err) 56 } 57 got, err := readHeader(buf) 58 if err != nil { 59 t.Fatal(err) 60 } 61 if !testutil.Equal(got, want) { 62 t.Errorf("got %v, want %v", got, want) 63 } 64 65 // readHeader errors 66 for _, contents := range []string{"", "badmagic", "gRPCReplay"} { 67 if _, err := readHeader(bytes.NewBufferString(contents)); err == nil { 68 t.Errorf("%q: got nil, want error", contents) 69 } 70 } 71} 72 73func TestEntryIO(t *testing.T) { 74 for i, want := range []*entry{ 75 { 76 kind: rpb.Entry_REQUEST, 77 method: "method", 78 msg: message{msg: &rpb.Entry{}}, 79 refIndex: 7, 80 }, 81 { 82 kind: rpb.Entry_RESPONSE, 83 method: "method", 84 msg: message{err: status.Error(codes.NotFound, "not found")}, 85 refIndex: 8, 86 }, 87 { 88 kind: rpb.Entry_RECV, 89 method: "method", 90 msg: message{err: io.EOF}, 91 refIndex: 3, 92 }, 93 } { 94 buf := &bytes.Buffer{} 95 if err := writeEntry(buf, want); err != nil { 96 t.Fatal(err) 97 } 98 got, err := readEntry(buf) 99 if err != nil { 100 t.Fatal(err) 101 } 102 if !got.equal(want) { 103 t.Errorf("#%d: got %v, want %v", i, got, want) 104 } 105 } 106} 107 108var initialState = []byte{1, 2, 3} 109 110func TestRecord(t *testing.T) { 111 buf := record(t, testService) 112 113 gotIstate, err := readHeader(buf) 114 if err != nil { 115 t.Fatal(err) 116 } 117 if !testutil.Equal(gotIstate, initialState) { 118 t.Fatalf("got %v, want %v", gotIstate, initialState) 119 } 120 item := &ipb.Item{Name: "a", Value: 1} 121 wantEntries := []*entry{ 122 // Set 123 { 124 kind: rpb.Entry_REQUEST, 125 method: "/intstore.IntStore/Set", 126 msg: message{msg: item}, 127 }, 128 { 129 kind: rpb.Entry_RESPONSE, 130 msg: message{msg: &ipb.SetResponse{PrevValue: 0}}, 131 refIndex: 1, 132 }, 133 // Get 134 { 135 kind: rpb.Entry_REQUEST, 136 method: "/intstore.IntStore/Get", 137 msg: message{msg: &ipb.GetRequest{Name: "a"}}, 138 }, 139 { 140 kind: rpb.Entry_RESPONSE, 141 msg: message{msg: item}, 142 refIndex: 3, 143 }, 144 { 145 kind: rpb.Entry_REQUEST, 146 method: "/intstore.IntStore/Get", 147 msg: message{msg: &ipb.GetRequest{Name: "x"}}, 148 }, 149 { 150 kind: rpb.Entry_RESPONSE, 151 msg: message{err: status.Error(codes.NotFound, `"x"`)}, 152 refIndex: 5, 153 }, 154 // ListItems 155 { // entry #7 156 kind: rpb.Entry_CREATE_STREAM, 157 method: "/intstore.IntStore/ListItems", 158 }, 159 { 160 kind: rpb.Entry_SEND, 161 msg: message{msg: &ipb.ListItemsRequest{}}, 162 refIndex: 7, 163 }, 164 { 165 kind: rpb.Entry_RECV, 166 msg: message{msg: item}, 167 refIndex: 7, 168 }, 169 { 170 kind: rpb.Entry_RECV, 171 msg: message{err: io.EOF}, 172 refIndex: 7, 173 }, 174 // SetStream 175 { // entry #11 176 kind: rpb.Entry_CREATE_STREAM, 177 method: "/intstore.IntStore/SetStream", 178 }, 179 { 180 kind: rpb.Entry_SEND, 181 msg: message{msg: &ipb.Item{Name: "b", Value: 2}}, 182 refIndex: 11, 183 }, 184 { 185 kind: rpb.Entry_SEND, 186 msg: message{msg: &ipb.Item{Name: "c", Value: 3}}, 187 refIndex: 11, 188 }, 189 { 190 kind: rpb.Entry_RECV, 191 msg: message{msg: &ipb.Summary{Count: 2}}, 192 refIndex: 11, 193 }, 194 195 // StreamChat 196 { // entry #15 197 kind: rpb.Entry_CREATE_STREAM, 198 method: "/intstore.IntStore/StreamChat", 199 }, 200 { 201 kind: rpb.Entry_SEND, 202 msg: message{msg: &ipb.Item{Name: "d", Value: 4}}, 203 refIndex: 15, 204 }, 205 { 206 kind: rpb.Entry_RECV, 207 msg: message{msg: &ipb.Item{Name: "d", Value: 4}}, 208 refIndex: 15, 209 }, 210 { 211 kind: rpb.Entry_SEND, 212 msg: message{msg: &ipb.Item{Name: "e", Value: 5}}, 213 refIndex: 15, 214 }, 215 { 216 kind: rpb.Entry_RECV, 217 msg: message{msg: &ipb.Item{Name: "e", Value: 5}}, 218 refIndex: 15, 219 }, 220 { 221 kind: rpb.Entry_RECV, 222 msg: message{err: io.EOF}, 223 refIndex: 15, 224 }, 225 } 226 for i, w := range wantEntries { 227 g, err := readEntry(buf) 228 if err != nil { 229 t.Fatalf("#%d: %v", i+1, err) 230 } 231 if !g.equal(w) { 232 t.Errorf("#%d:\ngot %+v\nwant %+v", i+1, g, w) 233 } 234 } 235 g, err := readEntry(buf) 236 if err != nil { 237 t.Fatal(err) 238 } 239 if g != nil { 240 t.Errorf("\ngot %+v\nwant nil", g) 241 } 242} 243 244func TestReplay(t *testing.T) { 245 buf := record(t, testService) 246 replay(t, buf, testService) 247} 248 249func record(t *testing.T, run func(*testing.T, *grpc.ClientConn)) *bytes.Buffer { 250 srv := newIntStoreServer() 251 defer srv.stop() 252 253 buf := &bytes.Buffer{} 254 rec, err := NewRecorderWriter(buf, initialState) 255 if err != nil { 256 t.Fatal(err) 257 } 258 conn, err := grpc.Dial(srv.Addr, 259 append([]grpc.DialOption{grpc.WithInsecure()}, rec.DialOptions()...)...) 260 if err != nil { 261 t.Fatal(err) 262 } 263 defer conn.Close() 264 run(t, conn) 265 if err := rec.Close(); err != nil { 266 t.Fatal(err) 267 } 268 return buf 269} 270 271func replay(t *testing.T, buf *bytes.Buffer, run func(*testing.T, *grpc.ClientConn)) { 272 rep, err := NewReplayerReader(buf) 273 if err != nil { 274 t.Fatal(err) 275 } 276 defer rep.Close() 277 if got, want := rep.Initial(), initialState; !testutil.Equal(got, want) { 278 t.Fatalf("got %v, want %v", got, want) 279 } 280 // Replay the test. 281 conn, err := rep.Connection() 282 if err != nil { 283 t.Fatal(err) 284 } 285 defer conn.Close() 286 run(t, conn) 287} 288 289func testService(t *testing.T, conn *grpc.ClientConn) { 290 client := ipb.NewIntStoreClient(conn) 291 ctx := context.Background() 292 item := &ipb.Item{Name: "a", Value: 1} 293 res, err := client.Set(ctx, item) 294 if err != nil { 295 t.Fatal(err) 296 } 297 if res.PrevValue != 0 { 298 t.Errorf("got %d, want 0", res.PrevValue) 299 } 300 got, err := client.Get(ctx, &ipb.GetRequest{Name: "a"}) 301 if err != nil { 302 t.Fatal(err) 303 } 304 if !proto.Equal(got, item) { 305 t.Errorf("got %v, want %v", got, item) 306 } 307 _, err = client.Get(ctx, &ipb.GetRequest{Name: "x"}) 308 if err == nil { 309 t.Fatal("got nil, want error") 310 } 311 if _, ok := status.FromError(err); !ok { 312 t.Errorf("got error type %T, want a grpc/status.Status", err) 313 } 314 315 gotItems := listItems(t, client, 0) 316 compareLists(t, gotItems, []*ipb.Item{item}) 317 318 ssc, err := client.SetStream(ctx) 319 if err != nil { 320 t.Fatal(err) 321 } 322 323 must := func(err error) { 324 if err != nil { 325 t.Fatal(err) 326 } 327 } 328 329 for i, name := range []string{"b", "c"} { 330 must(ssc.Send(&ipb.Item{Name: name, Value: int32(i + 2)})) 331 } 332 summary, err := ssc.CloseAndRecv() 333 if err != nil { 334 t.Fatal(err) 335 } 336 if got, want := summary.Count, int32(2); got != want { 337 t.Fatalf("got %d, want %d", got, want) 338 } 339 340 chatc, err := client.StreamChat(ctx) 341 if err != nil { 342 t.Fatal(err) 343 } 344 for i, name := range []string{"d", "e"} { 345 item := &ipb.Item{Name: name, Value: int32(i + 4)} 346 must(chatc.Send(item)) 347 got, err := chatc.Recv() 348 if err != nil { 349 t.Fatal(err) 350 } 351 if !proto.Equal(got, item) { 352 t.Errorf("got %v, want %v", got, item) 353 } 354 } 355 must(chatc.CloseSend()) 356 if _, err := chatc.Recv(); err != io.EOF { 357 t.Fatalf("got %v, want EOF", err) 358 } 359} 360 361func listItems(t *testing.T, client ipb.IntStoreClient, greaterThan int) []*ipb.Item { 362 t.Helper() 363 lic, err := client.ListItems(context.Background(), &ipb.ListItemsRequest{GreaterThan: int32(greaterThan)}) 364 if err != nil { 365 t.Fatal(err) 366 } 367 var items []*ipb.Item 368 for i := 0; ; i++ { 369 item, err := lic.Recv() 370 if err == io.EOF { 371 break 372 } 373 if err != nil { 374 t.Fatal(err) 375 } 376 items = append(items, item) 377 } 378 return items 379} 380 381func compareLists(t *testing.T, got, want []*ipb.Item) { 382 t.Helper() 383 diff := cmp.Diff(got, want, cmp.Comparer(proto.Equal), cmpopts.SortSlices(func(i1, i2 *ipb.Item) bool { 384 return i1.Value < i2.Value 385 })) 386 if diff != "" { 387 t.Error(diff) 388 } 389} 390 391func TestRecorderBeforeFunc(t *testing.T) { 392 var tests = []struct { 393 name string 394 msg, wantRespMsg, wantEntryMsg *ipb.Item 395 f func(string, proto.Message) error 396 wantErr bool 397 }{ 398 { 399 name: "BeforeFunc should modify messages saved, but not alter what is sent/received to/from services", 400 msg: &ipb.Item{Name: "foo", Value: 1}, 401 wantEntryMsg: &ipb.Item{Name: "bar", Value: 2}, 402 wantRespMsg: &ipb.Item{Name: "foo", Value: 1}, 403 f: func(method string, m proto.Message) error { 404 // This callback only runs when Set is called. 405 if !strings.HasSuffix(method, "Set") { 406 return nil 407 } 408 if _, ok := m.(*ipb.Item); !ok { 409 return nil 410 } 411 412 item := m.(*ipb.Item) 413 item.Name = "bar" 414 item.Value = 2 415 return nil 416 }, 417 }, 418 { 419 name: "BeforeFunc should not be able to alter returned responses", 420 msg: &ipb.Item{Name: "foo", Value: 1}, 421 wantRespMsg: &ipb.Item{Name: "foo", Value: 1}, 422 f: func(method string, m proto.Message) error { 423 // This callback only runs when Get is called. 424 if !strings.HasSuffix(method, "Get") { 425 return nil 426 } 427 if _, ok := m.(*ipb.Item); !ok { 428 return nil 429 } 430 431 item := m.(*ipb.Item) 432 item.Value = 2 433 return nil 434 }, 435 }, 436 { 437 name: "Errors should cause the RPC send to fail", 438 msg: &ipb.Item{}, 439 f: func(_ string, _ proto.Message) error { 440 return errors.New("err") 441 }, 442 wantErr: true, 443 }, 444 } 445 446 for _, tc := range tests { 447 // Wrap test cases in a func so defers execute correctly. 448 func() { 449 srv := newIntStoreServer() 450 defer srv.stop() 451 452 var b bytes.Buffer 453 r, err := NewRecorderWriter(&b, nil) 454 if err != nil { 455 t.Error(err) 456 return 457 } 458 r.BeforeFunc = tc.f 459 ctx := context.Background() 460 conn, err := grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, r.DialOptions()...)...) 461 if err != nil { 462 t.Error(err) 463 return 464 } 465 defer conn.Close() 466 467 client := ipb.NewIntStoreClient(conn) 468 _, err = client.Set(ctx, tc.msg) 469 switch { 470 case err != nil && !tc.wantErr: 471 t.Error(err) 472 return 473 case err == nil && tc.wantErr: 474 t.Errorf("got nil; want error") 475 return 476 case err != nil: 477 // Error found as expected, don't check Get(). 478 return 479 } 480 481 if tc.wantRespMsg != nil { 482 got, err := client.Get(ctx, &ipb.GetRequest{Name: tc.msg.GetName()}) 483 if err != nil { 484 t.Error(err) 485 return 486 } 487 if !cmp.Equal(got, tc.wantRespMsg) { 488 t.Errorf("got %+v; want %+v", got, tc.wantRespMsg) 489 } 490 } 491 492 r.Close() 493 494 if tc.wantEntryMsg != nil { 495 _, _ = readHeader(&b) 496 e, err := readEntry(&b) 497 if err != nil { 498 t.Error(err) 499 return 500 } 501 got := e.msg.msg.(*ipb.Item) 502 if !cmp.Equal(got, tc.wantEntryMsg) { 503 t.Errorf("got %v; want %v", got, tc.wantEntryMsg) 504 } 505 } 506 }() 507 } 508} 509 510func TestReplayerBeforeFunc(t *testing.T) { 511 var tests = []struct { 512 name string 513 msg, reqMsg *ipb.Item 514 f func(string, proto.Message) error 515 wantErr bool 516 }{ 517 { 518 name: "BeforeFunc should modify messages sent before they are passed to the replayer", 519 msg: &ipb.Item{Name: "foo", Value: 1}, 520 reqMsg: &ipb.Item{Name: "bar", Value: 1}, 521 f: func(method string, m proto.Message) error { 522 item := m.(*ipb.Item) 523 item.Name = "foo" 524 return nil 525 }, 526 }, 527 { 528 name: "Errors should cause the RPC send to fail", 529 msg: &ipb.Item{}, 530 f: func(_ string, _ proto.Message) error { 531 return errors.New("err") 532 }, 533 wantErr: true, 534 }, 535 } 536 537 for _, tc := range tests { 538 // Wrap test cases in a func so defers execute correctly. 539 func() { 540 srv := newIntStoreServer() 541 defer srv.stop() 542 543 var b bytes.Buffer 544 rec, err := NewRecorderWriter(&b, nil) 545 if err != nil { 546 t.Error(err) 547 return 548 } 549 ctx := context.Background() 550 conn, err := grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, rec.DialOptions()...)...) 551 if err != nil { 552 t.Error(err) 553 return 554 } 555 defer conn.Close() 556 557 client := ipb.NewIntStoreClient(conn) 558 _, err = client.Set(ctx, tc.msg) 559 if err != nil { 560 t.Error(err) 561 return 562 } 563 rec.Close() 564 565 rep, err := NewReplayerReader(&b) 566 if err != nil { 567 t.Error(err) 568 return 569 } 570 rep.BeforeFunc = tc.f 571 conn, err = grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, rep.DialOptions()...)...) 572 if err != nil { 573 t.Error(err) 574 return 575 } 576 defer conn.Close() 577 578 client = ipb.NewIntStoreClient(conn) 579 _, err = client.Set(ctx, tc.reqMsg) 580 switch { 581 case err != nil && !tc.wantErr: 582 t.Error(err) 583 case err == nil && tc.wantErr: 584 t.Errorf("got nil; want error") 585 } 586 }() 587 } 588} 589 590func TestOutOfOrderStreamReplay(t *testing.T) { 591 // Check that streams are matched by method and first request sent, if any. 592 593 items := []*ipb.Item{ 594 {Name: "a", Value: 1}, 595 {Name: "b", Value: 2}, 596 {Name: "c", Value: 3}, 597 } 598 run := func(t *testing.T, conn *grpc.ClientConn, arg1, arg2 int) { 599 client := ipb.NewIntStoreClient(conn) 600 ctx := context.Background() 601 // Set some items. 602 for _, item := range items { 603 _, err := client.Set(ctx, item) 604 if err != nil { 605 t.Fatal(err) 606 } 607 } 608 // List them twice, with different requests. 609 compareLists(t, listItems(t, client, arg1), items[arg1:]) 610 compareLists(t, listItems(t, client, arg2), items[arg2:]) 611 } 612 613 srv := newIntStoreServer() 614 defer srv.stop() 615 616 // Replay in the same order. 617 buf := record(t, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) }) 618 replay(t, buf, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) }) 619 620 // Replay in a different order. 621 buf = record(t, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) }) 622 replay(t, buf, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 2, 1) }) 623} 624