1/*
2Copyright 2017 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"errors"
22	"fmt"
23	"sync"
24	"sync/atomic"
25	"testing"
26	"time"
27
28	vkit "cloud.google.com/go/spanner/apiv1"
29	. "cloud.google.com/go/spanner/internal/testutil"
30	"github.com/golang/protobuf/proto"
31	proto3 "github.com/golang/protobuf/ptypes/struct"
32	structpb "github.com/golang/protobuf/ptypes/struct"
33	"github.com/googleapis/gax-go/v2"
34	"google.golang.org/api/iterator"
35	sppb "google.golang.org/genproto/googleapis/spanner/v1"
36	"google.golang.org/grpc/codes"
37	"google.golang.org/grpc/status"
38)
39
var (
	// Mocked transaction timestamp.
	trxTs = time.Unix(1, 2)
	// Metadata for mocked KV table, its rows are returned by SingleUse
	// transactions.
	//
	// Note: `meta := KvMeta` takes a value copy of the shared KvMeta test
	// fixture so that setting Transaction below does not mutate the fixture
	// used by other tests.
	kvMeta = func() *sppb.ResultSetMetadata {
		meta := KvMeta
		meta.Transaction = &sppb.Transaction{
			ReadTimestamp: timestampProto(trxTs),
		}
		return &meta
	}()
	// Metadata for mocked ListKV table, which uses List for its key and value.
	// Its rows are returned by snapshot readonly transactions, as indicated in
	// the transaction metadata (it carries both a transaction ID and a read
	// timestamp).
	kvListMeta = &sppb.ResultSetMetadata{
		RowType: &sppb.StructType{
			Fields: []*sppb.StructType_Field{
				{
					Name: "Key",
					Type: &sppb.Type{
						Code: sppb.TypeCode_ARRAY,
						ArrayElementType: &sppb.Type{
							Code: sppb.TypeCode_STRING,
						},
					},
				},
				{
					Name: "Value",
					Type: &sppb.Type{
						Code: sppb.TypeCode_ARRAY,
						ArrayElementType: &sppb.Type{
							Code: sppb.TypeCode_STRING,
						},
					},
				},
			},
		},
		Transaction: &sppb.Transaction{
			Id:            transactionID{5, 6, 7, 8, 9},
			ReadTimestamp: timestampProto(trxTs),
		},
	}
	// Metadata for mocked schema of a query result set, which has two struct
	// columns named "Col1" and "Col2", the struct's schema is like the
	// following:
	//
	//	STRUCT {
	//		INT
	//		LIST<STRING>
	//	}
	//
	// Its rows are returned in readwrite transaction, as indicated in the
	// transaction metadata (transaction ID only, no read timestamp).
	kvObjectMeta = &sppb.ResultSetMetadata{
		RowType: &sppb.StructType{
			Fields: []*sppb.StructType_Field{
				{
					Name: "Col1",
					Type: &sppb.Type{
						Code: sppb.TypeCode_STRUCT,
						StructType: &sppb.StructType{
							Fields: []*sppb.StructType_Field{
								{
									Name: "foo-f1",
									Type: &sppb.Type{
										Code: sppb.TypeCode_INT64,
									},
								},
								{
									Name: "foo-f2",
									Type: &sppb.Type{
										Code: sppb.TypeCode_ARRAY,
										ArrayElementType: &sppb.Type{
											Code: sppb.TypeCode_STRING,
										},
									},
								},
							},
						},
					},
				},
				{
					Name: "Col2",
					Type: &sppb.Type{
						Code: sppb.TypeCode_STRUCT,
						StructType: &sppb.StructType{
							Fields: []*sppb.StructType_Field{
								{
									Name: "bar-f1",
									Type: &sppb.Type{
										Code: sppb.TypeCode_INT64,
									},
								},
								{
									Name: "bar-f2",
									Type: &sppb.Type{
										Code: sppb.TypeCode_ARRAY,
										ArrayElementType: &sppb.Type{
											Code: sppb.TypeCode_STRING,
										},
									},
								},
							},
						},
					},
				},
			},
		},
		Transaction: &sppb.Transaction{
			Id: transactionID{1, 2, 3, 4, 5},
		},
	}
)
154
// String implements fmt.Stringer, printing the row's field metadata and raw
// proto values so test failures are readable.
func (r *Row) String() string {
	return fmt.Sprintf("{fields: %s, val: %s}", r.fields, r.vals)
}
159
160func describeRows(l []*Row) string {
161	// generate a nice test failure description
162	var s = "["
163	for i, r := range l {
164		if i != 0 {
165			s += ",\n "
166		}
167		s += fmt.Sprint(r)
168	}
169	s += "]"
170	return s
171}
172
173// Helper for generating proto3 Value_ListValue instances, making test code
174// shorter and readable.
175func genProtoListValue(v ...string) *proto3.Value_ListValue {
176	r := &proto3.Value_ListValue{
177		ListValue: &proto3.ListValue{
178			Values: []*proto3.Value{},
179		},
180	}
181	for _, e := range v {
182		r.ListValue.Values = append(
183			r.ListValue.Values,
184			&proto3.Value{
185				Kind: &proto3.Value_StringValue{StringValue: e},
186			},
187		)
188	}
189	return r
190}
191
192// Test Row generation logics of partialResultSetDecoder.
193func TestPartialResultSetDecoder(t *testing.T) {
194	restore := setMaxBytesBetweenResumeTokens()
195	defer restore()
196	var tests = []struct {
197		input    []*sppb.PartialResultSet
198		wantF    []*Row
199		wantTxID transactionID
200		wantTs   time.Time
201		wantD    bool
202	}{
203		{
204			// Empty input.
205			wantD: true,
206		},
207		// String merging examples.
208		{
209			// Single KV result.
210			input: []*sppb.PartialResultSet{
211				{
212					Metadata: kvMeta,
213					Values: []*proto3.Value{
214						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
215						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
216					},
217				},
218			},
219			wantF: []*Row{
220				{
221					fields: kvMeta.RowType.Fields,
222					vals: []*proto3.Value{
223						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
224						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
225					},
226				},
227			},
228			wantTs: trxTs,
229			wantD:  true,
230		},
231		{
232			// Incomplete partial result.
233			input: []*sppb.PartialResultSet{
234				{
235					Metadata: kvMeta,
236					Values: []*proto3.Value{
237						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
238					},
239				},
240			},
241			wantTs: trxTs,
242			wantD:  false,
243		},
244		{
245			// Complete splitted result.
246			input: []*sppb.PartialResultSet{
247				{
248					Metadata: kvMeta,
249					Values: []*proto3.Value{
250						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
251					},
252				},
253				{
254					Values: []*proto3.Value{
255						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
256					},
257				},
258			},
259			wantF: []*Row{
260				{
261					fields: kvMeta.RowType.Fields,
262					vals: []*proto3.Value{
263						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
264						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
265					},
266				},
267			},
268			wantTs: trxTs,
269			wantD:  true,
270		},
271		{
272			// Multi-row example with splitted row in the middle.
273			input: []*sppb.PartialResultSet{
274				{
275					Metadata: kvMeta,
276					Values: []*proto3.Value{
277						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
278						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
279						{Kind: &proto3.Value_StringValue{StringValue: "A"}},
280					},
281				},
282				{
283					Values: []*proto3.Value{
284						{Kind: &proto3.Value_StringValue{StringValue: "1"}},
285						{Kind: &proto3.Value_StringValue{StringValue: "B"}},
286						{Kind: &proto3.Value_StringValue{StringValue: "2"}},
287					},
288				},
289			},
290			wantF: []*Row{
291				{
292					fields: kvMeta.RowType.Fields,
293					vals: []*proto3.Value{
294						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
295						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
296					},
297				},
298				{
299					fields: kvMeta.RowType.Fields,
300					vals: []*proto3.Value{
301						{Kind: &proto3.Value_StringValue{StringValue: "A"}},
302						{Kind: &proto3.Value_StringValue{StringValue: "1"}},
303					},
304				},
305				{
306					fields: kvMeta.RowType.Fields,
307					vals: []*proto3.Value{
308						{Kind: &proto3.Value_StringValue{StringValue: "B"}},
309						{Kind: &proto3.Value_StringValue{StringValue: "2"}},
310					},
311				},
312			},
313			wantTs: trxTs,
314			wantD:  true,
315		},
316		{
317			// Merging example in result_set.proto.
318			input: []*sppb.PartialResultSet{
319				{
320					Metadata: kvMeta,
321					Values: []*proto3.Value{
322						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
323						{Kind: &proto3.Value_StringValue{StringValue: "W"}},
324					},
325					ChunkedValue: true,
326				},
327				{
328					Values: []*proto3.Value{
329						{Kind: &proto3.Value_StringValue{StringValue: "orl"}},
330					},
331					ChunkedValue: true,
332				},
333				{
334					Values: []*proto3.Value{
335						{Kind: &proto3.Value_StringValue{StringValue: "d"}},
336					},
337				},
338			},
339			wantF: []*Row{
340				{
341					fields: kvMeta.RowType.Fields,
342					vals: []*proto3.Value{
343						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
344						{Kind: &proto3.Value_StringValue{StringValue: "World"}},
345					},
346				},
347			},
348			wantTs: trxTs,
349			wantD:  true,
350		},
351		{
352			// More complex example showing completing a merge and
353			// starting a new merge in the same partialResultSet.
354			input: []*sppb.PartialResultSet{
355				{
356					Metadata: kvMeta,
357					Values: []*proto3.Value{
358						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
359						{Kind: &proto3.Value_StringValue{StringValue: "W"}}, // start split in value
360					},
361					ChunkedValue: true,
362				},
363				{
364					Values: []*proto3.Value{
365						{Kind: &proto3.Value_StringValue{StringValue: "orld"}}, // complete value
366						{Kind: &proto3.Value_StringValue{StringValue: "i"}},    // start split in key
367					},
368					ChunkedValue: true,
369				},
370				{
371					Values: []*proto3.Value{
372						{Kind: &proto3.Value_StringValue{StringValue: "s"}}, // complete key
373						{Kind: &proto3.Value_StringValue{StringValue: "not"}},
374						{Kind: &proto3.Value_StringValue{StringValue: "a"}},
375						{Kind: &proto3.Value_StringValue{StringValue: "qu"}}, // split in value
376					},
377					ChunkedValue: true,
378				},
379				{
380					Values: []*proto3.Value{
381						{Kind: &proto3.Value_StringValue{StringValue: "estion"}}, // complete value
382					},
383				},
384			},
385			wantF: []*Row{
386				{
387					fields: kvMeta.RowType.Fields,
388					vals: []*proto3.Value{
389						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
390						{Kind: &proto3.Value_StringValue{StringValue: "World"}},
391					},
392				},
393				{
394					fields: kvMeta.RowType.Fields,
395					vals: []*proto3.Value{
396						{Kind: &proto3.Value_StringValue{StringValue: "is"}},
397						{Kind: &proto3.Value_StringValue{StringValue: "not"}},
398					},
399				},
400				{
401					fields: kvMeta.RowType.Fields,
402					vals: []*proto3.Value{
403						{Kind: &proto3.Value_StringValue{StringValue: "a"}},
404						{Kind: &proto3.Value_StringValue{StringValue: "question"}},
405					},
406				},
407			},
408			wantTs: trxTs,
409			wantD:  true,
410		},
411		// List merging examples.
412		{
413			// Non-splitting Lists.
414			input: []*sppb.PartialResultSet{
415				{
416					Metadata: kvListMeta,
417					Values: []*proto3.Value{
418						{
419							Kind: genProtoListValue("foo-1", "foo-2"),
420						},
421					},
422				},
423				{
424					Values: []*proto3.Value{
425						{
426							Kind: genProtoListValue("bar-1", "bar-2"),
427						},
428					},
429				},
430			},
431			wantF: []*Row{
432				{
433					fields: kvListMeta.RowType.Fields,
434					vals: []*proto3.Value{
435						{
436							Kind: genProtoListValue("foo-1", "foo-2"),
437						},
438						{
439							Kind: genProtoListValue("bar-1", "bar-2"),
440						},
441					},
442				},
443			},
444			wantTxID: transactionID{5, 6, 7, 8, 9},
445			wantTs:   trxTs,
446			wantD:    true,
447		},
448		{
449			// Simple List merge case: splitted string element.
450			input: []*sppb.PartialResultSet{
451				{
452					Metadata: kvListMeta,
453					Values: []*proto3.Value{
454						{
455							Kind: genProtoListValue("foo-1", "foo-"),
456						},
457					},
458					ChunkedValue: true,
459				},
460				{
461					Values: []*proto3.Value{
462						{
463							Kind: genProtoListValue("2"),
464						},
465					},
466				},
467				{
468					Values: []*proto3.Value{
469						{
470							Kind: genProtoListValue("bar-1", "bar-2"),
471						},
472					},
473				},
474			},
475			wantF: []*Row{
476				{
477					fields: kvListMeta.RowType.Fields,
478					vals: []*proto3.Value{
479						{
480							Kind: genProtoListValue("foo-1", "foo-2"),
481						},
482						{
483							Kind: genProtoListValue("bar-1", "bar-2"),
484						},
485					},
486				},
487			},
488			wantTxID: transactionID{5, 6, 7, 8, 9},
489			wantTs:   trxTs,
490			wantD:    true,
491		},
492		{
493			// Struct merging is also implemented by List merging. Note that
494			// Cloud Spanner uses proto.ListValue to encode Structs as well.
495			input: []*sppb.PartialResultSet{
496				{
497					Metadata: kvObjectMeta,
498					Values: []*proto3.Value{
499						{
500							Kind: &proto3.Value_ListValue{
501								ListValue: &proto3.ListValue{
502									Values: []*proto3.Value{
503										{Kind: &proto3.Value_NumberValue{NumberValue: 23}},
504										{Kind: genProtoListValue("foo-1", "fo")},
505									},
506								},
507							},
508						},
509					},
510					ChunkedValue: true,
511				},
512				{
513					Values: []*proto3.Value{
514						{
515							Kind: &proto3.Value_ListValue{
516								ListValue: &proto3.ListValue{
517									Values: []*proto3.Value{
518										{Kind: genProtoListValue("o-2", "f")},
519									},
520								},
521							},
522						},
523					},
524					ChunkedValue: true,
525				},
526				{
527					Values: []*proto3.Value{
528						{
529							Kind: &proto3.Value_ListValue{
530								ListValue: &proto3.ListValue{
531									Values: []*proto3.Value{
532										{Kind: genProtoListValue("oo-3")},
533									},
534								},
535							},
536						},
537						{
538							Kind: &proto3.Value_ListValue{
539								ListValue: &proto3.ListValue{
540									Values: []*proto3.Value{
541										{Kind: &proto3.Value_NumberValue{NumberValue: 45}},
542										{Kind: genProtoListValue("bar-1")},
543									},
544								},
545							},
546						},
547					},
548				},
549			},
550			wantF: []*Row{
551				{
552					fields: kvObjectMeta.RowType.Fields,
553					vals: []*proto3.Value{
554						{
555							Kind: &proto3.Value_ListValue{
556								ListValue: &proto3.ListValue{
557									Values: []*proto3.Value{
558										{Kind: &proto3.Value_NumberValue{NumberValue: 23}},
559										{Kind: genProtoListValue("foo-1", "foo-2", "foo-3")},
560									},
561								},
562							},
563						},
564						{
565							Kind: &proto3.Value_ListValue{
566								ListValue: &proto3.ListValue{
567									Values: []*proto3.Value{
568										{Kind: &proto3.Value_NumberValue{NumberValue: 45}},
569										{Kind: genProtoListValue("bar-1")},
570									},
571								},
572							},
573						},
574					},
575				},
576			},
577			wantTxID: transactionID{1, 2, 3, 4, 5},
578			wantD:    true,
579		},
580	}
581
582nextTest:
583	for i, test := range tests {
584		var rows []*Row
585		p := &partialResultSetDecoder{}
586		for j, v := range test.input {
587			rs, _, err := p.add(v)
588			if err != nil {
589				t.Errorf("test %d.%d: partialResultSetDecoder.add(%v) = %v; want nil", i, j, v, err)
590				continue nextTest
591			}
592			rows = append(rows, rs...)
593		}
594		if !testEqual(p.ts, test.wantTs) {
595			t.Errorf("got transaction(%v), want %v", p.ts, test.wantTs)
596		}
597		if !testEqual(rows, test.wantF) {
598			t.Errorf("test %d: rows=\n%v\n; want\n%v\n; p.row:\n%v\n", i, describeRows(rows), describeRows(test.wantF), p.row)
599		}
600		if got := p.done(); got != test.wantD {
601			t.Errorf("test %d: partialResultSetDecoder.done() = %v", i, got)
602		}
603	}
604}
605
const (
	// maxBuffers is the maximum number of PartialResultSets that will be
	// buffered in tests; it sizes maxBytesBetweenResumeTokens in
	// setMaxBytesBetweenResumeTokens.
	maxBuffers = 16
)
610
611// setMaxBytesBetweenResumeTokens sets the global maxBytesBetweenResumeTokens to
612// a smaller value more suitable for tests. It returns a function which should
613// be called to restore the maxBytesBetweenResumeTokens to its old value.
614func setMaxBytesBetweenResumeTokens() func() {
615	o := atomic.LoadInt32(&maxBytesBetweenResumeTokens)
616	atomic.StoreInt32(&maxBytesBetweenResumeTokens, int32(maxBuffers*proto.Size(&sppb.PartialResultSet{
617		Metadata: kvMeta,
618		Values: []*proto3.Value{
619			{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
620			{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
621		},
622	})))
623	return func() {
624		atomic.StoreInt32(&maxBytesBetweenResumeTokens, o)
625	}
626}
627
// keyStr generates the key string of row i for the kvMeta schema.
func keyStr(i int) string {
	const format = "foo-%02d"
	return fmt.Sprintf(format, i)
}
632
// valStr generates the value string of row i for the kvMeta schema.
func valStr(i int) string {
	const format = "bar-%02d"
	return fmt.Sprintf(format, i)
}
637
638// Test state transitions of resumableStreamDecoder where state machine ends up
639// to a non-blocking state(resumableStreamDecoder.Next returns on non-blocking
640// state).
641func TestRsdNonblockingStates(t *testing.T) {
642	restore := setMaxBytesBetweenResumeTokens()
643	defer restore()
644	tests := []struct {
645		name         string
646		resumeTokens [][]byte
647		prsErrors    []PartialResultSetExecutionTime
648		rpc          func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
649		sql          string
650		// Expected values
651		want         []*sppb.PartialResultSet      // PartialResultSets that should be returned to caller
652		queue        []*sppb.PartialResultSet      // PartialResultSets that should be buffered
653		resumeToken  []byte                        // Resume token that is maintained by resumableStreamDecoder
654		stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
655		wantErr      error
656	}{
657		{
658			// unConnected->queueingRetryable->finished
659			name:         "unConnected->queueingRetryable->finished",
660			resumeTokens: make([][]byte, 2),
661			sql:          "SELECT t.key key, t.value value FROM t_mock t",
662			want: []*sppb.PartialResultSet{
663				{
664					Metadata: kvMeta,
665					Values: []*proto3.Value{
666						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
667						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
668					},
669				},
670			},
671			queue: []*sppb.PartialResultSet{
672				{
673					Metadata: kvMeta,
674					Values: []*proto3.Value{
675						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
676						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
677					},
678				},
679			},
680			stateHistory: []resumableStreamDecoderState{
681				queueingRetryable, // do RPC
682				queueingRetryable, // got foo-00
683				queueingRetryable, // got foo-01
684				finished,          // got EOF
685			},
686		},
687		{
688			// unConnected->queueingRetryable->aborted
689			name:         "unConnected->queueingRetryable->aborted",
690			resumeTokens: [][]byte{{}, EncodeResumeToken(1), {}, EncodeResumeToken(2)},
691			prsErrors: []PartialResultSetExecutionTime{{
692				ResumeToken: EncodeResumeToken(2),
693				Err:         status.Error(codes.Unknown, "I quit"),
694			}},
695			sql: "SELECT t.key key, t.value value FROM t_mock t",
696			want: []*sppb.PartialResultSet{
697				{
698					Metadata: kvMeta,
699					Values: []*proto3.Value{
700						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
701						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
702					},
703				},
704				{
705					Metadata: kvMeta,
706					Values: []*proto3.Value{
707						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
708						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
709					},
710					ResumeToken: EncodeResumeToken(1),
711				},
712			},
713			stateHistory: []resumableStreamDecoderState{
714				queueingRetryable, // do RPC
715				queueingRetryable, // got foo-00
716				queueingRetryable, // got foo-01
717				queueingRetryable, // foo-01, resume token
718				queueingRetryable, // got foo-02
719				aborted,           // got error
720			},
721			wantErr: status.Errorf(codes.Unknown, "I quit"),
722		},
723		{
724			// unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable
725			name:         "unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable",
726			resumeTokens: make([][]byte, maxBuffers+1),
727			sql:          "SELECT t.key key, t.value value FROM t_mock t",
728			want: func() (s []*sppb.PartialResultSet) {
729				for i := 0; i < maxBuffers+1; i++ {
730					s = append(s, &sppb.PartialResultSet{
731						Metadata: kvMeta,
732						Values: []*proto3.Value{
733							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
734							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
735						},
736					})
737				}
738				return s
739			}(),
740			stateHistory: func() (s []resumableStreamDecoderState) {
741				s = append(s, queueingRetryable) // RPC
742				for i := 0; i < maxBuffers; i++ {
743					s = append(s, queueingRetryable) // the internal queue of resumableStreamDecoder fills up
744				}
745				// the first item fills up the queue and triggers state transition;
746				// the second item is received under queueingUnretryable state.
747				s = append(s, queueingUnretryable)
748				s = append(s, queueingUnretryable)
749				return s
750			}(),
751		},
752		{
753			// unConnected->queueingRetryable->queueingUnretryable->aborted
754			name: "unConnected->queueingRetryable->queueingUnretryable->aborted",
755			resumeTokens: func() (rts [][]byte) {
756				rts = make([][]byte, maxBuffers+1)
757				rts[maxBuffers] = EncodeResumeToken(1)
758				return rts
759			}(),
760			prsErrors: []PartialResultSetExecutionTime{{
761				ResumeToken: EncodeResumeToken(1),
762				Err:         status.Error(codes.Unknown, "Just Abort It"),
763			}},
764			sql: "SELECT t.key key, t.value value FROM t_mock t",
765			want: func() (s []*sppb.PartialResultSet) {
766				for i := 0; i < maxBuffers; i++ {
767					s = append(s, &sppb.PartialResultSet{
768						Metadata: kvMeta,
769						Values: []*proto3.Value{
770							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
771							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
772						},
773					})
774				}
775				return s
776			}(),
777			stateHistory: func() (s []resumableStreamDecoderState) {
778				s = append(s, queueingRetryable) // RPC
779				for i := 0; i < maxBuffers; i++ {
780					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
781				}
782				s = append(s, queueingUnretryable) // the last row triggers state change
783				s = append(s, aborted)             // Error happens
784				return s
785			}(),
786			wantErr: status.Errorf(codes.Unknown, "Just Abort It"),
787		},
788	}
789	for _, test := range tests {
790		t.Run(test.name, func(t *testing.T) {
791			server, c, teardown := setupMockedTestServer(t)
792			defer teardown()
793			mc, err := c.sc.nextClient()
794			if err != nil {
795				t.Fatalf("failed to create a grpc client")
796			}
797
798			session, err := createSession(mc)
799			if err != nil {
800				t.Fatalf("failed to create a session")
801			}
802
803			if test.rpc == nil {
804				test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
805					return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
806						Session:     session.Name,
807						Sql:         test.sql,
808						ResumeToken: resumeToken,
809					})
810				}
811			}
812			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
813			defer cancel()
814			r := newResumableStreamDecoder(
815				ctx,
816				nil,
817				test.rpc,
818				nil,
819			)
820			st := []resumableStreamDecoderState{}
821			var lastErr error
822			// Once the expected number of state transitions are observed,
823			// send a signal by setting stateDone = true.
824			stateDone := false
825			// Set stateWitness to listen to state changes.
826			hl := len(test.stateHistory) // To avoid data race on test.
827			r.stateWitness = func(rs resumableStreamDecoderState) {
828				if !stateDone {
829					// Record state transitions.
830					st = append(st, rs)
831					if len(st) == hl {
832						lastErr = r.lastErr()
833						stateDone = true
834					}
835				}
836			}
837			// Let mock server stream given messages to resumableStreamDecoder.
838			err = setupStatementResult(t, server, test.sql, len(test.resumeTokens), test.resumeTokens)
839			if err != nil {
840				t.Fatalf("failed to set up a result for a statement: %v", err)
841			}
842
843			for _, et := range test.prsErrors {
844				server.TestSpanner.AddPartialResultSetError(
845					test.sql,
846					et,
847				)
848			}
849			var rs []*sppb.PartialResultSet
850			for {
851				select {
852				case <-ctx.Done():
853					t.Fatal("context cancelled or timeout during test")
854				default:
855				}
856				if stateDone {
857					// Check if resumableStreamDecoder carried out expected
858					// state transitions.
859					if !testEqual(st, test.stateHistory) {
860						t.Fatalf("observed state transitions: \n%v\n, want \n%v\n", st, test.stateHistory)
861					}
862					// Check if resumableStreamDecoder returns expected array of
863					// PartialResultSets.
864					if !testEqual(rs, test.want) {
865						t.Fatalf("received PartialResultSets: \n%v\n, want \n%v\n", rs, test.want)
866					}
867					// Verify that resumableStreamDecoder's internal buffering is
868					// also correct.
869					var q []*sppb.PartialResultSet
870					for {
871						item := r.q.pop()
872						if item == nil {
873							break
874						}
875						q = append(q, item)
876					}
877					if !testEqual(q, test.queue) {
878						t.Fatalf("PartialResultSets still queued: \n%v\n, want \n%v\n", q, test.queue)
879					}
880					// Verify resume token.
881					if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
882						t.Fatalf("Resume token is %v, want %v\n", r.resumeToken, test.resumeToken)
883					}
884					// Verify error message.
885					if !testEqual(lastErr, test.wantErr) {
886						t.Fatalf("got error %v, want %v", lastErr, test.wantErr)
887					}
888					return
889				}
890				// Receive next decoded item.
891				if r.next() {
892					rs = append(rs, r.get())
893				}
894			}
895		})
896	}
897}
898
899// Test state transitions of resumableStreamDecoder where state machine
900// ends up to a blocking state(resumableStreamDecoder.Next blocks
901// on blocking state).
902func TestRsdBlockingStates(t *testing.T) {
903	restore := setMaxBytesBetweenResumeTokens()
904	defer restore()
905	for _, test := range []struct {
906		name         string
907		resumeTokens [][]byte
908		rpc          func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
909		sql          string
910		// Expected values
911		want         []*sppb.PartialResultSet      // PartialResultSets that should be returned to caller
912		queue        []*sppb.PartialResultSet      // PartialResultSets that should be buffered
913		resumeToken  []byte                        // Resume token that is maintained by resumableStreamDecoder
914		stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
915		wantErr      error
916	}{
917		{
918			// unConnected -> unConnected
919			name: "unConnected -> unConnected",
920			rpc: func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
921				return nil, status.Errorf(codes.Unavailable, "trust me: server is unavailable")
922			},
923			sql:          "SELECT * from t_whatever",
924			stateHistory: []resumableStreamDecoderState{unConnected, unConnected, unConnected},
925			wantErr:      status.Errorf(codes.Unavailable, "trust me: server is unavailable"),
926		},
927		{
928			// unConnected -> queueingRetryable
929			name:         "unConnected -> queueingRetryable",
930			sql:          "SELECT t.key key, t.value value FROM t_mock t",
931			stateHistory: []resumableStreamDecoderState{queueingRetryable},
932		},
933		{
934			// unConnected->queueingRetryable->queueingRetryable
935			name:         "unConnected->queueingRetryable->queueingRetryable",
936			resumeTokens: [][]byte{{}, EncodeResumeToken(1), EncodeResumeToken(2), {}},
937			sql:          "SELECT t.key key, t.value value FROM t_mock t",
938			want: []*sppb.PartialResultSet{
939				{
940					Metadata: kvMeta,
941					Values: []*proto3.Value{
942						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
943						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
944					},
945				},
946				{
947					Metadata: kvMeta,
948					Values: []*proto3.Value{
949						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
950						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
951					},
952					ResumeToken: EncodeResumeToken(1),
953				},
954				{
955					Metadata: kvMeta,
956					Values: []*proto3.Value{
957						{Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
958						{Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
959					},
960					ResumeToken: EncodeResumeToken(2),
961				},
962				// The server sends an io.EOF at last and the decoder will
963				// flush out all messages in the internal queue.
964				{
965					Metadata: kvMeta,
966					Values: []*proto3.Value{
967						{Kind: &proto3.Value_StringValue{StringValue: keyStr(3)}},
968						{Kind: &proto3.Value_StringValue{StringValue: valStr(3)}},
969					},
970				},
971			},
972			queue: []*sppb.PartialResultSet{
973				{
974					Metadata: kvMeta,
975					Values: []*proto3.Value{
976						{Kind: &proto3.Value_StringValue{StringValue: keyStr(3)}},
977						{Kind: &proto3.Value_StringValue{StringValue: valStr(3)}},
978					},
979				},
980			},
981			resumeToken: EncodeResumeToken(2),
982			stateHistory: []resumableStreamDecoderState{
983				queueingRetryable, // do RPC
984				queueingRetryable, // got foo-00
985				queueingRetryable, // got foo-01
986				queueingRetryable, // foo-01, resume token
987				queueingRetryable, // got foo-02
988				queueingRetryable, // foo-02, resume token
989				queueingRetryable, // got foo-03
990			},
991		},
992		{
993			// unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable
994			name: "unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable",
995			resumeTokens: func() (rts [][]byte) {
996				rts = make([][]byte, maxBuffers+3)
997				rts[maxBuffers+1] = EncodeResumeToken(maxBuffers + 1)
998				return rts
999			}(),
1000			sql: "SELECT t.key key, t.value value FROM t_mock t",
1001			want: func() (s []*sppb.PartialResultSet) {
1002				// The server sends an io.EOF at last and the decoder will
1003				// flush out all messages in the internal queue. Although the
1004				// last message is supposed to be queued and the decoder waits
1005				// for the next resume token, an io.EOF leads to a `finished`
1006				// state that the last message will be removed from the queue
1007				// and be read by the client side.
1008				for i := 0; i < maxBuffers+3; i++ {
1009					s = append(s, &sppb.PartialResultSet{
1010						Metadata: kvMeta,
1011						Values: []*proto3.Value{
1012							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
1013							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
1014						},
1015					})
1016				}
1017				s[maxBuffers+1].ResumeToken = EncodeResumeToken(maxBuffers + 1)
1018				return s
1019			}(),
1020			resumeToken: EncodeResumeToken(maxBuffers + 1),
1021			queue: []*sppb.PartialResultSet{
1022				{
1023					Metadata: kvMeta,
1024					Values: []*proto3.Value{
1025						{Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 2)}},
1026						{Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 2)}},
1027					},
1028				},
1029			},
1030			stateHistory: func() (s []resumableStreamDecoderState) {
1031				s = append(s, queueingRetryable) // RPC
1032				for i := 0; i < maxBuffers; i++ {
1033					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder filles up
1034				}
1035				for i := maxBuffers - 1; i < maxBuffers+1; i++ {
1036					// the first item fills up the queue and triggers state
1037					// change; the second item is received under
1038					// queueingUnretryable state.
1039					s = append(s, queueingUnretryable)
1040				}
1041				s = append(s, queueingUnretryable) // got (maxBuffers+1)th row under Unretryable state
1042				s = append(s, queueingRetryable)   // (maxBuffers+1)th row has resume token
1043				s = append(s, queueingRetryable)   // (maxBuffers+2)th row has no resume token
1044				return s
1045			}(),
1046		},
1047		{
1048			// unConnected->queueingRetryable->queueingUnretryable->finished
1049			name:         "unConnected->queueingRetryable->queueingUnretryable->finished",
1050			resumeTokens: make([][]byte, maxBuffers),
1051			sql:          "SELECT t.key key, t.value value FROM t_mock t",
1052			want: func() (s []*sppb.PartialResultSet) {
1053				for i := 0; i < maxBuffers; i++ {
1054					s = append(s, &sppb.PartialResultSet{
1055						Metadata: kvMeta,
1056						Values: []*proto3.Value{
1057							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
1058							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
1059						},
1060					})
1061				}
1062				return s
1063			}(),
1064			stateHistory: func() (s []resumableStreamDecoderState) {
1065				s = append(s, queueingRetryable) // RPC
1066				for i := 0; i < maxBuffers; i++ {
1067					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
1068				}
1069				s = append(s, queueingUnretryable) // last row triggers state change
1070				s = append(s, finished)            // query finishes
1071				return s
1072			}(),
1073		},
1074	} {
1075		t.Run(test.name, func(t *testing.T) {
1076			server, c, teardown := setupMockedTestServer(t)
1077			defer teardown()
1078			mc, err := c.sc.nextClient()
1079			if err != nil {
1080				t.Fatalf("failed to create a grpc client")
1081			}
1082
1083			session, err := createSession(mc)
1084			if err != nil {
1085				t.Fatalf("failed to create a session")
1086			}
1087
1088			if test.rpc == nil {
1089				// Avoid using test.sql directly in closure because for loop changes
1090				// test.
1091				sql := test.sql
1092				test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1093					return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1094						Session:     session.Name,
1095						Sql:         sql,
1096						ResumeToken: resumeToken,
1097					})
1098				}
1099			}
1100			ctx, cancel := context.WithCancel(context.Background())
1101			defer cancel()
1102			r := newResumableStreamDecoder(
1103				ctx,
1104				nil,
1105				test.rpc,
1106				nil,
1107			)
1108			// Override backoff to make the test run faster.
1109			r.backoff = gax.Backoff{
1110				Initial:    1 * time.Nanosecond,
1111				Max:        1 * time.Nanosecond,
1112				Multiplier: 1.3,
1113			}
1114			// st is the set of observed state transitions.
1115			st := []resumableStreamDecoderState{}
1116			// q is the content of the decoder's partial result queue when expected
1117			// number of state transitions are done.
1118			q := []*sppb.PartialResultSet{}
1119			var lastErr error
1120			// Once the expected number of state transitions are observed, send a
1121			// signal to channel stateDone.
1122			stateDone := make(chan int)
1123			// Set stateWitness to listen to state changes.
1124			hl := len(test.stateHistory) // To avoid data race on test.
1125			r.stateWitness = func(rs resumableStreamDecoderState) {
1126				select {
1127				case <-stateDone:
1128					// Noop after expected number of state transitions
1129				default:
1130					// Record state transitions.
1131					st = append(st, rs)
1132					if len(st) == hl {
1133						lastErr = r.lastErr()
1134						q = r.q.dump()
1135						close(stateDone)
1136					}
1137				}
1138			}
1139			// Let mock server stream given messages to resumableStreamDecoder.
1140			err = setupStatementResult(t, server, test.sql, len(test.resumeTokens), test.resumeTokens)
1141			if err != nil {
1142				t.Fatalf("failed to set up a result for a statement: %v", err)
1143			}
1144			var mutex = &sync.Mutex{}
1145			var rs []*sppb.PartialResultSet
1146			rowsFetched := make(chan int)
1147			go func() {
1148				for {
1149					if !r.next() {
1150						// Note that r.Next also exits on context cancel/timeout.
1151						close(rowsFetched)
1152						return
1153					}
1154					mutex.Lock()
1155					rs = append(rs, r.get())
1156					mutex.Unlock()
1157				}
1158			}()
1159			// Wait until all rows have been fetched.
1160			if len(test.want) > 0 {
1161				select {
1162				case <-rowsFetched:
1163				case <-time.After(1 * time.Second):
1164					t.Fatal("Timeout in waiting for rows to be fetched")
1165				}
1166			}
1167			// Verify that resumableStreamDecoder reaches expected state.
1168			select {
1169			case <-stateDone: // Note that at this point, receiver is still blocking on r.next().
1170				// Check if resumableStreamDecoder carried out expected state
1171				// transitions.
1172				if !testEqual(st, test.stateHistory) {
1173					t.Fatalf("observed state transitions: \n%v\n, want \n%v\n", st, test.stateHistory)
1174				}
1175				// Check if resumableStreamDecoder returns expected array of
1176				// PartialResultSets.
1177				mutex.Lock()
1178				defer mutex.Unlock()
1179				if !testEqual(rs, test.want) {
1180					t.Fatalf("received PartialResultSets: \n%v\n, want \n%v\n", rs, test.want)
1181				}
1182				// Verify that resumableStreamDecoder's internal buffering is also
1183				// correct.
1184				if !testEqual(q, test.queue) {
1185					t.Fatalf("PartialResultSets still queued: \n%v\n, want \n%v\n", q, test.queue)
1186				}
1187				// Verify resume token.
1188				if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
1189					t.Fatalf("Resume token is %v, want %v\n", r.resumeToken, test.resumeToken)
1190				}
1191				// Verify error message.
1192				if !testEqual(lastErr, test.wantErr) {
1193					t.Fatalf("got error %v, want %v", lastErr, test.wantErr)
1194				}
1195			case <-time.After(1 * time.Second):
1196				t.Fatal("Timeout in waiting for state change")
1197			}
1198		})
1199	}
1200}
1201
// sReceiver wraps a Spanner streaming RPC receiver and signals every receiving
// attempt through a channel. It is used by TestResumeToken and TestQueueBytes
// to determine if the receiving of a certain PartialResultSet will be
// attempted next.
type sReceiver struct {
	// c receives a signal (the value 1) before each underlying Recv() call.
	c chan int
	// rpcReceiver is the wrapped gRPC streaming receiver that actually
	// produces PartialResultSets.
	rpcReceiver sppb.Spanner_ExecuteStreamingSqlClient
}
1209
// Recv implements streamingReceiver.Recv for sReceiver. It first signals the
// attempt on sr.c (tests allocate a large buffer for sr.c, so the send is
// non-blocking in practice) and then delegates to the wrapped RPC receiver.
func (sr *sReceiver) Recv() (*sppb.PartialResultSet, error) {
	sr.c <- 1
	return sr.rpcReceiver.Recv()
}
1215
1216// waitn waits for nth receiving attempt from now on, until the signal for nth
1217// Recv() attempts is received or timeout. Note that because the way stream()
1218// works, the signal for the nth Recv() means that the previous n - 1
1219// PartialResultSets has already been returned to caller or queued, if no error
1220// happened.
1221func (sr *sReceiver) waitn(n int) error {
1222	for i := 0; i < n; i++ {
1223		select {
1224		case <-sr.c:
1225		case <-time.After(10 * time.Second):
1226			return fmt.Errorf("timeout in waiting for %v-th Recv()", i+1)
1227		}
1228	}
1229	return nil
1230}
1231
// Test the handling of resumableStreamDecoder.bytesBetweenResumeTokens.
func TestQueueBytes(t *testing.T) {
	restore := setMaxBytesBetweenResumeTokens()
	defer restore()

	server, c, teardown := setupMockedTestServer(t)
	defer teardown()
	mc, err := c.sc.nextClient()
	if err != nil {
		t.Fatalf("failed to create a grpc client")
	}

	// Six rows; the per-row resume tokens change after the third and fifth
	// rows.
	rt1 := EncodeResumeToken(1)
	rt2 := EncodeResumeToken(2)
	rt3 := EncodeResumeToken(3)
	resumeTokens := [][]byte{rt1, rt1, rt1, rt2, rt2, rt3}
	err = setupStatementResult(t, server, "SELECT t.key key, t.value value FROM t_mock t", len(resumeTokens), resumeTokens)
	if err != nil {
		t.Fatalf("failed to set up a result for a statement: %v", err)
	}

	session, err := createSession(mc)
	if err != nil {
		t.Fatalf("failed to create a session")
	}

	sr := &sReceiver{
		c: make(chan int, 1000), // will never block in this test
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// Route all Recv() calls through sr so the decoder consumes the mocked
	// stream.
	decoder := newResumableStreamDecoder(
		ctx,
		nil,
		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
			r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
				Session:     session.Name,
				Sql:         "SELECT t.key key, t.value value FROM t_mock t",
				ResumeToken: resumeToken,
			})
			sr.rpcReceiver = r
			return sr, err
		},
		nil,
	)

	// Reference size of one streamed PartialResultSet; all mocked rows are
	// built from the same metadata and same-length key/value strings, so this
	// is used as the unit for the byte-counter assertions below.
	sizeOfPRS := proto.Size(&sppb.PartialResultSet{
		Metadata: kvMeta,
		Values: []*proto3.Value{
			{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
			{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
		},
		ResumeToken: rt1,
	})

	// Consume the first three rows and verify the counter equals two
	// PartialResultSet sizes.
	decoder.next()
	decoder.next()
	decoder.next()
	if got, want := decoder.bytesBetweenResumeTokens, int32(2*sizeOfPRS); got != want {
		t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", got, want)
	}

	// Fourth row: the counter is expected to reset to zero (the resume token
	// changes from rt1 to rt2 here).
	decoder.next()
	if decoder.bytesBetweenResumeTokens != 0 {
		t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", decoder.bytesBetweenResumeTokens)
	}

	// Fifth row: one PartialResultSet size accumulated since the reset.
	decoder.next()
	if got, want := decoder.bytesBetweenResumeTokens, int32(sizeOfPRS); got != want {
		t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", got, want)
	}

	// Sixth row: the counter is expected to reset again (token changes from
	// rt2 to rt3).
	decoder.next()
	if decoder.bytesBetweenResumeTokens != 0 {
		t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", decoder.bytesBetweenResumeTokens)
	}
}
1309
// Verify that client can deal with resume token correctly
func TestResumeToken(t *testing.T) {
	restore := setMaxBytesBetweenResumeTokens()
	defer restore()
	query := "SELECT t.key key, t.value value FROM t_mock t"
	server, c, teardown := setupMockedTestServer(t)
	defer teardown()
	mc, err := c.sc.nextClient()
	if err != nil {
		t.Fatalf("failed to create a grpc client")
	}

	// 3+maxBuffers rows; only the second row carries rt1 and only the very
	// last row carries rt2. All other rows stream without a resume token.
	rt1 := EncodeResumeToken(1)
	rt2 := EncodeResumeToken(2)
	resumeTokens := make([][]byte, 3+maxBuffers)
	resumeTokens[1] = rt1
	resumeTokens[3+maxBuffers-1] = rt2
	err = setupStatementResult(t, server, query, len(resumeTokens), resumeTokens)
	if err != nil {
		t.Fatalf("failed to set up a result for a statement: %v", err)
	}

	// The first error will be retried.
	server.TestSpanner.AddPartialResultSetError(
		query,
		PartialResultSetExecutionTime{
			ResumeToken: rt1,
			Err:         status.Error(codes.Unavailable, "mock server unavailable"),
		},
	)
	// The second error will not be retried because maxBytesBetweenResumeTokens
	// is reached and the state of resumableStreamDecoder:
	// queueingRetryable -> queueingUnretryable. The query will just fail.
	server.TestSpanner.AddPartialResultSetError(
		query,
		PartialResultSetExecutionTime{
			ResumeToken: rt2,
			Err:         status.Error(codes.Unavailable, "mock server wants some sleep"),
		},
	)

	session, err := createSession(mc)
	if err != nil {
		t.Fatalf("failed to create a session")
	}

	sr := &sReceiver{
		c: make(chan int, 1000), // will never block in this test
	}
	rows := []*Row{}

	// streaming establishes a fresh stream to the mock server, routing every
	// Recv() call through sr.
	streaming := func() *RowIterator {
		return stream(context.Background(), nil,
			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
				r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
					Session:     session.Name,
					Sql:         query,
					ResumeToken: resumeToken,
				})
				sr.rpcReceiver = r
				return sr, err
			},
			nil,
			func(error) {})
	}

	// Establish a stream to mock cloud spanner server.
	iter := streaming()
	defer iter.Stop()
	var row *Row

	// Read the first three rows. The injected Unavailable error at rt1 is
	// expected to be retried transparently.
	for i := 0; i < 3; i++ {
		row, err = iter.Next()
		if err != nil {
			t.Fatalf("failed to get next value: %v", err)
		}
		rows = append(rows, row)
	}

	want := []*Row{
		{
			fields: kvMeta.RowType.Fields,
			vals: []*proto3.Value{
				{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
				{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
			},
		},
		{
			fields: kvMeta.RowType.Fields,
			vals: []*proto3.Value{
				{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
				{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
			},
		},
		{
			fields: kvMeta.RowType.Fields,
			vals: []*proto3.Value{
				{Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
				{Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
			},
		},
	}
	if !testEqual(rows, want) {
		t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
	}

	// Trigger state change of resumableStreamDecoder:
	// queueingRetryable -> queueingUnretryable
	for i := 0; i < maxBuffers-1; i++ {
		row, err = iter.Next()
		if err != nil {
			t.Fatalf("failed to get next value: %v", err)
		}
		rows = append(rows, row)
	}

	// Since resumableStreamDecoder is already at queueingUnretryable state,
	// query will just fail.
	_, err = iter.Next()
	if wantErr := spannerErrorf(codes.Unavailable, "mock server wants some sleep"); !testEqual(err, wantErr) {
		t.Fatalf("stream() returns error: %v, but want error: %v", err, wantErr)
	}

	// Let server send two rows without resume token.
	resumeTokens = make([][]byte, 2)
	err = setupStatementResult(t, server, query, len(resumeTokens), resumeTokens)
	if err != nil {
		t.Fatalf("failed to set up a result for a statement: %v", err)
	}

	// Reconnect to mock Cloud Spanner.
	rows = []*Row{}
	iter = streaming()
	defer iter.Stop()

	for i := 0; i < 2; i++ {
		row, err = iter.Next()
		if err != nil {
			t.Fatalf("failed to get next value: %v", err)
		}
		rows = append(rows, row)
	}

	// Verify if a normal server side EOF flushes all queued rows.
	want = []*Row{
		{
			fields: kvMeta.RowType.Fields,
			vals: []*proto3.Value{
				{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
				{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
			},
		},
		{
			fields: kvMeta.RowType.Fields,
			vals: []*proto3.Value{
				{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
				{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
			},
		},
	}
	if !testEqual(rows, want) {
		t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
	}
}
1475
1476// Verify that streaming query get retried upon real gRPC server transport
1477// failures.
1478func TestGrpcReconnect(t *testing.T) {
1479	restore := setMaxBytesBetweenResumeTokens()
1480	defer restore()
1481
1482	server, c, teardown := setupMockedTestServer(t)
1483	defer teardown()
1484	mc, err := c.sc.nextClient()
1485	if err != nil {
1486		t.Fatalf("failed to create a grpc client")
1487	}
1488
1489	session, err := createSession(mc)
1490	if err != nil {
1491		t.Fatalf("failed to create a session")
1492	}
1493
1494	// Simulate an unavailable error to interrupt the stream of PartialResultSet
1495	// in order to test the grpc retrying mechanism.
1496	server.TestSpanner.AddPartialResultSetError(
1497		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1498		PartialResultSetExecutionTime{
1499			ResumeToken: EncodeResumeToken(2),
1500			Err:         status.Errorf(codes.Unavailable, "server is unavailable"),
1501		},
1502	)
1503
1504	// The retry is counted from the second call.
1505	r := -1
1506	// Establish a stream to mock cloud spanner server.
1507	iter := stream(context.Background(), nil,
1508		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1509			r++
1510			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1511				Session:     session.Name,
1512				Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1513				ResumeToken: resumeToken,
1514			})
1515
1516		},
1517		nil,
1518		func(error) {})
1519	defer iter.Stop()
1520	for {
1521		_, err := iter.Next()
1522		if err == iterator.Done {
1523			err = nil
1524			break
1525		}
1526		if err != nil {
1527			break
1528		}
1529	}
1530	if r != 1 {
1531		t.Errorf("retry count = %v, want 1", r)
1532	}
1533}
1534
// Test cancel/timeout for client operations.
func TestCancelTimeout(t *testing.T) {
	restore := setMaxBytesBetweenResumeTokens()
	defer restore()
	server, c, teardown := setupMockedTestServer(t)
	defer teardown()
	// Force every ExecuteStreamingSql call to take at least one second so
	// cancellation/timeout reliably fires while the query is still in flight.
	server.TestSpanner.PutExecutionTime(
		MethodExecuteStreamingSql,
		SimulatedExecutionTime{MinimumExecutionTime: 1 * time.Second},
	)
	mc, err := c.sc.nextClient()
	if err != nil {
		t.Fatalf("failed to create a grpc client")
	}

	session, err := createSession(mc)
	if err != nil {
		t.Fatalf("failed to create a session")
	}
	// done signals that the streaming goroutine has finished; reading err in
	// the main goroutine after receiving on done is ordered by the channel
	// send that follows the write to err.
	done := make(chan int)

	// Test cancelling query.
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		// Establish a stream to mock cloud spanner server.
		iter := stream(ctx, nil,
			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
				return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
					Session:     session.Name,
					Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
					ResumeToken: resumeToken,
				})
			},
			nil,
			func(error) {})
		defer iter.Stop()
		for {
			_, err = iter.Next()
			if err == iterator.Done {
				break
			}
			if err != nil {
				// Only a real error signals done; iterator.Done would mean the
				// query completed before the cancellation took effect.
				done <- 0
				break
			}
		}
	}()
	cancel()
	select {
	case <-done:
		if ErrCode(err) != codes.Canceled {
			t.Errorf("streaming query is canceled and returns error %v, want error code %v", err, codes.Canceled)
		}
	case <-time.After(1 * time.Second):
		t.Errorf("query doesn't exit timely after being cancelled")
	}

	// Test query timeout.
	ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	go func() {
		// Establish a stream to mock cloud spanner server.
		iter := stream(ctx, nil,
			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
				return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
					Session:     session.Name,
					Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
					ResumeToken: resumeToken,
				})
			},
			nil,
			func(error) {})
		defer iter.Stop()
		for {
			_, err = iter.Next()
			if err == iterator.Done {
				err = nil
				break
			}
			if err != nil {
				break
			}
		}
		// Unlike the cancellation case above, done is signaled on both the
		// success and error paths; the select below inspects err either way.
		done <- 0
	}()
	select {
	case <-done:
		if wantErr := codes.DeadlineExceeded; ErrCode(err) != wantErr {
			t.Errorf("streaming query timeout returns error %v, want error code %v", err, wantErr)
		}
	case <-time.After(2 * time.Second):
		t.Errorf("query doesn't timeout as expected")
	}
}
1629
1630func setupStatementResult(t *testing.T, server *MockedSpannerInMemTestServer, stmt string, rowCount int, resumeTokens [][]byte) error {
1631	selectValues := make([][]string, rowCount)
1632	for i := 0; i < rowCount; i++ {
1633		selectValues[i] = []string{keyStr(i), valStr(i)}
1634	}
1635
1636	rows := make([]*structpb.ListValue, len(selectValues))
1637	for i, values := range selectValues {
1638		rowValues := make([]*structpb.Value, len(kvMeta.RowType.Fields))
1639		for j, value := range values {
1640			rowValues[j] = &structpb.Value{
1641				Kind: &structpb.Value_StringValue{StringValue: value},
1642			}
1643		}
1644		rows[i] = &structpb.ListValue{
1645			Values: rowValues,
1646		}
1647	}
1648	resultSet := &sppb.ResultSet{
1649		Metadata: kvMeta,
1650		Rows:     rows,
1651	}
1652	result := &StatementResult{
1653		Type:         StatementResultResultSet,
1654		ResultSet:    resultSet,
1655		ResumeTokens: resumeTokens,
1656	}
1657	return server.TestSpanner.PutStatementResult(stmt, result)
1658}
1659
1660func TestRowIteratorDo(t *testing.T) {
1661	restore := setMaxBytesBetweenResumeTokens()
1662	defer restore()
1663
1664	_, c, teardown := setupMockedTestServer(t)
1665	defer teardown()
1666	mc, err := c.sc.nextClient()
1667	if err != nil {
1668		t.Fatalf("failed to create a grpc client")
1669	}
1670
1671	session, err := createSession(mc)
1672	if err != nil {
1673		t.Fatalf("failed to create a session")
1674	}
1675
1676	nRows := 0
1677	iter := stream(context.Background(), nil,
1678		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1679			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1680				Session:     session.Name,
1681				Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1682				ResumeToken: resumeToken,
1683			})
1684		},
1685		nil,
1686		func(error) {})
1687	err = iter.Do(func(r *Row) error { nRows++; return nil })
1688	if err != nil {
1689		t.Errorf("Using Do: %v", err)
1690	}
1691	if nRows != 3 {
1692		t.Errorf("got %d rows, want 3", nRows)
1693	}
1694}
1695
1696func TestRowIteratorDoWithError(t *testing.T) {
1697	restore := setMaxBytesBetweenResumeTokens()
1698	defer restore()
1699
1700	_, c, teardown := setupMockedTestServer(t)
1701	defer teardown()
1702	mc, err := c.sc.nextClient()
1703	if err != nil {
1704		t.Fatalf("failed to create a grpc client")
1705	}
1706
1707	session, err := createSession(mc)
1708	if err != nil {
1709		t.Fatalf("failed to create a session")
1710	}
1711
1712	iter := stream(context.Background(), nil,
1713		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1714			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1715				Session:     session.Name,
1716				Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1717				ResumeToken: resumeToken,
1718			})
1719		},
1720		nil,
1721		func(error) {})
1722	injected := errors.New("Failed iterator")
1723	err = iter.Do(func(r *Row) error { return injected })
1724	if err != injected {
1725		t.Errorf("got <%v>, want <%v>", err, injected)
1726	}
1727}
1728
1729func TestIteratorStopEarly(t *testing.T) {
1730	ctx := context.Background()
1731	restore := setMaxBytesBetweenResumeTokens()
1732	defer restore()
1733
1734	_, c, teardown := setupMockedTestServer(t)
1735	defer teardown()
1736	mc, err := c.sc.nextClient()
1737	if err != nil {
1738		t.Fatalf("failed to create a grpc client")
1739	}
1740
1741	session, err := createSession(mc)
1742	if err != nil {
1743		t.Fatalf("failed to create a session")
1744	}
1745
1746	iter := stream(ctx, nil,
1747		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1748			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1749				Session:     session.Name,
1750				Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1751				ResumeToken: resumeToken,
1752			})
1753		},
1754		nil,
1755		func(error) {})
1756	_, err = iter.Next()
1757	if err != nil {
1758		t.Fatalf("before Stop: %v", err)
1759	}
1760	iter.Stop()
1761	// Stop sets r.err to the FailedPrecondition error "Next called after Stop".
1762	_, err = iter.Next()
1763	if g, w := ErrCode(err), codes.FailedPrecondition; g != w {
1764		t.Errorf("after Stop: got: %v, want: %v", g, w)
1765	}
1766}
1767
1768func TestIteratorWithError(t *testing.T) {
1769	injected := errors.New("Failed iterator")
1770	iter := RowIterator{err: injected}
1771	defer iter.Stop()
1772	if _, err := iter.Next(); err != injected {
1773		t.Fatalf("Expected error: %v, got %v", injected, err)
1774	}
1775}
1776
1777func createSession(client *vkit.Client) (*sppb.Session, error) {
1778	var formattedDatabase string = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
1779	var request = &sppb.CreateSessionRequest{
1780		Database: formattedDatabase,
1781	}
1782	return client.CreateSession(context.Background(), request)
1783}
1784