1/*
2Copyright 2017 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
import (
	"context"
	"errors"
	"fmt"
	"io"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	. "cloud.google.com/go/spanner/internal/testutil"
	"github.com/golang/protobuf/proto"
	proto3 "github.com/golang/protobuf/ptypes/struct"
	"github.com/googleapis/gax-go/v2"
	"google.golang.org/api/iterator"
	sppb "google.golang.org/genproto/googleapis/spanner/v1"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)
38
39var (
40	// Mocked transaction timestamp.
41	trxTs = time.Unix(1, 2)
42	// Metadata for mocked KV table, its rows are returned by SingleUse
43	// transactions.
44	kvMeta = func() *sppb.ResultSetMetadata {
45		meta := KvMeta
46		meta.Transaction = &sppb.Transaction{
47			ReadTimestamp: timestampProto(trxTs),
48		}
49		return &meta
50	}()
51	// Metadata for mocked ListKV table, which uses List for its key and value.
52	// Its rows are returned by snapshot readonly transactions, as indicated in
53	// the transaction metadata.
54	kvListMeta = &sppb.ResultSetMetadata{
55		RowType: &sppb.StructType{
56			Fields: []*sppb.StructType_Field{
57				{
58					Name: "Key",
59					Type: &sppb.Type{
60						Code: sppb.TypeCode_ARRAY,
61						ArrayElementType: &sppb.Type{
62							Code: sppb.TypeCode_STRING,
63						},
64					},
65				},
66				{
67					Name: "Value",
68					Type: &sppb.Type{
69						Code: sppb.TypeCode_ARRAY,
70						ArrayElementType: &sppb.Type{
71							Code: sppb.TypeCode_STRING,
72						},
73					},
74				},
75			},
76		},
77		Transaction: &sppb.Transaction{
78			Id:            transactionID{5, 6, 7, 8, 9},
79			ReadTimestamp: timestampProto(trxTs),
80		},
81	}
82	// Metadata for mocked schema of a query result set, which has two struct
83	// columns named "Col1" and "Col2", the struct's schema is like the
84	// following:
85	//
86	//	STRUCT {
87	//		INT
88	//		LIST<STRING>
89	//	}
90	//
91	// Its rows are returned in readwrite transaction, as indicated in the
92	// transaction metadata.
93	kvObjectMeta = &sppb.ResultSetMetadata{
94		RowType: &sppb.StructType{
95			Fields: []*sppb.StructType_Field{
96				{
97					Name: "Col1",
98					Type: &sppb.Type{
99						Code: sppb.TypeCode_STRUCT,
100						StructType: &sppb.StructType{
101							Fields: []*sppb.StructType_Field{
102								{
103									Name: "foo-f1",
104									Type: &sppb.Type{
105										Code: sppb.TypeCode_INT64,
106									},
107								},
108								{
109									Name: "foo-f2",
110									Type: &sppb.Type{
111										Code: sppb.TypeCode_ARRAY,
112										ArrayElementType: &sppb.Type{
113											Code: sppb.TypeCode_STRING,
114										},
115									},
116								},
117							},
118						},
119					},
120				},
121				{
122					Name: "Col2",
123					Type: &sppb.Type{
124						Code: sppb.TypeCode_STRUCT,
125						StructType: &sppb.StructType{
126							Fields: []*sppb.StructType_Field{
127								{
128									Name: "bar-f1",
129									Type: &sppb.Type{
130										Code: sppb.TypeCode_INT64,
131									},
132								},
133								{
134									Name: "bar-f2",
135									Type: &sppb.Type{
136										Code: sppb.TypeCode_ARRAY,
137										ArrayElementType: &sppb.Type{
138											Code: sppb.TypeCode_STRING,
139										},
140									},
141								},
142							},
143						},
144					},
145				},
146			},
147		},
148		Transaction: &sppb.Transaction{
149			Id: transactionID{1, 2, 3, 4, 5},
150		},
151	}
152)
153
154// String implements fmt.stringer.
155func (r *Row) String() string {
156	return fmt.Sprintf("{fields: %s, val: %s}", r.fields, r.vals)
157}
158
159func describeRows(l []*Row) string {
160	// generate a nice test failure description
161	var s = "["
162	for i, r := range l {
163		if i != 0 {
164			s += ",\n "
165		}
166		s += fmt.Sprint(r)
167	}
168	s += "]"
169	return s
170}
171
172// Helper for generating proto3 Value_ListValue instances, making test code
173// shorter and readable.
174func genProtoListValue(v ...string) *proto3.Value_ListValue {
175	r := &proto3.Value_ListValue{
176		ListValue: &proto3.ListValue{
177			Values: []*proto3.Value{},
178		},
179	}
180	for _, e := range v {
181		r.ListValue.Values = append(
182			r.ListValue.Values,
183			&proto3.Value{
184				Kind: &proto3.Value_StringValue{StringValue: e},
185			},
186		)
187	}
188	return r
189}
190
191// Test Row generation logics of partialResultSetDecoder.
192func TestPartialResultSetDecoder(t *testing.T) {
193	restore := setMaxBytesBetweenResumeTokens()
194	defer restore()
195	var tests = []struct {
196		input    []*sppb.PartialResultSet
197		wantF    []*Row
198		wantTxID transactionID
199		wantTs   time.Time
200		wantD    bool
201	}{
202		{
203			// Empty input.
204			wantD: true,
205		},
206		// String merging examples.
207		{
208			// Single KV result.
209			input: []*sppb.PartialResultSet{
210				{
211					Metadata: kvMeta,
212					Values: []*proto3.Value{
213						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
214						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
215					},
216				},
217			},
218			wantF: []*Row{
219				{
220					fields: kvMeta.RowType.Fields,
221					vals: []*proto3.Value{
222						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
223						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
224					},
225				},
226			},
227			wantTs: trxTs,
228			wantD:  true,
229		},
230		{
231			// Incomplete partial result.
232			input: []*sppb.PartialResultSet{
233				{
234					Metadata: kvMeta,
235					Values: []*proto3.Value{
236						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
237					},
238				},
239			},
240			wantTs: trxTs,
241			wantD:  false,
242		},
243		{
244			// Complete splitted result.
245			input: []*sppb.PartialResultSet{
246				{
247					Metadata: kvMeta,
248					Values: []*proto3.Value{
249						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
250					},
251				},
252				{
253					Values: []*proto3.Value{
254						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
255					},
256				},
257			},
258			wantF: []*Row{
259				{
260					fields: kvMeta.RowType.Fields,
261					vals: []*proto3.Value{
262						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
263						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
264					},
265				},
266			},
267			wantTs: trxTs,
268			wantD:  true,
269		},
270		{
271			// Multi-row example with splitted row in the middle.
272			input: []*sppb.PartialResultSet{
273				{
274					Metadata: kvMeta,
275					Values: []*proto3.Value{
276						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
277						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
278						{Kind: &proto3.Value_StringValue{StringValue: "A"}},
279					},
280				},
281				{
282					Values: []*proto3.Value{
283						{Kind: &proto3.Value_StringValue{StringValue: "1"}},
284						{Kind: &proto3.Value_StringValue{StringValue: "B"}},
285						{Kind: &proto3.Value_StringValue{StringValue: "2"}},
286					},
287				},
288			},
289			wantF: []*Row{
290				{
291					fields: kvMeta.RowType.Fields,
292					vals: []*proto3.Value{
293						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
294						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
295					},
296				},
297				{
298					fields: kvMeta.RowType.Fields,
299					vals: []*proto3.Value{
300						{Kind: &proto3.Value_StringValue{StringValue: "A"}},
301						{Kind: &proto3.Value_StringValue{StringValue: "1"}},
302					},
303				},
304				{
305					fields: kvMeta.RowType.Fields,
306					vals: []*proto3.Value{
307						{Kind: &proto3.Value_StringValue{StringValue: "B"}},
308						{Kind: &proto3.Value_StringValue{StringValue: "2"}},
309					},
310				},
311			},
312			wantTs: trxTs,
313			wantD:  true,
314		},
315		{
316			// Merging example in result_set.proto.
317			input: []*sppb.PartialResultSet{
318				{
319					Metadata: kvMeta,
320					Values: []*proto3.Value{
321						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
322						{Kind: &proto3.Value_StringValue{StringValue: "W"}},
323					},
324					ChunkedValue: true,
325				},
326				{
327					Values: []*proto3.Value{
328						{Kind: &proto3.Value_StringValue{StringValue: "orl"}},
329					},
330					ChunkedValue: true,
331				},
332				{
333					Values: []*proto3.Value{
334						{Kind: &proto3.Value_StringValue{StringValue: "d"}},
335					},
336				},
337			},
338			wantF: []*Row{
339				{
340					fields: kvMeta.RowType.Fields,
341					vals: []*proto3.Value{
342						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
343						{Kind: &proto3.Value_StringValue{StringValue: "World"}},
344					},
345				},
346			},
347			wantTs: trxTs,
348			wantD:  true,
349		},
350		{
351			// More complex example showing completing a merge and
352			// starting a new merge in the same partialResultSet.
353			input: []*sppb.PartialResultSet{
354				{
355					Metadata: kvMeta,
356					Values: []*proto3.Value{
357						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
358						{Kind: &proto3.Value_StringValue{StringValue: "W"}}, // start split in value
359					},
360					ChunkedValue: true,
361				},
362				{
363					Values: []*proto3.Value{
364						{Kind: &proto3.Value_StringValue{StringValue: "orld"}}, // complete value
365						{Kind: &proto3.Value_StringValue{StringValue: "i"}},    // start split in key
366					},
367					ChunkedValue: true,
368				},
369				{
370					Values: []*proto3.Value{
371						{Kind: &proto3.Value_StringValue{StringValue: "s"}}, // complete key
372						{Kind: &proto3.Value_StringValue{StringValue: "not"}},
373						{Kind: &proto3.Value_StringValue{StringValue: "a"}},
374						{Kind: &proto3.Value_StringValue{StringValue: "qu"}}, // split in value
375					},
376					ChunkedValue: true,
377				},
378				{
379					Values: []*proto3.Value{
380						{Kind: &proto3.Value_StringValue{StringValue: "estion"}}, // complete value
381					},
382				},
383			},
384			wantF: []*Row{
385				{
386					fields: kvMeta.RowType.Fields,
387					vals: []*proto3.Value{
388						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
389						{Kind: &proto3.Value_StringValue{StringValue: "World"}},
390					},
391				},
392				{
393					fields: kvMeta.RowType.Fields,
394					vals: []*proto3.Value{
395						{Kind: &proto3.Value_StringValue{StringValue: "is"}},
396						{Kind: &proto3.Value_StringValue{StringValue: "not"}},
397					},
398				},
399				{
400					fields: kvMeta.RowType.Fields,
401					vals: []*proto3.Value{
402						{Kind: &proto3.Value_StringValue{StringValue: "a"}},
403						{Kind: &proto3.Value_StringValue{StringValue: "question"}},
404					},
405				},
406			},
407			wantTs: trxTs,
408			wantD:  true,
409		},
410		// List merging examples.
411		{
412			// Non-splitting Lists.
413			input: []*sppb.PartialResultSet{
414				{
415					Metadata: kvListMeta,
416					Values: []*proto3.Value{
417						{
418							Kind: genProtoListValue("foo-1", "foo-2"),
419						},
420					},
421				},
422				{
423					Values: []*proto3.Value{
424						{
425							Kind: genProtoListValue("bar-1", "bar-2"),
426						},
427					},
428				},
429			},
430			wantF: []*Row{
431				{
432					fields: kvListMeta.RowType.Fields,
433					vals: []*proto3.Value{
434						{
435							Kind: genProtoListValue("foo-1", "foo-2"),
436						},
437						{
438							Kind: genProtoListValue("bar-1", "bar-2"),
439						},
440					},
441				},
442			},
443			wantTxID: transactionID{5, 6, 7, 8, 9},
444			wantTs:   trxTs,
445			wantD:    true,
446		},
447		{
448			// Simple List merge case: splitted string element.
449			input: []*sppb.PartialResultSet{
450				{
451					Metadata: kvListMeta,
452					Values: []*proto3.Value{
453						{
454							Kind: genProtoListValue("foo-1", "foo-"),
455						},
456					},
457					ChunkedValue: true,
458				},
459				{
460					Values: []*proto3.Value{
461						{
462							Kind: genProtoListValue("2"),
463						},
464					},
465				},
466				{
467					Values: []*proto3.Value{
468						{
469							Kind: genProtoListValue("bar-1", "bar-2"),
470						},
471					},
472				},
473			},
474			wantF: []*Row{
475				{
476					fields: kvListMeta.RowType.Fields,
477					vals: []*proto3.Value{
478						{
479							Kind: genProtoListValue("foo-1", "foo-2"),
480						},
481						{
482							Kind: genProtoListValue("bar-1", "bar-2"),
483						},
484					},
485				},
486			},
487			wantTxID: transactionID{5, 6, 7, 8, 9},
488			wantTs:   trxTs,
489			wantD:    true,
490		},
491		{
492			// Struct merging is also implemented by List merging. Note that
493			// Cloud Spanner uses proto.ListValue to encode Structs as well.
494			input: []*sppb.PartialResultSet{
495				{
496					Metadata: kvObjectMeta,
497					Values: []*proto3.Value{
498						{
499							Kind: &proto3.Value_ListValue{
500								ListValue: &proto3.ListValue{
501									Values: []*proto3.Value{
502										{Kind: &proto3.Value_NumberValue{NumberValue: 23}},
503										{Kind: genProtoListValue("foo-1", "fo")},
504									},
505								},
506							},
507						},
508					},
509					ChunkedValue: true,
510				},
511				{
512					Values: []*proto3.Value{
513						{
514							Kind: &proto3.Value_ListValue{
515								ListValue: &proto3.ListValue{
516									Values: []*proto3.Value{
517										{Kind: genProtoListValue("o-2", "f")},
518									},
519								},
520							},
521						},
522					},
523					ChunkedValue: true,
524				},
525				{
526					Values: []*proto3.Value{
527						{
528							Kind: &proto3.Value_ListValue{
529								ListValue: &proto3.ListValue{
530									Values: []*proto3.Value{
531										{Kind: genProtoListValue("oo-3")},
532									},
533								},
534							},
535						},
536						{
537							Kind: &proto3.Value_ListValue{
538								ListValue: &proto3.ListValue{
539									Values: []*proto3.Value{
540										{Kind: &proto3.Value_NumberValue{NumberValue: 45}},
541										{Kind: genProtoListValue("bar-1")},
542									},
543								},
544							},
545						},
546					},
547				},
548			},
549			wantF: []*Row{
550				{
551					fields: kvObjectMeta.RowType.Fields,
552					vals: []*proto3.Value{
553						{
554							Kind: &proto3.Value_ListValue{
555								ListValue: &proto3.ListValue{
556									Values: []*proto3.Value{
557										{Kind: &proto3.Value_NumberValue{NumberValue: 23}},
558										{Kind: genProtoListValue("foo-1", "foo-2", "foo-3")},
559									},
560								},
561							},
562						},
563						{
564							Kind: &proto3.Value_ListValue{
565								ListValue: &proto3.ListValue{
566									Values: []*proto3.Value{
567										{Kind: &proto3.Value_NumberValue{NumberValue: 45}},
568										{Kind: genProtoListValue("bar-1")},
569									},
570								},
571							},
572						},
573					},
574				},
575			},
576			wantTxID: transactionID{1, 2, 3, 4, 5},
577			wantD:    true,
578		},
579	}
580
581nextTest:
582	for i, test := range tests {
583		var rows []*Row
584		p := &partialResultSetDecoder{}
585		for j, v := range test.input {
586			rs, err := p.add(v)
587			if err != nil {
588				t.Errorf("test %d.%d: partialResultSetDecoder.add(%v) = %v; want nil", i, j, v, err)
589				continue nextTest
590			}
591			rows = append(rows, rs...)
592		}
593		if !testEqual(p.ts, test.wantTs) {
594			t.Errorf("got transaction(%v), want %v", p.ts, test.wantTs)
595		}
596		if !testEqual(rows, test.wantF) {
597			t.Errorf("test %d: rows=\n%v\n; want\n%v\n; p.row:\n%v\n", i, describeRows(rows), describeRows(test.wantF), p.row)
598		}
599		if got := p.done(); got != test.wantD {
600			t.Errorf("test %d: partialResultSetDecoder.done() = %v", i, got)
601		}
602	}
603}
604
605const (
606	// max number of PartialResultSets that will be buffered in tests.
607	maxBuffers = 16
608)
609
610// setMaxBytesBetweenResumeTokens sets the global maxBytesBetweenResumeTokens to
611// a smaller value more suitable for tests. It returns a function which should
612// be called to restore the maxBytesBetweenResumeTokens to its old value.
613func setMaxBytesBetweenResumeTokens() func() {
614	o := atomic.LoadInt32(&maxBytesBetweenResumeTokens)
615	atomic.StoreInt32(&maxBytesBetweenResumeTokens, int32(maxBuffers*proto.Size(&sppb.PartialResultSet{
616		Metadata: kvMeta,
617		Values: []*proto3.Value{
618			{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
619			{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
620		},
621	})))
622	return func() {
623		atomic.StoreInt32(&maxBytesBetweenResumeTokens, o)
624	}
625}
626
627// keyStr generates key string for kvMeta schema.
628func keyStr(i int) string {
629	return fmt.Sprintf("foo-%02d", i)
630}
631
632// valStr generates value string for kvMeta schema.
633func valStr(i int) string {
634	return fmt.Sprintf("bar-%02d", i)
635}
636
637// Test state transitions of resumableStreamDecoder where state machine ends up
638// to a non-blocking state(resumableStreamDecoder.Next returns on non-blocking
639// state).
640func TestRsdNonblockingStates(t *testing.T) {
641	restore := setMaxBytesBetweenResumeTokens()
642	defer restore()
643	tests := []struct {
644		name string
645		msgs []MockCtlMsg
646		rpc  func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
647		sql  string
648		// Expected values
649		want         []*sppb.PartialResultSet      // PartialResultSets that should be returned to caller
650		queue        []*sppb.PartialResultSet      // PartialResultSets that should be buffered
651		resumeToken  []byte                        // Resume token that is maintained by resumableStreamDecoder
652		stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
653		wantErr      error
654	}{
655		{
656			// unConnected->queueingRetryable->finished
657			name: "unConnected->queueingRetryable->finished",
658			msgs: []MockCtlMsg{
659				{},
660				{},
661				{Err: io.EOF, ResumeToken: false},
662			},
663			sql: "SELECT t.key key, t.value value FROM t_mock t",
664			want: []*sppb.PartialResultSet{
665				{
666					Metadata: kvMeta,
667					Values: []*proto3.Value{
668						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
669						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
670					},
671				},
672			},
673			queue: []*sppb.PartialResultSet{
674				{
675					Metadata: kvMeta,
676					Values: []*proto3.Value{
677						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
678						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
679					},
680				},
681			},
682			stateHistory: []resumableStreamDecoderState{
683				queueingRetryable, // do RPC
684				queueingRetryable, // got foo-00
685				queueingRetryable, // got foo-01
686				finished,          // got EOF
687			},
688		},
689		{
690			// unConnected->queueingRetryable->aborted
691			name: "unConnected->queueingRetryable->aborted",
692			msgs: []MockCtlMsg{
693				{},
694				{Err: nil, ResumeToken: true},
695				{},
696				{Err: errors.New("I quit"), ResumeToken: false},
697			},
698			sql: "SELECT t.key key, t.value value FROM t_mock t",
699			want: []*sppb.PartialResultSet{
700				{
701					Metadata: kvMeta,
702					Values: []*proto3.Value{
703						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
704						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
705					},
706				},
707				{
708					Metadata: kvMeta,
709					Values: []*proto3.Value{
710						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
711						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
712					},
713					ResumeToken: EncodeResumeToken(1),
714				},
715			},
716			stateHistory: []resumableStreamDecoderState{
717				queueingRetryable, // do RPC
718				queueingRetryable, // got foo-00
719				queueingRetryable, // got foo-01
720				queueingRetryable, // foo-01, resume token
721				queueingRetryable, // got foo-02
722				aborted,           // got error
723			},
724			wantErr: status.Errorf(codes.Unknown, "I quit"),
725		},
726		{
727			// unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable
728			name: "unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable",
729			msgs: func() (m []MockCtlMsg) {
730				for i := 0; i < maxBuffers+1; i++ {
731					m = append(m, MockCtlMsg{})
732				}
733				return m
734			}(),
735			sql: "SELECT t.key key, t.value value FROM t_mock t",
736			want: func() (s []*sppb.PartialResultSet) {
737				for i := 0; i < maxBuffers+1; i++ {
738					s = append(s, &sppb.PartialResultSet{
739						Metadata: kvMeta,
740						Values: []*proto3.Value{
741							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
742							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
743						},
744					})
745				}
746				return s
747			}(),
748			stateHistory: func() (s []resumableStreamDecoderState) {
749				s = append(s, queueingRetryable) // RPC
750				for i := 0; i < maxBuffers; i++ {
751					s = append(s, queueingRetryable) // the internal queue of resumableStreamDecoder fills up
752				}
753				// the first item fills up the queue and triggers state transition;
754				// the second item is received under queueingUnretryable state.
755				s = append(s, queueingUnretryable)
756				s = append(s, queueingUnretryable)
757				return s
758			}(),
759		},
760		{
761			// unConnected->queueingRetryable->queueingUnretryable->aborted
762			name: "unConnected->queueingRetryable->queueingUnretryable->aborted",
763			msgs: func() (m []MockCtlMsg) {
764				for i := 0; i < maxBuffers; i++ {
765					m = append(m, MockCtlMsg{})
766				}
767				m = append(m, MockCtlMsg{Err: errors.New("Just Abort It"), ResumeToken: false})
768				return m
769			}(),
770			sql: "SELECT t.key key, t.value value FROM t_mock t",
771			want: func() (s []*sppb.PartialResultSet) {
772				for i := 0; i < maxBuffers; i++ {
773					s = append(s, &sppb.PartialResultSet{
774						Metadata: kvMeta,
775						Values: []*proto3.Value{
776							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
777							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
778						},
779					})
780				}
781				return s
782			}(),
783			stateHistory: func() (s []resumableStreamDecoderState) {
784				s = append(s, queueingRetryable) // RPC
785				for i := 0; i < maxBuffers; i++ {
786					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
787				}
788				s = append(s, queueingUnretryable) // the last row triggers state change
789				s = append(s, aborted)             // Error happens
790				return s
791			}(),
792			wantErr: status.Errorf(codes.Unknown, "Just Abort It"),
793		},
794	}
795	for _, test := range tests {
796		t.Run(test.name, func(t *testing.T) {
797			ms := NewMockCloudSpanner(t, trxTs)
798			ms.Serve()
799			mc := sppb.NewSpannerClient(dialMock(t, ms))
800			if test.rpc == nil {
801				test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
802					return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
803						Sql:         test.sql,
804						ResumeToken: resumeToken,
805					})
806				}
807			}
808			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
809			defer cancel()
810			r := newResumableStreamDecoder(
811				ctx,
812				nil,
813				test.rpc,
814				nil,
815			)
816			st := []resumableStreamDecoderState{}
817			var lastErr error
818			// Once the expected number of state transitions are observed,
819			// send a signal by setting stateDone = true.
820			stateDone := false
821			// Set stateWitness to listen to state changes.
822			hl := len(test.stateHistory) // To avoid data race on test.
823			r.stateWitness = func(rs resumableStreamDecoderState) {
824				if !stateDone {
825					// Record state transitions.
826					st = append(st, rs)
827					if len(st) == hl {
828						lastErr = r.lastErr()
829						stateDone = true
830					}
831				}
832			}
833			// Let mock server stream given messages to resumableStreamDecoder.
834			for _, m := range test.msgs {
835				ms.AddMsg(m.Err, m.ResumeToken)
836			}
837			var rs []*sppb.PartialResultSet
838			for {
839				select {
840				case <-ctx.Done():
841					t.Fatal("context cancelled or timeout during test")
842				default:
843				}
844				if stateDone {
845					// Check if resumableStreamDecoder carried out expected
846					// state transitions.
847					if !testEqual(st, test.stateHistory) {
848						t.Fatalf("observed state transitions: \n%v\n, want \n%v\n", st, test.stateHistory)
849					}
850					// Check if resumableStreamDecoder returns expected array of
851					// PartialResultSets.
852					if !testEqual(rs, test.want) {
853						t.Fatalf("received PartialResultSets: \n%v\n, want \n%v\n", rs, test.want)
854					}
855					// Verify that resumableStreamDecoder's internal buffering is
856					// also correct.
857					var q []*sppb.PartialResultSet
858					for {
859						item := r.q.pop()
860						if item == nil {
861							break
862						}
863						q = append(q, item)
864					}
865					if !testEqual(q, test.queue) {
866						t.Fatalf("PartialResultSets still queued: \n%v\n, want \n%v\n", q, test.queue)
867					}
868					// Verify resume token.
869					if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
870						t.Fatalf("Resume token is %v, want %v\n", r.resumeToken, test.resumeToken)
871					}
872					// Verify error message.
873					if !testEqual(lastErr, test.wantErr) {
874						t.Fatalf("got error %v, want %v", lastErr, test.wantErr)
875					}
876					return
877				}
878				// Receive next decoded item.
879				if r.next() {
880					rs = append(rs, r.get())
881				}
882			}
883		})
884	}
885}
886
887// Test state transitions of resumableStreamDecoder where state machine
888// ends up to a blocking state(resumableStreamDecoder.Next blocks
889// on blocking state).
890func TestRsdBlockingStates(t *testing.T) {
891	restore := setMaxBytesBetweenResumeTokens()
892	defer restore()
893	for _, test := range []struct {
894		name string
895		msgs []MockCtlMsg
896		rpc  func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
897		sql  string
898		// Expected values
899		want         []*sppb.PartialResultSet      // PartialResultSets that should be returned to caller
900		queue        []*sppb.PartialResultSet      // PartialResultSets that should be buffered
901		resumeToken  []byte                        // Resume token that is maintained by resumableStreamDecoder
902		stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
903		wantErr      error
904	}{
905		{
906			// unConnected -> unConnected
907			name: "unConnected -> unConnected",
908			rpc: func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
909				return nil, status.Errorf(codes.Unavailable, "trust me: server is unavailable")
910			},
911			sql:          "SELECT * from t_whatever",
912			stateHistory: []resumableStreamDecoderState{unConnected, unConnected, unConnected},
913			wantErr:      status.Errorf(codes.Unavailable, "trust me: server is unavailable"),
914		},
915		{
916			// unConnected -> queueingRetryable
917			name:         "unConnected -> queueingRetryable",
918			sql:          "SELECT t.key key, t.value value FROM t_mock t",
919			stateHistory: []resumableStreamDecoderState{queueingRetryable},
920		},
921		{
922			// unConnected->queueingRetryable->queueingRetryable
923			name: "unConnected->queueingRetryable->queueingRetryable",
924			msgs: []MockCtlMsg{
925				{},
926				{Err: nil, ResumeToken: true},
927				{Err: nil, ResumeToken: true},
928				{},
929			},
930			sql: "SELECT t.key key, t.value value FROM t_mock t",
931			want: []*sppb.PartialResultSet{
932				{
933					Metadata: kvMeta,
934					Values: []*proto3.Value{
935						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
936						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
937					},
938				},
939				{
940					Metadata: kvMeta,
941					Values: []*proto3.Value{
942						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
943						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
944					},
945					ResumeToken: EncodeResumeToken(1),
946				},
947				{
948					Metadata: kvMeta,
949					Values: []*proto3.Value{
950						{Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
951						{Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
952					},
953					ResumeToken: EncodeResumeToken(2),
954				},
955			},
956			queue: []*sppb.PartialResultSet{
957				{
958					Metadata: kvMeta,
959					Values: []*proto3.Value{
960						{Kind: &proto3.Value_StringValue{StringValue: keyStr(3)}},
961						{Kind: &proto3.Value_StringValue{StringValue: valStr(3)}},
962					},
963				},
964			},
965			resumeToken: EncodeResumeToken(2),
966			stateHistory: []resumableStreamDecoderState{
967				queueingRetryable, // do RPC
968				queueingRetryable, // got foo-00
969				queueingRetryable, // got foo-01
970				queueingRetryable, // foo-01, resume token
971				queueingRetryable, // got foo-02
972				queueingRetryable, // foo-02, resume token
973				queueingRetryable, // got foo-03
974			},
975		},
976		{
977			// unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable
978			name: "unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable",
979			msgs: func() (m []MockCtlMsg) {
980				for i := 0; i < maxBuffers+1; i++ {
981					m = append(m, MockCtlMsg{})
982				}
983				m = append(m, MockCtlMsg{Err: nil, ResumeToken: true})
984				m = append(m, MockCtlMsg{})
985				return m
986			}(),
987			sql: "SELECT t.key key, t.value value FROM t_mock t",
988			want: func() (s []*sppb.PartialResultSet) {
989				for i := 0; i < maxBuffers+2; i++ {
990					s = append(s, &sppb.PartialResultSet{
991						Metadata: kvMeta,
992						Values: []*proto3.Value{
993							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
994							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
995						},
996					})
997				}
998				s[maxBuffers+1].ResumeToken = EncodeResumeToken(maxBuffers + 1)
999				return s
1000			}(),
1001			resumeToken: EncodeResumeToken(maxBuffers + 1),
1002			queue: []*sppb.PartialResultSet{
1003				{
1004					Metadata: kvMeta,
1005					Values: []*proto3.Value{
1006						{Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 2)}},
1007						{Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 2)}},
1008					},
1009				},
1010			},
1011			stateHistory: func() (s []resumableStreamDecoderState) {
1012				s = append(s, queueingRetryable) // RPC
1013				for i := 0; i < maxBuffers; i++ {
1014					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder filles up
1015				}
1016				for i := maxBuffers - 1; i < maxBuffers+1; i++ {
1017					// the first item fills up the queue and triggers state
1018					// change; the second item is received under
1019					// queueingUnretryable state.
1020					s = append(s, queueingUnretryable)
1021				}
1022				s = append(s, queueingUnretryable) // got (maxBuffers+1)th row under Unretryable state
1023				s = append(s, queueingRetryable)   // (maxBuffers+1)th row has resume token
1024				s = append(s, queueingRetryable)   // (maxBuffers+2)th row has no resume token
1025				return s
1026			}(),
1027		},
1028		{
1029			// unConnected->queueingRetryable->queueingUnretryable->finished
1030			name: "unConnected->queueingRetryable->queueingUnretryable->finished",
1031			msgs: func() (m []MockCtlMsg) {
1032				for i := 0; i < maxBuffers; i++ {
1033					m = append(m, MockCtlMsg{})
1034				}
1035				m = append(m, MockCtlMsg{Err: io.EOF, ResumeToken: false})
1036				return m
1037			}(),
1038			sql: "SELECT t.key key, t.value value FROM t_mock t",
1039			want: func() (s []*sppb.PartialResultSet) {
1040				for i := 0; i < maxBuffers; i++ {
1041					s = append(s, &sppb.PartialResultSet{
1042						Metadata: kvMeta,
1043						Values: []*proto3.Value{
1044							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
1045							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
1046						},
1047					})
1048				}
1049				return s
1050			}(),
1051			stateHistory: func() (s []resumableStreamDecoderState) {
1052				s = append(s, queueingRetryable) // RPC
1053				for i := 0; i < maxBuffers; i++ {
1054					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
1055				}
1056				s = append(s, queueingUnretryable) // last row triggers state change
1057				s = append(s, finished)            // query finishes
1058				return s
1059			}(),
1060		},
1061	} {
1062		t.Run(test.name, func(t *testing.T) {
1063			ms := NewMockCloudSpanner(t, trxTs)
1064			ms.Serve()
1065			cc := dialMock(t, ms)
1066			mc := sppb.NewSpannerClient(cc)
1067			if test.rpc == nil {
1068				// Avoid using test.sql directly in closure because for loop changes
1069				// test.
1070				sql := test.sql
1071				test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1072					return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1073						Sql:         sql,
1074						ResumeToken: resumeToken,
1075					})
1076				}
1077			}
1078			ctx, cancel := context.WithCancel(context.Background())
1079			defer cancel()
1080			r := newResumableStreamDecoder(
1081				ctx,
1082				nil,
1083				test.rpc,
1084				nil,
1085			)
1086			// Override backoff to make the test run faster.
1087			r.backoff = gax.Backoff{
1088				Initial:    1 * time.Nanosecond,
1089				Max:        1 * time.Nanosecond,
1090				Multiplier: 1.3,
1091			}
1092			// st is the set of observed state transitions.
1093			st := []resumableStreamDecoderState{}
1094			// q is the content of the decoder's partial result queue when expected
1095			// number of state transitions are done.
1096			q := []*sppb.PartialResultSet{}
1097			var lastErr error
1098			// Once the expected number of state transitions are observed, send a
1099			// signal to channel stateDone.
1100			stateDone := make(chan int)
1101			// Set stateWitness to listen to state changes.
1102			hl := len(test.stateHistory) // To avoid data race on test.
1103			r.stateWitness = func(rs resumableStreamDecoderState) {
1104				select {
1105				case <-stateDone:
1106					// Noop after expected number of state transitions
1107				default:
1108					// Record state transitions.
1109					st = append(st, rs)
1110					if len(st) == hl {
1111						lastErr = r.lastErr()
1112						q = r.q.dump()
1113						close(stateDone)
1114					}
1115				}
1116			}
1117			// Let mock server stream given messages to resumableStreamDecoder.
1118			for _, m := range test.msgs {
1119				ms.AddMsg(m.Err, m.ResumeToken)
1120			}
1121			var rs []*sppb.PartialResultSet
1122			go func() {
1123				for {
1124					if !r.next() {
1125						// Note that r.Next also exits on context cancel/timeout.
1126						return
1127					}
1128					rs = append(rs, r.get())
1129				}
1130			}()
1131			// Verify that resumableStreamDecoder reaches expected state.
1132			select {
1133			case <-stateDone: // Note that at this point, receiver is still blockingon r.next().
1134				// Check if resumableStreamDecoder carried out expected state
1135				// transitions.
1136				if !testEqual(st, test.stateHistory) {
1137					t.Fatalf("observed state transitions: \n%v\n, want \n%v\n", st, test.stateHistory)
1138				}
1139				// Check if resumableStreamDecoder returns expected array of
1140				// PartialResultSets.
1141				if !testEqual(rs, test.want) {
1142					t.Fatalf("received PartialResultSets: \n%v\n, want \n%v\n", rs, test.want)
1143				}
1144				// Verify that resumableStreamDecoder's internal buffering is also
1145				// correct.
1146				if !testEqual(q, test.queue) {
1147					t.Fatalf("PartialResultSets still queued: \n%v\n, want \n%v\n", q, test.queue)
1148				}
1149				// Verify resume token.
1150				if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
1151					t.Fatalf("Resume token is %v, want %v\n", r.resumeToken, test.resumeToken)
1152				}
1153				// Verify error message.
1154				if !testEqual(lastErr, test.wantErr) {
1155					t.Fatalf("got error %v, want %v", lastErr, test.wantErr)
1156				}
1157			case <-time.After(1 * time.Second):
1158				t.Fatal("Timeout in waiting for state change")
1159			}
1160			ms.Stop()
1161			if err := cc.Close(); err != nil {
1162				t.Fatal(err)
1163			}
1164		})
1165	}
1166}
1167
// sReceiver signals every receiving attempt through a channel, used by
// TestResumeToken to determine if the receiving of a certain PartialResultSet
// will be attempted next.
type sReceiver struct {
	c           chan int                                // receives a signal before every Recv attempt; buffered by tests so sends never block
	rpcReceiver sppb.Spanner_ExecuteStreamingSqlClient // underlying gRPC stream that actually produces the results
}
1175
1176// Recv() implements streamingReceiver.Recv for sReceiver.
1177func (sr *sReceiver) Recv() (*sppb.PartialResultSet, error) {
1178	sr.c <- 1
1179	return sr.rpcReceiver.Recv()
1180}
1181
1182// waitn waits for nth receiving attempt from now on, until the signal for nth
1183// Recv() attempts is received or timeout. Note that because the way stream()
1184// works, the signal for the nth Recv() means that the previous n - 1
1185// PartialResultSets has already been returned to caller or queued, if no error
1186// happened.
1187func (sr *sReceiver) waitn(n int) error {
1188	for i := 0; i < n; i++ {
1189		select {
1190		case <-sr.c:
1191		case <-time.After(10 * time.Second):
1192			return fmt.Errorf("timeout in waiting for %v-th Recv()", i+1)
1193		}
1194	}
1195	return nil
1196}
1197
1198// Test the handling of resumableStreamDecoder.bytesBetweenResumeTokens.
1199func TestQueueBytes(t *testing.T) {
1200	restore := setMaxBytesBetweenResumeTokens()
1201	defer restore()
1202	ms := NewMockCloudSpanner(t, trxTs)
1203	ms.Serve()
1204	defer ms.Stop()
1205	cc := dialMock(t, ms)
1206	defer cc.Close()
1207	mc := sppb.NewSpannerClient(cc)
1208	sr := &sReceiver{
1209		c: make(chan int, 1000), // will never block in this test
1210	}
1211	wantQueueBytes := 0
1212	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
1213	defer cancel()
1214	r := newResumableStreamDecoder(
1215		ctx,
1216		nil,
1217		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1218			r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1219				Sql:         "SELECT t.key key, t.value value FROM t_mock t",
1220				ResumeToken: resumeToken,
1221			})
1222			sr.rpcReceiver = r
1223			return sr, err
1224		},
1225		nil,
1226	)
1227	go func() {
1228		for r.next() {
1229		}
1230	}()
1231	// Let server send maxBuffers / 2 rows.
1232	for i := 0; i < maxBuffers/2; i++ {
1233		wantQueueBytes += proto.Size(&sppb.PartialResultSet{
1234			Metadata: kvMeta,
1235			Values: []*proto3.Value{
1236				{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
1237				{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
1238			},
1239		})
1240		ms.AddMsg(nil, false)
1241	}
1242	if err := sr.waitn(maxBuffers/2 + 1); err != nil {
1243		t.Fatalf("failed to wait for the first %v recv() calls: %v", maxBuffers, err)
1244	}
1245	if int32(wantQueueBytes) != r.bytesBetweenResumeTokens {
1246		t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", r.bytesBetweenResumeTokens, wantQueueBytes)
1247	}
1248	// Now send a resume token to drain the queue.
1249	ms.AddMsg(nil, true)
1250	// Wait for all rows to be processes.
1251	if err := sr.waitn(1); err != nil {
1252		t.Fatalf("failed to wait for rows to be processed: %v", err)
1253	}
1254	if r.bytesBetweenResumeTokens != 0 {
1255		t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens)
1256	}
1257	// Let server send maxBuffers - 1 rows.
1258	wantQueueBytes = 0
1259	for i := 0; i < maxBuffers-1; i++ {
1260		wantQueueBytes += proto.Size(&sppb.PartialResultSet{
1261			Metadata: kvMeta,
1262			Values: []*proto3.Value{
1263				{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
1264				{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
1265			},
1266		})
1267		ms.AddMsg(nil, false)
1268	}
1269	if err := sr.waitn(maxBuffers - 1); err != nil {
1270		t.Fatalf("failed to wait for %v rows to be processed: %v", maxBuffers-1, err)
1271	}
1272	if int32(wantQueueBytes) != r.bytesBetweenResumeTokens {
1273		t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens)
1274	}
1275	// Trigger a state transition: queueingRetryable -> queueingUnretryable.
1276	ms.AddMsg(nil, false)
1277	if err := sr.waitn(1); err != nil {
1278		t.Fatalf("failed to wait for state transition: %v", err)
1279	}
1280	if r.bytesBetweenResumeTokens != 0 {
1281		t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens)
1282	}
1283}
1284
// TestResumeToken verifies that the client handles resume tokens correctly:
// rows are only yielded once a resume token (or end of stream) arrives, a
// resumable failure restarts the stream from the last token, and a failure
// received while the decoder is in the unretryable state surfaces to the
// caller.
func TestResumeToken(t *testing.T) {
	restore := setMaxBytesBetweenResumeTokens()
	defer restore()
	ms := NewMockCloudSpanner(t, trxTs)
	ms.Serve()
	defer ms.Stop()
	cc := dialMock(t, ms)
	defer cc.Close()
	mc := sppb.NewSpannerClient(cc)
	sr := &sReceiver{
		c: make(chan int, 1000), // will never block in this test
	}
	rows := []*Row{}
	// done receives the terminal error (nil on clean completion) of each
	// streaming attempt below.
	done := make(chan error)
	streaming := func() {
		// Establish a stream to mock cloud spanner server.
		iter := stream(context.Background(), nil,
			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
				r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
					Sql:         "SELECT t.key key, t.value value FROM t_mock t",
					ResumeToken: resumeToken,
				})
				sr.rpcReceiver = r
				return sr, err
			},
			nil,
			func(error) {})
		defer iter.Stop()
		var err error
		for {
			var row *Row
			row, err = iter.Next()
			if err == iterator.Done {
				err = nil
				break
			}
			if err != nil {
				break
			}
			rows = append(rows, row)
		}
		done <- err
	}
	go streaming()
	// Server streaming row 0 - 2, only row 1 has resume token.
	// Client will receive row 0 - 2, so it will try receiving for
	// 4 times (the last recv will block), and only row 0 - 1 will
	// be yielded.
	for i := 0; i < 3; i++ {
		if i == 1 {
			ms.AddMsg(nil, true)
		} else {
			ms.AddMsg(nil, false)
		}
	}
	// Wait for 4 receive attempts, as explained above.
	if err := sr.waitn(4); err != nil {
		t.Fatalf("failed to wait for row 0 - 2: %v", err)
	}
	want := []*Row{
		{
			fields: kvMeta.RowType.Fields,
			vals: []*proto3.Value{
				{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
				{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
			},
		},
		{
			fields: kvMeta.RowType.Fields,
			vals: []*proto3.Value{
				{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
				{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
			},
		},
	}
	if !testEqual(rows, want) {
		t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
	}
	// Inject resumable failure.
	ms.AddMsg(
		status.Errorf(codes.Unavailable, "mock server unavailable"),
		false,
	)
	// Test if client detects the resumable failure and retries.
	if err := sr.waitn(1); err != nil {
		t.Fatalf("failed to wait for client to retry: %v", err)
	}
	// Client has resumed the query, now server resend row 2.
	ms.AddMsg(nil, true)
	if err := sr.waitn(1); err != nil {
		t.Fatalf("failed to wait for resending row 2: %v", err)
	}
	// Now client should have received row 0 - 2.
	want = append(want, &Row{
		fields: kvMeta.RowType.Fields,
		vals: []*proto3.Value{
			{Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
			{Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
		},
	})
	if !testEqual(rows, want) {
		t.Errorf("received rows: \n%v\n, want\n%v\n", rows, want)
	}
	// Sending 3rd - (maxBuffers+1)th rows without resume tokens, client should buffer them.
	for i := 3; i < maxBuffers+2; i++ {
		ms.AddMsg(nil, false)
	}
	if err := sr.waitn(maxBuffers - 1); err != nil {
		t.Fatalf("failed to wait for row 3-%v: %v", maxBuffers+1, err)
	}
	// Received rows should be unchanged.
	if !testEqual(rows, want) {
		t.Errorf("receive rows: \n%v\n, want\n%v\n", rows, want)
	}
	// Send (maxBuffers+2)th row to trigger state change of resumableStreamDecoder:
	// queueingRetryable -> queueingUnretryable
	ms.AddMsg(nil, false)
	if err := sr.waitn(1); err != nil {
		t.Fatalf("failed to wait for row %v: %v", maxBuffers+2, err)
	}
	// Client should yield row 3rd - (maxBuffers+2)th to application. Therefore,
	// application should see row 0 - (maxBuffers+2)th so far.
	for i := 3; i < maxBuffers+3; i++ {
		want = append(want, &Row{
			fields: kvMeta.RowType.Fields,
			vals: []*proto3.Value{
				{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
				{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
			},
		})
	}
	if !testEqual(rows, want) {
		t.Errorf("received rows: \n%v\n; want\n%v\n", rows, want)
	}
	// Inject resumable error, but since resumableStreamDecoder is already at
	// queueingUnretryable state, query will just fail.
	ms.AddMsg(
		status.Errorf(codes.Unavailable, "mock server wants some sleep"),
		false,
	)
	var gotErr error
	select {
	case gotErr = <-done:
	case <-time.After(10 * time.Second):
		t.Fatalf("timeout in waiting for failed query to return.")
	}
	if wantErr := spannerErrorf(codes.Unavailable, "mock server wants some sleep"); !testEqual(gotErr, wantErr) {
		t.Fatalf("stream() returns error: %v, but want error: %v", gotErr, wantErr)
	}

	// Reconnect to mock Cloud Spanner.
	rows = []*Row{}
	go streaming()
	// Let server send two rows without resume token.
	for i := maxBuffers + 3; i < maxBuffers+5; i++ {
		ms.AddMsg(nil, false)
	}
	if err := sr.waitn(3); err != nil {
		t.Fatalf("failed to wait for row %v - %v: %v", maxBuffers+3, maxBuffers+5, err)
	}
	// No resume token has been seen yet, so nothing should be yielded.
	if len(rows) > 0 {
		t.Errorf("client received some rows unexpectedly: %v, want nothing", rows)
	}
	// Let server end the query.
	ms.AddMsg(io.EOF, false)
	select {
	case gotErr = <-done:
	case <-time.After(10 * time.Second):
		t.Fatalf("timeout in waiting for failed query to return")
	}
	if gotErr != nil {
		t.Fatalf("stream() returns unexpected error: %v, but want no error", gotErr)
	}
	// Verify if a normal server side EOF flushes all queued rows.
	want = []*Row{
		{
			fields: kvMeta.RowType.Fields,
			vals: []*proto3.Value{
				{Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 3)}},
				{Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 3)}},
			},
		},
		{
			fields: kvMeta.RowType.Fields,
			vals: []*proto3.Value{
				{Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 4)}},
				{Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 4)}},
			},
		},
	}
	if !testEqual(rows, want) {
		t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
	}
}
1480
// Verify that streaming query get retried upon real gRPC server transport
// failures.
func TestGrpcReconnect(t *testing.T) {
	restore := setMaxBytesBetweenResumeTokens()
	defer restore()
	ms := NewMockCloudSpanner(t, trxTs)
	ms.Serve()
	defer ms.Stop()
	cc := dialMock(t, ms)
	defer cc.Close()
	mc := sppb.NewSpannerClient(cc)
	// retry receives the attempt count whenever the stream RPC is
	// re-established after the first attempt.
	retry := make(chan int)
	// row receives a signal for every row yielded by the iterator.
	row := make(chan int)
	var err error
	go func() {
		// r counts RPC establishment attempts; attempt 0 is the initial
		// connection, anything above that is a retry.
		r := 0
		// Establish a stream to mock cloud spanner server.
		iter := stream(context.Background(), nil,
			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
				if r > 0 {
					// This RPC attempt is a retry, signal it.
					retry <- r
				}
				r++
				return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
					Sql:         "SELECT t.key key, t.value value FROM t_mock t",
					ResumeToken: resumeToken,
				})

			},
			nil,
			func(error) {})
		defer iter.Stop()
		for {
			_, err = iter.Next()
			if err == iterator.Done {
				err = nil
				break
			}
			if err != nil {
				break
			}
			row <- 0
		}
	}()
	// Add a message and wait for the receipt.
	ms.AddMsg(nil, true)
	select {
	case <-row:
	case <-time.After(10 * time.Second):
		t.Fatalf("expect stream to be established within 10 seconds, but it didn't")
	}
	// Error injection: force server to close all connections.
	ms.Stop()
	// Test to see if client respond to the real RPC failure correctly by
	// retrying RPC.
	select {
	case r, ok := <-retry:
		if ok && r == 1 {
			break
		}
		t.Errorf("retry count = %v, want 1", r)
	case <-time.After(10 * time.Second):
		t.Errorf("client library failed to respond after 10 seconds, aborting")
		return
	}
}
1548
// Test cancel/timeout for client operations.
func TestCancelTimeout(t *testing.T) {
	restore := setMaxBytesBetweenResumeTokens()
	defer restore()
	ms := NewMockCloudSpanner(t, trxTs)
	ms.Serve()
	defer ms.Stop()
	cc := dialMock(t, ms)
	defer cc.Close()
	mc := sppb.NewSpannerClient(cc)
	done := make(chan int)
	// Keep the mock server producing rows forever so the queries below can
	// only end through cancellation or timeout, never normal completion.
	go func() {
		for {
			ms.AddMsg(nil, true)
		}
	}()
	// Test cancelling query.
	ctx, cancel := context.WithCancel(context.Background())
	var err error
	go func() {
		// Establish a stream to mock cloud spanner server.
		iter := stream(ctx, nil,
			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
				return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
					Sql:         "SELECT t.key key, t.value value FROM t_mock t",
					ResumeToken: resumeToken,
				})
			},
			nil,
			func(error) {})
		defer iter.Stop()
		for {
			_, err = iter.Next()
			if err == iterator.Done {
				break
			}
			if err != nil {
				// Signal that the query failed (expected: cancellation).
				done <- 0
				break
			}
		}
	}()
	cancel()
	select {
	case <-done:
		if ErrCode(err) != codes.Canceled {
			t.Errorf("streaming query is canceled and returns error %v, want error code %v", err, codes.Canceled)
		}
	case <-time.After(1 * time.Second):
		t.Errorf("query doesn't exit timely after being cancelled")
	}
	// Test query timeout.
	ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	go func() {
		// Establish a stream to mock cloud spanner server.
		iter := stream(ctx, nil,
			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
				return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
					Sql:         "SELECT t.key key, t.value value FROM t_mock t",
					ResumeToken: resumeToken,
				})
			},
			nil,
			func(error) {})
		defer iter.Stop()
		for {
			_, err = iter.Next()
			if err == iterator.Done {
				err = nil
				break
			}
			if err != nil {
				break
			}
		}
		done <- 0
	}()
	select {
	case <-done:
		if wantErr := codes.DeadlineExceeded; ErrCode(err) != wantErr {
			t.Errorf("streaming query timeout returns error %v, want error code %v", err, wantErr)
		}
	case <-time.After(2 * time.Second):
		t.Errorf("query doesn't timeout as expected")
	}
}
1636
1637func TestRowIteratorDo(t *testing.T) {
1638	restore := setMaxBytesBetweenResumeTokens()
1639	defer restore()
1640	ms := NewMockCloudSpanner(t, trxTs)
1641	ms.Serve()
1642	defer ms.Stop()
1643	cc := dialMock(t, ms)
1644	defer cc.Close()
1645	mc := sppb.NewSpannerClient(cc)
1646
1647	for i := 0; i < 3; i++ {
1648		ms.AddMsg(nil, false)
1649	}
1650	ms.AddMsg(io.EOF, true)
1651	nRows := 0
1652	iter := stream(context.Background(), nil,
1653		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1654			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1655				Sql:         "SELECT t.key key, t.value value FROM t_mock t",
1656				ResumeToken: resumeToken,
1657			})
1658		},
1659		nil,
1660		func(error) {})
1661	err := iter.Do(func(r *Row) error { nRows++; return nil })
1662	if err != nil {
1663		t.Errorf("Using Do: %v", err)
1664	}
1665	if nRows != 3 {
1666		t.Errorf("got %d rows, want 3", nRows)
1667	}
1668}
1669
1670func TestRowIteratorDoWithError(t *testing.T) {
1671	restore := setMaxBytesBetweenResumeTokens()
1672	defer restore()
1673	ms := NewMockCloudSpanner(t, trxTs)
1674	ms.Serve()
1675	defer ms.Stop()
1676	cc := dialMock(t, ms)
1677	defer cc.Close()
1678	mc := sppb.NewSpannerClient(cc)
1679
1680	for i := 0; i < 3; i++ {
1681		ms.AddMsg(nil, false)
1682	}
1683	ms.AddMsg(io.EOF, true)
1684	iter := stream(context.Background(), nil,
1685		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1686			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1687				Sql:         "SELECT t.key key, t.value value FROM t_mock t",
1688				ResumeToken: resumeToken,
1689			})
1690		},
1691		nil,
1692		func(error) {})
1693	injected := errors.New("Failed iterator")
1694	err := iter.Do(func(r *Row) error { return injected })
1695	if err != injected {
1696		t.Errorf("got <%v>, want <%v>", err, injected)
1697	}
1698}
1699
1700func TestIteratorStopEarly(t *testing.T) {
1701	ctx := context.Background()
1702	restore := setMaxBytesBetweenResumeTokens()
1703	defer restore()
1704	ms := NewMockCloudSpanner(t, trxTs)
1705	ms.Serve()
1706	defer ms.Stop()
1707	cc := dialMock(t, ms)
1708	defer cc.Close()
1709	mc := sppb.NewSpannerClient(cc)
1710
1711	ms.AddMsg(nil, false)
1712	ms.AddMsg(nil, false)
1713	ms.AddMsg(io.EOF, true)
1714
1715	iter := stream(ctx, nil,
1716		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1717			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1718				Sql:         "SELECT t.key key, t.value value FROM t_mock t",
1719				ResumeToken: resumeToken,
1720			})
1721		},
1722		nil,
1723		func(error) {})
1724	_, err := iter.Next()
1725	if err != nil {
1726		t.Fatalf("before Stop: %v", err)
1727	}
1728	iter.Stop()
1729	// Stop sets r.err to the FailedPrecondition error "Next called after Stop".
1730	_, err = iter.Next()
1731	if g, w := ErrCode(err), codes.FailedPrecondition; g != w {
1732		t.Errorf("after Stop: got: %v, want: %v", g, w)
1733	}
1734}
1735
1736func TestIteratorWithError(t *testing.T) {
1737	injected := errors.New("Failed iterator")
1738	iter := RowIterator{err: injected}
1739	defer iter.Stop()
1740	if _, err := iter.Next(); err != injected {
1741		t.Fatalf("Expected error: %v, got %v", injected, err)
1742	}
1743}
1744
1745func dialMock(t *testing.T, ms *MockCloudSpanner) *grpc.ClientConn {
1746	cc, err := grpc.Dial(ms.Addr(), grpc.WithInsecure(), grpc.WithBlock())
1747	if err != nil {
1748		t.Fatalf("Dial(%q) = %v", ms.Addr(), err)
1749	}
1750	return cc
1751}
1752