1/*
2Copyright 2017 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"errors"
22	"fmt"
23	"sync"
24	"sync/atomic"
25	"testing"
26	"time"
27
28	vkit "cloud.google.com/go/spanner/apiv1"
29	. "cloud.google.com/go/spanner/internal/testutil"
30	"github.com/golang/protobuf/proto"
31	proto3 "github.com/golang/protobuf/ptypes/struct"
32	structpb "github.com/golang/protobuf/ptypes/struct"
33	"github.com/googleapis/gax-go/v2"
34	"google.golang.org/api/iterator"
35	sppb "google.golang.org/genproto/googleapis/spanner/v1"
36	"google.golang.org/grpc/codes"
37	"google.golang.org/grpc/status"
38)
39
40var (
41	// Mocked transaction timestamp.
42	trxTs = time.Unix(1, 2)
43	// Metadata for mocked KV table, its rows are returned by SingleUse
44	// transactions.
45	kvMeta = func() *sppb.ResultSetMetadata {
46		meta := KvMeta
47		meta.Transaction = &sppb.Transaction{
48			ReadTimestamp: timestampProto(trxTs),
49		}
50		return &meta
51	}()
52	// Metadata for mocked ListKV table, which uses List for its key and value.
53	// Its rows are returned by snapshot readonly transactions, as indicated in
54	// the transaction metadata.
55	kvListMeta = &sppb.ResultSetMetadata{
56		RowType: &sppb.StructType{
57			Fields: []*sppb.StructType_Field{
58				{
59					Name: "Key",
60					Type: &sppb.Type{
61						Code: sppb.TypeCode_ARRAY,
62						ArrayElementType: &sppb.Type{
63							Code: sppb.TypeCode_STRING,
64						},
65					},
66				},
67				{
68					Name: "Value",
69					Type: &sppb.Type{
70						Code: sppb.TypeCode_ARRAY,
71						ArrayElementType: &sppb.Type{
72							Code: sppb.TypeCode_STRING,
73						},
74					},
75				},
76			},
77		},
78		Transaction: &sppb.Transaction{
79			Id:            transactionID{5, 6, 7, 8, 9},
80			ReadTimestamp: timestampProto(trxTs),
81		},
82	}
83	// Metadata for mocked schema of a query result set, which has two struct
84	// columns named "Col1" and "Col2", the struct's schema is like the
85	// following:
86	//
87	//	STRUCT {
88	//		INT
89	//		LIST<STRING>
90	//	}
91	//
92	// Its rows are returned in readwrite transaction, as indicated in the
93	// transaction metadata.
94	kvObjectMeta = &sppb.ResultSetMetadata{
95		RowType: &sppb.StructType{
96			Fields: []*sppb.StructType_Field{
97				{
98					Name: "Col1",
99					Type: &sppb.Type{
100						Code: sppb.TypeCode_STRUCT,
101						StructType: &sppb.StructType{
102							Fields: []*sppb.StructType_Field{
103								{
104									Name: "foo-f1",
105									Type: &sppb.Type{
106										Code: sppb.TypeCode_INT64,
107									},
108								},
109								{
110									Name: "foo-f2",
111									Type: &sppb.Type{
112										Code: sppb.TypeCode_ARRAY,
113										ArrayElementType: &sppb.Type{
114											Code: sppb.TypeCode_STRING,
115										},
116									},
117								},
118							},
119						},
120					},
121				},
122				{
123					Name: "Col2",
124					Type: &sppb.Type{
125						Code: sppb.TypeCode_STRUCT,
126						StructType: &sppb.StructType{
127							Fields: []*sppb.StructType_Field{
128								{
129									Name: "bar-f1",
130									Type: &sppb.Type{
131										Code: sppb.TypeCode_INT64,
132									},
133								},
134								{
135									Name: "bar-f2",
136									Type: &sppb.Type{
137										Code: sppb.TypeCode_ARRAY,
138										ArrayElementType: &sppb.Type{
139											Code: sppb.TypeCode_STRING,
140										},
141									},
142								},
143							},
144						},
145					},
146				},
147			},
148		},
149		Transaction: &sppb.Transaction{
150			Id: transactionID{1, 2, 3, 4, 5},
151		},
152	}
153)
154
155// String implements fmt.stringer.
156func (r *Row) String() string {
157	return fmt.Sprintf("{fields: %s, val: %s}", r.fields, r.vals)
158}
159
160func describeRows(l []*Row) string {
161	// generate a nice test failure description
162	var s = "["
163	for i, r := range l {
164		if i != 0 {
165			s += ",\n "
166		}
167		s += fmt.Sprint(r)
168	}
169	s += "]"
170	return s
171}
172
173// Helper for generating proto3 Value_ListValue instances, making test code
174// shorter and readable.
175func genProtoListValue(v ...string) *proto3.Value_ListValue {
176	r := &proto3.Value_ListValue{
177		ListValue: &proto3.ListValue{
178			Values: []*proto3.Value{},
179		},
180	}
181	for _, e := range v {
182		r.ListValue.Values = append(
183			r.ListValue.Values,
184			&proto3.Value{
185				Kind: &proto3.Value_StringValue{StringValue: e},
186			},
187		)
188	}
189	return r
190}
191
192// Test Row generation logics of partialResultSetDecoder.
193func TestPartialResultSetDecoder(t *testing.T) {
194	restore := setMaxBytesBetweenResumeTokens()
195	defer restore()
196	var tests = []struct {
197		input    []*sppb.PartialResultSet
198		wantF    []*Row
199		wantTxID transactionID
200		wantTs   time.Time
201		wantD    bool
202	}{
203		{
204			// Empty input.
205			wantD: true,
206		},
207		// String merging examples.
208		{
209			// Single KV result.
210			input: []*sppb.PartialResultSet{
211				{
212					Metadata: kvMeta,
213					Values: []*proto3.Value{
214						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
215						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
216					},
217				},
218			},
219			wantF: []*Row{
220				{
221					fields: kvMeta.RowType.Fields,
222					vals: []*proto3.Value{
223						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
224						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
225					},
226				},
227			},
228			wantTs: trxTs,
229			wantD:  true,
230		},
231		{
232			// Incomplete partial result.
233			input: []*sppb.PartialResultSet{
234				{
235					Metadata: kvMeta,
236					Values: []*proto3.Value{
237						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
238					},
239				},
240			},
241			wantTs: trxTs,
242			wantD:  false,
243		},
244		{
245			// Complete splitted result.
246			input: []*sppb.PartialResultSet{
247				{
248					Metadata: kvMeta,
249					Values: []*proto3.Value{
250						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
251					},
252				},
253				{
254					Values: []*proto3.Value{
255						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
256					},
257				},
258			},
259			wantF: []*Row{
260				{
261					fields: kvMeta.RowType.Fields,
262					vals: []*proto3.Value{
263						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
264						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
265					},
266				},
267			},
268			wantTs: trxTs,
269			wantD:  true,
270		},
271		{
272			// Multi-row example with splitted row in the middle.
273			input: []*sppb.PartialResultSet{
274				{
275					Metadata: kvMeta,
276					Values: []*proto3.Value{
277						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
278						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
279						{Kind: &proto3.Value_StringValue{StringValue: "A"}},
280					},
281				},
282				{
283					Values: []*proto3.Value{
284						{Kind: &proto3.Value_StringValue{StringValue: "1"}},
285						{Kind: &proto3.Value_StringValue{StringValue: "B"}},
286						{Kind: &proto3.Value_StringValue{StringValue: "2"}},
287					},
288				},
289			},
290			wantF: []*Row{
291				{
292					fields: kvMeta.RowType.Fields,
293					vals: []*proto3.Value{
294						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
295						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
296					},
297				},
298				{
299					fields: kvMeta.RowType.Fields,
300					vals: []*proto3.Value{
301						{Kind: &proto3.Value_StringValue{StringValue: "A"}},
302						{Kind: &proto3.Value_StringValue{StringValue: "1"}},
303					},
304				},
305				{
306					fields: kvMeta.RowType.Fields,
307					vals: []*proto3.Value{
308						{Kind: &proto3.Value_StringValue{StringValue: "B"}},
309						{Kind: &proto3.Value_StringValue{StringValue: "2"}},
310					},
311				},
312			},
313			wantTs: trxTs,
314			wantD:  true,
315		},
316		{
317			// Merging example in result_set.proto.
318			input: []*sppb.PartialResultSet{
319				{
320					Metadata: kvMeta,
321					Values: []*proto3.Value{
322						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
323						{Kind: &proto3.Value_StringValue{StringValue: "W"}},
324					},
325					ChunkedValue: true,
326				},
327				{
328					Values: []*proto3.Value{
329						{Kind: &proto3.Value_StringValue{StringValue: "orl"}},
330					},
331					ChunkedValue: true,
332				},
333				{
334					Values: []*proto3.Value{
335						{Kind: &proto3.Value_StringValue{StringValue: "d"}},
336					},
337				},
338			},
339			wantF: []*Row{
340				{
341					fields: kvMeta.RowType.Fields,
342					vals: []*proto3.Value{
343						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
344						{Kind: &proto3.Value_StringValue{StringValue: "World"}},
345					},
346				},
347			},
348			wantTs: trxTs,
349			wantD:  true,
350		},
351		{
352			// More complex example showing completing a merge and
353			// starting a new merge in the same partialResultSet.
354			input: []*sppb.PartialResultSet{
355				{
356					Metadata: kvMeta,
357					Values: []*proto3.Value{
358						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
359						{Kind: &proto3.Value_StringValue{StringValue: "W"}}, // start split in value
360					},
361					ChunkedValue: true,
362				},
363				{
364					Values: []*proto3.Value{
365						{Kind: &proto3.Value_StringValue{StringValue: "orld"}}, // complete value
366						{Kind: &proto3.Value_StringValue{StringValue: "i"}},    // start split in key
367					},
368					ChunkedValue: true,
369				},
370				{
371					Values: []*proto3.Value{
372						{Kind: &proto3.Value_StringValue{StringValue: "s"}}, // complete key
373						{Kind: &proto3.Value_StringValue{StringValue: "not"}},
374						{Kind: &proto3.Value_StringValue{StringValue: "a"}},
375						{Kind: &proto3.Value_StringValue{StringValue: "qu"}}, // split in value
376					},
377					ChunkedValue: true,
378				},
379				{
380					Values: []*proto3.Value{
381						{Kind: &proto3.Value_StringValue{StringValue: "estion"}}, // complete value
382					},
383				},
384			},
385			wantF: []*Row{
386				{
387					fields: kvMeta.RowType.Fields,
388					vals: []*proto3.Value{
389						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
390						{Kind: &proto3.Value_StringValue{StringValue: "World"}},
391					},
392				},
393				{
394					fields: kvMeta.RowType.Fields,
395					vals: []*proto3.Value{
396						{Kind: &proto3.Value_StringValue{StringValue: "is"}},
397						{Kind: &proto3.Value_StringValue{StringValue: "not"}},
398					},
399				},
400				{
401					fields: kvMeta.RowType.Fields,
402					vals: []*proto3.Value{
403						{Kind: &proto3.Value_StringValue{StringValue: "a"}},
404						{Kind: &proto3.Value_StringValue{StringValue: "question"}},
405					},
406				},
407			},
408			wantTs: trxTs,
409			wantD:  true,
410		},
411		// List merging examples.
412		{
413			// Non-splitting Lists.
414			input: []*sppb.PartialResultSet{
415				{
416					Metadata: kvListMeta,
417					Values: []*proto3.Value{
418						{
419							Kind: genProtoListValue("foo-1", "foo-2"),
420						},
421					},
422				},
423				{
424					Values: []*proto3.Value{
425						{
426							Kind: genProtoListValue("bar-1", "bar-2"),
427						},
428					},
429				},
430			},
431			wantF: []*Row{
432				{
433					fields: kvListMeta.RowType.Fields,
434					vals: []*proto3.Value{
435						{
436							Kind: genProtoListValue("foo-1", "foo-2"),
437						},
438						{
439							Kind: genProtoListValue("bar-1", "bar-2"),
440						},
441					},
442				},
443			},
444			wantTxID: transactionID{5, 6, 7, 8, 9},
445			wantTs:   trxTs,
446			wantD:    true,
447		},
448		{
449			// Simple List merge case: splitted string element.
450			input: []*sppb.PartialResultSet{
451				{
452					Metadata: kvListMeta,
453					Values: []*proto3.Value{
454						{
455							Kind: genProtoListValue("foo-1", "foo-"),
456						},
457					},
458					ChunkedValue: true,
459				},
460				{
461					Values: []*proto3.Value{
462						{
463							Kind: genProtoListValue("2"),
464						},
465					},
466				},
467				{
468					Values: []*proto3.Value{
469						{
470							Kind: genProtoListValue("bar-1", "bar-2"),
471						},
472					},
473				},
474			},
475			wantF: []*Row{
476				{
477					fields: kvListMeta.RowType.Fields,
478					vals: []*proto3.Value{
479						{
480							Kind: genProtoListValue("foo-1", "foo-2"),
481						},
482						{
483							Kind: genProtoListValue("bar-1", "bar-2"),
484						},
485					},
486				},
487			},
488			wantTxID: transactionID{5, 6, 7, 8, 9},
489			wantTs:   trxTs,
490			wantD:    true,
491		},
492		{
493			// Struct merging is also implemented by List merging. Note that
494			// Cloud Spanner uses proto.ListValue to encode Structs as well.
495			input: []*sppb.PartialResultSet{
496				{
497					Metadata: kvObjectMeta,
498					Values: []*proto3.Value{
499						{
500							Kind: &proto3.Value_ListValue{
501								ListValue: &proto3.ListValue{
502									Values: []*proto3.Value{
503										{Kind: &proto3.Value_NumberValue{NumberValue: 23}},
504										{Kind: genProtoListValue("foo-1", "fo")},
505									},
506								},
507							},
508						},
509					},
510					ChunkedValue: true,
511				},
512				{
513					Values: []*proto3.Value{
514						{
515							Kind: &proto3.Value_ListValue{
516								ListValue: &proto3.ListValue{
517									Values: []*proto3.Value{
518										{Kind: genProtoListValue("o-2", "f")},
519									},
520								},
521							},
522						},
523					},
524					ChunkedValue: true,
525				},
526				{
527					Values: []*proto3.Value{
528						{
529							Kind: &proto3.Value_ListValue{
530								ListValue: &proto3.ListValue{
531									Values: []*proto3.Value{
532										{Kind: genProtoListValue("oo-3")},
533									},
534								},
535							},
536						},
537						{
538							Kind: &proto3.Value_ListValue{
539								ListValue: &proto3.ListValue{
540									Values: []*proto3.Value{
541										{Kind: &proto3.Value_NumberValue{NumberValue: 45}},
542										{Kind: genProtoListValue("bar-1")},
543									},
544								},
545							},
546						},
547					},
548				},
549			},
550			wantF: []*Row{
551				{
552					fields: kvObjectMeta.RowType.Fields,
553					vals: []*proto3.Value{
554						{
555							Kind: &proto3.Value_ListValue{
556								ListValue: &proto3.ListValue{
557									Values: []*proto3.Value{
558										{Kind: &proto3.Value_NumberValue{NumberValue: 23}},
559										{Kind: genProtoListValue("foo-1", "foo-2", "foo-3")},
560									},
561								},
562							},
563						},
564						{
565							Kind: &proto3.Value_ListValue{
566								ListValue: &proto3.ListValue{
567									Values: []*proto3.Value{
568										{Kind: &proto3.Value_NumberValue{NumberValue: 45}},
569										{Kind: genProtoListValue("bar-1")},
570									},
571								},
572							},
573						},
574					},
575				},
576			},
577			wantTxID: transactionID{1, 2, 3, 4, 5},
578			wantD:    true,
579		},
580	}
581
582nextTest:
583	for i, test := range tests {
584		var rows []*Row
585		p := &partialResultSetDecoder{}
586		for j, v := range test.input {
587			rs, err := p.add(v)
588			if err != nil {
589				t.Errorf("test %d.%d: partialResultSetDecoder.add(%v) = %v; want nil", i, j, v, err)
590				continue nextTest
591			}
592			rows = append(rows, rs...)
593		}
594		if !testEqual(p.ts, test.wantTs) {
595			t.Errorf("got transaction(%v), want %v", p.ts, test.wantTs)
596		}
597		if !testEqual(rows, test.wantF) {
598			t.Errorf("test %d: rows=\n%v\n; want\n%v\n; p.row:\n%v\n", i, describeRows(rows), describeRows(test.wantF), p.row)
599		}
600		if got := p.done(); got != test.wantD {
601			t.Errorf("test %d: partialResultSetDecoder.done() = %v", i, got)
602		}
603	}
604}
605
// maxBuffers is the maximum number of PartialResultSets that tests expect to
// be buffered; setMaxBytesBetweenResumeTokens scales the byte limit from it.
const maxBuffers = 16
610
611// setMaxBytesBetweenResumeTokens sets the global maxBytesBetweenResumeTokens to
612// a smaller value more suitable for tests. It returns a function which should
613// be called to restore the maxBytesBetweenResumeTokens to its old value.
614func setMaxBytesBetweenResumeTokens() func() {
615	o := atomic.LoadInt32(&maxBytesBetweenResumeTokens)
616	atomic.StoreInt32(&maxBytesBetweenResumeTokens, int32(maxBuffers*proto.Size(&sppb.PartialResultSet{
617		Metadata: kvMeta,
618		Values: []*proto3.Value{
619			{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
620			{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
621		},
622	})))
623	return func() {
624		atomic.StoreInt32(&maxBytesBetweenResumeTokens, o)
625	}
626}
627
// keyStr returns the key string for row i of the kvMeta schema: "foo-"
// followed by i formatted with at least two digits (zero-padded).
func keyStr(i int) string {
	const prefix = "foo-"
	return prefix + fmt.Sprintf("%02d", i)
}
632
// valStr returns the value string for row i of the kvMeta schema: "bar-"
// followed by i formatted with at least two digits (zero-padded).
func valStr(i int) string {
	const prefix = "bar-"
	return prefix + fmt.Sprintf("%02d", i)
}
637
638// Test state transitions of resumableStreamDecoder where state machine ends up
639// to a non-blocking state(resumableStreamDecoder.Next returns on non-blocking
640// state).
641func TestRsdNonblockingStates(t *testing.T) {
642	restore := setMaxBytesBetweenResumeTokens()
643	defer restore()
644	tests := []struct {
645		name         string
646		resumeTokens [][]byte
647		prsErrors    []PartialResultSetExecutionTime
648		rpc          func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
649		sql          string
650		// Expected values
651		want         []*sppb.PartialResultSet      // PartialResultSets that should be returned to caller
652		queue        []*sppb.PartialResultSet      // PartialResultSets that should be buffered
653		resumeToken  []byte                        // Resume token that is maintained by resumableStreamDecoder
654		stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
655		wantErr      error
656	}{
657		{
658			// unConnected->queueingRetryable->finished
659			name:         "unConnected->queueingRetryable->finished",
660			resumeTokens: make([][]byte, 2),
661			sql:          "SELECT t.key key, t.value value FROM t_mock t",
662			want: []*sppb.PartialResultSet{
663				{
664					Metadata: kvMeta,
665					Values: []*proto3.Value{
666						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
667						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
668					},
669				},
670			},
671			queue: []*sppb.PartialResultSet{
672				{
673					Metadata: kvMeta,
674					Values: []*proto3.Value{
675						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
676						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
677					},
678				},
679			},
680			stateHistory: []resumableStreamDecoderState{
681				queueingRetryable, // do RPC
682				queueingRetryable, // got foo-00
683				queueingRetryable, // got foo-01
684				finished,          // got EOF
685			},
686		},
687		{
688			// unConnected->queueingRetryable->aborted
689			name:         "unConnected->queueingRetryable->aborted",
690			resumeTokens: [][]byte{{}, EncodeResumeToken(1), {}, EncodeResumeToken(2)},
691			prsErrors: []PartialResultSetExecutionTime{{
692				ResumeToken: EncodeResumeToken(2),
693				Err:         status.Error(codes.Unknown, "I quit"),
694			}},
695			sql: "SELECT t.key key, t.value value FROM t_mock t",
696			want: []*sppb.PartialResultSet{
697				{
698					Metadata: kvMeta,
699					Values: []*proto3.Value{
700						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
701						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
702					},
703				},
704				{
705					Metadata: kvMeta,
706					Values: []*proto3.Value{
707						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
708						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
709					},
710					ResumeToken: EncodeResumeToken(1),
711				},
712			},
713			stateHistory: []resumableStreamDecoderState{
714				queueingRetryable, // do RPC
715				queueingRetryable, // got foo-00
716				queueingRetryable, // got foo-01
717				queueingRetryable, // foo-01, resume token
718				queueingRetryable, // got foo-02
719				aborted,           // got error
720			},
721			wantErr: status.Errorf(codes.Unknown, "I quit"),
722		},
723		{
724			// unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable
725			name:         "unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable",
726			resumeTokens: make([][]byte, maxBuffers+1),
727			sql:          "SELECT t.key key, t.value value FROM t_mock t",
728			want: func() (s []*sppb.PartialResultSet) {
729				for i := 0; i < maxBuffers+1; i++ {
730					s = append(s, &sppb.PartialResultSet{
731						Metadata: kvMeta,
732						Values: []*proto3.Value{
733							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
734							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
735						},
736					})
737				}
738				return s
739			}(),
740			stateHistory: func() (s []resumableStreamDecoderState) {
741				s = append(s, queueingRetryable) // RPC
742				for i := 0; i < maxBuffers; i++ {
743					s = append(s, queueingRetryable) // the internal queue of resumableStreamDecoder fills up
744				}
745				// the first item fills up the queue and triggers state transition;
746				// the second item is received under queueingUnretryable state.
747				s = append(s, queueingUnretryable)
748				s = append(s, queueingUnretryable)
749				return s
750			}(),
751		},
752		{
753			// unConnected->queueingRetryable->queueingUnretryable->aborted
754			name: "unConnected->queueingRetryable->queueingUnretryable->aborted",
755			resumeTokens: func() (rts [][]byte) {
756				rts = make([][]byte, maxBuffers+1)
757				rts[maxBuffers] = EncodeResumeToken(1)
758				return rts
759			}(),
760			prsErrors: []PartialResultSetExecutionTime{{
761				ResumeToken: EncodeResumeToken(1),
762				Err:         status.Error(codes.Unknown, "Just Abort It"),
763			}},
764			sql: "SELECT t.key key, t.value value FROM t_mock t",
765			want: func() (s []*sppb.PartialResultSet) {
766				for i := 0; i < maxBuffers; i++ {
767					s = append(s, &sppb.PartialResultSet{
768						Metadata: kvMeta,
769						Values: []*proto3.Value{
770							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
771							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
772						},
773					})
774				}
775				return s
776			}(),
777			stateHistory: func() (s []resumableStreamDecoderState) {
778				s = append(s, queueingRetryable) // RPC
779				for i := 0; i < maxBuffers; i++ {
780					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
781				}
782				s = append(s, queueingUnretryable) // the last row triggers state change
783				s = append(s, aborted)             // Error happens
784				return s
785			}(),
786			wantErr: status.Errorf(codes.Unknown, "Just Abort It"),
787		},
788	}
789	for _, test := range tests {
790		t.Run(test.name, func(t *testing.T) {
791			server, c, teardown := setupMockedTestServer(t)
792			defer teardown()
793			mc, err := c.sc.nextClient()
794			if err != nil {
795				t.Fatalf("failed to create a grpc client")
796			}
797
798			session, err := createSession(mc)
799			if err != nil {
800				t.Fatalf("failed to create a session")
801			}
802
803			if test.rpc == nil {
804				test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
805					return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
806						Session:     session.Name,
807						Sql:         test.sql,
808						ResumeToken: resumeToken,
809					})
810				}
811			}
812			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
813			defer cancel()
814			r := newResumableStreamDecoder(
815				ctx,
816				nil,
817				test.rpc,
818				nil,
819			)
820			st := []resumableStreamDecoderState{}
821			var lastErr error
822			// Once the expected number of state transitions are observed,
823			// send a signal by setting stateDone = true.
824			stateDone := false
825			// Set stateWitness to listen to state changes.
826			hl := len(test.stateHistory) // To avoid data race on test.
827			r.stateWitness = func(rs resumableStreamDecoderState) {
828				if !stateDone {
829					// Record state transitions.
830					st = append(st, rs)
831					if len(st) == hl {
832						lastErr = r.lastErr()
833						stateDone = true
834					}
835				}
836			}
837			// Let mock server stream given messages to resumableStreamDecoder.
838			err = setupStatementResult(t, server, test.sql, len(test.resumeTokens), test.resumeTokens)
839			if err != nil {
840				t.Fatalf("failed to set up a result for a statement: %v", err)
841			}
842
843			for _, et := range test.prsErrors {
844				server.TestSpanner.AddPartialResultSetError(
845					test.sql,
846					et,
847				)
848			}
849			var rs []*sppb.PartialResultSet
850			for {
851				select {
852				case <-ctx.Done():
853					t.Fatal("context cancelled or timeout during test")
854				default:
855				}
856				if stateDone {
857					// Check if resumableStreamDecoder carried out expected
858					// state transitions.
859					if !testEqual(st, test.stateHistory) {
860						t.Fatalf("observed state transitions: \n%v\n, want \n%v\n", st, test.stateHistory)
861					}
862					// Check if resumableStreamDecoder returns expected array of
863					// PartialResultSets.
864					if !testEqual(rs, test.want) {
865						t.Fatalf("received PartialResultSets: \n%v\n, want \n%v\n", rs, test.want)
866					}
867					// Verify that resumableStreamDecoder's internal buffering is
868					// also correct.
869					var q []*sppb.PartialResultSet
870					for {
871						item := r.q.pop()
872						if item == nil {
873							break
874						}
875						q = append(q, item)
876					}
877					if !testEqual(q, test.queue) {
878						t.Fatalf("PartialResultSets still queued: \n%v\n, want \n%v\n", q, test.queue)
879					}
880					// Verify resume token.
881					if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
882						t.Fatalf("Resume token is %v, want %v\n", r.resumeToken, test.resumeToken)
883					}
884					// Verify error message.
885					if !testEqual(lastErr, test.wantErr) {
886						t.Fatalf("got error %v, want %v", lastErr, test.wantErr)
887					}
888					return
889				}
890				// Receive next decoded item.
891				if r.next() {
892					rs = append(rs, r.get())
893				}
894			}
895		})
896	}
897}
898
899// Test state transitions of resumableStreamDecoder where state machine
900// ends up to a blocking state(resumableStreamDecoder.Next blocks
901// on blocking state).
902func TestRsdBlockingStates(t *testing.T) {
903	restore := setMaxBytesBetweenResumeTokens()
904	defer restore()
905	for _, test := range []struct {
906		name         string
907		resumeTokens [][]byte
908		rpc          func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
909		sql          string
910		// Expected values
911		want         []*sppb.PartialResultSet      // PartialResultSets that should be returned to caller
912		queue        []*sppb.PartialResultSet      // PartialResultSets that should be buffered
913		resumeToken  []byte                        // Resume token that is maintained by resumableStreamDecoder
914		stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
915		wantErr      error
916	}{
917		{
918			// unConnected -> unConnected
919			name: "unConnected -> unConnected",
920			rpc: func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
921				return nil, status.Errorf(codes.Unavailable, "trust me: server is unavailable")
922			},
923			sql:          "SELECT * from t_whatever",
924			stateHistory: []resumableStreamDecoderState{unConnected, unConnected, unConnected},
925			wantErr:      status.Errorf(codes.Unavailable, "trust me: server is unavailable"),
926		},
927		{
928			// unConnected -> queueingRetryable
929			name:         "unConnected -> queueingRetryable",
930			sql:          "SELECT t.key key, t.value value FROM t_mock t",
931			stateHistory: []resumableStreamDecoderState{queueingRetryable},
932		},
933		{
934			// unConnected->queueingRetryable->queueingRetryable
935			name:         "unConnected->queueingRetryable->queueingRetryable",
936			resumeTokens: [][]byte{{}, EncodeResumeToken(1), EncodeResumeToken(2), {}},
937			sql:          "SELECT t.key key, t.value value FROM t_mock t",
938			want: []*sppb.PartialResultSet{
939				{
940					Metadata: kvMeta,
941					Values: []*proto3.Value{
942						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
943						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
944					},
945				},
946				{
947					Metadata: kvMeta,
948					Values: []*proto3.Value{
949						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
950						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
951					},
952					ResumeToken: EncodeResumeToken(1),
953				},
954				{
955					Metadata: kvMeta,
956					Values: []*proto3.Value{
957						{Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
958						{Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
959					},
960					ResumeToken: EncodeResumeToken(2),
961				},
962				// The server sends an io.EOF at last and the decoder will
963				// flush out all messages in the internal queue.
964				{
965					Metadata: kvMeta,
966					Values: []*proto3.Value{
967						{Kind: &proto3.Value_StringValue{StringValue: keyStr(3)}},
968						{Kind: &proto3.Value_StringValue{StringValue: valStr(3)}},
969					},
970				},
971			},
972			queue: []*sppb.PartialResultSet{
973				{
974					Metadata: kvMeta,
975					Values: []*proto3.Value{
976						{Kind: &proto3.Value_StringValue{StringValue: keyStr(3)}},
977						{Kind: &proto3.Value_StringValue{StringValue: valStr(3)}},
978					},
979				},
980			},
981			resumeToken: EncodeResumeToken(2),
982			stateHistory: []resumableStreamDecoderState{
983				queueingRetryable, // do RPC
984				queueingRetryable, // got foo-00
985				queueingRetryable, // got foo-01
986				queueingRetryable, // foo-01, resume token
987				queueingRetryable, // got foo-02
988				queueingRetryable, // foo-02, resume token
989				queueingRetryable, // got foo-03
990			},
991		},
992		{
993			// unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable
994			name: "unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable",
995			resumeTokens: func() (rts [][]byte) {
996				rts = make([][]byte, maxBuffers+3)
997				rts[maxBuffers+1] = EncodeResumeToken(maxBuffers + 1)
998				return rts
999			}(),
1000			sql: "SELECT t.key key, t.value value FROM t_mock t",
1001			want: func() (s []*sppb.PartialResultSet) {
1002				// The server sends an io.EOF at last and the decoder will
1003				// flush out all messages in the internal queue. Although the
1004				// last message is supposed to be queued and the decoder waits
1005				// for the next resume token, an io.EOF leads to a `finished`
1006				// state that the last message will be removed from the queue
1007				// and be read by the client side.
1008				for i := 0; i < maxBuffers+3; i++ {
1009					s = append(s, &sppb.PartialResultSet{
1010						Metadata: kvMeta,
1011						Values: []*proto3.Value{
1012							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
1013							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
1014						},
1015					})
1016				}
1017				s[maxBuffers+1].ResumeToken = EncodeResumeToken(maxBuffers + 1)
1018				return s
1019			}(),
1020			resumeToken: EncodeResumeToken(maxBuffers + 1),
1021			queue: []*sppb.PartialResultSet{
1022				{
1023					Metadata: kvMeta,
1024					Values: []*proto3.Value{
1025						{Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 2)}},
1026						{Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 2)}},
1027					},
1028				},
1029			},
1030			stateHistory: func() (s []resumableStreamDecoderState) {
1031				s = append(s, queueingRetryable) // RPC
1032				for i := 0; i < maxBuffers; i++ {
1033					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder filles up
1034				}
1035				for i := maxBuffers - 1; i < maxBuffers+1; i++ {
1036					// the first item fills up the queue and triggers state
1037					// change; the second item is received under
1038					// queueingUnretryable state.
1039					s = append(s, queueingUnretryable)
1040				}
1041				s = append(s, queueingUnretryable) // got (maxBuffers+1)th row under Unretryable state
1042				s = append(s, queueingRetryable)   // (maxBuffers+1)th row has resume token
1043				s = append(s, queueingRetryable)   // (maxBuffers+2)th row has no resume token
1044				return s
1045			}(),
1046		},
1047		{
1048			// unConnected->queueingRetryable->queueingUnretryable->finished
1049			name:         "unConnected->queueingRetryable->queueingUnretryable->finished",
1050			resumeTokens: make([][]byte, maxBuffers),
1051			sql:          "SELECT t.key key, t.value value FROM t_mock t",
1052			want: func() (s []*sppb.PartialResultSet) {
1053				for i := 0; i < maxBuffers; i++ {
1054					s = append(s, &sppb.PartialResultSet{
1055						Metadata: kvMeta,
1056						Values: []*proto3.Value{
1057							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
1058							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
1059						},
1060					})
1061				}
1062				return s
1063			}(),
1064			stateHistory: func() (s []resumableStreamDecoderState) {
1065				s = append(s, queueingRetryable) // RPC
1066				for i := 0; i < maxBuffers; i++ {
1067					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
1068				}
1069				s = append(s, queueingUnretryable) // last row triggers state change
1070				s = append(s, finished)            // query finishes
1071				return s
1072			}(),
1073		},
1074	} {
1075		t.Run(test.name, func(t *testing.T) {
1076			server, c, teardown := setupMockedTestServer(t)
1077			defer teardown()
1078			mc, err := c.sc.nextClient()
1079			if err != nil {
1080				t.Fatalf("failed to create a grpc client")
1081			}
1082
1083			session, err := createSession(mc)
1084			if err != nil {
1085				t.Fatalf("failed to create a session")
1086			}
1087
1088			if test.rpc == nil {
1089				// Avoid using test.sql directly in closure because for loop changes
1090				// test.
1091				sql := test.sql
1092				test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1093					return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1094						Session:     session.Name,
1095						Sql:         sql,
1096						ResumeToken: resumeToken,
1097					})
1098				}
1099			}
1100			ctx, cancel := context.WithCancel(context.Background())
1101			defer cancel()
1102			r := newResumableStreamDecoder(
1103				ctx,
1104				nil,
1105				test.rpc,
1106				nil,
1107			)
1108			// Override backoff to make the test run faster.
1109			r.backoff = gax.Backoff{
1110				Initial:    1 * time.Nanosecond,
1111				Max:        1 * time.Nanosecond,
1112				Multiplier: 1.3,
1113			}
1114			// st is the set of observed state transitions.
1115			st := []resumableStreamDecoderState{}
1116			// q is the content of the decoder's partial result queue when expected
1117			// number of state transitions are done.
1118			q := []*sppb.PartialResultSet{}
1119			var lastErr error
1120			// Once the expected number of state transitions are observed, send a
1121			// signal to channel stateDone.
1122			stateDone := make(chan int)
1123			// Set stateWitness to listen to state changes.
1124			hl := len(test.stateHistory) // To avoid data race on test.
1125			r.stateWitness = func(rs resumableStreamDecoderState) {
1126				select {
1127				case <-stateDone:
1128					// Noop after expected number of state transitions
1129				default:
1130					// Record state transitions.
1131					st = append(st, rs)
1132					if len(st) == hl {
1133						lastErr = r.lastErr()
1134						q = r.q.dump()
1135						close(stateDone)
1136					}
1137				}
1138			}
1139			// Let mock server stream given messages to resumableStreamDecoder.
1140			err = setupStatementResult(t, server, test.sql, len(test.resumeTokens), test.resumeTokens)
1141			if err != nil {
1142				t.Fatalf("failed to set up a result for a statement: %v", err)
1143			}
1144			var mutex = &sync.Mutex{}
1145			var rs []*sppb.PartialResultSet
1146			go func() {
1147				for {
1148					if !r.next() {
1149						// Note that r.Next also exits on context cancel/timeout.
1150						return
1151					}
1152					mutex.Lock()
1153					rs = append(rs, r.get())
1154					mutex.Unlock()
1155				}
1156			}()
1157			// Verify that resumableStreamDecoder reaches expected state.
1158			select {
1159			case <-stateDone: // Note that at this point, receiver is still blockingon r.next().
1160				// Check if resumableStreamDecoder carried out expected state
1161				// transitions.
1162				if !testEqual(st, test.stateHistory) {
1163					t.Fatalf("observed state transitions: \n%v\n, want \n%v\n", st, test.stateHistory)
1164				}
1165				// Check if resumableStreamDecoder returns expected array of
1166				// PartialResultSets.
1167				mutex.Lock()
1168				defer mutex.Unlock()
1169				if !testEqual(rs, test.want) {
1170					t.Fatalf("received PartialResultSets: \n%v\n, want \n%v\n", rs, test.want)
1171				}
1172				// Verify that resumableStreamDecoder's internal buffering is also
1173				// correct.
1174				if !testEqual(q, test.queue) {
1175					t.Fatalf("PartialResultSets still queued: \n%v\n, want \n%v\n", q, test.queue)
1176				}
1177				// Verify resume token.
1178				if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
1179					t.Fatalf("Resume token is %v, want %v\n", r.resumeToken, test.resumeToken)
1180				}
1181				// Verify error message.
1182				if !testEqual(lastErr, test.wantErr) {
1183					t.Fatalf("got error %v, want %v", lastErr, test.wantErr)
1184				}
1185			case <-time.After(1 * time.Second):
1186				t.Fatal("Timeout in waiting for state change")
1187			}
1188		})
1189	}
1190}
1191
1192// sReceiver signals every receiving attempt through a channel, used by
1193// TestResumeToken to determine if the receiving of a certain PartialResultSet
1194// will be attempted next.
1195type sReceiver struct {
1196	c           chan int
1197	rpcReceiver sppb.Spanner_ExecuteStreamingSqlClient
1198}
1199
1200// Recv() implements streamingReceiver.Recv for sReceiver.
1201func (sr *sReceiver) Recv() (*sppb.PartialResultSet, error) {
1202	sr.c <- 1
1203	return sr.rpcReceiver.Recv()
1204}
1205
1206// waitn waits for nth receiving attempt from now on, until the signal for nth
1207// Recv() attempts is received or timeout. Note that because the way stream()
1208// works, the signal for the nth Recv() means that the previous n - 1
1209// PartialResultSets has already been returned to caller or queued, if no error
1210// happened.
1211func (sr *sReceiver) waitn(n int) error {
1212	for i := 0; i < n; i++ {
1213		select {
1214		case <-sr.c:
1215		case <-time.After(10 * time.Second):
1216			return fmt.Errorf("timeout in waiting for %v-th Recv()", i+1)
1217		}
1218	}
1219	return nil
1220}
1221
1222// Test the handling of resumableStreamDecoder.bytesBetweenResumeTokens.
1223func TestQueueBytes(t *testing.T) {
1224	restore := setMaxBytesBetweenResumeTokens()
1225	defer restore()
1226
1227	server, c, teardown := setupMockedTestServer(t)
1228	defer teardown()
1229	mc, err := c.sc.nextClient()
1230	if err != nil {
1231		t.Fatalf("failed to create a grpc client")
1232	}
1233
1234	rt1 := EncodeResumeToken(1)
1235	rt2 := EncodeResumeToken(2)
1236	rt3 := EncodeResumeToken(3)
1237	resumeTokens := [][]byte{rt1, rt1, rt1, rt2, rt2, rt3}
1238	err = setupStatementResult(t, server, "SELECT t.key key, t.value value FROM t_mock t", len(resumeTokens), resumeTokens)
1239	if err != nil {
1240		t.Fatalf("failed to set up a result for a statement: %v", err)
1241	}
1242
1243	session, err := createSession(mc)
1244	if err != nil {
1245		t.Fatalf("failed to create a session")
1246	}
1247
1248	sr := &sReceiver{
1249		c: make(chan int, 1000), // will never block in this test
1250	}
1251	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
1252	defer cancel()
1253	decoder := newResumableStreamDecoder(
1254		ctx,
1255		nil,
1256		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1257			r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1258				Session:     session.Name,
1259				Sql:         "SELECT t.key key, t.value value FROM t_mock t",
1260				ResumeToken: resumeToken,
1261			})
1262			sr.rpcReceiver = r
1263			return sr, err
1264		},
1265		nil,
1266	)
1267
1268	sizeOfPRS := proto.Size(&sppb.PartialResultSet{
1269		Metadata: kvMeta,
1270		Values: []*proto3.Value{
1271			{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
1272			{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
1273		},
1274		ResumeToken: rt1,
1275	})
1276
1277	decoder.next()
1278	decoder.next()
1279	decoder.next()
1280	if got, want := decoder.bytesBetweenResumeTokens, int32(2*sizeOfPRS); got != want {
1281		t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", got, want)
1282	}
1283
1284	decoder.next()
1285	if decoder.bytesBetweenResumeTokens != 0 {
1286		t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", decoder.bytesBetweenResumeTokens)
1287	}
1288
1289	decoder.next()
1290	if got, want := decoder.bytesBetweenResumeTokens, int32(sizeOfPRS); got != want {
1291		t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", got, want)
1292	}
1293
1294	decoder.next()
1295	if decoder.bytesBetweenResumeTokens != 0 {
1296		t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", decoder.bytesBetweenResumeTokens)
1297	}
1298}
1299
1300// Verify that client can deal with resume token correctly
1301func TestResumeToken(t *testing.T) {
1302	restore := setMaxBytesBetweenResumeTokens()
1303	defer restore()
1304	query := "SELECT t.key key, t.value value FROM t_mock t"
1305	server, c, teardown := setupMockedTestServer(t)
1306	defer teardown()
1307	mc, err := c.sc.nextClient()
1308	if err != nil {
1309		t.Fatalf("failed to create a grpc client")
1310	}
1311
1312	rt1 := EncodeResumeToken(1)
1313	rt2 := EncodeResumeToken(2)
1314	resumeTokens := make([][]byte, 3+maxBuffers)
1315	resumeTokens[1] = rt1
1316	resumeTokens[3+maxBuffers-1] = rt2
1317	err = setupStatementResult(t, server, query, len(resumeTokens), resumeTokens)
1318	if err != nil {
1319		t.Fatalf("failed to set up a result for a statement: %v", err)
1320	}
1321
1322	// The first error will be retried.
1323	server.TestSpanner.AddPartialResultSetError(
1324		query,
1325		PartialResultSetExecutionTime{
1326			ResumeToken: rt1,
1327			Err:         status.Error(codes.Unavailable, "mock server unavailable"),
1328		},
1329	)
1330	// The second error will not be retried because maxBytesBetweenResumeTokens
1331	// is reached and the state of resumableStreamDecoder:
1332	// queueingRetryable -> queueingUnretryable. The query will just fail.
1333	server.TestSpanner.AddPartialResultSetError(
1334		query,
1335		PartialResultSetExecutionTime{
1336			ResumeToken: rt2,
1337			Err:         status.Error(codes.Unavailable, "mock server wants some sleep"),
1338		},
1339	)
1340
1341	session, err := createSession(mc)
1342	if err != nil {
1343		t.Fatalf("failed to create a session")
1344	}
1345
1346	sr := &sReceiver{
1347		c: make(chan int, 1000), // will never block in this test
1348	}
1349	rows := []*Row{}
1350
1351	streaming := func() *RowIterator {
1352		return stream(context.Background(), nil,
1353			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1354				r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1355					Session:     session.Name,
1356					Sql:         query,
1357					ResumeToken: resumeToken,
1358				})
1359				sr.rpcReceiver = r
1360				return sr, err
1361			},
1362			nil,
1363			func(error) {})
1364	}
1365
1366	// Establish a stream to mock cloud spanner server.
1367	iter := streaming()
1368	defer iter.Stop()
1369	var row *Row
1370
1371	// Read first two rows.
1372	for i := 0; i < 3; i++ {
1373		row, err = iter.Next()
1374		if err != nil {
1375			t.Fatalf("failed to get next value: %v", err)
1376		}
1377		rows = append(rows, row)
1378	}
1379
1380	want := []*Row{
1381		{
1382			fields: kvMeta.RowType.Fields,
1383			vals: []*proto3.Value{
1384				{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
1385				{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
1386			},
1387		},
1388		{
1389			fields: kvMeta.RowType.Fields,
1390			vals: []*proto3.Value{
1391				{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
1392				{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
1393			},
1394		},
1395		{
1396			fields: kvMeta.RowType.Fields,
1397			vals: []*proto3.Value{
1398				{Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
1399				{Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
1400			},
1401		},
1402	}
1403	if !testEqual(rows, want) {
1404		t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
1405	}
1406
1407	// Trigger state change of resumableStreamDecoder:
1408	// queueingRetryable -> queueingUnretryable
1409	for i := 0; i < maxBuffers-1; i++ {
1410		row, err = iter.Next()
1411		if err != nil {
1412			t.Fatalf("failed to get next value: %v", err)
1413		}
1414		rows = append(rows, row)
1415	}
1416
1417	// Since resumableStreamDecoder is already at queueingUnretryable state,
1418	// query will just fail.
1419	_, err = iter.Next()
1420	if wantErr := spannerErrorf(codes.Unavailable, "mock server wants some sleep"); !testEqual(err, wantErr) {
1421		t.Fatalf("stream() returns error: %v, but want error: %v", err, wantErr)
1422	}
1423
1424	// Let server send two rows without resume token.
1425	resumeTokens = make([][]byte, 2)
1426	err = setupStatementResult(t, server, query, len(resumeTokens), resumeTokens)
1427	if err != nil {
1428		t.Fatalf("failed to set up a result for a statement: %v", err)
1429	}
1430
1431	// Reconnect to mock Cloud Spanner.
1432	rows = []*Row{}
1433	iter = streaming()
1434	defer iter.Stop()
1435
1436	for i := 0; i < 2; i++ {
1437		row, err = iter.Next()
1438		if err != nil {
1439			t.Fatalf("failed to get next value: %v", err)
1440		}
1441		rows = append(rows, row)
1442	}
1443
1444	// Verify if a normal server side EOF flushes all queued rows.
1445	want = []*Row{
1446		{
1447			fields: kvMeta.RowType.Fields,
1448			vals: []*proto3.Value{
1449				{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
1450				{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
1451			},
1452		},
1453		{
1454			fields: kvMeta.RowType.Fields,
1455			vals: []*proto3.Value{
1456				{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
1457				{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
1458			},
1459		},
1460	}
1461	if !testEqual(rows, want) {
1462		t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
1463	}
1464}
1465
1466// Verify that streaming query get retried upon real gRPC server transport
1467// failures.
1468func TestGrpcReconnect(t *testing.T) {
1469	restore := setMaxBytesBetweenResumeTokens()
1470	defer restore()
1471
1472	server, c, teardown := setupMockedTestServer(t)
1473	defer teardown()
1474	mc, err := c.sc.nextClient()
1475	if err != nil {
1476		t.Fatalf("failed to create a grpc client")
1477	}
1478
1479	session, err := createSession(mc)
1480	if err != nil {
1481		t.Fatalf("failed to create a session")
1482	}
1483
1484	// Simulate an unavailable error to interrupt the stream of PartialResultSet
1485	// in order to test the grpc retrying mechanism.
1486	server.TestSpanner.AddPartialResultSetError(
1487		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1488		PartialResultSetExecutionTime{
1489			ResumeToken: EncodeResumeToken(2),
1490			Err:         status.Errorf(codes.Unavailable, "server is unavailable"),
1491		},
1492	)
1493
1494	// The retry is counted from the second call.
1495	r := -1
1496	// Establish a stream to mock cloud spanner server.
1497	iter := stream(context.Background(), nil,
1498		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1499			r++
1500			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1501				Session:     session.Name,
1502				Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1503				ResumeToken: resumeToken,
1504			})
1505
1506		},
1507		nil,
1508		func(error) {})
1509	defer iter.Stop()
1510	for {
1511		_, err := iter.Next()
1512		if err == iterator.Done {
1513			err = nil
1514			break
1515		}
1516		if err != nil {
1517			break
1518		}
1519	}
1520	if r != 1 {
1521		t.Errorf("retry count = %v, want 1", r)
1522	}
1523}
1524
1525// Test cancel/timeout for client operations.
1526func TestCancelTimeout(t *testing.T) {
1527	restore := setMaxBytesBetweenResumeTokens()
1528	defer restore()
1529	server, c, teardown := setupMockedTestServer(t)
1530	defer teardown()
1531	server.TestSpanner.PutExecutionTime(
1532		MethodExecuteStreamingSql,
1533		SimulatedExecutionTime{MinimumExecutionTime: 1 * time.Second},
1534	)
1535	mc, err := c.sc.nextClient()
1536	if err != nil {
1537		t.Fatalf("failed to create a grpc client")
1538	}
1539
1540	session, err := createSession(mc)
1541	if err != nil {
1542		t.Fatalf("failed to create a session")
1543	}
1544	done := make(chan int)
1545
1546	// Test cancelling query.
1547	ctx, cancel := context.WithCancel(context.Background())
1548	go func() {
1549		// Establish a stream to mock cloud spanner server.
1550		iter := stream(ctx, nil,
1551			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1552				return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1553					Session:     session.Name,
1554					Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1555					ResumeToken: resumeToken,
1556				})
1557			},
1558			nil,
1559			func(error) {})
1560		defer iter.Stop()
1561		for {
1562			_, err = iter.Next()
1563			if err == iterator.Done {
1564				break
1565			}
1566			if err != nil {
1567				done <- 0
1568				break
1569			}
1570		}
1571	}()
1572	cancel()
1573	select {
1574	case <-done:
1575		if ErrCode(err) != codes.Canceled {
1576			t.Errorf("streaming query is canceled and returns error %v, want error code %v", err, codes.Canceled)
1577		}
1578	case <-time.After(1 * time.Second):
1579		t.Errorf("query doesn't exit timely after being cancelled")
1580	}
1581
1582	// Test query timeout.
1583	ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
1584	defer cancel()
1585	go func() {
1586		// Establish a stream to mock cloud spanner server.
1587		iter := stream(ctx, nil,
1588			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1589				return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1590					Session:     session.Name,
1591					Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1592					ResumeToken: resumeToken,
1593				})
1594			},
1595			nil,
1596			func(error) {})
1597		defer iter.Stop()
1598		for {
1599			_, err = iter.Next()
1600			if err == iterator.Done {
1601				err = nil
1602				break
1603			}
1604			if err != nil {
1605				break
1606			}
1607		}
1608		done <- 0
1609	}()
1610	select {
1611	case <-done:
1612		if wantErr := codes.DeadlineExceeded; ErrCode(err) != wantErr {
1613			t.Errorf("streaming query timeout returns error %v, want error code %v", err, wantErr)
1614		}
1615	case <-time.After(2 * time.Second):
1616		t.Errorf("query doesn't timeout as expected")
1617	}
1618}
1619
1620func setupStatementResult(t *testing.T, server *MockedSpannerInMemTestServer, stmt string, rowCount int, resumeTokens [][]byte) error {
1621	selectValues := make([][]string, rowCount)
1622	for i := 0; i < rowCount; i++ {
1623		selectValues[i] = []string{keyStr(i), valStr(i)}
1624	}
1625
1626	rows := make([]*structpb.ListValue, len(selectValues))
1627	for i, values := range selectValues {
1628		rowValues := make([]*structpb.Value, len(kvMeta.RowType.Fields))
1629		for j, value := range values {
1630			rowValues[j] = &structpb.Value{
1631				Kind: &structpb.Value_StringValue{StringValue: value},
1632			}
1633		}
1634		rows[i] = &structpb.ListValue{
1635			Values: rowValues,
1636		}
1637	}
1638	resultSet := &sppb.ResultSet{
1639		Metadata: kvMeta,
1640		Rows:     rows,
1641	}
1642	result := &StatementResult{
1643		Type:         StatementResultResultSet,
1644		ResultSet:    resultSet,
1645		ResumeTokens: resumeTokens,
1646	}
1647	return server.TestSpanner.PutStatementResult(stmt, result)
1648}
1649
1650func TestRowIteratorDo(t *testing.T) {
1651	restore := setMaxBytesBetweenResumeTokens()
1652	defer restore()
1653
1654	_, c, teardown := setupMockedTestServer(t)
1655	defer teardown()
1656	mc, err := c.sc.nextClient()
1657	if err != nil {
1658		t.Fatalf("failed to create a grpc client")
1659	}
1660
1661	session, err := createSession(mc)
1662	if err != nil {
1663		t.Fatalf("failed to create a session")
1664	}
1665
1666	nRows := 0
1667	iter := stream(context.Background(), nil,
1668		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1669			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1670				Session:     session.Name,
1671				Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1672				ResumeToken: resumeToken,
1673			})
1674		},
1675		nil,
1676		func(error) {})
1677	err = iter.Do(func(r *Row) error { nRows++; return nil })
1678	if err != nil {
1679		t.Errorf("Using Do: %v", err)
1680	}
1681	if nRows != 3 {
1682		t.Errorf("got %d rows, want 3", nRows)
1683	}
1684}
1685
1686func TestRowIteratorDoWithError(t *testing.T) {
1687	restore := setMaxBytesBetweenResumeTokens()
1688	defer restore()
1689
1690	_, c, teardown := setupMockedTestServer(t)
1691	defer teardown()
1692	mc, err := c.sc.nextClient()
1693	if err != nil {
1694		t.Fatalf("failed to create a grpc client")
1695	}
1696
1697	session, err := createSession(mc)
1698	if err != nil {
1699		t.Fatalf("failed to create a session")
1700	}
1701
1702	iter := stream(context.Background(), nil,
1703		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1704			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1705				Session:     session.Name,
1706				Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1707				ResumeToken: resumeToken,
1708			})
1709		},
1710		nil,
1711		func(error) {})
1712	injected := errors.New("Failed iterator")
1713	err = iter.Do(func(r *Row) error { return injected })
1714	if err != injected {
1715		t.Errorf("got <%v>, want <%v>", err, injected)
1716	}
1717}
1718
1719func TestIteratorStopEarly(t *testing.T) {
1720	ctx := context.Background()
1721	restore := setMaxBytesBetweenResumeTokens()
1722	defer restore()
1723
1724	_, c, teardown := setupMockedTestServer(t)
1725	defer teardown()
1726	mc, err := c.sc.nextClient()
1727	if err != nil {
1728		t.Fatalf("failed to create a grpc client")
1729	}
1730
1731	session, err := createSession(mc)
1732	if err != nil {
1733		t.Fatalf("failed to create a session")
1734	}
1735
1736	iter := stream(ctx, nil,
1737		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1738			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1739				Session:     session.Name,
1740				Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1741				ResumeToken: resumeToken,
1742			})
1743		},
1744		nil,
1745		func(error) {})
1746	_, err = iter.Next()
1747	if err != nil {
1748		t.Fatalf("before Stop: %v", err)
1749	}
1750	iter.Stop()
1751	// Stop sets r.err to the FailedPrecondition error "Next called after Stop".
1752	_, err = iter.Next()
1753	if g, w := ErrCode(err), codes.FailedPrecondition; g != w {
1754		t.Errorf("after Stop: got: %v, want: %v", g, w)
1755	}
1756}
1757
1758func TestIteratorWithError(t *testing.T) {
1759	injected := errors.New("Failed iterator")
1760	iter := RowIterator{err: injected}
1761	defer iter.Stop()
1762	if _, err := iter.Next(); err != injected {
1763		t.Fatalf("Expected error: %v, got %v", injected, err)
1764	}
1765}
1766
1767func createSession(client *vkit.Client) (*sppb.Session, error) {
1768	var formattedDatabase string = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
1769	var request = &sppb.CreateSessionRequest{
1770		Database: formattedDatabase,
1771	}
1772	return client.CreateSession(context.Background(), request)
1773}
1774