1/* 2Copyright 2017 Google LLC 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package spanner 18 19import ( 20 "context" 21 "errors" 22 "fmt" 23 "io" 24 "sync/atomic" 25 "testing" 26 "time" 27 28 . "cloud.google.com/go/spanner/internal/testutil" 29 "github.com/golang/protobuf/proto" 30 proto3 "github.com/golang/protobuf/ptypes/struct" 31 "github.com/googleapis/gax-go/v2" 32 "google.golang.org/api/iterator" 33 sppb "google.golang.org/genproto/googleapis/spanner/v1" 34 "google.golang.org/grpc" 35 "google.golang.org/grpc/codes" 36 "google.golang.org/grpc/status" 37) 38 39var ( 40 // Mocked transaction timestamp. 41 trxTs = time.Unix(1, 2) 42 // Metadata for mocked KV table, its rows are returned by SingleUse 43 // transactions. 44 kvMeta = func() *sppb.ResultSetMetadata { 45 meta := KvMeta 46 meta.Transaction = &sppb.Transaction{ 47 ReadTimestamp: timestampProto(trxTs), 48 } 49 return &meta 50 }() 51 // Metadata for mocked ListKV table, which uses List for its key and value. 52 // Its rows are returned by snapshot readonly transactions, as indicated in 53 // the transaction metadata. 54 kvListMeta = &sppb.ResultSetMetadata{ 55 RowType: &sppb.StructType{ 56 Fields: []*sppb.StructType_Field{ 57 { 58 Name: "Key", 59 Type: &sppb.Type{ 60 Code: sppb.TypeCode_ARRAY, 61 ArrayElementType: &sppb.Type{ 62 Code: sppb.TypeCode_STRING, 63 }, 64 }, 65 }, 66 { 67 Name: "Value", 68 Type: &sppb.Type{ 69 Code: sppb.TypeCode_ARRAY, 70 ArrayElementType: &sppb.Type{ 71 Code: sppb.TypeCode_STRING, 72 }, 73 }, 74 }, 75 }, 76 }, 77 Transaction: &sppb.Transaction{ 78 Id: transactionID{5, 6, 7, 8, 9}, 79 ReadTimestamp: timestampProto(trxTs), 80 }, 81 } 82 // Metadata for mocked schema of a query result set, which has two struct 83 // columns named "Col1" and "Col2", the struct's schema is like the 84 // following: 85 // 86 // STRUCT { 87 // INT 88 // LIST<STRING> 89 // } 90 // 91 // Its rows are returned in readwrite transaction, as indicated in the 92 // transaction metadata. 93 kvObjectMeta = &sppb.ResultSetMetadata{ 94 RowType: &sppb.StructType{ 95 Fields: []*sppb.StructType_Field{ 96 { 97 Name: "Col1", 98 Type: &sppb.Type{ 99 Code: sppb.TypeCode_STRUCT, 100 StructType: &sppb.StructType{ 101 Fields: []*sppb.StructType_Field{ 102 { 103 Name: "foo-f1", 104 Type: &sppb.Type{ 105 Code: sppb.TypeCode_INT64, 106 }, 107 }, 108 { 109 Name: "foo-f2", 110 Type: &sppb.Type{ 111 Code: sppb.TypeCode_ARRAY, 112 ArrayElementType: &sppb.Type{ 113 Code: sppb.TypeCode_STRING, 114 }, 115 }, 116 }, 117 }, 118 }, 119 }, 120 }, 121 { 122 Name: "Col2", 123 Type: &sppb.Type{ 124 Code: sppb.TypeCode_STRUCT, 125 StructType: &sppb.StructType{ 126 Fields: []*sppb.StructType_Field{ 127 { 128 Name: "bar-f1", 129 Type: &sppb.Type{ 130 Code: sppb.TypeCode_INT64, 131 }, 132 }, 133 { 134 Name: "bar-f2", 135 Type: &sppb.Type{ 136 Code: sppb.TypeCode_ARRAY, 137 ArrayElementType: &sppb.Type{ 138 Code: sppb.TypeCode_STRING, 139 }, 140 }, 141 }, 142 }, 143 }, 144 }, 145 }, 146 }, 147 }, 148 Transaction: &sppb.Transaction{ 149 Id: transactionID{1, 2, 3, 4, 5}, 150 }, 151 } 152) 153 154// String implements fmt.stringer. 155func (r *Row) String() string { 156 return fmt.Sprintf("{fields: %s, val: %s}", r.fields, r.vals) 157} 158 159func describeRows(l []*Row) string { 160 // generate a nice test failure description 161 var s = "[" 162 for i, r := range l { 163 if i != 0 { 164 s += ",\n " 165 } 166 s += fmt.Sprint(r) 167 } 168 s += "]" 169 return s 170} 171 172// Helper for generating proto3 Value_ListValue instances, making test code 173// shorter and readable. 174func genProtoListValue(v ...string) *proto3.Value_ListValue { 175 r := &proto3.Value_ListValue{ 176 ListValue: &proto3.ListValue{ 177 Values: []*proto3.Value{}, 178 }, 179 } 180 for _, e := range v { 181 r.ListValue.Values = append( 182 r.ListValue.Values, 183 &proto3.Value{ 184 Kind: &proto3.Value_StringValue{StringValue: e}, 185 }, 186 ) 187 } 188 return r 189} 190 191// Test Row generation logics of partialResultSetDecoder. 192func TestPartialResultSetDecoder(t *testing.T) { 193 restore := setMaxBytesBetweenResumeTokens() 194 defer restore() 195 var tests = []struct { 196 input []*sppb.PartialResultSet 197 wantF []*Row 198 wantTxID transactionID 199 wantTs time.Time 200 wantD bool 201 }{ 202 { 203 // Empty input. 204 wantD: true, 205 }, 206 // String merging examples. 207 { 208 // Single KV result. 209 input: []*sppb.PartialResultSet{ 210 { 211 Metadata: kvMeta, 212 Values: []*proto3.Value{ 213 {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, 214 {Kind: &proto3.Value_StringValue{StringValue: "bar"}}, 215 }, 216 }, 217 }, 218 wantF: []*Row{ 219 { 220 fields: kvMeta.RowType.Fields, 221 vals: []*proto3.Value{ 222 {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, 223 {Kind: &proto3.Value_StringValue{StringValue: "bar"}}, 224 }, 225 }, 226 }, 227 wantTs: trxTs, 228 wantD: true, 229 }, 230 { 231 // Incomplete partial result. 232 input: []*sppb.PartialResultSet{ 233 { 234 Metadata: kvMeta, 235 Values: []*proto3.Value{ 236 {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, 237 }, 238 }, 239 }, 240 wantTs: trxTs, 241 wantD: false, 242 }, 243 { 244 // Complete splitted result. 245 input: []*sppb.PartialResultSet{ 246 { 247 Metadata: kvMeta, 248 Values: []*proto3.Value{ 249 {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, 250 }, 251 }, 252 { 253 Values: []*proto3.Value{ 254 {Kind: &proto3.Value_StringValue{StringValue: "bar"}}, 255 }, 256 }, 257 }, 258 wantF: []*Row{ 259 { 260 fields: kvMeta.RowType.Fields, 261 vals: []*proto3.Value{ 262 {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, 263 {Kind: &proto3.Value_StringValue{StringValue: "bar"}}, 264 }, 265 }, 266 }, 267 wantTs: trxTs, 268 wantD: true, 269 }, 270 { 271 // Multi-row example with splitted row in the middle. 272 input: []*sppb.PartialResultSet{ 273 { 274 Metadata: kvMeta, 275 Values: []*proto3.Value{ 276 {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, 277 {Kind: &proto3.Value_StringValue{StringValue: "bar"}}, 278 {Kind: &proto3.Value_StringValue{StringValue: "A"}}, 279 }, 280 }, 281 { 282 Values: []*proto3.Value{ 283 {Kind: &proto3.Value_StringValue{StringValue: "1"}}, 284 {Kind: &proto3.Value_StringValue{StringValue: "B"}}, 285 {Kind: &proto3.Value_StringValue{StringValue: "2"}}, 286 }, 287 }, 288 }, 289 wantF: []*Row{ 290 { 291 fields: kvMeta.RowType.Fields, 292 vals: []*proto3.Value{ 293 {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, 294 {Kind: &proto3.Value_StringValue{StringValue: "bar"}}, 295 }, 296 }, 297 { 298 fields: kvMeta.RowType.Fields, 299 vals: []*proto3.Value{ 300 {Kind: &proto3.Value_StringValue{StringValue: "A"}}, 301 {Kind: &proto3.Value_StringValue{StringValue: "1"}}, 302 }, 303 }, 304 { 305 fields: kvMeta.RowType.Fields, 306 vals: []*proto3.Value{ 307 {Kind: &proto3.Value_StringValue{StringValue: "B"}}, 308 {Kind: &proto3.Value_StringValue{StringValue: "2"}}, 309 }, 310 }, 311 }, 312 wantTs: trxTs, 313 wantD: true, 314 }, 315 { 316 // Merging example in result_set.proto. 317 input: []*sppb.PartialResultSet{ 318 { 319 Metadata: kvMeta, 320 Values: []*proto3.Value{ 321 {Kind: &proto3.Value_StringValue{StringValue: "Hello"}}, 322 {Kind: &proto3.Value_StringValue{StringValue: "W"}}, 323 }, 324 ChunkedValue: true, 325 }, 326 { 327 Values: []*proto3.Value{ 328 {Kind: &proto3.Value_StringValue{StringValue: "orl"}}, 329 }, 330 ChunkedValue: true, 331 }, 332 { 333 Values: []*proto3.Value{ 334 {Kind: &proto3.Value_StringValue{StringValue: "d"}}, 335 }, 336 }, 337 }, 338 wantF: []*Row{ 339 { 340 fields: kvMeta.RowType.Fields, 341 vals: []*proto3.Value{ 342 {Kind: &proto3.Value_StringValue{StringValue: "Hello"}}, 343 {Kind: &proto3.Value_StringValue{StringValue: "World"}}, 344 }, 345 }, 346 }, 347 wantTs: trxTs, 348 wantD: true, 349 }, 350 { 351 // More complex example showing completing a merge and 352 // starting a new merge in the same partialResultSet. 353 input: []*sppb.PartialResultSet{ 354 { 355 Metadata: kvMeta, 356 Values: []*proto3.Value{ 357 {Kind: &proto3.Value_StringValue{StringValue: "Hello"}}, 358 {Kind: &proto3.Value_StringValue{StringValue: "W"}}, // start split in value 359 }, 360 ChunkedValue: true, 361 }, 362 { 363 Values: []*proto3.Value{ 364 {Kind: &proto3.Value_StringValue{StringValue: "orld"}}, // complete value 365 {Kind: &proto3.Value_StringValue{StringValue: "i"}}, // start split in key 366 }, 367 ChunkedValue: true, 368 }, 369 { 370 Values: []*proto3.Value{ 371 {Kind: &proto3.Value_StringValue{StringValue: "s"}}, // complete key 372 {Kind: &proto3.Value_StringValue{StringValue: "not"}}, 373 {Kind: &proto3.Value_StringValue{StringValue: "a"}}, 374 {Kind: &proto3.Value_StringValue{StringValue: "qu"}}, // split in value 375 }, 376 ChunkedValue: true, 377 }, 378 { 379 Values: []*proto3.Value{ 380 {Kind: &proto3.Value_StringValue{StringValue: "estion"}}, // complete value 381 }, 382 }, 383 }, 384 wantF: []*Row{ 385 { 386 fields: kvMeta.RowType.Fields, 387 vals: []*proto3.Value{ 388 {Kind: &proto3.Value_StringValue{StringValue: "Hello"}}, 389 {Kind: &proto3.Value_StringValue{StringValue: "World"}}, 390 }, 391 }, 392 { 393 fields: kvMeta.RowType.Fields, 394 vals: []*proto3.Value{ 395 {Kind: &proto3.Value_StringValue{StringValue: "is"}}, 396 {Kind: &proto3.Value_StringValue{StringValue: "not"}}, 397 }, 398 }, 399 { 400 fields: kvMeta.RowType.Fields, 401 vals: []*proto3.Value{ 402 {Kind: &proto3.Value_StringValue{StringValue: "a"}}, 403 {Kind: &proto3.Value_StringValue{StringValue: "question"}}, 404 }, 405 }, 406 }, 407 wantTs: trxTs, 408 wantD: true, 409 }, 410 // List merging examples. 411 { 412 // Non-splitting Lists. 413 input: []*sppb.PartialResultSet{ 414 { 415 Metadata: kvListMeta, 416 Values: []*proto3.Value{ 417 { 418 Kind: genProtoListValue("foo-1", "foo-2"), 419 }, 420 }, 421 }, 422 { 423 Values: []*proto3.Value{ 424 { 425 Kind: genProtoListValue("bar-1", "bar-2"), 426 }, 427 }, 428 }, 429 }, 430 wantF: []*Row{ 431 { 432 fields: kvListMeta.RowType.Fields, 433 vals: []*proto3.Value{ 434 { 435 Kind: genProtoListValue("foo-1", "foo-2"), 436 }, 437 { 438 Kind: genProtoListValue("bar-1", "bar-2"), 439 }, 440 }, 441 }, 442 }, 443 wantTxID: transactionID{5, 6, 7, 8, 9}, 444 wantTs: trxTs, 445 wantD: true, 446 }, 447 { 448 // Simple List merge case: splitted string element. 449 input: []*sppb.PartialResultSet{ 450 { 451 Metadata: kvListMeta, 452 Values: []*proto3.Value{ 453 { 454 Kind: genProtoListValue("foo-1", "foo-"), 455 }, 456 }, 457 ChunkedValue: true, 458 }, 459 { 460 Values: []*proto3.Value{ 461 { 462 Kind: genProtoListValue("2"), 463 }, 464 }, 465 }, 466 { 467 Values: []*proto3.Value{ 468 { 469 Kind: genProtoListValue("bar-1", "bar-2"), 470 }, 471 }, 472 }, 473 }, 474 wantF: []*Row{ 475 { 476 fields: kvListMeta.RowType.Fields, 477 vals: []*proto3.Value{ 478 { 479 Kind: genProtoListValue("foo-1", "foo-2"), 480 }, 481 { 482 Kind: genProtoListValue("bar-1", "bar-2"), 483 }, 484 }, 485 }, 486 }, 487 wantTxID: transactionID{5, 6, 7, 8, 9}, 488 wantTs: trxTs, 489 wantD: true, 490 }, 491 { 492 // Struct merging is also implemented by List merging. Note that 493 // Cloud Spanner uses proto.ListValue to encode Structs as well. 494 input: []*sppb.PartialResultSet{ 495 { 496 Metadata: kvObjectMeta, 497 Values: []*proto3.Value{ 498 { 499 Kind: &proto3.Value_ListValue{ 500 ListValue: &proto3.ListValue{ 501 Values: []*proto3.Value{ 502 {Kind: &proto3.Value_NumberValue{NumberValue: 23}}, 503 {Kind: genProtoListValue("foo-1", "fo")}, 504 }, 505 }, 506 }, 507 }, 508 }, 509 ChunkedValue: true, 510 }, 511 { 512 Values: []*proto3.Value{ 513 { 514 Kind: &proto3.Value_ListValue{ 515 ListValue: &proto3.ListValue{ 516 Values: []*proto3.Value{ 517 {Kind: genProtoListValue("o-2", "f")}, 518 }, 519 }, 520 }, 521 }, 522 }, 523 ChunkedValue: true, 524 }, 525 { 526 Values: []*proto3.Value{ 527 { 528 Kind: &proto3.Value_ListValue{ 529 ListValue: &proto3.ListValue{ 530 Values: []*proto3.Value{ 531 {Kind: genProtoListValue("oo-3")}, 532 }, 533 }, 534 }, 535 }, 536 { 537 Kind: &proto3.Value_ListValue{ 538 ListValue: &proto3.ListValue{ 539 Values: []*proto3.Value{ 540 {Kind: &proto3.Value_NumberValue{NumberValue: 45}}, 541 {Kind: genProtoListValue("bar-1")}, 542 }, 543 }, 544 }, 545 }, 546 }, 547 }, 548 }, 549 wantF: []*Row{ 550 { 551 fields: kvObjectMeta.RowType.Fields, 552 vals: []*proto3.Value{ 553 { 554 Kind: &proto3.Value_ListValue{ 555 ListValue: &proto3.ListValue{ 556 Values: []*proto3.Value{ 557 {Kind: &proto3.Value_NumberValue{NumberValue: 23}}, 558 {Kind: genProtoListValue("foo-1", "foo-2", "foo-3")}, 559 }, 560 }, 561 }, 562 }, 563 { 564 Kind: &proto3.Value_ListValue{ 565 ListValue: &proto3.ListValue{ 566 Values: []*proto3.Value{ 567 {Kind: &proto3.Value_NumberValue{NumberValue: 45}}, 568 {Kind: genProtoListValue("bar-1")}, 569 }, 570 }, 571 }, 572 }, 573 }, 574 }, 575 }, 576 wantTxID: transactionID{1, 2, 3, 4, 5}, 577 wantD: true, 578 }, 579 } 580 581nextTest: 582 for i, test := range tests { 583 var rows []*Row 584 p := &partialResultSetDecoder{} 585 for j, v := range test.input { 586 rs, err := p.add(v) 587 if err != nil { 588 t.Errorf("test %d.%d: partialResultSetDecoder.add(%v) = %v; want nil", i, j, v, err) 589 continue nextTest 590 } 591 rows = append(rows, rs...) 592 } 593 if !testEqual(p.ts, test.wantTs) { 594 t.Errorf("got transaction(%v), want %v", p.ts, test.wantTs) 595 } 596 if !testEqual(rows, test.wantF) { 597 t.Errorf("test %d: rows=\n%v\n; want\n%v\n; p.row:\n%v\n", i, describeRows(rows), describeRows(test.wantF), p.row) 598 } 599 if got := p.done(); got != test.wantD { 600 t.Errorf("test %d: partialResultSetDecoder.done() = %v", i, got) 601 } 602 } 603} 604 605const ( 606 // max number of PartialResultSets that will be buffered in tests. 607 maxBuffers = 16 608) 609 610// setMaxBytesBetweenResumeTokens sets the global maxBytesBetweenResumeTokens to 611// a smaller value more suitable for tests. It returns a function which should 612// be called to restore the maxBytesBetweenResumeTokens to its old value. 613func setMaxBytesBetweenResumeTokens() func() { 614 o := atomic.LoadInt32(&maxBytesBetweenResumeTokens) 615 atomic.StoreInt32(&maxBytesBetweenResumeTokens, int32(maxBuffers*proto.Size(&sppb.PartialResultSet{ 616 Metadata: kvMeta, 617 Values: []*proto3.Value{ 618 {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}}, 619 {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}}, 620 }, 621 }))) 622 return func() { 623 atomic.StoreInt32(&maxBytesBetweenResumeTokens, o) 624 } 625} 626 627// keyStr generates key string for kvMeta schema. 628func keyStr(i int) string { 629 return fmt.Sprintf("foo-%02d", i) 630} 631 632// valStr generates value string for kvMeta schema. 633func valStr(i int) string { 634 return fmt.Sprintf("bar-%02d", i) 635} 636 637// Test state transitions of resumableStreamDecoder where state machine ends up 638// to a non-blocking state(resumableStreamDecoder.Next returns on non-blocking 639// state). 640func TestRsdNonblockingStates(t *testing.T) { 641 restore := setMaxBytesBetweenResumeTokens() 642 defer restore() 643 tests := []struct { 644 name string 645 msgs []MockCtlMsg 646 rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error) 647 sql string 648 // Expected values 649 want []*sppb.PartialResultSet // PartialResultSets that should be returned to caller 650 queue []*sppb.PartialResultSet // PartialResultSets that should be buffered 651 resumeToken []byte // Resume token that is maintained by resumableStreamDecoder 652 stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder 653 wantErr error 654 }{ 655 { 656 // unConnected->queueingRetryable->finished 657 name: "unConnected->queueingRetryable->finished", 658 msgs: []MockCtlMsg{ 659 {}, 660 {}, 661 {Err: io.EOF, ResumeToken: false}, 662 }, 663 sql: "SELECT t.key key, t.value value FROM t_mock t", 664 want: []*sppb.PartialResultSet{ 665 { 666 Metadata: kvMeta, 667 Values: []*proto3.Value{ 668 {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}}, 669 {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}}, 670 }, 671 }, 672 }, 673 queue: []*sppb.PartialResultSet{ 674 { 675 Metadata: kvMeta, 676 Values: []*proto3.Value{ 677 {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}}, 678 {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}}, 679 }, 680 }, 681 }, 682 stateHistory: []resumableStreamDecoderState{ 683 queueingRetryable, // do RPC 684 queueingRetryable, // got foo-00 685 queueingRetryable, // got foo-01 686 finished, // got EOF 687 }, 688 }, 689 { 690 // unConnected->queueingRetryable->aborted 691 name: "unConnected->queueingRetryable->aborted", 692 msgs: []MockCtlMsg{ 693 {}, 694 {Err: nil, ResumeToken: true}, 695 {}, 696 {Err: errors.New("I quit"), ResumeToken: false}, 697 }, 698 sql: "SELECT t.key key, t.value value FROM t_mock t", 699 want: []*sppb.PartialResultSet{ 700 { 701 Metadata: kvMeta, 702 Values: []*proto3.Value{ 703 {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}}, 704 {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}}, 705 }, 706 }, 707 { 708 Metadata: kvMeta, 709 Values: []*proto3.Value{ 710 {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}}, 711 {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}}, 712 }, 713 ResumeToken: EncodeResumeToken(1), 714 }, 715 }, 716 stateHistory: []resumableStreamDecoderState{ 717 queueingRetryable, // do RPC 718 queueingRetryable, // got foo-00 719 queueingRetryable, // got foo-01 720 queueingRetryable, // foo-01, resume token 721 queueingRetryable, // got foo-02 722 aborted, // got error 723 }, 724 wantErr: status.Errorf(codes.Unknown, "I quit"), 725 }, 726 { 727 // unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable 728 name: "unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable", 729 msgs: func() (m []MockCtlMsg) { 730 for i := 0; i < maxBuffers+1; i++ { 731 m = append(m, MockCtlMsg{}) 732 } 733 return m 734 }(), 735 sql: "SELECT t.key key, t.value value FROM t_mock t", 736 want: func() (s []*sppb.PartialResultSet) { 737 for i := 0; i < maxBuffers+1; i++ { 738 s = append(s, &sppb.PartialResultSet{ 739 Metadata: kvMeta, 740 Values: []*proto3.Value{ 741 {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, 742 {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, 743 }, 744 }) 745 } 746 return s 747 }(), 748 stateHistory: func() (s []resumableStreamDecoderState) { 749 s = append(s, queueingRetryable) // RPC 750 for i := 0; i < maxBuffers; i++ { 751 s = append(s, queueingRetryable) // the internal queue of resumableStreamDecoder fills up 752 } 753 // the first item fills up the queue and triggers state transition; 754 // the second item is received under queueingUnretryable state. 755 s = append(s, queueingUnretryable) 756 s = append(s, queueingUnretryable) 757 return s 758 }(), 759 }, 760 { 761 // unConnected->queueingRetryable->queueingUnretryable->aborted 762 name: "unConnected->queueingRetryable->queueingUnretryable->aborted", 763 msgs: func() (m []MockCtlMsg) { 764 for i := 0; i < maxBuffers; i++ { 765 m = append(m, MockCtlMsg{}) 766 } 767 m = append(m, MockCtlMsg{Err: errors.New("Just Abort It"), ResumeToken: false}) 768 return m 769 }(), 770 sql: "SELECT t.key key, t.value value FROM t_mock t", 771 want: func() (s []*sppb.PartialResultSet) { 772 for i := 0; i < maxBuffers; i++ { 773 s = append(s, &sppb.PartialResultSet{ 774 Metadata: kvMeta, 775 Values: []*proto3.Value{ 776 {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, 777 {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, 778 }, 779 }) 780 } 781 return s 782 }(), 783 stateHistory: func() (s []resumableStreamDecoderState) { 784 s = append(s, queueingRetryable) // RPC 785 for i := 0; i < maxBuffers; i++ { 786 s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up 787 } 788 s = append(s, queueingUnretryable) // the last row triggers state change 789 s = append(s, aborted) // Error happens 790 return s 791 }(), 792 wantErr: status.Errorf(codes.Unknown, "Just Abort It"), 793 }, 794 } 795 for _, test := range tests { 796 t.Run(test.name, func(t *testing.T) { 797 ms := NewMockCloudSpanner(t, trxTs) 798 ms.Serve() 799 mc := sppb.NewSpannerClient(dialMock(t, ms)) 800 if test.rpc == nil { 801 test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { 802 return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ 803 Sql: test.sql, 804 ResumeToken: resumeToken, 805 }) 806 } 807 } 808 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 809 defer cancel() 810 r := newResumableStreamDecoder( 811 ctx, 812 nil, 813 test.rpc, 814 nil, 815 ) 816 st := []resumableStreamDecoderState{} 817 var lastErr error 818 // Once the expected number of state transitions are observed, 819 // send a signal by setting stateDone = true. 820 stateDone := false 821 // Set stateWitness to listen to state changes. 822 hl := len(test.stateHistory) // To avoid data race on test. 823 r.stateWitness = func(rs resumableStreamDecoderState) { 824 if !stateDone { 825 // Record state transitions. 826 st = append(st, rs) 827 if len(st) == hl { 828 lastErr = r.lastErr() 829 stateDone = true 830 } 831 } 832 } 833 // Let mock server stream given messages to resumableStreamDecoder. 834 for _, m := range test.msgs { 835 ms.AddMsg(m.Err, m.ResumeToken) 836 } 837 var rs []*sppb.PartialResultSet 838 for { 839 select { 840 case <-ctx.Done(): 841 t.Fatal("context cancelled or timeout during test") 842 default: 843 } 844 if stateDone { 845 // Check if resumableStreamDecoder carried out expected 846 // state transitions. 847 if !testEqual(st, test.stateHistory) { 848 t.Fatalf("observed state transitions: \n%v\n, want \n%v\n", st, test.stateHistory) 849 } 850 // Check if resumableStreamDecoder returns expected array of 851 // PartialResultSets. 852 if !testEqual(rs, test.want) { 853 t.Fatalf("received PartialResultSets: \n%v\n, want \n%v\n", rs, test.want) 854 } 855 // Verify that resumableStreamDecoder's internal buffering is 856 // also correct. 857 var q []*sppb.PartialResultSet 858 for { 859 item := r.q.pop() 860 if item == nil { 861 break 862 } 863 q = append(q, item) 864 } 865 if !testEqual(q, test.queue) { 866 t.Fatalf("PartialResultSets still queued: \n%v\n, want \n%v\n", q, test.queue) 867 } 868 // Verify resume token. 869 if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) { 870 t.Fatalf("Resume token is %v, want %v\n", r.resumeToken, test.resumeToken) 871 } 872 // Verify error message. 873 if !testEqual(lastErr, test.wantErr) { 874 t.Fatalf("got error %v, want %v", lastErr, test.wantErr) 875 } 876 return 877 } 878 // Receive next decoded item. 879 if r.next() { 880 rs = append(rs, r.get()) 881 } 882 } 883 }) 884 } 885} 886 887// Test state transitions of resumableStreamDecoder where state machine 888// ends up to a blocking state(resumableStreamDecoder.Next blocks 889// on blocking state). 890func TestRsdBlockingStates(t *testing.T) { 891 restore := setMaxBytesBetweenResumeTokens() 892 defer restore() 893 for _, test := range []struct { 894 name string 895 msgs []MockCtlMsg 896 rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error) 897 sql string 898 // Expected values 899 want []*sppb.PartialResultSet // PartialResultSets that should be returned to caller 900 queue []*sppb.PartialResultSet // PartialResultSets that should be buffered 901 resumeToken []byte // Resume token that is maintained by resumableStreamDecoder 902 stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder 903 wantErr error 904 }{ 905 { 906 // unConnected -> unConnected 907 name: "unConnected -> unConnected", 908 rpc: func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { 909 return nil, status.Errorf(codes.Unavailable, "trust me: server is unavailable") 910 }, 911 sql: "SELECT * from t_whatever", 912 stateHistory: []resumableStreamDecoderState{unConnected, unConnected, unConnected}, 913 wantErr: status.Errorf(codes.Unavailable, "trust me: server is unavailable"), 914 }, 915 { 916 // unConnected -> queueingRetryable 917 name: "unConnected -> queueingRetryable", 918 sql: "SELECT t.key key, t.value value FROM t_mock t", 919 stateHistory: []resumableStreamDecoderState{queueingRetryable}, 920 }, 921 { 922 // unConnected->queueingRetryable->queueingRetryable 923 name: "unConnected->queueingRetryable->queueingRetryable", 924 msgs: []MockCtlMsg{ 925 {}, 926 {Err: nil, ResumeToken: true}, 927 {Err: nil, ResumeToken: true}, 928 {}, 929 }, 930 sql: "SELECT t.key key, t.value value FROM t_mock t", 931 want: []*sppb.PartialResultSet{ 932 { 933 Metadata: kvMeta, 934 Values: []*proto3.Value{ 935 {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}}, 936 {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}}, 937 }, 938 }, 939 { 940 Metadata: kvMeta, 941 Values: []*proto3.Value{ 942 {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}}, 943 {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}}, 944 }, 945 ResumeToken: EncodeResumeToken(1), 946 }, 947 { 948 Metadata: kvMeta, 949 Values: []*proto3.Value{ 950 {Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}}, 951 {Kind: &proto3.Value_StringValue{StringValue: valStr(2)}}, 952 }, 953 ResumeToken: EncodeResumeToken(2), 954 }, 955 }, 956 queue: []*sppb.PartialResultSet{ 957 { 958 Metadata: kvMeta, 959 Values: []*proto3.Value{ 960 {Kind: &proto3.Value_StringValue{StringValue: keyStr(3)}}, 961 {Kind: &proto3.Value_StringValue{StringValue: valStr(3)}}, 962 }, 963 }, 964 }, 965 resumeToken: EncodeResumeToken(2), 966 stateHistory: []resumableStreamDecoderState{ 967 queueingRetryable, // do RPC 968 queueingRetryable, // got foo-00 969 queueingRetryable, // got foo-01 970 queueingRetryable, // foo-01, resume token 971 queueingRetryable, // got foo-02 972 queueingRetryable, // foo-02, resume token 973 queueingRetryable, // got foo-03 974 }, 975 }, 976 { 977 // unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable 978 name: "unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable", 979 msgs: func() (m []MockCtlMsg) { 980 for i := 0; i < maxBuffers+1; i++ { 981 m = append(m, MockCtlMsg{}) 982 } 983 m = append(m, MockCtlMsg{Err: nil, ResumeToken: true}) 984 m = append(m, MockCtlMsg{}) 985 return m 986 }(), 987 sql: "SELECT t.key key, t.value value FROM t_mock t", 988 want: func() (s []*sppb.PartialResultSet) { 989 for i := 0; i < maxBuffers+2; i++ { 990 s = append(s, &sppb.PartialResultSet{ 991 Metadata: kvMeta, 992 Values: []*proto3.Value{ 993 {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, 994 {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, 995 }, 996 }) 997 } 998 s[maxBuffers+1].ResumeToken = EncodeResumeToken(maxBuffers + 1) 999 return s 1000 }(), 1001 resumeToken: EncodeResumeToken(maxBuffers + 1), 1002 queue: []*sppb.PartialResultSet{ 1003 { 1004 Metadata: kvMeta, 1005 Values: []*proto3.Value{ 1006 {Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 2)}}, 1007 {Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 2)}}, 1008 }, 1009 }, 1010 }, 1011 stateHistory: func() (s []resumableStreamDecoderState) { 1012 s = append(s, queueingRetryable) // RPC 1013 for i := 0; i < maxBuffers; i++ { 1014 s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder filles up 1015 } 1016 for i := maxBuffers - 1; i < maxBuffers+1; i++ { 1017 // the first item fills up the queue and triggers state 1018 // change; the second item is received under 1019 // queueingUnretryable state. 1020 s = append(s, queueingUnretryable) 1021 } 1022 s = append(s, queueingUnretryable) // got (maxBuffers+1)th row under Unretryable state 1023 s = append(s, queueingRetryable) // (maxBuffers+1)th row has resume token 1024 s = append(s, queueingRetryable) // (maxBuffers+2)th row has no resume token 1025 return s 1026 }(), 1027 }, 1028 { 1029 // unConnected->queueingRetryable->queueingUnretryable->finished 1030 name: "unConnected->queueingRetryable->queueingUnretryable->finished", 1031 msgs: func() (m []MockCtlMsg) { 1032 for i := 0; i < maxBuffers; i++ { 1033 m = append(m, MockCtlMsg{}) 1034 } 1035 m = append(m, MockCtlMsg{Err: io.EOF, ResumeToken: false}) 1036 return m 1037 }(), 1038 sql: "SELECT t.key key, t.value value FROM t_mock t", 1039 want: func() (s []*sppb.PartialResultSet) { 1040 for i := 0; i < maxBuffers; i++ { 1041 s = append(s, &sppb.PartialResultSet{ 1042 Metadata: kvMeta, 1043 Values: []*proto3.Value{ 1044 {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, 1045 {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, 1046 }, 1047 }) 1048 } 1049 return s 1050 }(), 1051 stateHistory: func() (s []resumableStreamDecoderState) { 1052 s = append(s, queueingRetryable) // RPC 1053 for i := 0; i < maxBuffers; i++ { 1054 s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up 1055 } 1056 s = append(s, queueingUnretryable) // last row triggers state change 1057 s = append(s, finished) // query finishes 1058 return s 1059 }(), 1060 }, 1061 } { 1062 t.Run(test.name, func(t *testing.T) { 1063 ms := NewMockCloudSpanner(t, trxTs) 1064 ms.Serve() 1065 cc := dialMock(t, ms) 1066 mc := sppb.NewSpannerClient(cc) 1067 if test.rpc == nil { 1068 // Avoid using test.sql directly in closure because for loop changes 1069 // test. 1070 sql := test.sql 1071 test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { 1072 return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ 1073 Sql: sql, 1074 ResumeToken: resumeToken, 1075 }) 1076 } 1077 } 1078 ctx, cancel := context.WithCancel(context.Background()) 1079 defer cancel() 1080 r := newResumableStreamDecoder( 1081 ctx, 1082 nil, 1083 test.rpc, 1084 nil, 1085 ) 1086 // Override backoff to make the test run faster. 1087 r.backoff = gax.Backoff{ 1088 Initial: 1 * time.Nanosecond, 1089 Max: 1 * time.Nanosecond, 1090 Multiplier: 1.3, 1091 } 1092 // st is the set of observed state transitions. 1093 st := []resumableStreamDecoderState{} 1094 // q is the content of the decoder's partial result queue when expected 1095 // number of state transitions are done. 1096 q := []*sppb.PartialResultSet{} 1097 var lastErr error 1098 // Once the expected number of state transitions are observed, send a 1099 // signal to channel stateDone. 1100 stateDone := make(chan int) 1101 // Set stateWitness to listen to state changes. 1102 hl := len(test.stateHistory) // To avoid data race on test. 1103 r.stateWitness = func(rs resumableStreamDecoderState) { 1104 select { 1105 case <-stateDone: 1106 // Noop after expected number of state transitions 1107 default: 1108 // Record state transitions. 1109 st = append(st, rs) 1110 if len(st) == hl { 1111 lastErr = r.lastErr() 1112 q = r.q.dump() 1113 close(stateDone) 1114 } 1115 } 1116 } 1117 // Let mock server stream given messages to resumableStreamDecoder. 1118 for _, m := range test.msgs { 1119 ms.AddMsg(m.Err, m.ResumeToken) 1120 } 1121 var rs []*sppb.PartialResultSet 1122 go func() { 1123 for { 1124 if !r.next() { 1125 // Note that r.Next also exits on context cancel/timeout. 1126 return 1127 } 1128 rs = append(rs, r.get()) 1129 } 1130 }() 1131 // Verify that resumableStreamDecoder reaches expected state. 1132 select { 1133 case <-stateDone: // Note that at this point, receiver is still blockingon r.next(). 1134 // Check if resumableStreamDecoder carried out expected state 1135 // transitions. 1136 if !testEqual(st, test.stateHistory) { 1137 t.Fatalf("observed state transitions: \n%v\n, want \n%v\n", st, test.stateHistory) 1138 } 1139 // Check if resumableStreamDecoder returns expected array of 1140 // PartialResultSets. 1141 if !testEqual(rs, test.want) { 1142 t.Fatalf("received PartialResultSets: \n%v\n, want \n%v\n", rs, test.want) 1143 } 1144 // Verify that resumableStreamDecoder's internal buffering is also 1145 // correct. 1146 if !testEqual(q, test.queue) { 1147 t.Fatalf("PartialResultSets still queued: \n%v\n, want \n%v\n", q, test.queue) 1148 } 1149 // Verify resume token. 1150 if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) { 1151 t.Fatalf("Resume token is %v, want %v\n", r.resumeToken, test.resumeToken) 1152 } 1153 // Verify error message. 1154 if !testEqual(lastErr, test.wantErr) { 1155 t.Fatalf("got error %v, want %v", lastErr, test.wantErr) 1156 } 1157 case <-time.After(1 * time.Second): 1158 t.Fatal("Timeout in waiting for state change") 1159 } 1160 ms.Stop() 1161 if err := cc.Close(); err != nil { 1162 t.Fatal(err) 1163 } 1164 }) 1165 } 1166} 1167 1168// sReceiver signals every receiving attempt through a channel, used by 1169// TestResumeToken to determine if the receiving of a certain PartialResultSet 1170// will be attempted next. 1171type sReceiver struct { 1172 c chan int 1173 rpcReceiver sppb.Spanner_ExecuteStreamingSqlClient 1174} 1175 1176// Recv() implements streamingReceiver.Recv for sReceiver. 1177func (sr *sReceiver) Recv() (*sppb.PartialResultSet, error) { 1178 sr.c <- 1 1179 return sr.rpcReceiver.Recv() 1180} 1181 1182// waitn waits for nth receiving attempt from now on, until the signal for nth 1183// Recv() attempts is received or timeout. Note that because the way stream() 1184// works, the signal for the nth Recv() means that the previous n - 1 1185// PartialResultSets has already been returned to caller or queued, if no error 1186// happened. 1187func (sr *sReceiver) waitn(n int) error { 1188 for i := 0; i < n; i++ { 1189 select { 1190 case <-sr.c: 1191 case <-time.After(10 * time.Second): 1192 return fmt.Errorf("timeout in waiting for %v-th Recv()", i+1) 1193 } 1194 } 1195 return nil 1196} 1197 1198// Test the handling of resumableStreamDecoder.bytesBetweenResumeTokens. 1199func TestQueueBytes(t *testing.T) { 1200 restore := setMaxBytesBetweenResumeTokens() 1201 defer restore() 1202 ms := NewMockCloudSpanner(t, trxTs) 1203 ms.Serve() 1204 defer ms.Stop() 1205 cc := dialMock(t, ms) 1206 defer cc.Close() 1207 mc := sppb.NewSpannerClient(cc) 1208 sr := &sReceiver{ 1209 c: make(chan int, 1000), // will never block in this test 1210 } 1211 wantQueueBytes := 0 1212 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 1213 defer cancel() 1214 r := newResumableStreamDecoder( 1215 ctx, 1216 nil, 1217 func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { 1218 r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ 1219 Sql: "SELECT t.key key, t.value value FROM t_mock t", 1220 ResumeToken: resumeToken, 1221 }) 1222 sr.rpcReceiver = r 1223 return sr, err 1224 }, 1225 nil, 1226 ) 1227 go func() { 1228 for r.next() { 1229 } 1230 }() 1231 // Let server send maxBuffers / 2 rows. 1232 for i := 0; i < maxBuffers/2; i++ { 1233 wantQueueBytes += proto.Size(&sppb.PartialResultSet{ 1234 Metadata: kvMeta, 1235 Values: []*proto3.Value{ 1236 {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, 1237 {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, 1238 }, 1239 }) 1240 ms.AddMsg(nil, false) 1241 } 1242 if err := sr.waitn(maxBuffers/2 + 1); err != nil { 1243 t.Fatalf("failed to wait for the first %v recv() calls: %v", maxBuffers, err) 1244 } 1245 if int32(wantQueueBytes) != r.bytesBetweenResumeTokens { 1246 t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", r.bytesBetweenResumeTokens, wantQueueBytes) 1247 } 1248 // Now send a resume token to drain the queue. 1249 ms.AddMsg(nil, true) 1250 // Wait for all rows to be processes. 1251 if err := sr.waitn(1); err != nil { 1252 t.Fatalf("failed to wait for rows to be processed: %v", err) 1253 } 1254 if r.bytesBetweenResumeTokens != 0 { 1255 t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens) 1256 } 1257 // Let server send maxBuffers - 1 rows. 1258 wantQueueBytes = 0 1259 for i := 0; i < maxBuffers-1; i++ { 1260 wantQueueBytes += proto.Size(&sppb.PartialResultSet{ 1261 Metadata: kvMeta, 1262 Values: []*proto3.Value{ 1263 {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, 1264 {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, 1265 }, 1266 }) 1267 ms.AddMsg(nil, false) 1268 } 1269 if err := sr.waitn(maxBuffers - 1); err != nil { 1270 t.Fatalf("failed to wait for %v rows to be processed: %v", maxBuffers-1, err) 1271 } 1272 if int32(wantQueueBytes) != r.bytesBetweenResumeTokens { 1273 t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens) 1274 } 1275 // Trigger a state transition: queueingRetryable -> queueingUnretryable. 1276 ms.AddMsg(nil, false) 1277 if err := sr.waitn(1); err != nil { 1278 t.Fatalf("failed to wait for state transition: %v", err) 1279 } 1280 if r.bytesBetweenResumeTokens != 0 { 1281 t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens) 1282 } 1283} 1284 1285// Verify that client can deal with resume token correctly 1286func TestResumeToken(t *testing.T) { 1287 restore := setMaxBytesBetweenResumeTokens() 1288 defer restore() 1289 ms := NewMockCloudSpanner(t, trxTs) 1290 ms.Serve() 1291 defer ms.Stop() 1292 cc := dialMock(t, ms) 1293 defer cc.Close() 1294 mc := sppb.NewSpannerClient(cc) 1295 sr := &sReceiver{ 1296 c: make(chan int, 1000), // will never block in this test 1297 } 1298 rows := []*Row{} 1299 done := make(chan error) 1300 streaming := func() { 1301 // Establish a stream to mock cloud spanner server. 1302 iter := stream(context.Background(), nil, 1303 func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { 1304 r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ 1305 Sql: "SELECT t.key key, t.value value FROM t_mock t", 1306 ResumeToken: resumeToken, 1307 }) 1308 sr.rpcReceiver = r 1309 return sr, err 1310 }, 1311 nil, 1312 func(error) {}) 1313 defer iter.Stop() 1314 var err error 1315 for { 1316 var row *Row 1317 row, err = iter.Next() 1318 if err == iterator.Done { 1319 err = nil 1320 break 1321 } 1322 if err != nil { 1323 break 1324 } 1325 rows = append(rows, row) 1326 } 1327 done <- err 1328 } 1329 go streaming() 1330 // Server streaming row 0 - 2, only row 1 has resume token. 1331 // Client will receive row 0 - 2, so it will try receiving for 1332 // 4 times (the last recv will block), and only row 0 - 1 will 1333 // be yielded. 1334 for i := 0; i < 3; i++ { 1335 if i == 1 { 1336 ms.AddMsg(nil, true) 1337 } else { 1338 ms.AddMsg(nil, false) 1339 } 1340 } 1341 // Wait for 4 receive attempts, as explained above. 1342 if err := sr.waitn(4); err != nil { 1343 t.Fatalf("failed to wait for row 0 - 2: %v", err) 1344 } 1345 want := []*Row{ 1346 { 1347 fields: kvMeta.RowType.Fields, 1348 vals: []*proto3.Value{ 1349 {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}}, 1350 {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}}, 1351 }, 1352 }, 1353 { 1354 fields: kvMeta.RowType.Fields, 1355 vals: []*proto3.Value{ 1356 {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}}, 1357 {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}}, 1358 }, 1359 }, 1360 } 1361 if !testEqual(rows, want) { 1362 t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want) 1363 } 1364 // Inject resumable failure. 1365 ms.AddMsg( 1366 status.Errorf(codes.Unavailable, "mock server unavailable"), 1367 false, 1368 ) 1369 // Test if client detects the resumable failure and retries. 1370 if err := sr.waitn(1); err != nil { 1371 t.Fatalf("failed to wait for client to retry: %v", err) 1372 } 1373 // Client has resumed the query, now server resend row 2. 1374 ms.AddMsg(nil, true) 1375 if err := sr.waitn(1); err != nil { 1376 t.Fatalf("failed to wait for resending row 2: %v", err) 1377 } 1378 // Now client should have received row 0 - 2. 1379 want = append(want, &Row{ 1380 fields: kvMeta.RowType.Fields, 1381 vals: []*proto3.Value{ 1382 {Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}}, 1383 {Kind: &proto3.Value_StringValue{StringValue: valStr(2)}}, 1384 }, 1385 }) 1386 if !testEqual(rows, want) { 1387 t.Errorf("received rows: \n%v\n, want\n%v\n", rows, want) 1388 } 1389 // Sending 3rd - (maxBuffers+1)th rows without resume tokens, client should buffer them. 1390 for i := 3; i < maxBuffers+2; i++ { 1391 ms.AddMsg(nil, false) 1392 } 1393 if err := sr.waitn(maxBuffers - 1); err != nil { 1394 t.Fatalf("failed to wait for row 3-%v: %v", maxBuffers+1, err) 1395 } 1396 // Received rows should be unchanged. 1397 if !testEqual(rows, want) { 1398 t.Errorf("receive rows: \n%v\n, want\n%v\n", rows, want) 1399 } 1400 // Send (maxBuffers+2)th row to trigger state change of resumableStreamDecoder: 1401 // queueingRetryable -> queueingUnretryable 1402 ms.AddMsg(nil, false) 1403 if err := sr.waitn(1); err != nil { 1404 t.Fatalf("failed to wait for row %v: %v", maxBuffers+2, err) 1405 } 1406 // Client should yield row 3rd - (maxBuffers+2)th to application. Therefore, 1407 // application should see row 0 - (maxBuffers+2)th so far. 1408 for i := 3; i < maxBuffers+3; i++ { 1409 want = append(want, &Row{ 1410 fields: kvMeta.RowType.Fields, 1411 vals: []*proto3.Value{ 1412 {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, 1413 {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, 1414 }, 1415 }) 1416 } 1417 if !testEqual(rows, want) { 1418 t.Errorf("received rows: \n%v\n; want\n%v\n", rows, want) 1419 } 1420 // Inject resumable error, but since resumableStreamDecoder is already at 1421 // queueingUnretryable state, query will just fail. 1422 ms.AddMsg( 1423 status.Errorf(codes.Unavailable, "mock server wants some sleep"), 1424 false, 1425 ) 1426 var gotErr error 1427 select { 1428 case gotErr = <-done: 1429 case <-time.After(10 * time.Second): 1430 t.Fatalf("timeout in waiting for failed query to return.") 1431 } 1432 if wantErr := spannerErrorf(codes.Unavailable, "mock server wants some sleep"); !testEqual(gotErr, wantErr) { 1433 t.Fatalf("stream() returns error: %v, but want error: %v", gotErr, wantErr) 1434 } 1435 1436 // Reconnect to mock Cloud Spanner. 1437 rows = []*Row{} 1438 go streaming() 1439 // Let server send two rows without resume token. 1440 for i := maxBuffers + 3; i < maxBuffers+5; i++ { 1441 ms.AddMsg(nil, false) 1442 } 1443 if err := sr.waitn(3); err != nil { 1444 t.Fatalf("failed to wait for row %v - %v: %v", maxBuffers+3, maxBuffers+5, err) 1445 } 1446 if len(rows) > 0 { 1447 t.Errorf("client received some rows unexpectedly: %v, want nothing", rows) 1448 } 1449 // Let server end the query. 1450 ms.AddMsg(io.EOF, false) 1451 select { 1452 case gotErr = <-done: 1453 case <-time.After(10 * time.Second): 1454 t.Fatalf("timeout in waiting for failed query to return") 1455 } 1456 if gotErr != nil { 1457 t.Fatalf("stream() returns unexpected error: %v, but want no error", gotErr) 1458 } 1459 // Verify if a normal server side EOF flushes all queued rows. 1460 want = []*Row{ 1461 { 1462 fields: kvMeta.RowType.Fields, 1463 vals: []*proto3.Value{ 1464 {Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 3)}}, 1465 {Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 3)}}, 1466 }, 1467 }, 1468 { 1469 fields: kvMeta.RowType.Fields, 1470 vals: []*proto3.Value{ 1471 {Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 4)}}, 1472 {Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 4)}}, 1473 }, 1474 }, 1475 } 1476 if !testEqual(rows, want) { 1477 t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want) 1478 } 1479} 1480 1481// Verify that streaming query get retried upon real gRPC server transport 1482// failures. 1483func TestGrpcReconnect(t *testing.T) { 1484 restore := setMaxBytesBetweenResumeTokens() 1485 defer restore() 1486 ms := NewMockCloudSpanner(t, trxTs) 1487 ms.Serve() 1488 defer ms.Stop() 1489 cc := dialMock(t, ms) 1490 defer cc.Close() 1491 mc := sppb.NewSpannerClient(cc) 1492 retry := make(chan int) 1493 row := make(chan int) 1494 var err error 1495 go func() { 1496 r := 0 1497 // Establish a stream to mock cloud spanner server. 1498 iter := stream(context.Background(), nil, 1499 func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { 1500 if r > 0 { 1501 // This RPC attempt is a retry, signal it. 1502 retry <- r 1503 } 1504 r++ 1505 return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ 1506 Sql: "SELECT t.key key, t.value value FROM t_mock t", 1507 ResumeToken: resumeToken, 1508 }) 1509 1510 }, 1511 nil, 1512 func(error) {}) 1513 defer iter.Stop() 1514 for { 1515 _, err = iter.Next() 1516 if err == iterator.Done { 1517 err = nil 1518 break 1519 } 1520 if err != nil { 1521 break 1522 } 1523 row <- 0 1524 } 1525 }() 1526 // Add a message and wait for the receipt. 1527 ms.AddMsg(nil, true) 1528 select { 1529 case <-row: 1530 case <-time.After(10 * time.Second): 1531 t.Fatalf("expect stream to be established within 10 seconds, but it didn't") 1532 } 1533 // Error injection: force server to close all connections. 1534 ms.Stop() 1535 // Test to see if client respond to the real RPC failure correctly by 1536 // retrying RPC. 1537 select { 1538 case r, ok := <-retry: 1539 if ok && r == 1 { 1540 break 1541 } 1542 t.Errorf("retry count = %v, want 1", r) 1543 case <-time.After(10 * time.Second): 1544 t.Errorf("client library failed to respond after 10 seconds, aborting") 1545 return 1546 } 1547} 1548 1549// Test cancel/timeout for client operations. 1550func TestCancelTimeout(t *testing.T) { 1551 restore := setMaxBytesBetweenResumeTokens() 1552 defer restore() 1553 ms := NewMockCloudSpanner(t, trxTs) 1554 ms.Serve() 1555 defer ms.Stop() 1556 cc := dialMock(t, ms) 1557 defer cc.Close() 1558 mc := sppb.NewSpannerClient(cc) 1559 done := make(chan int) 1560 go func() { 1561 for { 1562 ms.AddMsg(nil, true) 1563 } 1564 }() 1565 // Test cancelling query. 1566 ctx, cancel := context.WithCancel(context.Background()) 1567 var err error 1568 go func() { 1569 // Establish a stream to mock cloud spanner server. 1570 iter := stream(ctx, nil, 1571 func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { 1572 return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ 1573 Sql: "SELECT t.key key, t.value value FROM t_mock t", 1574 ResumeToken: resumeToken, 1575 }) 1576 }, 1577 nil, 1578 func(error) {}) 1579 defer iter.Stop() 1580 for { 1581 _, err = iter.Next() 1582 if err == iterator.Done { 1583 break 1584 } 1585 if err != nil { 1586 done <- 0 1587 break 1588 } 1589 } 1590 }() 1591 cancel() 1592 select { 1593 case <-done: 1594 if ErrCode(err) != codes.Canceled { 1595 t.Errorf("streaming query is canceled and returns error %v, want error code %v", err, codes.Canceled) 1596 } 1597 case <-time.After(1 * time.Second): 1598 t.Errorf("query doesn't exit timely after being cancelled") 1599 } 1600 // Test query timeout. 1601 ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) 1602 defer cancel() 1603 go func() { 1604 // Establish a stream to mock cloud spanner server. 1605 iter := stream(ctx, nil, 1606 func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { 1607 return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ 1608 Sql: "SELECT t.key key, t.value value FROM t_mock t", 1609 ResumeToken: resumeToken, 1610 }) 1611 }, 1612 nil, 1613 func(error) {}) 1614 defer iter.Stop() 1615 for { 1616 _, err = iter.Next() 1617 if err == iterator.Done { 1618 err = nil 1619 break 1620 } 1621 if err != nil { 1622 break 1623 } 1624 } 1625 done <- 0 1626 }() 1627 select { 1628 case <-done: 1629 if wantErr := codes.DeadlineExceeded; ErrCode(err) != wantErr { 1630 t.Errorf("streaming query timeout returns error %v, want error code %v", err, wantErr) 1631 } 1632 case <-time.After(2 * time.Second): 1633 t.Errorf("query doesn't timeout as expected") 1634 } 1635} 1636 1637func TestRowIteratorDo(t *testing.T) { 1638 restore := setMaxBytesBetweenResumeTokens() 1639 defer restore() 1640 ms := NewMockCloudSpanner(t, trxTs) 1641 ms.Serve() 1642 defer ms.Stop() 1643 cc := dialMock(t, ms) 1644 defer cc.Close() 1645 mc := sppb.NewSpannerClient(cc) 1646 1647 for i := 0; i < 3; i++ { 1648 ms.AddMsg(nil, false) 1649 } 1650 ms.AddMsg(io.EOF, true) 1651 nRows := 0 1652 iter := stream(context.Background(), nil, 1653 func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { 1654 return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ 1655 Sql: "SELECT t.key key, t.value value FROM t_mock t", 1656 ResumeToken: resumeToken, 1657 }) 1658 }, 1659 nil, 1660 func(error) {}) 1661 err := iter.Do(func(r *Row) error { nRows++; return nil }) 1662 if err != nil { 1663 t.Errorf("Using Do: %v", err) 1664 } 1665 if nRows != 3 { 1666 t.Errorf("got %d rows, want 3", nRows) 1667 } 1668} 1669 1670func TestRowIteratorDoWithError(t *testing.T) { 1671 restore := setMaxBytesBetweenResumeTokens() 1672 defer restore() 1673 ms := NewMockCloudSpanner(t, trxTs) 1674 ms.Serve() 1675 defer ms.Stop() 1676 cc := dialMock(t, ms) 1677 defer cc.Close() 1678 mc := sppb.NewSpannerClient(cc) 1679 1680 for i := 0; i < 3; i++ { 1681 ms.AddMsg(nil, false) 1682 } 1683 ms.AddMsg(io.EOF, true) 1684 iter := stream(context.Background(), nil, 1685 func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { 1686 return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ 1687 Sql: "SELECT t.key key, t.value value FROM t_mock t", 1688 ResumeToken: resumeToken, 1689 }) 1690 }, 1691 nil, 1692 func(error) {}) 1693 injected := errors.New("Failed iterator") 1694 err := iter.Do(func(r *Row) error { return injected }) 1695 if err != injected { 1696 t.Errorf("got <%v>, want <%v>", err, injected) 1697 } 1698} 1699 1700func TestIteratorStopEarly(t *testing.T) { 1701 ctx := context.Background() 1702 restore := setMaxBytesBetweenResumeTokens() 1703 defer restore() 1704 ms := NewMockCloudSpanner(t, trxTs) 1705 ms.Serve() 1706 defer ms.Stop() 1707 cc := dialMock(t, ms) 1708 defer cc.Close() 1709 mc := sppb.NewSpannerClient(cc) 1710 1711 ms.AddMsg(nil, false) 1712 ms.AddMsg(nil, false) 1713 ms.AddMsg(io.EOF, true) 1714 1715 iter := stream(ctx, nil, 1716 func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { 1717 return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ 1718 Sql: "SELECT t.key key, t.value value FROM t_mock t", 1719 ResumeToken: resumeToken, 1720 }) 1721 }, 1722 nil, 1723 func(error) {}) 1724 _, err := iter.Next() 1725 if err != nil { 1726 t.Fatalf("before Stop: %v", err) 1727 } 1728 iter.Stop() 1729 // Stop sets r.err to the FailedPrecondition error "Next called after Stop". 1730 _, err = iter.Next() 1731 if g, w := ErrCode(err), codes.FailedPrecondition; g != w { 1732 t.Errorf("after Stop: got: %v, want: %v", g, w) 1733 } 1734} 1735 1736func TestIteratorWithError(t *testing.T) { 1737 injected := errors.New("Failed iterator") 1738 iter := RowIterator{err: injected} 1739 defer iter.Stop() 1740 if _, err := iter.Next(); err != injected { 1741 t.Fatalf("Expected error: %v, got %v", injected, err) 1742 } 1743} 1744 1745func dialMock(t *testing.T, ms *MockCloudSpanner) *grpc.ClientConn { 1746 cc, err := grpc.Dial(ms.Addr(), grpc.WithInsecure(), grpc.WithBlock()) 1747 if err != nil { 1748 t.Fatalf("Dial(%q) = %v", ms.Addr(), err) 1749 } 1750 return cc 1751} 1752