1// Copyright 2017 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package rpcreplay 16 17import ( 18 "bytes" 19 "context" 20 "errors" 21 "io" 22 "strings" 23 "testing" 24 25 "cloud.google.com/go/internal/testutil" 26 ipb "cloud.google.com/go/rpcreplay/proto/intstore" 27 pb "cloud.google.com/go/rpcreplay/proto/intstore" 28 rpb "cloud.google.com/go/rpcreplay/proto/rpcreplay" 29 "github.com/golang/protobuf/proto" 30 "github.com/google/go-cmp/cmp" 31 "github.com/google/go-cmp/cmp/cmpopts" 32 "google.golang.org/grpc" 33 "google.golang.org/grpc/codes" 34 "google.golang.org/grpc/status" 35) 36 37func TestRecordIO(t *testing.T) { 38 buf := &bytes.Buffer{} 39 want := []byte{1, 2, 3} 40 if err := writeRecord(buf, want); err != nil { 41 t.Fatal(err) 42 } 43 got, err := readRecord(buf) 44 if err != nil { 45 t.Fatal(err) 46 } 47 if !bytes.Equal(got, want) { 48 t.Errorf("got %v, want %v", got, want) 49 } 50} 51 52func TestHeaderIO(t *testing.T) { 53 buf := &bytes.Buffer{} 54 want := []byte{1, 2, 3} 55 if err := writeHeader(buf, want); err != nil { 56 t.Fatal(err) 57 } 58 got, err := readHeader(buf) 59 if err != nil { 60 t.Fatal(err) 61 } 62 if !testutil.Equal(got, want) { 63 t.Errorf("got %v, want %v", got, want) 64 } 65 66 // readHeader errors 67 for _, contents := range []string{"", "badmagic", "gRPCReplay"} { 68 if _, err := readHeader(bytes.NewBufferString(contents)); err == nil { 69 t.Errorf("%q: got nil, want error", contents) 70 } 71 } 72} 73 74func TestEntryIO(t *testing.T) { 75 for i, want := range []*entry{ 76 { 77 kind: rpb.Entry_REQUEST, 78 method: "method", 79 msg: message{msg: &rpb.Entry{}}, 80 refIndex: 7, 81 }, 82 { 83 kind: rpb.Entry_RESPONSE, 84 method: "method", 85 msg: message{err: status.Error(codes.NotFound, "not found")}, 86 refIndex: 8, 87 }, 88 { 89 kind: rpb.Entry_RECV, 90 method: "method", 91 msg: message{err: io.EOF}, 92 refIndex: 3, 93 }, 94 } { 95 buf := &bytes.Buffer{} 96 if err := writeEntry(buf, want); err != nil { 97 t.Fatal(err) 98 } 99 got, err := readEntry(buf) 100 if err != nil { 101 t.Fatal(err) 102 } 103 if !got.equal(want) { 104 t.Errorf("#%d: got %v, want %v", i, got, want) 105 } 106 } 107} 108 109var initialState = []byte{1, 2, 3} 110 111func TestRecord(t *testing.T) { 112 buf := record(t, testService) 113 114 gotIstate, err := readHeader(buf) 115 if err != nil { 116 t.Fatal(err) 117 } 118 if !testutil.Equal(gotIstate, initialState) { 119 t.Fatalf("got %v, want %v", gotIstate, initialState) 120 } 121 item := &ipb.Item{Name: "a", Value: 1} 122 wantEntries := []*entry{ 123 // Set 124 { 125 kind: rpb.Entry_REQUEST, 126 method: "/intstore.IntStore/Set", 127 msg: message{msg: item}, 128 }, 129 { 130 kind: rpb.Entry_RESPONSE, 131 msg: message{msg: &ipb.SetResponse{PrevValue: 0}}, 132 refIndex: 1, 133 }, 134 // Get 135 { 136 kind: rpb.Entry_REQUEST, 137 method: "/intstore.IntStore/Get", 138 msg: message{msg: &ipb.GetRequest{Name: "a"}}, 139 }, 140 { 141 kind: rpb.Entry_RESPONSE, 142 msg: message{msg: item}, 143 refIndex: 3, 144 }, 145 { 146 kind: rpb.Entry_REQUEST, 147 method: "/intstore.IntStore/Get", 148 msg: message{msg: &ipb.GetRequest{Name: "x"}}, 149 }, 150 { 151 kind: rpb.Entry_RESPONSE, 152 msg: message{err: status.Error(codes.NotFound, `"x"`)}, 153 refIndex: 5, 154 }, 155 // ListItems 156 { // entry #7 157 kind: rpb.Entry_CREATE_STREAM, 158 method: "/intstore.IntStore/ListItems", 159 }, 160 { 161 kind: rpb.Entry_SEND, 162 msg: message{msg: &ipb.ListItemsRequest{}}, 163 refIndex: 7, 164 }, 165 { 166 kind: rpb.Entry_RECV, 167 msg: message{msg: item}, 168 refIndex: 7, 169 }, 170 { 171 kind: rpb.Entry_RECV, 172 msg: message{err: io.EOF}, 173 refIndex: 7, 174 }, 175 // SetStream 176 { // entry #11 177 kind: rpb.Entry_CREATE_STREAM, 178 method: "/intstore.IntStore/SetStream", 179 }, 180 { 181 kind: rpb.Entry_SEND, 182 msg: message{msg: &ipb.Item{Name: "b", Value: 2}}, 183 refIndex: 11, 184 }, 185 { 186 kind: rpb.Entry_SEND, 187 msg: message{msg: &ipb.Item{Name: "c", Value: 3}}, 188 refIndex: 11, 189 }, 190 { 191 kind: rpb.Entry_RECV, 192 msg: message{msg: &ipb.Summary{Count: 2}}, 193 refIndex: 11, 194 }, 195 196 // StreamChat 197 { // entry #15 198 kind: rpb.Entry_CREATE_STREAM, 199 method: "/intstore.IntStore/StreamChat", 200 }, 201 { 202 kind: rpb.Entry_SEND, 203 msg: message{msg: &ipb.Item{Name: "d", Value: 4}}, 204 refIndex: 15, 205 }, 206 { 207 kind: rpb.Entry_RECV, 208 msg: message{msg: &ipb.Item{Name: "d", Value: 4}}, 209 refIndex: 15, 210 }, 211 { 212 kind: rpb.Entry_SEND, 213 msg: message{msg: &ipb.Item{Name: "e", Value: 5}}, 214 refIndex: 15, 215 }, 216 { 217 kind: rpb.Entry_RECV, 218 msg: message{msg: &ipb.Item{Name: "e", Value: 5}}, 219 refIndex: 15, 220 }, 221 { 222 kind: rpb.Entry_RECV, 223 msg: message{err: io.EOF}, 224 refIndex: 15, 225 }, 226 } 227 for i, w := range wantEntries { 228 g, err := readEntry(buf) 229 if err != nil { 230 t.Fatalf("#%d: %v", i+1, err) 231 } 232 if !g.equal(w) { 233 t.Errorf("#%d:\ngot %+v\nwant %+v", i+1, g, w) 234 } 235 } 236 g, err := readEntry(buf) 237 if err != nil { 238 t.Fatal(err) 239 } 240 if g != nil { 241 t.Errorf("\ngot %+v\nwant nil", g) 242 } 243} 244 245func TestReplay(t *testing.T) { 246 buf := record(t, testService) 247 replay(t, buf, testService) 248} 249 250func record(t *testing.T, run func(*testing.T, *grpc.ClientConn)) *bytes.Buffer { 251 srv := newIntStoreServer() 252 defer srv.stop() 253 254 buf := &bytes.Buffer{} 255 rec, err := NewRecorderWriter(buf, initialState) 256 if err != nil { 257 t.Fatal(err) 258 } 259 conn, err := grpc.Dial(srv.Addr, 260 append([]grpc.DialOption{grpc.WithInsecure()}, rec.DialOptions()...)...) 261 if err != nil { 262 t.Fatal(err) 263 } 264 defer conn.Close() 265 run(t, conn) 266 if err := rec.Close(); err != nil { 267 t.Fatal(err) 268 } 269 return buf 270} 271 272func replay(t *testing.T, buf *bytes.Buffer, run func(*testing.T, *grpc.ClientConn)) { 273 rep, err := NewReplayerReader(buf) 274 if err != nil { 275 t.Fatal(err) 276 } 277 defer rep.Close() 278 if got, want := rep.Initial(), initialState; !testutil.Equal(got, want) { 279 t.Fatalf("got %v, want %v", got, want) 280 } 281 // Replay the test. 282 conn, err := rep.Connection() 283 if err != nil { 284 t.Fatal(err) 285 } 286 defer conn.Close() 287 run(t, conn) 288} 289 290func testService(t *testing.T, conn *grpc.ClientConn) { 291 client := ipb.NewIntStoreClient(conn) 292 ctx := context.Background() 293 item := &ipb.Item{Name: "a", Value: 1} 294 res, err := client.Set(ctx, item) 295 if err != nil { 296 t.Fatal(err) 297 } 298 if res.PrevValue != 0 { 299 t.Errorf("got %d, want 0", res.PrevValue) 300 } 301 got, err := client.Get(ctx, &ipb.GetRequest{Name: "a"}) 302 if err != nil { 303 t.Fatal(err) 304 } 305 if !proto.Equal(got, item) { 306 t.Errorf("got %v, want %v", got, item) 307 } 308 _, err = client.Get(ctx, &ipb.GetRequest{Name: "x"}) 309 if err == nil { 310 t.Fatal("got nil, want error") 311 } 312 if _, ok := status.FromError(err); !ok { 313 t.Errorf("got error type %T, want a grpc/status.Status", err) 314 } 315 316 gotItems := listItems(t, client, 0) 317 compareLists(t, gotItems, []*ipb.Item{item}) 318 319 ssc, err := client.SetStream(ctx) 320 if err != nil { 321 t.Fatal(err) 322 } 323 324 must := func(err error) { 325 if err != nil { 326 t.Fatal(err) 327 } 328 } 329 330 for i, name := range []string{"b", "c"} { 331 must(ssc.Send(&ipb.Item{Name: name, Value: int32(i + 2)})) 332 } 333 summary, err := ssc.CloseAndRecv() 334 if err != nil { 335 t.Fatal(err) 336 } 337 if got, want := summary.Count, int32(2); got != want { 338 t.Fatalf("got %d, want %d", got, want) 339 } 340 341 chatc, err := client.StreamChat(ctx) 342 if err != nil { 343 t.Fatal(err) 344 } 345 for i, name := range []string{"d", "e"} { 346 item := &ipb.Item{Name: name, Value: int32(i + 4)} 347 must(chatc.Send(item)) 348 got, err := chatc.Recv() 349 if err != nil { 350 t.Fatal(err) 351 } 352 if !proto.Equal(got, item) { 353 t.Errorf("got %v, want %v", got, item) 354 } 355 } 356 must(chatc.CloseSend()) 357 if _, err := chatc.Recv(); err != io.EOF { 358 t.Fatalf("got %v, want EOF", err) 359 } 360} 361 362func listItems(t *testing.T, client ipb.IntStoreClient, greaterThan int) []*ipb.Item { 363 t.Helper() 364 lic, err := client.ListItems(context.Background(), &ipb.ListItemsRequest{GreaterThan: int32(greaterThan)}) 365 if err != nil { 366 t.Fatal(err) 367 } 368 var items []*ipb.Item 369 for i := 0; ; i++ { 370 item, err := lic.Recv() 371 if err == io.EOF { 372 break 373 } 374 if err != nil { 375 t.Fatal(err) 376 } 377 items = append(items, item) 378 } 379 return items 380} 381 382func compareLists(t *testing.T, got, want []*pb.Item) { 383 t.Helper() 384 diff := cmp.Diff(got, want, cmp.Comparer(proto.Equal), cmpopts.SortSlices(func(i1, i2 *pb.Item) bool { 385 return i1.Value < i2.Value 386 })) 387 if diff != "" { 388 t.Error(diff) 389 } 390} 391 392func TestRecorderBeforeFunc(t *testing.T) { 393 var tests = []struct { 394 name string 395 msg, wantRespMsg, wantEntryMsg *ipb.Item 396 f func(string, proto.Message) error 397 wantErr bool 398 }{ 399 { 400 name: "BeforeFunc should modify messages saved, but not alter what is sent/received to/from services", 401 msg: &ipb.Item{Name: "foo", Value: 1}, 402 wantEntryMsg: &ipb.Item{Name: "bar", Value: 2}, 403 wantRespMsg: &ipb.Item{Name: "foo", Value: 1}, 404 f: func(method string, m proto.Message) error { 405 // This callback only runs when Set is called. 406 if !strings.HasSuffix(method, "Set") { 407 return nil 408 } 409 if _, ok := m.(*ipb.Item); !ok { 410 return nil 411 } 412 413 item := m.(*ipb.Item) 414 item.Name = "bar" 415 item.Value = 2 416 return nil 417 }, 418 }, 419 { 420 name: "BeforeFunc should not be able to alter returned responses", 421 msg: &ipb.Item{Name: "foo", Value: 1}, 422 wantRespMsg: &ipb.Item{Name: "foo", Value: 1}, 423 f: func(method string, m proto.Message) error { 424 // This callback only runs when Get is called. 425 if !strings.HasSuffix(method, "Get") { 426 return nil 427 } 428 if _, ok := m.(*ipb.Item); !ok { 429 return nil 430 } 431 432 item := m.(*ipb.Item) 433 item.Value = 2 434 return nil 435 }, 436 }, 437 { 438 name: "Errors should cause the RPC send to fail", 439 msg: &ipb.Item{}, 440 f: func(_ string, _ proto.Message) error { 441 return errors.New("err") 442 }, 443 wantErr: true, 444 }, 445 } 446 447 for _, tc := range tests { 448 // Wrap test cases in a func so defers execute correctly. 449 func() { 450 srv := newIntStoreServer() 451 defer srv.stop() 452 453 var b bytes.Buffer 454 r, err := NewRecorderWriter(&b, nil) 455 if err != nil { 456 t.Error(err) 457 return 458 } 459 r.BeforeFunc = tc.f 460 ctx := context.Background() 461 conn, err := grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, r.DialOptions()...)...) 462 if err != nil { 463 t.Error(err) 464 return 465 } 466 defer conn.Close() 467 468 client := ipb.NewIntStoreClient(conn) 469 _, err = client.Set(ctx, tc.msg) 470 switch { 471 case err != nil && !tc.wantErr: 472 t.Error(err) 473 return 474 case err == nil && tc.wantErr: 475 t.Errorf("got nil; want error") 476 return 477 case err != nil: 478 // Error found as expected, don't check Get(). 479 return 480 } 481 482 if tc.wantRespMsg != nil { 483 got, err := client.Get(ctx, &ipb.GetRequest{Name: tc.msg.GetName()}) 484 if err != nil { 485 t.Error(err) 486 return 487 } 488 if !cmp.Equal(got, tc.wantRespMsg) { 489 t.Errorf("got %+v; want %+v", got, tc.wantRespMsg) 490 } 491 } 492 493 r.Close() 494 495 if tc.wantEntryMsg != nil { 496 _, _ = readHeader(&b) 497 e, err := readEntry(&b) 498 if err != nil { 499 t.Error(err) 500 return 501 } 502 got := e.msg.msg.(*ipb.Item) 503 if !cmp.Equal(got, tc.wantEntryMsg) { 504 t.Errorf("got %v; want %v", got, tc.wantEntryMsg) 505 } 506 } 507 }() 508 } 509} 510 511func TestReplayerBeforeFunc(t *testing.T) { 512 var tests = []struct { 513 name string 514 msg, reqMsg *ipb.Item 515 f func(string, proto.Message) error 516 wantErr bool 517 }{ 518 { 519 name: "BeforeFunc should modify messages sent before they are passed to the replayer", 520 msg: &ipb.Item{Name: "foo", Value: 1}, 521 reqMsg: &ipb.Item{Name: "bar", Value: 1}, 522 f: func(method string, m proto.Message) error { 523 item := m.(*ipb.Item) 524 item.Name = "foo" 525 return nil 526 }, 527 }, 528 { 529 name: "Errors should cause the RPC send to fail", 530 msg: &ipb.Item{}, 531 f: func(_ string, _ proto.Message) error { 532 return errors.New("err") 533 }, 534 wantErr: true, 535 }, 536 } 537 538 for _, tc := range tests { 539 // Wrap test cases in a func so defers execute correctly. 540 func() { 541 srv := newIntStoreServer() 542 defer srv.stop() 543 544 var b bytes.Buffer 545 rec, err := NewRecorderWriter(&b, nil) 546 if err != nil { 547 t.Error(err) 548 return 549 } 550 ctx := context.Background() 551 conn, err := grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, rec.DialOptions()...)...) 552 if err != nil { 553 t.Error(err) 554 return 555 } 556 defer conn.Close() 557 558 client := ipb.NewIntStoreClient(conn) 559 _, err = client.Set(ctx, tc.msg) 560 if err != nil { 561 t.Error(err) 562 return 563 } 564 rec.Close() 565 566 rep, err := NewReplayerReader(&b) 567 if err != nil { 568 t.Error(err) 569 return 570 } 571 rep.BeforeFunc = tc.f 572 conn, err = grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, rep.DialOptions()...)...) 573 if err != nil { 574 t.Error(err) 575 return 576 } 577 defer conn.Close() 578 579 client = ipb.NewIntStoreClient(conn) 580 _, err = client.Set(ctx, tc.reqMsg) 581 switch { 582 case err != nil && !tc.wantErr: 583 t.Error(err) 584 case err == nil && tc.wantErr: 585 t.Errorf("got nil; want error") 586 } 587 }() 588 } 589} 590 591func TestOutOfOrderStreamReplay(t *testing.T) { 592 // Check that streams are matched by method and first request sent, if any. 593 594 items := []*ipb.Item{ 595 {Name: "a", Value: 1}, 596 {Name: "b", Value: 2}, 597 {Name: "c", Value: 3}, 598 } 599 run := func(t *testing.T, conn *grpc.ClientConn, arg1, arg2 int) { 600 client := ipb.NewIntStoreClient(conn) 601 ctx := context.Background() 602 // Set some items. 603 for _, item := range items { 604 _, err := client.Set(ctx, item) 605 if err != nil { 606 t.Fatal(err) 607 } 608 } 609 // List them twice, with different requests. 610 compareLists(t, listItems(t, client, arg1), items[arg1:]) 611 compareLists(t, listItems(t, client, arg2), items[arg2:]) 612 } 613 614 srv := newIntStoreServer() 615 defer srv.stop() 616 617 // Replay in the same order. 618 buf := record(t, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) }) 619 replay(t, buf, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) }) 620 621 // Replay in a different order. 622 buf = record(t, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) }) 623 replay(t, buf, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 2, 1) }) 624} 625