1/*
2Copyright 2017 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"errors"
22	"fmt"
23	"sync"
24	"sync/atomic"
25	"testing"
26	"time"
27
28	vkit "cloud.google.com/go/spanner/apiv1"
29	. "cloud.google.com/go/spanner/internal/testutil"
30	"github.com/golang/protobuf/proto"
31	proto3 "github.com/golang/protobuf/ptypes/struct"
32	structpb "github.com/golang/protobuf/ptypes/struct"
33	"github.com/googleapis/gax-go/v2"
34	"google.golang.org/api/iterator"
35	sppb "google.golang.org/genproto/googleapis/spanner/v1"
36	"google.golang.org/grpc/codes"
37	"google.golang.org/grpc/status"
38)
39
40var (
41	// Mocked transaction timestamp.
42	trxTs = time.Unix(1, 2)
43	// Metadata for mocked KV table, its rows are returned by SingleUse
44	// transactions.
45	kvMeta = func() *sppb.ResultSetMetadata {
46		meta := KvMeta
47		meta.Transaction = &sppb.Transaction{
48			ReadTimestamp: timestampProto(trxTs),
49		}
50		return &meta
51	}()
52	// Metadata for mocked ListKV table, which uses List for its key and value.
53	// Its rows are returned by snapshot readonly transactions, as indicated in
54	// the transaction metadata.
55	kvListMeta = &sppb.ResultSetMetadata{
56		RowType: &sppb.StructType{
57			Fields: []*sppb.StructType_Field{
58				{
59					Name: "Key",
60					Type: &sppb.Type{
61						Code: sppb.TypeCode_ARRAY,
62						ArrayElementType: &sppb.Type{
63							Code: sppb.TypeCode_STRING,
64						},
65					},
66				},
67				{
68					Name: "Value",
69					Type: &sppb.Type{
70						Code: sppb.TypeCode_ARRAY,
71						ArrayElementType: &sppb.Type{
72							Code: sppb.TypeCode_STRING,
73						},
74					},
75				},
76			},
77		},
78		Transaction: &sppb.Transaction{
79			Id:            transactionID{5, 6, 7, 8, 9},
80			ReadTimestamp: timestampProto(trxTs),
81		},
82	}
83	// Metadata for mocked schema of a query result set, which has two struct
84	// columns named "Col1" and "Col2", the struct's schema is like the
85	// following:
86	//
87	//	STRUCT {
88	//		INT
89	//		LIST<STRING>
90	//	}
91	//
92	// Its rows are returned in readwrite transaction, as indicated in the
93	// transaction metadata.
94	kvObjectMeta = &sppb.ResultSetMetadata{
95		RowType: &sppb.StructType{
96			Fields: []*sppb.StructType_Field{
97				{
98					Name: "Col1",
99					Type: &sppb.Type{
100						Code: sppb.TypeCode_STRUCT,
101						StructType: &sppb.StructType{
102							Fields: []*sppb.StructType_Field{
103								{
104									Name: "foo-f1",
105									Type: &sppb.Type{
106										Code: sppb.TypeCode_INT64,
107									},
108								},
109								{
110									Name: "foo-f2",
111									Type: &sppb.Type{
112										Code: sppb.TypeCode_ARRAY,
113										ArrayElementType: &sppb.Type{
114											Code: sppb.TypeCode_STRING,
115										},
116									},
117								},
118							},
119						},
120					},
121				},
122				{
123					Name: "Col2",
124					Type: &sppb.Type{
125						Code: sppb.TypeCode_STRUCT,
126						StructType: &sppb.StructType{
127							Fields: []*sppb.StructType_Field{
128								{
129									Name: "bar-f1",
130									Type: &sppb.Type{
131										Code: sppb.TypeCode_INT64,
132									},
133								},
134								{
135									Name: "bar-f2",
136									Type: &sppb.Type{
137										Code: sppb.TypeCode_ARRAY,
138										ArrayElementType: &sppb.Type{
139											Code: sppb.TypeCode_STRING,
140										},
141									},
142								},
143							},
144						},
145					},
146				},
147			},
148		},
149		Transaction: &sppb.Transaction{
150			Id: transactionID{1, 2, 3, 4, 5},
151		},
152	}
153)
154
155func describeRows(l []*Row) string {
156	// generate a nice test failure description
157	var s = "["
158	for i, r := range l {
159		if i != 0 {
160			s += ",\n "
161		}
162		s += fmt.Sprint(r)
163	}
164	s += "]"
165	return s
166}
167
168// Helper for generating proto3 Value_ListValue instances, making test code
169// shorter and readable.
170func genProtoListValue(v ...string) *proto3.Value_ListValue {
171	r := &proto3.Value_ListValue{
172		ListValue: &proto3.ListValue{
173			Values: []*proto3.Value{},
174		},
175	}
176	for _, e := range v {
177		r.ListValue.Values = append(
178			r.ListValue.Values,
179			&proto3.Value{
180				Kind: &proto3.Value_StringValue{StringValue: e},
181			},
182		)
183	}
184	return r
185}
186
187// Test Row generation logics of partialResultSetDecoder.
188func TestPartialResultSetDecoder(t *testing.T) {
189	restore := setMaxBytesBetweenResumeTokens()
190	defer restore()
191	var tests = []struct {
192		input    []*sppb.PartialResultSet
193		wantF    []*Row
194		wantTxID transactionID
195		wantTs   time.Time
196		wantD    bool
197	}{
198		{
199			// Empty input.
200			wantD: true,
201		},
202		// String merging examples.
203		{
204			// Single KV result.
205			input: []*sppb.PartialResultSet{
206				{
207					Metadata: kvMeta,
208					Values: []*proto3.Value{
209						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
210						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
211					},
212				},
213			},
214			wantF: []*Row{
215				{
216					fields: kvMeta.RowType.Fields,
217					vals: []*proto3.Value{
218						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
219						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
220					},
221				},
222			},
223			wantTs: trxTs,
224			wantD:  true,
225		},
226		{
227			// Incomplete partial result.
228			input: []*sppb.PartialResultSet{
229				{
230					Metadata: kvMeta,
231					Values: []*proto3.Value{
232						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
233					},
234				},
235			},
236			wantTs: trxTs,
237			wantD:  false,
238		},
239		{
240			// Complete splitted result.
241			input: []*sppb.PartialResultSet{
242				{
243					Metadata: kvMeta,
244					Values: []*proto3.Value{
245						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
246					},
247				},
248				{
249					Values: []*proto3.Value{
250						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
251					},
252				},
253			},
254			wantF: []*Row{
255				{
256					fields: kvMeta.RowType.Fields,
257					vals: []*proto3.Value{
258						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
259						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
260					},
261				},
262			},
263			wantTs: trxTs,
264			wantD:  true,
265		},
266		{
267			// Multi-row example with splitted row in the middle.
268			input: []*sppb.PartialResultSet{
269				{
270					Metadata: kvMeta,
271					Values: []*proto3.Value{
272						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
273						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
274						{Kind: &proto3.Value_StringValue{StringValue: "A"}},
275					},
276				},
277				{
278					Values: []*proto3.Value{
279						{Kind: &proto3.Value_StringValue{StringValue: "1"}},
280						{Kind: &proto3.Value_StringValue{StringValue: "B"}},
281						{Kind: &proto3.Value_StringValue{StringValue: "2"}},
282					},
283				},
284			},
285			wantF: []*Row{
286				{
287					fields: kvMeta.RowType.Fields,
288					vals: []*proto3.Value{
289						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
290						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
291					},
292				},
293				{
294					fields: kvMeta.RowType.Fields,
295					vals: []*proto3.Value{
296						{Kind: &proto3.Value_StringValue{StringValue: "A"}},
297						{Kind: &proto3.Value_StringValue{StringValue: "1"}},
298					},
299				},
300				{
301					fields: kvMeta.RowType.Fields,
302					vals: []*proto3.Value{
303						{Kind: &proto3.Value_StringValue{StringValue: "B"}},
304						{Kind: &proto3.Value_StringValue{StringValue: "2"}},
305					},
306				},
307			},
308			wantTs: trxTs,
309			wantD:  true,
310		},
311		{
312			// Merging example in result_set.proto.
313			input: []*sppb.PartialResultSet{
314				{
315					Metadata: kvMeta,
316					Values: []*proto3.Value{
317						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
318						{Kind: &proto3.Value_StringValue{StringValue: "W"}},
319					},
320					ChunkedValue: true,
321				},
322				{
323					Values: []*proto3.Value{
324						{Kind: &proto3.Value_StringValue{StringValue: "orl"}},
325					},
326					ChunkedValue: true,
327				},
328				{
329					Values: []*proto3.Value{
330						{Kind: &proto3.Value_StringValue{StringValue: "d"}},
331					},
332				},
333			},
334			wantF: []*Row{
335				{
336					fields: kvMeta.RowType.Fields,
337					vals: []*proto3.Value{
338						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
339						{Kind: &proto3.Value_StringValue{StringValue: "World"}},
340					},
341				},
342			},
343			wantTs: trxTs,
344			wantD:  true,
345		},
346		{
347			// More complex example showing completing a merge and
348			// starting a new merge in the same partialResultSet.
349			input: []*sppb.PartialResultSet{
350				{
351					Metadata: kvMeta,
352					Values: []*proto3.Value{
353						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
354						{Kind: &proto3.Value_StringValue{StringValue: "W"}}, // start split in value
355					},
356					ChunkedValue: true,
357				},
358				{
359					Values: []*proto3.Value{
360						{Kind: &proto3.Value_StringValue{StringValue: "orld"}}, // complete value
361						{Kind: &proto3.Value_StringValue{StringValue: "i"}},    // start split in key
362					},
363					ChunkedValue: true,
364				},
365				{
366					Values: []*proto3.Value{
367						{Kind: &proto3.Value_StringValue{StringValue: "s"}}, // complete key
368						{Kind: &proto3.Value_StringValue{StringValue: "not"}},
369						{Kind: &proto3.Value_StringValue{StringValue: "a"}},
370						{Kind: &proto3.Value_StringValue{StringValue: "qu"}}, // split in value
371					},
372					ChunkedValue: true,
373				},
374				{
375					Values: []*proto3.Value{
376						{Kind: &proto3.Value_StringValue{StringValue: "estion"}}, // complete value
377					},
378				},
379			},
380			wantF: []*Row{
381				{
382					fields: kvMeta.RowType.Fields,
383					vals: []*proto3.Value{
384						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
385						{Kind: &proto3.Value_StringValue{StringValue: "World"}},
386					},
387				},
388				{
389					fields: kvMeta.RowType.Fields,
390					vals: []*proto3.Value{
391						{Kind: &proto3.Value_StringValue{StringValue: "is"}},
392						{Kind: &proto3.Value_StringValue{StringValue: "not"}},
393					},
394				},
395				{
396					fields: kvMeta.RowType.Fields,
397					vals: []*proto3.Value{
398						{Kind: &proto3.Value_StringValue{StringValue: "a"}},
399						{Kind: &proto3.Value_StringValue{StringValue: "question"}},
400					},
401				},
402			},
403			wantTs: trxTs,
404			wantD:  true,
405		},
406		// List merging examples.
407		{
408			// Non-splitting Lists.
409			input: []*sppb.PartialResultSet{
410				{
411					Metadata: kvListMeta,
412					Values: []*proto3.Value{
413						{
414							Kind: genProtoListValue("foo-1", "foo-2"),
415						},
416					},
417				},
418				{
419					Values: []*proto3.Value{
420						{
421							Kind: genProtoListValue("bar-1", "bar-2"),
422						},
423					},
424				},
425			},
426			wantF: []*Row{
427				{
428					fields: kvListMeta.RowType.Fields,
429					vals: []*proto3.Value{
430						{
431							Kind: genProtoListValue("foo-1", "foo-2"),
432						},
433						{
434							Kind: genProtoListValue("bar-1", "bar-2"),
435						},
436					},
437				},
438			},
439			wantTxID: transactionID{5, 6, 7, 8, 9},
440			wantTs:   trxTs,
441			wantD:    true,
442		},
443		{
444			// Simple List merge case: splitted string element.
445			input: []*sppb.PartialResultSet{
446				{
447					Metadata: kvListMeta,
448					Values: []*proto3.Value{
449						{
450							Kind: genProtoListValue("foo-1", "foo-"),
451						},
452					},
453					ChunkedValue: true,
454				},
455				{
456					Values: []*proto3.Value{
457						{
458							Kind: genProtoListValue("2"),
459						},
460					},
461				},
462				{
463					Values: []*proto3.Value{
464						{
465							Kind: genProtoListValue("bar-1", "bar-2"),
466						},
467					},
468				},
469			},
470			wantF: []*Row{
471				{
472					fields: kvListMeta.RowType.Fields,
473					vals: []*proto3.Value{
474						{
475							Kind: genProtoListValue("foo-1", "foo-2"),
476						},
477						{
478							Kind: genProtoListValue("bar-1", "bar-2"),
479						},
480					},
481				},
482			},
483			wantTxID: transactionID{5, 6, 7, 8, 9},
484			wantTs:   trxTs,
485			wantD:    true,
486		},
487		{
488			// Struct merging is also implemented by List merging. Note that
489			// Cloud Spanner uses proto.ListValue to encode Structs as well.
490			input: []*sppb.PartialResultSet{
491				{
492					Metadata: kvObjectMeta,
493					Values: []*proto3.Value{
494						{
495							Kind: &proto3.Value_ListValue{
496								ListValue: &proto3.ListValue{
497									Values: []*proto3.Value{
498										{Kind: &proto3.Value_NumberValue{NumberValue: 23}},
499										{Kind: genProtoListValue("foo-1", "fo")},
500									},
501								},
502							},
503						},
504					},
505					ChunkedValue: true,
506				},
507				{
508					Values: []*proto3.Value{
509						{
510							Kind: &proto3.Value_ListValue{
511								ListValue: &proto3.ListValue{
512									Values: []*proto3.Value{
513										{Kind: genProtoListValue("o-2", "f")},
514									},
515								},
516							},
517						},
518					},
519					ChunkedValue: true,
520				},
521				{
522					Values: []*proto3.Value{
523						{
524							Kind: &proto3.Value_ListValue{
525								ListValue: &proto3.ListValue{
526									Values: []*proto3.Value{
527										{Kind: genProtoListValue("oo-3")},
528									},
529								},
530							},
531						},
532						{
533							Kind: &proto3.Value_ListValue{
534								ListValue: &proto3.ListValue{
535									Values: []*proto3.Value{
536										{Kind: &proto3.Value_NumberValue{NumberValue: 45}},
537										{Kind: genProtoListValue("bar-1")},
538									},
539								},
540							},
541						},
542					},
543				},
544			},
545			wantF: []*Row{
546				{
547					fields: kvObjectMeta.RowType.Fields,
548					vals: []*proto3.Value{
549						{
550							Kind: &proto3.Value_ListValue{
551								ListValue: &proto3.ListValue{
552									Values: []*proto3.Value{
553										{Kind: &proto3.Value_NumberValue{NumberValue: 23}},
554										{Kind: genProtoListValue("foo-1", "foo-2", "foo-3")},
555									},
556								},
557							},
558						},
559						{
560							Kind: &proto3.Value_ListValue{
561								ListValue: &proto3.ListValue{
562									Values: []*proto3.Value{
563										{Kind: &proto3.Value_NumberValue{NumberValue: 45}},
564										{Kind: genProtoListValue("bar-1")},
565									},
566								},
567							},
568						},
569					},
570				},
571			},
572			wantTxID: transactionID{1, 2, 3, 4, 5},
573			wantD:    true,
574		},
575	}
576
577nextTest:
578	for i, test := range tests {
579		var rows []*Row
580		p := &partialResultSetDecoder{}
581		for j, v := range test.input {
582			rs, _, err := p.add(v)
583			if err != nil {
584				t.Errorf("test %d.%d: partialResultSetDecoder.add(%v) = %v; want nil", i, j, v, err)
585				continue nextTest
586			}
587			rows = append(rows, rs...)
588		}
589		if !testEqual(p.ts, test.wantTs) {
590			t.Errorf("got transaction(%v), want %v", p.ts, test.wantTs)
591		}
592		if !testEqual(rows, test.wantF) {
593			t.Errorf("test %d: rows=\n%v\n; want\n%v\n; p.row:\n%v\n", i, describeRows(rows), describeRows(test.wantF), p.row)
594		}
595		if got := p.done(); got != test.wantD {
596			t.Errorf("test %d: partialResultSetDecoder.done() = %v", i, got)
597		}
598	}
599}
600
601const (
602	// max number of PartialResultSets that will be buffered in tests.
603	maxBuffers = 16
604)
605
606// setMaxBytesBetweenResumeTokens sets the global maxBytesBetweenResumeTokens to
607// a smaller value more suitable for tests. It returns a function which should
608// be called to restore the maxBytesBetweenResumeTokens to its old value.
609func setMaxBytesBetweenResumeTokens() func() {
610	o := atomic.LoadInt32(&maxBytesBetweenResumeTokens)
611	atomic.StoreInt32(&maxBytesBetweenResumeTokens, int32(maxBuffers*proto.Size(&sppb.PartialResultSet{
612		Metadata: kvMeta,
613		Values: []*proto3.Value{
614			{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
615			{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
616		},
617	})))
618	return func() {
619		atomic.StoreInt32(&maxBytesBetweenResumeTokens, o)
620	}
621}
622
623// keyStr generates key string for kvMeta schema.
624func keyStr(i int) string {
625	return fmt.Sprintf("foo-%02d", i)
626}
627
628// valStr generates value string for kvMeta schema.
629func valStr(i int) string {
630	return fmt.Sprintf("bar-%02d", i)
631}
632
633// Test state transitions of resumableStreamDecoder where state machine ends up
634// to a non-blocking state(resumableStreamDecoder.Next returns on non-blocking
635// state).
636func TestRsdNonblockingStates(t *testing.T) {
637	restore := setMaxBytesBetweenResumeTokens()
638	defer restore()
639	tests := []struct {
640		name         string
641		resumeTokens [][]byte
642		prsErrors    []PartialResultSetExecutionTime
643		rpc          func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
644		sql          string
645		// Expected values
646		want         []*sppb.PartialResultSet      // PartialResultSets that should be returned to caller
647		queue        []*sppb.PartialResultSet      // PartialResultSets that should be buffered
648		resumeToken  []byte                        // Resume token that is maintained by resumableStreamDecoder
649		stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
650		wantErr      error
651	}{
652		{
653			// unConnected->queueingRetryable->finished
654			name:         "unConnected->queueingRetryable->finished",
655			resumeTokens: make([][]byte, 2),
656			sql:          "SELECT t.key key, t.value value FROM t_mock t",
657			want: []*sppb.PartialResultSet{
658				{
659					Metadata: kvMeta,
660					Values: []*proto3.Value{
661						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
662						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
663					},
664				},
665			},
666			queue: []*sppb.PartialResultSet{
667				{
668					Metadata: kvMeta,
669					Values: []*proto3.Value{
670						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
671						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
672					},
673				},
674			},
675			stateHistory: []resumableStreamDecoderState{
676				queueingRetryable, // do RPC
677				queueingRetryable, // got foo-00
678				queueingRetryable, // got foo-01
679				finished,          // got EOF
680			},
681		},
682		{
683			// unConnected->queueingRetryable->aborted
684			name:         "unConnected->queueingRetryable->aborted",
685			resumeTokens: [][]byte{{}, EncodeResumeToken(1), {}, EncodeResumeToken(2)},
686			prsErrors: []PartialResultSetExecutionTime{{
687				ResumeToken: EncodeResumeToken(2),
688				Err:         status.Error(codes.Unknown, "I quit"),
689			}},
690			sql: "SELECT t.key key, t.value value FROM t_mock t",
691			want: []*sppb.PartialResultSet{
692				{
693					Metadata: kvMeta,
694					Values: []*proto3.Value{
695						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
696						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
697					},
698				},
699				{
700					Metadata: kvMeta,
701					Values: []*proto3.Value{
702						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
703						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
704					},
705					ResumeToken: EncodeResumeToken(1),
706				},
707			},
708			stateHistory: []resumableStreamDecoderState{
709				queueingRetryable, // do RPC
710				queueingRetryable, // got foo-00
711				queueingRetryable, // got foo-01
712				queueingRetryable, // foo-01, resume token
713				queueingRetryable, // got foo-02
714				aborted,           // got error
715			},
716			wantErr: status.Errorf(codes.Unknown, "I quit"),
717		},
718		{
719			// unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable
720			name:         "unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable",
721			resumeTokens: make([][]byte, maxBuffers+1),
722			sql:          "SELECT t.key key, t.value value FROM t_mock t",
723			want: func() (s []*sppb.PartialResultSet) {
724				for i := 0; i < maxBuffers+1; i++ {
725					s = append(s, &sppb.PartialResultSet{
726						Metadata: kvMeta,
727						Values: []*proto3.Value{
728							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
729							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
730						},
731					})
732				}
733				return s
734			}(),
735			stateHistory: func() (s []resumableStreamDecoderState) {
736				s = append(s, queueingRetryable) // RPC
737				for i := 0; i < maxBuffers; i++ {
738					s = append(s, queueingRetryable) // the internal queue of resumableStreamDecoder fills up
739				}
740				// the first item fills up the queue and triggers state transition;
741				// the second item is received under queueingUnretryable state.
742				s = append(s, queueingUnretryable)
743				s = append(s, queueingUnretryable)
744				return s
745			}(),
746		},
747		{
748			// unConnected->queueingRetryable->queueingUnretryable->aborted
749			name: "unConnected->queueingRetryable->queueingUnretryable->aborted",
750			resumeTokens: func() (rts [][]byte) {
751				rts = make([][]byte, maxBuffers+1)
752				rts[maxBuffers] = EncodeResumeToken(1)
753				return rts
754			}(),
755			prsErrors: []PartialResultSetExecutionTime{{
756				ResumeToken: EncodeResumeToken(1),
757				Err:         status.Error(codes.Unknown, "Just Abort It"),
758			}},
759			sql: "SELECT t.key key, t.value value FROM t_mock t",
760			want: func() (s []*sppb.PartialResultSet) {
761				for i := 0; i < maxBuffers; i++ {
762					s = append(s, &sppb.PartialResultSet{
763						Metadata: kvMeta,
764						Values: []*proto3.Value{
765							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
766							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
767						},
768					})
769				}
770				return s
771			}(),
772			stateHistory: func() (s []resumableStreamDecoderState) {
773				s = append(s, queueingRetryable) // RPC
774				for i := 0; i < maxBuffers; i++ {
775					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
776				}
777				s = append(s, queueingUnretryable) // the last row triggers state change
778				s = append(s, aborted)             // Error happens
779				return s
780			}(),
781			wantErr: status.Errorf(codes.Unknown, "Just Abort It"),
782		},
783	}
784	for _, test := range tests {
785		t.Run(test.name, func(t *testing.T) {
786			server, c, teardown := setupMockedTestServer(t)
787			defer teardown()
788			mc, err := c.sc.nextClient()
789			if err != nil {
790				t.Fatalf("failed to create a grpc client")
791			}
792
793			session, err := createSession(mc)
794			if err != nil {
795				t.Fatalf("failed to create a session")
796			}
797
798			if test.rpc == nil {
799				test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
800					return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
801						Session:     session.Name,
802						Sql:         test.sql,
803						ResumeToken: resumeToken,
804					})
805				}
806			}
807			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
808			defer cancel()
809			r := newResumableStreamDecoder(
810				ctx,
811				nil,
812				test.rpc,
813				nil,
814			)
815			st := []resumableStreamDecoderState{}
816			var lastErr error
817			// Once the expected number of state transitions are observed,
818			// send a signal by setting stateDone = true.
819			stateDone := false
820			// Set stateWitness to listen to state changes.
821			hl := len(test.stateHistory) // To avoid data race on test.
822			r.stateWitness = func(rs resumableStreamDecoderState) {
823				if !stateDone {
824					// Record state transitions.
825					st = append(st, rs)
826					if len(st) == hl {
827						lastErr = r.lastErr()
828						stateDone = true
829					}
830				}
831			}
832			// Let mock server stream given messages to resumableStreamDecoder.
833			err = setupStatementResult(t, server, test.sql, len(test.resumeTokens), test.resumeTokens)
834			if err != nil {
835				t.Fatalf("failed to set up a result for a statement: %v", err)
836			}
837
838			for _, et := range test.prsErrors {
839				server.TestSpanner.AddPartialResultSetError(
840					test.sql,
841					et,
842				)
843			}
844			var rs []*sppb.PartialResultSet
845			for {
846				select {
847				case <-ctx.Done():
848					t.Fatal("context cancelled or timeout during test")
849				default:
850				}
851				if stateDone {
852					// Check if resumableStreamDecoder carried out expected
853					// state transitions.
854					if !testEqual(st, test.stateHistory) {
855						t.Fatalf("observed state transitions: \n%v\n, want \n%v\n", st, test.stateHistory)
856					}
857					// Check if resumableStreamDecoder returns expected array of
858					// PartialResultSets.
859					if !testEqual(rs, test.want) {
860						t.Fatalf("received PartialResultSets: \n%v\n, want \n%v\n", rs, test.want)
861					}
862					// Verify that resumableStreamDecoder's internal buffering is
863					// also correct.
864					var q []*sppb.PartialResultSet
865					for {
866						item := r.q.pop()
867						if item == nil {
868							break
869						}
870						q = append(q, item)
871					}
872					if !testEqual(q, test.queue) {
873						t.Fatalf("PartialResultSets still queued: \n%v\n, want \n%v\n", q, test.queue)
874					}
875					// Verify resume token.
876					if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
877						t.Fatalf("Resume token is %v, want %v\n", r.resumeToken, test.resumeToken)
878					}
879					// Verify error message.
880					if !testEqual(lastErr, test.wantErr) {
881						t.Fatalf("got error %v, want %v", lastErr, test.wantErr)
882					}
883					return
884				}
885				// Receive next decoded item.
886				if r.next() {
887					rs = append(rs, r.get())
888				}
889			}
890		})
891	}
892}
893
894// Test state transitions of resumableStreamDecoder where state machine
895// ends up to a blocking state(resumableStreamDecoder.Next blocks
896// on blocking state).
897func TestRsdBlockingStates(t *testing.T) {
898	restore := setMaxBytesBetweenResumeTokens()
899	defer restore()
900	for _, test := range []struct {
901		name         string
902		resumeTokens [][]byte
903		rpc          func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
904		sql          string
905		// Expected values
906		want         []*sppb.PartialResultSet      // PartialResultSets that should be returned to caller
907		queue        []*sppb.PartialResultSet      // PartialResultSets that should be buffered
908		resumeToken  []byte                        // Resume token that is maintained by resumableStreamDecoder
909		stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
910		wantErr      error
911	}{
912		{
913			// unConnected -> unConnected
914			name: "unConnected -> unConnected",
915			rpc: func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
916				return nil, status.Errorf(codes.Unavailable, "trust me: server is unavailable")
917			},
918			sql:          "SELECT * from t_whatever",
919			stateHistory: []resumableStreamDecoderState{unConnected, unConnected, unConnected},
920			wantErr:      status.Errorf(codes.Unavailable, "trust me: server is unavailable"),
921		},
922		{
923			// unConnected -> queueingRetryable
924			name:         "unConnected -> queueingRetryable",
925			sql:          "SELECT t.key key, t.value value FROM t_mock t",
926			stateHistory: []resumableStreamDecoderState{queueingRetryable},
927			want: []*sppb.PartialResultSet{
928				{
929					Metadata: kvMeta,
930				},
931			},
932		},
933		{
934			// unConnected->queueingRetryable->queueingRetryable
935			name:         "unConnected->queueingRetryable->queueingRetryable",
936			resumeTokens: [][]byte{{}, EncodeResumeToken(1), EncodeResumeToken(2), {}},
937			sql:          "SELECT t.key key, t.value value FROM t_mock t",
938			want: []*sppb.PartialResultSet{
939				{
940					Metadata: kvMeta,
941					Values: []*proto3.Value{
942						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
943						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
944					},
945				},
946				{
947					Metadata: kvMeta,
948					Values: []*proto3.Value{
949						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
950						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
951					},
952					ResumeToken: EncodeResumeToken(1),
953				},
954				{
955					Metadata: kvMeta,
956					Values: []*proto3.Value{
957						{Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
958						{Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
959					},
960					ResumeToken: EncodeResumeToken(2),
961				},
962				// The server sends an io.EOF at last and the decoder will
963				// flush out all messages in the internal queue.
964				{
965					Metadata: kvMeta,
966					Values: []*proto3.Value{
967						{Kind: &proto3.Value_StringValue{StringValue: keyStr(3)}},
968						{Kind: &proto3.Value_StringValue{StringValue: valStr(3)}},
969					},
970				},
971			},
972			queue: []*sppb.PartialResultSet{
973				{
974					Metadata: kvMeta,
975					Values: []*proto3.Value{
976						{Kind: &proto3.Value_StringValue{StringValue: keyStr(3)}},
977						{Kind: &proto3.Value_StringValue{StringValue: valStr(3)}},
978					},
979				},
980			},
981			resumeToken: EncodeResumeToken(2),
982			stateHistory: []resumableStreamDecoderState{
983				queueingRetryable, // do RPC
984				queueingRetryable, // got foo-00
985				queueingRetryable, // got foo-01
986				queueingRetryable, // foo-01, resume token
987				queueingRetryable, // got foo-02
988				queueingRetryable, // foo-02, resume token
989				queueingRetryable, // got foo-03
990			},
991		},
992		{
993			// unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable
994			name: "unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable",
995			resumeTokens: func() (rts [][]byte) {
996				rts = make([][]byte, maxBuffers+3)
997				rts[maxBuffers+1] = EncodeResumeToken(maxBuffers + 1)
998				return rts
999			}(),
1000			sql: "SELECT t.key key, t.value value FROM t_mock t",
1001			want: func() (s []*sppb.PartialResultSet) {
1002				// The server sends an io.EOF at last and the decoder will
1003				// flush out all messages in the internal queue. Although the
1004				// last message is supposed to be queued and the decoder waits
1005				// for the next resume token, an io.EOF leads to a `finished`
1006				// state that the last message will be removed from the queue
1007				// and be read by the client side.
1008				for i := 0; i < maxBuffers+3; i++ {
1009					s = append(s, &sppb.PartialResultSet{
1010						Metadata: kvMeta,
1011						Values: []*proto3.Value{
1012							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
1013							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
1014						},
1015					})
1016				}
1017				s[maxBuffers+1].ResumeToken = EncodeResumeToken(maxBuffers + 1)
1018				return s
1019			}(),
1020			resumeToken: EncodeResumeToken(maxBuffers + 1),
1021			queue: []*sppb.PartialResultSet{
1022				{
1023					Metadata: kvMeta,
1024					Values: []*proto3.Value{
1025						{Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 2)}},
1026						{Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 2)}},
1027					},
1028				},
1029			},
1030			stateHistory: func() (s []resumableStreamDecoderState) {
1031				s = append(s, queueingRetryable) // RPC
1032				for i := 0; i < maxBuffers; i++ {
1033					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder filles up
1034				}
1035				for i := maxBuffers - 1; i < maxBuffers+1; i++ {
1036					// the first item fills up the queue and triggers state
1037					// change; the second item is received under
1038					// queueingUnretryable state.
1039					s = append(s, queueingUnretryable)
1040				}
1041				s = append(s, queueingUnretryable) // got (maxBuffers+1)th row under Unretryable state
1042				s = append(s, queueingRetryable)   // (maxBuffers+1)th row has resume token
1043				s = append(s, queueingRetryable)   // (maxBuffers+2)th row has no resume token
1044				return s
1045			}(),
1046		},
1047		{
1048			// unConnected->queueingRetryable->queueingUnretryable->finished
1049			name:         "unConnected->queueingRetryable->queueingUnretryable->finished",
1050			resumeTokens: make([][]byte, maxBuffers),
1051			sql:          "SELECT t.key key, t.value value FROM t_mock t",
1052			want: func() (s []*sppb.PartialResultSet) {
1053				for i := 0; i < maxBuffers; i++ {
1054					s = append(s, &sppb.PartialResultSet{
1055						Metadata: kvMeta,
1056						Values: []*proto3.Value{
1057							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
1058							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
1059						},
1060					})
1061				}
1062				return s
1063			}(),
1064			stateHistory: func() (s []resumableStreamDecoderState) {
1065				s = append(s, queueingRetryable) // RPC
1066				for i := 0; i < maxBuffers; i++ {
1067					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
1068				}
1069				s = append(s, queueingUnretryable) // last row triggers state change
1070				s = append(s, finished)            // query finishes
1071				return s
1072			}(),
1073		},
1074	} {
1075		t.Run(test.name, func(t *testing.T) {
1076			server, c, teardown := setupMockedTestServer(t)
1077			defer teardown()
1078			mc, err := c.sc.nextClient()
1079			if err != nil {
1080				t.Fatalf("failed to create a grpc client")
1081			}
1082
1083			session, err := createSession(mc)
1084			if err != nil {
1085				t.Fatalf("failed to create a session")
1086			}
1087
1088			if test.rpc == nil {
1089				// Avoid using test.sql directly in closure because for loop changes
1090				// test.
1091				sql := test.sql
1092				test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1093					return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1094						Session:     session.Name,
1095						Sql:         sql,
1096						ResumeToken: resumeToken,
1097					})
1098				}
1099			}
1100			ctx, cancel := context.WithCancel(context.Background())
1101			defer cancel()
1102			r := newResumableStreamDecoder(
1103				ctx,
1104				nil,
1105				test.rpc,
1106				nil,
1107			)
1108			// Override backoff to make the test run faster.
1109			r.backoff = gax.Backoff{
1110				Initial:    1 * time.Nanosecond,
1111				Max:        1 * time.Nanosecond,
1112				Multiplier: 1.3,
1113			}
1114			// st is the set of observed state transitions.
1115			st := []resumableStreamDecoderState{}
1116			// q is the content of the decoder's partial result queue when expected
1117			// number of state transitions are done.
1118			q := []*sppb.PartialResultSet{}
1119			var lastErr error
1120			// Once the expected number of state transitions are observed, send a
1121			// signal to channel stateDone.
1122			stateDone := make(chan int)
1123			// Set stateWitness to listen to state changes.
1124			hl := len(test.stateHistory) // To avoid data race on test.
1125			r.stateWitness = func(rs resumableStreamDecoderState) {
1126				select {
1127				case <-stateDone:
1128					// Noop after expected number of state transitions
1129				default:
1130					// Record state transitions.
1131					st = append(st, rs)
1132					if len(st) == hl {
1133						lastErr = r.lastErr()
1134						q = r.q.dump()
1135						close(stateDone)
1136					}
1137				}
1138			}
1139			// Let mock server stream given messages to resumableStreamDecoder.
1140			err = setupStatementResult(t, server, test.sql, len(test.resumeTokens), test.resumeTokens)
1141			if err != nil {
1142				t.Fatalf("failed to set up a result for a statement: %v", err)
1143			}
1144			var mutex = &sync.Mutex{}
1145			var rs []*sppb.PartialResultSet
1146			rowsFetched := make(chan int)
1147			go func() {
1148				for {
1149					if !r.next() {
1150						// Note that r.Next also exits on context cancel/timeout.
1151						close(rowsFetched)
1152						return
1153					}
1154					mutex.Lock()
1155					rs = append(rs, r.get())
1156					mutex.Unlock()
1157				}
1158			}()
1159			// Wait until all rows have been fetched.
1160			if len(test.want) > 0 {
1161				select {
1162				case <-rowsFetched:
1163				case <-time.After(1 * time.Second):
1164					t.Fatal("Timeout in waiting for rows to be fetched")
1165				}
1166			}
1167			// Verify that resumableStreamDecoder reaches expected state.
1168			select {
1169			case <-stateDone: // Note that at this point, receiver is still blocking on r.next().
1170				// Check if resumableStreamDecoder carried out expected state
1171				// transitions.
1172				if !testEqual(st, test.stateHistory) {
1173					t.Fatalf("observed state transitions: \n%v\n, want \n%v\n", st, test.stateHistory)
1174				}
1175				// Check if resumableStreamDecoder returns expected array of
1176				// PartialResultSets.
1177				mutex.Lock()
1178				defer mutex.Unlock()
1179				if !testEqual(rs, test.want) {
1180					t.Fatalf("%s: received PartialResultSets: \n%v\n, want \n%v\n", test.name, rs, test.want)
1181				}
1182				// Verify that resumableStreamDecoder's internal buffering is also
1183				// correct.
1184				if !testEqual(q, test.queue) {
1185					t.Fatalf("PartialResultSets still queued: \n%v\n, want \n%v\n", q, test.queue)
1186				}
1187				// Verify resume token.
1188				if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
1189					t.Fatalf("Resume token is %v, want %v\n", r.resumeToken, test.resumeToken)
1190				}
1191				// Verify error message.
1192				if !testEqual(lastErr, test.wantErr) {
1193					t.Fatalf("got error %v, want %v", lastErr, test.wantErr)
1194				}
1195			case <-time.After(1 * time.Second):
1196				t.Fatal("Timeout in waiting for state change")
1197			}
1198		})
1199	}
1200}
1201
1202// sReceiver signals every receiving attempt through a channel, used by
1203// TestResumeToken to determine if the receiving of a certain PartialResultSet
1204// will be attempted next.
1205type sReceiver struct {
1206	c           chan int
1207	rpcReceiver sppb.Spanner_ExecuteStreamingSqlClient
1208}
1209
1210// Recv() implements streamingReceiver.Recv for sReceiver.
1211func (sr *sReceiver) Recv() (*sppb.PartialResultSet, error) {
1212	sr.c <- 1
1213	return sr.rpcReceiver.Recv()
1214}
1215
1216// waitn waits for nth receiving attempt from now on, until the signal for nth
1217// Recv() attempts is received or timeout. Note that because the way stream()
1218// works, the signal for the nth Recv() means that the previous n - 1
1219// PartialResultSets has already been returned to caller or queued, if no error
1220// happened.
1221func (sr *sReceiver) waitn(n int) error {
1222	for i := 0; i < n; i++ {
1223		select {
1224		case <-sr.c:
1225		case <-time.After(10 * time.Second):
1226			return fmt.Errorf("timeout in waiting for %v-th Recv()", i+1)
1227		}
1228	}
1229	return nil
1230}
1231
1232// Test the handling of resumableStreamDecoder.bytesBetweenResumeTokens.
1233func TestQueueBytes(t *testing.T) {
1234	restore := setMaxBytesBetweenResumeTokens()
1235	defer restore()
1236
1237	server, c, teardown := setupMockedTestServer(t)
1238	defer teardown()
1239	mc, err := c.sc.nextClient()
1240	if err != nil {
1241		t.Fatalf("failed to create a grpc client")
1242	}
1243
1244	rt1 := EncodeResumeToken(1)
1245	rt2 := EncodeResumeToken(2)
1246	rt3 := EncodeResumeToken(3)
1247	resumeTokens := [][]byte{rt1, rt1, rt1, rt2, rt2, rt3}
1248	err = setupStatementResult(t, server, "SELECT t.key key, t.value value FROM t_mock t", len(resumeTokens), resumeTokens)
1249	if err != nil {
1250		t.Fatalf("failed to set up a result for a statement: %v", err)
1251	}
1252
1253	session, err := createSession(mc)
1254	if err != nil {
1255		t.Fatalf("failed to create a session")
1256	}
1257
1258	sr := &sReceiver{
1259		c: make(chan int, 1000), // will never block in this test
1260	}
1261	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
1262	defer cancel()
1263	decoder := newResumableStreamDecoder(
1264		ctx,
1265		nil,
1266		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1267			r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1268				Session:     session.Name,
1269				Sql:         "SELECT t.key key, t.value value FROM t_mock t",
1270				ResumeToken: resumeToken,
1271			})
1272			sr.rpcReceiver = r
1273			return sr, err
1274		},
1275		nil,
1276	)
1277
1278	sizeOfPRS := proto.Size(&sppb.PartialResultSet{
1279		Metadata: kvMeta,
1280		Values: []*proto3.Value{
1281			{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
1282			{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
1283		},
1284		ResumeToken: rt1,
1285	})
1286
1287	decoder.next()
1288	decoder.next()
1289	decoder.next()
1290	if got, want := decoder.bytesBetweenResumeTokens, int32(2*sizeOfPRS); got != want {
1291		t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", got, want)
1292	}
1293
1294	decoder.next()
1295	if decoder.bytesBetweenResumeTokens != 0 {
1296		t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", decoder.bytesBetweenResumeTokens)
1297	}
1298
1299	decoder.next()
1300	if got, want := decoder.bytesBetweenResumeTokens, int32(sizeOfPRS); got != want {
1301		t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", got, want)
1302	}
1303
1304	decoder.next()
1305	if decoder.bytesBetweenResumeTokens != 0 {
1306		t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", decoder.bytesBetweenResumeTokens)
1307	}
1308}
1309
1310// Verify that client can deal with resume token correctly
1311func TestResumeToken(t *testing.T) {
1312	restore := setMaxBytesBetweenResumeTokens()
1313	defer restore()
1314	query := "SELECT t.key key, t.value value FROM t_mock t"
1315	server, c, teardown := setupMockedTestServer(t)
1316	defer teardown()
1317	mc, err := c.sc.nextClient()
1318	if err != nil {
1319		t.Fatalf("failed to create a grpc client")
1320	}
1321
1322	rt1 := EncodeResumeToken(1)
1323	rt2 := EncodeResumeToken(2)
1324	resumeTokens := make([][]byte, 3+maxBuffers)
1325	resumeTokens[1] = rt1
1326	resumeTokens[3+maxBuffers-1] = rt2
1327	err = setupStatementResult(t, server, query, len(resumeTokens), resumeTokens)
1328	if err != nil {
1329		t.Fatalf("failed to set up a result for a statement: %v", err)
1330	}
1331
1332	// The first error will be retried.
1333	server.TestSpanner.AddPartialResultSetError(
1334		query,
1335		PartialResultSetExecutionTime{
1336			ResumeToken: rt1,
1337			Err:         status.Error(codes.Unavailable, "mock server unavailable"),
1338		},
1339	)
1340	// The second error will not be retried because maxBytesBetweenResumeTokens
1341	// is reached and the state of resumableStreamDecoder:
1342	// queueingRetryable -> queueingUnretryable. The query will just fail.
1343	server.TestSpanner.AddPartialResultSetError(
1344		query,
1345		PartialResultSetExecutionTime{
1346			ResumeToken: rt2,
1347			Err:         status.Error(codes.Unavailable, "mock server wants some sleep"),
1348		},
1349	)
1350
1351	session, err := createSession(mc)
1352	if err != nil {
1353		t.Fatalf("failed to create a session")
1354	}
1355
1356	sr := &sReceiver{
1357		c: make(chan int, 1000), // will never block in this test
1358	}
1359	rows := []*Row{}
1360
1361	streaming := func() *RowIterator {
1362		return stream(context.Background(), nil,
1363			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1364				r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1365					Session:     session.Name,
1366					Sql:         query,
1367					ResumeToken: resumeToken,
1368				})
1369				sr.rpcReceiver = r
1370				return sr, err
1371			},
1372			nil,
1373			func(error) {})
1374	}
1375
1376	// Establish a stream to mock cloud spanner server.
1377	iter := streaming()
1378	defer iter.Stop()
1379	var row *Row
1380
1381	// Read first two rows.
1382	for i := 0; i < 3; i++ {
1383		row, err = iter.Next()
1384		if err != nil {
1385			t.Fatalf("failed to get next value: %v", err)
1386		}
1387		rows = append(rows, row)
1388	}
1389
1390	want := []*Row{
1391		{
1392			fields: kvMeta.RowType.Fields,
1393			vals: []*proto3.Value{
1394				{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
1395				{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
1396			},
1397		},
1398		{
1399			fields: kvMeta.RowType.Fields,
1400			vals: []*proto3.Value{
1401				{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
1402				{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
1403			},
1404		},
1405		{
1406			fields: kvMeta.RowType.Fields,
1407			vals: []*proto3.Value{
1408				{Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
1409				{Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
1410			},
1411		},
1412	}
1413	if !testEqual(rows, want) {
1414		t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
1415	}
1416
1417	// Trigger state change of resumableStreamDecoder:
1418	// queueingRetryable -> queueingUnretryable
1419	for i := 0; i < maxBuffers-1; i++ {
1420		row, err = iter.Next()
1421		if err != nil {
1422			t.Fatalf("failed to get next value: %v", err)
1423		}
1424		rows = append(rows, row)
1425	}
1426
1427	// Since resumableStreamDecoder is already at queueingUnretryable state,
1428	// query will just fail.
1429	_, err = iter.Next()
1430	if wantErr := spannerErrorf(codes.Unavailable, "mock server wants some sleep"); !testEqual(err, wantErr) {
1431		t.Fatalf("stream() returns error: %v, but want error: %v", err, wantErr)
1432	}
1433
1434	// Let server send two rows without resume token.
1435	resumeTokens = make([][]byte, 2)
1436	err = setupStatementResult(t, server, query, len(resumeTokens), resumeTokens)
1437	if err != nil {
1438		t.Fatalf("failed to set up a result for a statement: %v", err)
1439	}
1440
1441	// Reconnect to mock Cloud Spanner.
1442	rows = []*Row{}
1443	iter = streaming()
1444	defer iter.Stop()
1445
1446	for i := 0; i < 2; i++ {
1447		row, err = iter.Next()
1448		if err != nil {
1449			t.Fatalf("failed to get next value: %v", err)
1450		}
1451		rows = append(rows, row)
1452	}
1453
1454	// Verify if a normal server side EOF flushes all queued rows.
1455	want = []*Row{
1456		{
1457			fields: kvMeta.RowType.Fields,
1458			vals: []*proto3.Value{
1459				{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
1460				{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
1461			},
1462		},
1463		{
1464			fields: kvMeta.RowType.Fields,
1465			vals: []*proto3.Value{
1466				{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
1467				{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
1468			},
1469		},
1470	}
1471	if !testEqual(rows, want) {
1472		t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
1473	}
1474}
1475
1476// Verify that streaming query get retried upon real gRPC server transport
1477// failures.
1478func TestGrpcReconnect(t *testing.T) {
1479	restore := setMaxBytesBetweenResumeTokens()
1480	defer restore()
1481
1482	server, c, teardown := setupMockedTestServer(t)
1483	defer teardown()
1484	mc, err := c.sc.nextClient()
1485	if err != nil {
1486		t.Fatalf("failed to create a grpc client")
1487	}
1488
1489	session, err := createSession(mc)
1490	if err != nil {
1491		t.Fatalf("failed to create a session")
1492	}
1493
1494	// Simulate an unavailable error to interrupt the stream of PartialResultSet
1495	// in order to test the grpc retrying mechanism.
1496	server.TestSpanner.AddPartialResultSetError(
1497		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1498		PartialResultSetExecutionTime{
1499			ResumeToken: EncodeResumeToken(2),
1500			Err:         status.Errorf(codes.Unavailable, "server is unavailable"),
1501		},
1502	)
1503
1504	// The retry is counted from the second call.
1505	r := -1
1506	// Establish a stream to mock cloud spanner server.
1507	iter := stream(context.Background(), nil,
1508		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1509			r++
1510			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1511				Session:     session.Name,
1512				Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1513				ResumeToken: resumeToken,
1514			})
1515
1516		},
1517		nil,
1518		func(error) {})
1519	defer iter.Stop()
1520	for {
1521		_, err := iter.Next()
1522		if err == iterator.Done {
1523			err = nil
1524			break
1525		}
1526		if err != nil {
1527			break
1528		}
1529	}
1530	if r != 1 {
1531		t.Errorf("retry count = %v, want 1", r)
1532	}
1533}
1534
1535// Test cancel/timeout for client operations.
1536func TestCancelTimeout(t *testing.T) {
1537	restore := setMaxBytesBetweenResumeTokens()
1538	defer restore()
1539	server, c, teardown := setupMockedTestServer(t)
1540	defer teardown()
1541	server.TestSpanner.PutExecutionTime(
1542		MethodExecuteStreamingSql,
1543		SimulatedExecutionTime{MinimumExecutionTime: 1 * time.Second},
1544	)
1545	mc, err := c.sc.nextClient()
1546	if err != nil {
1547		t.Fatalf("failed to create a grpc client")
1548	}
1549
1550	session, err := createSession(mc)
1551	if err != nil {
1552		t.Fatalf("failed to create a session")
1553	}
1554	done := make(chan int)
1555
1556	// Test cancelling query.
1557	ctx, cancel := context.WithCancel(context.Background())
1558	go func() {
1559		// Establish a stream to mock cloud spanner server.
1560		iter := stream(ctx, nil,
1561			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1562				return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1563					Session:     session.Name,
1564					Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1565					ResumeToken: resumeToken,
1566				})
1567			},
1568			nil,
1569			func(error) {})
1570		defer iter.Stop()
1571		for {
1572			_, err = iter.Next()
1573			if err == iterator.Done {
1574				break
1575			}
1576			if err != nil {
1577				done <- 0
1578				break
1579			}
1580		}
1581	}()
1582	cancel()
1583	select {
1584	case <-done:
1585		if ErrCode(err) != codes.Canceled {
1586			t.Errorf("streaming query is canceled and returns error %v, want error code %v", err, codes.Canceled)
1587		}
1588	case <-time.After(1 * time.Second):
1589		t.Errorf("query doesn't exit timely after being cancelled")
1590	}
1591
1592	// Test query timeout.
1593	ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
1594	defer cancel()
1595	go func() {
1596		// Establish a stream to mock cloud spanner server.
1597		iter := stream(ctx, nil,
1598			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1599				return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1600					Session:     session.Name,
1601					Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1602					ResumeToken: resumeToken,
1603				})
1604			},
1605			nil,
1606			func(error) {})
1607		defer iter.Stop()
1608		for {
1609			_, err = iter.Next()
1610			if err == iterator.Done {
1611				err = nil
1612				break
1613			}
1614			if err != nil {
1615				break
1616			}
1617		}
1618		done <- 0
1619	}()
1620	select {
1621	case <-done:
1622		if wantErr := codes.DeadlineExceeded; ErrCode(err) != wantErr {
1623			t.Errorf("streaming query timeout returns error %v, want error code %v", err, wantErr)
1624		}
1625	case <-time.After(2 * time.Second):
1626		t.Errorf("query doesn't timeout as expected")
1627	}
1628}
1629
1630func setupStatementResult(t *testing.T, server *MockedSpannerInMemTestServer, stmt string, rowCount int, resumeTokens [][]byte) error {
1631	selectValues := make([][]string, rowCount)
1632	for i := 0; i < rowCount; i++ {
1633		selectValues[i] = []string{keyStr(i), valStr(i)}
1634	}
1635
1636	rows := make([]*structpb.ListValue, len(selectValues))
1637	for i, values := range selectValues {
1638		rowValues := make([]*structpb.Value, len(kvMeta.RowType.Fields))
1639		for j, value := range values {
1640			rowValues[j] = &structpb.Value{
1641				Kind: &structpb.Value_StringValue{StringValue: value},
1642			}
1643		}
1644		rows[i] = &structpb.ListValue{
1645			Values: rowValues,
1646		}
1647	}
1648	resultSet := &sppb.ResultSet{
1649		Metadata: kvMeta,
1650		Rows:     rows,
1651	}
1652	result := &StatementResult{
1653		Type:         StatementResultResultSet,
1654		ResultSet:    resultSet,
1655		ResumeTokens: resumeTokens,
1656	}
1657	return server.TestSpanner.PutStatementResult(stmt, result)
1658}
1659
1660func TestRowIteratorDo(t *testing.T) {
1661	restore := setMaxBytesBetweenResumeTokens()
1662	defer restore()
1663
1664	_, c, teardown := setupMockedTestServer(t)
1665	defer teardown()
1666	mc, err := c.sc.nextClient()
1667	if err != nil {
1668		t.Fatalf("failed to create a grpc client")
1669	}
1670
1671	session, err := createSession(mc)
1672	if err != nil {
1673		t.Fatalf("failed to create a session")
1674	}
1675
1676	nRows := 0
1677	iter := stream(context.Background(), nil,
1678		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1679			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1680				Session:     session.Name,
1681				Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1682				ResumeToken: resumeToken,
1683			})
1684		},
1685		nil,
1686		func(error) {})
1687	err = iter.Do(func(r *Row) error { nRows++; return nil })
1688	if err != nil {
1689		t.Errorf("Using Do: %v", err)
1690	}
1691	if nRows != 3 {
1692		t.Errorf("got %d rows, want 3", nRows)
1693	}
1694}
1695
1696func TestRowIteratorDoWithError(t *testing.T) {
1697	restore := setMaxBytesBetweenResumeTokens()
1698	defer restore()
1699
1700	_, c, teardown := setupMockedTestServer(t)
1701	defer teardown()
1702	mc, err := c.sc.nextClient()
1703	if err != nil {
1704		t.Fatalf("failed to create a grpc client")
1705	}
1706
1707	session, err := createSession(mc)
1708	if err != nil {
1709		t.Fatalf("failed to create a session")
1710	}
1711
1712	iter := stream(context.Background(), nil,
1713		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1714			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1715				Session:     session.Name,
1716				Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1717				ResumeToken: resumeToken,
1718			})
1719		},
1720		nil,
1721		func(error) {})
1722	injected := errors.New("Failed iterator")
1723	err = iter.Do(func(r *Row) error { return injected })
1724	if err != injected {
1725		t.Errorf("got <%v>, want <%v>", err, injected)
1726	}
1727}
1728
1729func TestIteratorStopEarly(t *testing.T) {
1730	ctx := context.Background()
1731	restore := setMaxBytesBetweenResumeTokens()
1732	defer restore()
1733
1734	_, c, teardown := setupMockedTestServer(t)
1735	defer teardown()
1736	mc, err := c.sc.nextClient()
1737	if err != nil {
1738		t.Fatalf("failed to create a grpc client")
1739	}
1740
1741	session, err := createSession(mc)
1742	if err != nil {
1743		t.Fatalf("failed to create a session")
1744	}
1745
1746	iter := stream(ctx, nil,
1747		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1748			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1749				Session:     session.Name,
1750				Sql:         SelectSingerIDAlbumIDAlbumTitleFromAlbums,
1751				ResumeToken: resumeToken,
1752			})
1753		},
1754		nil,
1755		func(error) {})
1756	_, err = iter.Next()
1757	if err != nil {
1758		t.Fatalf("before Stop: %v", err)
1759	}
1760	iter.Stop()
1761	// Stop sets r.err to the FailedPrecondition error "Next called after Stop".
1762	_, err = iter.Next()
1763	if g, w := ErrCode(err), codes.FailedPrecondition; g != w {
1764		t.Errorf("after Stop: got: %v, want: %v", g, w)
1765	}
1766}
1767
1768func TestIteratorWithError(t *testing.T) {
1769	injected := errors.New("Failed iterator")
1770	iter := RowIterator{err: injected}
1771	defer iter.Stop()
1772	if _, err := iter.Next(); err != injected {
1773		t.Fatalf("Expected error: %v, got %v", injected, err)
1774	}
1775}
1776
1777func createSession(client *vkit.Client) (*sppb.Session, error) {
1778	var formattedDatabase string = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
1779	var request = &sppb.CreateSessionRequest{
1780		Database: formattedDatabase,
1781	}
1782	return client.CreateSession(context.Background(), request)
1783}
1784