1/*
2Copyright 2017 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"errors"
22	"fmt"
23	"io"
24	"sync/atomic"
25	"testing"
26	"time"
27
28	"cloud.google.com/go/spanner/internal/backoff"
29	"cloud.google.com/go/spanner/internal/testutil"
30	"github.com/golang/protobuf/proto"
31	proto3 "github.com/golang/protobuf/ptypes/struct"
32	"google.golang.org/api/iterator"
33	sppb "google.golang.org/genproto/googleapis/spanner/v1"
34	"google.golang.org/grpc"
35	"google.golang.org/grpc/codes"
36	"google.golang.org/grpc/status"
37)
38
var (
	// Mocked transaction timestamp. It is used as the read timestamp in the
	// transaction metadata of kvMeta and kvListMeta below.
	trxTs = time.Unix(1, 2)
	// Metadata for mocked KV table, its rows are returned by SingleUse
	// transactions.
	//
	// This is a shallow copy of testutil.KvMeta (the RowType pointer is still
	// shared) with a Transaction whose read timestamp is trxTs attached. Only
	// the copy's Transaction field is assigned, so testutil.KvMeta itself is
	// left unchanged.
	kvMeta = func() *sppb.ResultSetMetadata {
		meta := testutil.KvMeta
		meta.Transaction = &sppb.Transaction{
			ReadTimestamp: timestampProto(trxTs),
		}
		return &meta
	}()
	// Metadata for mocked ListKV table, whose "Key" and "Value" columns are
	// both ARRAY<STRING>.
	// Its rows are returned by snapshot readonly transactions, as indicated in
	// the transaction metadata (which carries both an ID and a read timestamp).
	kvListMeta = &sppb.ResultSetMetadata{
		RowType: &sppb.StructType{
			Fields: []*sppb.StructType_Field{
				{
					Name: "Key",
					Type: &sppb.Type{
						Code: sppb.TypeCode_ARRAY,
						ArrayElementType: &sppb.Type{
							Code: sppb.TypeCode_STRING,
						},
					},
				},
				{
					Name: "Value",
					Type: &sppb.Type{
						Code: sppb.TypeCode_ARRAY,
						ArrayElementType: &sppb.Type{
							Code: sppb.TypeCode_STRING,
						},
					},
				},
			},
		},
		Transaction: &sppb.Transaction{
			Id:            transactionID{5, 6, 7, 8, 9},
			ReadTimestamp: timestampProto(trxTs),
		},
	}
	// Metadata for mocked schema of a query result set, which has two struct
	// columns named "Col1" and "Col2". Each struct has the following shape
	// (its fields are named "foo-f1"/"foo-f2" in Col1 and "bar-f1"/"bar-f2"
	// in Col2):
	//
	//	STRUCT {
	//		INT64
	//		ARRAY<STRING>
	//	}
	//
	// Its rows are returned in readwrite transaction, as indicated in the
	// transaction metadata (an ID with no read timestamp).
	kvObjectMeta = &sppb.ResultSetMetadata{
		RowType: &sppb.StructType{
			Fields: []*sppb.StructType_Field{
				{
					Name: "Col1",
					Type: &sppb.Type{
						Code: sppb.TypeCode_STRUCT,
						StructType: &sppb.StructType{
							Fields: []*sppb.StructType_Field{
								{
									Name: "foo-f1",
									Type: &sppb.Type{
										Code: sppb.TypeCode_INT64,
									},
								},
								{
									Name: "foo-f2",
									Type: &sppb.Type{
										Code: sppb.TypeCode_ARRAY,
										ArrayElementType: &sppb.Type{
											Code: sppb.TypeCode_STRING,
										},
									},
								},
							},
						},
					},
				},
				{
					Name: "Col2",
					Type: &sppb.Type{
						Code: sppb.TypeCode_STRUCT,
						StructType: &sppb.StructType{
							Fields: []*sppb.StructType_Field{
								{
									Name: "bar-f1",
									Type: &sppb.Type{
										Code: sppb.TypeCode_INT64,
									},
								},
								{
									Name: "bar-f2",
									Type: &sppb.Type{
										Code: sppb.TypeCode_ARRAY,
										ArrayElementType: &sppb.Type{
											Code: sppb.TypeCode_STRING,
										},
									},
								},
							},
						},
					},
				},
			},
		},
		Transaction: &sppb.Transaction{
			Id: transactionID{1, 2, 3, 4, 5},
		},
	}
)
153
154// String implements fmt.stringer.
155func (r *Row) String() string {
156	return fmt.Sprintf("{fields: %s, val: %s}", r.fields, r.vals)
157}
158
159func describeRows(l []*Row) string {
160	// generate a nice test failure description
161	var s = "["
162	for i, r := range l {
163		if i != 0 {
164			s += ",\n "
165		}
166		s += fmt.Sprint(r)
167	}
168	s += "]"
169	return s
170}
171
172// Helper for generating proto3 Value_ListValue instances, making test code
173// shorter and readable.
174func genProtoListValue(v ...string) *proto3.Value_ListValue {
175	r := &proto3.Value_ListValue{
176		ListValue: &proto3.ListValue{
177			Values: []*proto3.Value{},
178		},
179	}
180	for _, e := range v {
181		r.ListValue.Values = append(
182			r.ListValue.Values,
183			&proto3.Value{
184				Kind: &proto3.Value_StringValue{StringValue: e},
185			},
186		)
187	}
188	return r
189}
190
191// Test Row generation logics of partialResultSetDecoder.
192func TestPartialResultSetDecoder(t *testing.T) {
193	restore := setMaxBytesBetweenResumeTokens()
194	defer restore()
195	var tests = []struct {
196		input    []*sppb.PartialResultSet
197		wantF    []*Row
198		wantTxID transactionID
199		wantTs   time.Time
200		wantD    bool
201	}{
202		{
203			// Empty input.
204			wantD: true,
205		},
206		// String merging examples.
207		{
208			// Single KV result.
209			input: []*sppb.PartialResultSet{
210				{
211					Metadata: kvMeta,
212					Values: []*proto3.Value{
213						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
214						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
215					},
216				},
217			},
218			wantF: []*Row{
219				{
220					fields: kvMeta.RowType.Fields,
221					vals: []*proto3.Value{
222						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
223						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
224					},
225				},
226			},
227			wantTs: trxTs,
228			wantD:  true,
229		},
230		{
231			// Incomplete partial result.
232			input: []*sppb.PartialResultSet{
233				{
234					Metadata: kvMeta,
235					Values: []*proto3.Value{
236						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
237					},
238				},
239			},
240			wantTs: trxTs,
241			wantD:  false,
242		},
243		{
244			// Complete splitted result.
245			input: []*sppb.PartialResultSet{
246				{
247					Metadata: kvMeta,
248					Values: []*proto3.Value{
249						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
250					},
251				},
252				{
253					Values: []*proto3.Value{
254						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
255					},
256				},
257			},
258			wantF: []*Row{
259				{
260					fields: kvMeta.RowType.Fields,
261					vals: []*proto3.Value{
262						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
263						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
264					},
265				},
266			},
267			wantTs: trxTs,
268			wantD:  true,
269		},
270		{
271			// Multi-row example with splitted row in the middle.
272			input: []*sppb.PartialResultSet{
273				{
274					Metadata: kvMeta,
275					Values: []*proto3.Value{
276						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
277						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
278						{Kind: &proto3.Value_StringValue{StringValue: "A"}},
279					},
280				},
281				{
282					Values: []*proto3.Value{
283						{Kind: &proto3.Value_StringValue{StringValue: "1"}},
284						{Kind: &proto3.Value_StringValue{StringValue: "B"}},
285						{Kind: &proto3.Value_StringValue{StringValue: "2"}},
286					},
287				},
288			},
289			wantF: []*Row{
290				{
291					fields: kvMeta.RowType.Fields,
292					vals: []*proto3.Value{
293						{Kind: &proto3.Value_StringValue{StringValue: "foo"}},
294						{Kind: &proto3.Value_StringValue{StringValue: "bar"}},
295					},
296				},
297				{
298					fields: kvMeta.RowType.Fields,
299					vals: []*proto3.Value{
300						{Kind: &proto3.Value_StringValue{StringValue: "A"}},
301						{Kind: &proto3.Value_StringValue{StringValue: "1"}},
302					},
303				},
304				{
305					fields: kvMeta.RowType.Fields,
306					vals: []*proto3.Value{
307						{Kind: &proto3.Value_StringValue{StringValue: "B"}},
308						{Kind: &proto3.Value_StringValue{StringValue: "2"}},
309					},
310				},
311			},
312			wantTs: trxTs,
313			wantD:  true,
314		},
315		{
316			// Merging example in result_set.proto.
317			input: []*sppb.PartialResultSet{
318				{
319					Metadata: kvMeta,
320					Values: []*proto3.Value{
321						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
322						{Kind: &proto3.Value_StringValue{StringValue: "W"}},
323					},
324					ChunkedValue: true,
325				},
326				{
327					Values: []*proto3.Value{
328						{Kind: &proto3.Value_StringValue{StringValue: "orl"}},
329					},
330					ChunkedValue: true,
331				},
332				{
333					Values: []*proto3.Value{
334						{Kind: &proto3.Value_StringValue{StringValue: "d"}},
335					},
336				},
337			},
338			wantF: []*Row{
339				{
340					fields: kvMeta.RowType.Fields,
341					vals: []*proto3.Value{
342						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
343						{Kind: &proto3.Value_StringValue{StringValue: "World"}},
344					},
345				},
346			},
347			wantTs: trxTs,
348			wantD:  true,
349		},
350		{
351			// More complex example showing completing a merge and
352			// starting a new merge in the same partialResultSet.
353			input: []*sppb.PartialResultSet{
354				{
355					Metadata: kvMeta,
356					Values: []*proto3.Value{
357						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
358						{Kind: &proto3.Value_StringValue{StringValue: "W"}}, // start split in value
359					},
360					ChunkedValue: true,
361				},
362				{
363					Values: []*proto3.Value{
364						{Kind: &proto3.Value_StringValue{StringValue: "orld"}}, // complete value
365						{Kind: &proto3.Value_StringValue{StringValue: "i"}},    // start split in key
366					},
367					ChunkedValue: true,
368				},
369				{
370					Values: []*proto3.Value{
371						{Kind: &proto3.Value_StringValue{StringValue: "s"}}, // complete key
372						{Kind: &proto3.Value_StringValue{StringValue: "not"}},
373						{Kind: &proto3.Value_StringValue{StringValue: "a"}},
374						{Kind: &proto3.Value_StringValue{StringValue: "qu"}}, // split in value
375					},
376					ChunkedValue: true,
377				},
378				{
379					Values: []*proto3.Value{
380						{Kind: &proto3.Value_StringValue{StringValue: "estion"}}, // complete value
381					},
382				},
383			},
384			wantF: []*Row{
385				{
386					fields: kvMeta.RowType.Fields,
387					vals: []*proto3.Value{
388						{Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
389						{Kind: &proto3.Value_StringValue{StringValue: "World"}},
390					},
391				},
392				{
393					fields: kvMeta.RowType.Fields,
394					vals: []*proto3.Value{
395						{Kind: &proto3.Value_StringValue{StringValue: "is"}},
396						{Kind: &proto3.Value_StringValue{StringValue: "not"}},
397					},
398				},
399				{
400					fields: kvMeta.RowType.Fields,
401					vals: []*proto3.Value{
402						{Kind: &proto3.Value_StringValue{StringValue: "a"}},
403						{Kind: &proto3.Value_StringValue{StringValue: "question"}},
404					},
405				},
406			},
407			wantTs: trxTs,
408			wantD:  true,
409		},
410		// List merging examples.
411		{
412			// Non-splitting Lists.
413			input: []*sppb.PartialResultSet{
414				{
415					Metadata: kvListMeta,
416					Values: []*proto3.Value{
417						{
418							Kind: genProtoListValue("foo-1", "foo-2"),
419						},
420					},
421				},
422				{
423					Values: []*proto3.Value{
424						{
425							Kind: genProtoListValue("bar-1", "bar-2"),
426						},
427					},
428				},
429			},
430			wantF: []*Row{
431				{
432					fields: kvListMeta.RowType.Fields,
433					vals: []*proto3.Value{
434						{
435							Kind: genProtoListValue("foo-1", "foo-2"),
436						},
437						{
438							Kind: genProtoListValue("bar-1", "bar-2"),
439						},
440					},
441				},
442			},
443			wantTxID: transactionID{5, 6, 7, 8, 9},
444			wantTs:   trxTs,
445			wantD:    true,
446		},
447		{
448			// Simple List merge case: splitted string element.
449			input: []*sppb.PartialResultSet{
450				{
451					Metadata: kvListMeta,
452					Values: []*proto3.Value{
453						{
454							Kind: genProtoListValue("foo-1", "foo-"),
455						},
456					},
457					ChunkedValue: true,
458				},
459				{
460					Values: []*proto3.Value{
461						{
462							Kind: genProtoListValue("2"),
463						},
464					},
465				},
466				{
467					Values: []*proto3.Value{
468						{
469							Kind: genProtoListValue("bar-1", "bar-2"),
470						},
471					},
472				},
473			},
474			wantF: []*Row{
475				{
476					fields: kvListMeta.RowType.Fields,
477					vals: []*proto3.Value{
478						{
479							Kind: genProtoListValue("foo-1", "foo-2"),
480						},
481						{
482							Kind: genProtoListValue("bar-1", "bar-2"),
483						},
484					},
485				},
486			},
487			wantTxID: transactionID{5, 6, 7, 8, 9},
488			wantTs:   trxTs,
489			wantD:    true,
490		},
491		{
492			// Struct merging is also implemented by List merging. Note that
493			// Cloud Spanner uses proto.ListValue to encode Structs as well.
494			input: []*sppb.PartialResultSet{
495				{
496					Metadata: kvObjectMeta,
497					Values: []*proto3.Value{
498						{
499							Kind: &proto3.Value_ListValue{
500								ListValue: &proto3.ListValue{
501									Values: []*proto3.Value{
502										{Kind: &proto3.Value_NumberValue{NumberValue: 23}},
503										{Kind: genProtoListValue("foo-1", "fo")},
504									},
505								},
506							},
507						},
508					},
509					ChunkedValue: true,
510				},
511				{
512					Values: []*proto3.Value{
513						{
514							Kind: &proto3.Value_ListValue{
515								ListValue: &proto3.ListValue{
516									Values: []*proto3.Value{
517										{Kind: genProtoListValue("o-2", "f")},
518									},
519								},
520							},
521						},
522					},
523					ChunkedValue: true,
524				},
525				{
526					Values: []*proto3.Value{
527						{
528							Kind: &proto3.Value_ListValue{
529								ListValue: &proto3.ListValue{
530									Values: []*proto3.Value{
531										{Kind: genProtoListValue("oo-3")},
532									},
533								},
534							},
535						},
536						{
537							Kind: &proto3.Value_ListValue{
538								ListValue: &proto3.ListValue{
539									Values: []*proto3.Value{
540										{Kind: &proto3.Value_NumberValue{NumberValue: 45}},
541										{Kind: genProtoListValue("bar-1")},
542									},
543								},
544							},
545						},
546					},
547				},
548			},
549			wantF: []*Row{
550				{
551					fields: kvObjectMeta.RowType.Fields,
552					vals: []*proto3.Value{
553						{
554							Kind: &proto3.Value_ListValue{
555								ListValue: &proto3.ListValue{
556									Values: []*proto3.Value{
557										{Kind: &proto3.Value_NumberValue{NumberValue: 23}},
558										{Kind: genProtoListValue("foo-1", "foo-2", "foo-3")},
559									},
560								},
561							},
562						},
563						{
564							Kind: &proto3.Value_ListValue{
565								ListValue: &proto3.ListValue{
566									Values: []*proto3.Value{
567										{Kind: &proto3.Value_NumberValue{NumberValue: 45}},
568										{Kind: genProtoListValue("bar-1")},
569									},
570								},
571							},
572						},
573					},
574				},
575			},
576			wantTxID: transactionID{1, 2, 3, 4, 5},
577			wantD:    true,
578		},
579	}
580
581nextTest:
582	for i, test := range tests {
583		var rows []*Row
584		p := &partialResultSetDecoder{}
585		for j, v := range test.input {
586			rs, err := p.add(v)
587			if err != nil {
588				t.Errorf("test %d.%d: partialResultSetDecoder.add(%v) = %v; want nil", i, j, v, err)
589				continue nextTest
590			}
591			rows = append(rows, rs...)
592		}
593		if !testEqual(p.ts, test.wantTs) {
594			t.Errorf("got transaction(%v), want %v", p.ts, test.wantTs)
595		}
596		if !testEqual(rows, test.wantF) {
597			t.Errorf("test %d: rows=\n%v\n; want\n%v\n; p.row:\n%v\n", i, describeRows(rows), describeRows(test.wantF), p.row)
598		}
599		if got := p.done(); got != test.wantD {
600			t.Errorf("test %d: partialResultSetDecoder.done() = %v", i, got)
601		}
602	}
603}
604
// maxBuffers is the maximum number of PartialResultSets that will be buffered
// in tests; setMaxBytesBetweenResumeTokens derives its byte limit from it.
const maxBuffers = 16
609
610// setMaxBytesBetweenResumeTokens sets the global maxBytesBetweenResumeTokens to
611// a smaller value more suitable for tests. It returns a function which should
612// be called to restore the maxBytesBetweenResumeTokens to its old value.
613func setMaxBytesBetweenResumeTokens() func() {
614	o := atomic.LoadInt32(&maxBytesBetweenResumeTokens)
615	atomic.StoreInt32(&maxBytesBetweenResumeTokens, int32(maxBuffers*proto.Size(&sppb.PartialResultSet{
616		Metadata: kvMeta,
617		Values: []*proto3.Value{
618			{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
619			{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
620		},
621	})))
622	return func() {
623		atomic.StoreInt32(&maxBytesBetweenResumeTokens, o)
624	}
625}
626
// keyStr returns the key column value of the i-th mocked kvMeta row, with the
// index zero-padded to at least two digits (e.g. keyStr(3) == "foo-03").
func keyStr(i int) string {
	const prefix = "foo-"
	return prefix + fmt.Sprintf("%02d", i)
}
631
// valStr returns the value column of the i-th mocked kvMeta row, with the
// index zero-padded to at least two digits (e.g. valStr(3) == "bar-03").
func valStr(i int) string {
	const prefix = "bar-"
	return prefix + fmt.Sprintf("%02d", i)
}
636
637// Test state transitions of resumableStreamDecoder where state machine ends up
638// to a non-blocking state(resumableStreamDecoder.Next returns on non-blocking
639// state).
640func TestRsdNonblockingStates(t *testing.T) {
641	restore := setMaxBytesBetweenResumeTokens()
642	defer restore()
643	tests := []struct {
644		name string
645		msgs []testutil.MockCtlMsg
646		rpc  func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
647		sql  string
648		// Expected values
649		want         []*sppb.PartialResultSet      // PartialResultSets that should be returned to caller
650		queue        []*sppb.PartialResultSet      // PartialResultSets that should be buffered
651		resumeToken  []byte                        // Resume token that is maintained by resumableStreamDecoder
652		stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
653		wantErr      error
654	}{
655		{
656			// unConnected->queueingRetryable->finished
657			name: "unConnected->queueingRetryable->finished",
658			msgs: []testutil.MockCtlMsg{
659				{},
660				{},
661				{Err: io.EOF, ResumeToken: false},
662			},
663			sql: "SELECT t.key key, t.value value FROM t_mock t",
664			want: []*sppb.PartialResultSet{
665				{
666					Metadata: kvMeta,
667					Values: []*proto3.Value{
668						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
669						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
670					},
671				},
672			},
673			queue: []*sppb.PartialResultSet{
674				{
675					Metadata: kvMeta,
676					Values: []*proto3.Value{
677						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
678						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
679					},
680				},
681			},
682			stateHistory: []resumableStreamDecoderState{
683				queueingRetryable, // do RPC
684				queueingRetryable, // got foo-00
685				queueingRetryable, // got foo-01
686				finished,          // got EOF
687			},
688		},
689		{
690			// unConnected->queueingRetryable->aborted
691			name: "unConnected->queueingRetryable->aborted",
692			msgs: []testutil.MockCtlMsg{
693				{},
694				{Err: nil, ResumeToken: true},
695				{},
696				{Err: errors.New("I quit"), ResumeToken: false},
697			},
698			sql: "SELECT t.key key, t.value value FROM t_mock t",
699			want: []*sppb.PartialResultSet{
700				{
701					Metadata: kvMeta,
702					Values: []*proto3.Value{
703						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
704						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
705					},
706				},
707				{
708					Metadata: kvMeta,
709					Values: []*proto3.Value{
710						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
711						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
712					},
713					ResumeToken: testutil.EncodeResumeToken(1),
714				},
715			},
716			stateHistory: []resumableStreamDecoderState{
717				queueingRetryable, // do RPC
718				queueingRetryable, // got foo-00
719				queueingRetryable, // got foo-01
720				queueingRetryable, // foo-01, resume token
721				queueingRetryable, // got foo-02
722				aborted,           // got error
723			},
724			wantErr: status.Errorf(codes.Unknown, "I quit"),
725		},
726		{
727			// unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable
728			name: "unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable",
729			msgs: func() (m []testutil.MockCtlMsg) {
730				for i := 0; i < maxBuffers+1; i++ {
731					m = append(m, testutil.MockCtlMsg{})
732				}
733				return m
734			}(),
735			sql: "SELECT t.key key, t.value value FROM t_mock t",
736			want: func() (s []*sppb.PartialResultSet) {
737				for i := 0; i < maxBuffers+1; i++ {
738					s = append(s, &sppb.PartialResultSet{
739						Metadata: kvMeta,
740						Values: []*proto3.Value{
741							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
742							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
743						},
744					})
745				}
746				return s
747			}(),
748			stateHistory: func() (s []resumableStreamDecoderState) {
749				s = append(s, queueingRetryable) // RPC
750				for i := 0; i < maxBuffers; i++ {
751					s = append(s, queueingRetryable) // the internal queue of resumableStreamDecoder fills up
752				}
753				// the first item fills up the queue and triggers state transition;
754				// the second item is received under queueingUnretryable state.
755				s = append(s, queueingUnretryable)
756				s = append(s, queueingUnretryable)
757				return s
758			}(),
759		},
760		{
761			// unConnected->queueingRetryable->queueingUnretryable->aborted
762			name: "unConnected->queueingRetryable->queueingUnretryable->aborted",
763			msgs: func() (m []testutil.MockCtlMsg) {
764				for i := 0; i < maxBuffers; i++ {
765					m = append(m, testutil.MockCtlMsg{})
766				}
767				m = append(m, testutil.MockCtlMsg{Err: errors.New("Just Abort It"), ResumeToken: false})
768				return m
769			}(),
770			sql: "SELECT t.key key, t.value value FROM t_mock t",
771			want: func() (s []*sppb.PartialResultSet) {
772				for i := 0; i < maxBuffers; i++ {
773					s = append(s, &sppb.PartialResultSet{
774						Metadata: kvMeta,
775						Values: []*proto3.Value{
776							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
777							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
778						},
779					})
780				}
781				return s
782			}(),
783			stateHistory: func() (s []resumableStreamDecoderState) {
784				s = append(s, queueingRetryable) // RPC
785				for i := 0; i < maxBuffers; i++ {
786					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
787				}
788				s = append(s, queueingUnretryable) // the last row triggers state change
789				s = append(s, aborted)             // Error happens
790				return s
791			}(),
792			wantErr: status.Errorf(codes.Unknown, "Just Abort It"),
793		},
794	}
795	for _, test := range tests {
796		t.Run(test.name, func(t *testing.T) {
797			ms := testutil.NewMockCloudSpanner(t, trxTs)
798			ms.Serve()
799			mc := sppb.NewSpannerClient(dialMock(t, ms))
800			if test.rpc == nil {
801				test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
802					return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
803						Sql:         test.sql,
804						ResumeToken: resumeToken,
805					})
806				}
807			}
808			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
809			defer cancel()
810			r := newResumableStreamDecoder(
811				ctx,
812				test.rpc,
813			)
814			st := []resumableStreamDecoderState{}
815			var lastErr error
816			// Once the expected number of state transitions are observed,
817			// send a signal by setting stateDone = true.
818			stateDone := false
819			// Set stateWitness to listen to state changes.
820			hl := len(test.stateHistory) // To avoid data race on test.
821			r.stateWitness = func(rs resumableStreamDecoderState) {
822				if !stateDone {
823					// Record state transitions.
824					st = append(st, rs)
825					if len(st) == hl {
826						lastErr = r.lastErr()
827						stateDone = true
828					}
829				}
830			}
831			// Let mock server stream given messages to resumableStreamDecoder.
832			for _, m := range test.msgs {
833				ms.AddMsg(m.Err, m.ResumeToken)
834			}
835			var rs []*sppb.PartialResultSet
836			for {
837				select {
838				case <-ctx.Done():
839					t.Fatal("context cancelled or timeout during test")
840				default:
841				}
842				if stateDone {
843					// Check if resumableStreamDecoder carried out expected
844					// state transitions.
845					if !testEqual(st, test.stateHistory) {
846						t.Fatalf("observed state transitions: \n%v\n, want \n%v\n", st, test.stateHistory)
847					}
848					// Check if resumableStreamDecoder returns expected array of
849					// PartialResultSets.
850					if !testEqual(rs, test.want) {
851						t.Fatalf("received PartialResultSets: \n%v\n, want \n%v\n", rs, test.want)
852					}
853					// Verify that resumableStreamDecoder's internal buffering is
854					// also correct.
855					var q []*sppb.PartialResultSet
856					for {
857						item := r.q.pop()
858						if item == nil {
859							break
860						}
861						q = append(q, item)
862					}
863					if !testEqual(q, test.queue) {
864						t.Fatalf("PartialResultSets still queued: \n%v\n, want \n%v\n", q, test.queue)
865					}
866					// Verify resume token.
867					if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
868						t.Fatalf("Resume token is %v, want %v\n", r.resumeToken, test.resumeToken)
869					}
870					// Verify error message.
871					if !testEqual(lastErr, test.wantErr) {
872						t.Fatalf("got error %v, want %v", lastErr, test.wantErr)
873					}
874					return
875				}
876				// Receive next decoded item.
877				if r.next() {
878					rs = append(rs, r.get())
879				}
880			}
881		})
882	}
883}
884
885// Test state transitions of resumableStreamDecoder where state machine
886// ends up to a blocking state(resumableStreamDecoder.Next blocks
887// on blocking state).
888func TestRsdBlockingStates(t *testing.T) {
889	restore := setMaxBytesBetweenResumeTokens()
890	defer restore()
891	for _, test := range []struct {
892		name string
893		msgs []testutil.MockCtlMsg
894		rpc  func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
895		sql  string
896		// Expected values
897		want         []*sppb.PartialResultSet      // PartialResultSets that should be returned to caller
898		queue        []*sppb.PartialResultSet      // PartialResultSets that should be buffered
899		resumeToken  []byte                        // Resume token that is maintained by resumableStreamDecoder
900		stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
901		wantErr      error
902	}{
903		{
904			// unConnected -> unConnected
905			name: "unConnected -> unConnected",
906			rpc: func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
907				return nil, status.Errorf(codes.Unavailable, "trust me: server is unavailable")
908			},
909			sql:          "SELECT * from t_whatever",
910			stateHistory: []resumableStreamDecoderState{unConnected, unConnected, unConnected},
911			wantErr:      status.Errorf(codes.Unavailable, "trust me: server is unavailable"),
912		},
913		{
914			// unConnected -> queueingRetryable
915			name:         "unConnected -> queueingRetryable",
916			sql:          "SELECT t.key key, t.value value FROM t_mock t",
917			stateHistory: []resumableStreamDecoderState{queueingRetryable},
918		},
919		{
920			// unConnected->queueingRetryable->queueingRetryable
921			name: "unConnected->queueingRetryable->queueingRetryable",
922			msgs: []testutil.MockCtlMsg{
923				{},
924				{Err: nil, ResumeToken: true},
925				{Err: nil, ResumeToken: true},
926				{},
927			},
928			sql: "SELECT t.key key, t.value value FROM t_mock t",
929			want: []*sppb.PartialResultSet{
930				{
931					Metadata: kvMeta,
932					Values: []*proto3.Value{
933						{Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
934						{Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
935					},
936				},
937				{
938					Metadata: kvMeta,
939					Values: []*proto3.Value{
940						{Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
941						{Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
942					},
943					ResumeToken: testutil.EncodeResumeToken(1),
944				},
945				{
946					Metadata: kvMeta,
947					Values: []*proto3.Value{
948						{Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
949						{Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
950					},
951					ResumeToken: testutil.EncodeResumeToken(2),
952				},
953			},
954			queue: []*sppb.PartialResultSet{
955				{
956					Metadata: kvMeta,
957					Values: []*proto3.Value{
958						{Kind: &proto3.Value_StringValue{StringValue: keyStr(3)}},
959						{Kind: &proto3.Value_StringValue{StringValue: valStr(3)}},
960					},
961				},
962			},
963			resumeToken: testutil.EncodeResumeToken(2),
964			stateHistory: []resumableStreamDecoderState{
965				queueingRetryable, // do RPC
966				queueingRetryable, // got foo-00
967				queueingRetryable, // got foo-01
968				queueingRetryable, // foo-01, resume token
969				queueingRetryable, // got foo-02
970				queueingRetryable, // foo-02, resume token
971				queueingRetryable, // got foo-03
972			},
973		},
974		{
975			// unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable
976			name: "unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable",
977			msgs: func() (m []testutil.MockCtlMsg) {
978				for i := 0; i < maxBuffers+1; i++ {
979					m = append(m, testutil.MockCtlMsg{})
980				}
981				m = append(m, testutil.MockCtlMsg{Err: nil, ResumeToken: true})
982				m = append(m, testutil.MockCtlMsg{})
983				return m
984			}(),
985			sql: "SELECT t.key key, t.value value FROM t_mock t",
986			want: func() (s []*sppb.PartialResultSet) {
987				for i := 0; i < maxBuffers+2; i++ {
988					s = append(s, &sppb.PartialResultSet{
989						Metadata: kvMeta,
990						Values: []*proto3.Value{
991							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
992							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
993						},
994					})
995				}
996				s[maxBuffers+1].ResumeToken = testutil.EncodeResumeToken(maxBuffers + 1)
997				return s
998			}(),
999			resumeToken: testutil.EncodeResumeToken(maxBuffers + 1),
1000			queue: []*sppb.PartialResultSet{
1001				{
1002					Metadata: kvMeta,
1003					Values: []*proto3.Value{
1004						{Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 2)}},
1005						{Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 2)}},
1006					},
1007				},
1008			},
1009			stateHistory: func() (s []resumableStreamDecoderState) {
1010				s = append(s, queueingRetryable) // RPC
1011				for i := 0; i < maxBuffers; i++ {
1012					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder filles up
1013				}
1014				for i := maxBuffers - 1; i < maxBuffers+1; i++ {
1015					// the first item fills up the queue and triggers state
1016					// change; the second item is received under
1017					// queueingUnretryable state.
1018					s = append(s, queueingUnretryable)
1019				}
1020				s = append(s, queueingUnretryable) // got (maxBuffers+1)th row under Unretryable state
1021				s = append(s, queueingRetryable)   // (maxBuffers+1)th row has resume token
1022				s = append(s, queueingRetryable)   // (maxBuffers+2)th row has no resume token
1023				return s
1024			}(),
1025		},
1026		{
1027			// unConnected->queueingRetryable->queueingUnretryable->finished
1028			name: "unConnected->queueingRetryable->queueingUnretryable->finished",
1029			msgs: func() (m []testutil.MockCtlMsg) {
1030				for i := 0; i < maxBuffers; i++ {
1031					m = append(m, testutil.MockCtlMsg{})
1032				}
1033				m = append(m, testutil.MockCtlMsg{Err: io.EOF, ResumeToken: false})
1034				return m
1035			}(),
1036			sql: "SELECT t.key key, t.value value FROM t_mock t",
1037			want: func() (s []*sppb.PartialResultSet) {
1038				for i := 0; i < maxBuffers; i++ {
1039					s = append(s, &sppb.PartialResultSet{
1040						Metadata: kvMeta,
1041						Values: []*proto3.Value{
1042							{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
1043							{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
1044						},
1045					})
1046				}
1047				return s
1048			}(),
1049			stateHistory: func() (s []resumableStreamDecoderState) {
1050				s = append(s, queueingRetryable) // RPC
1051				for i := 0; i < maxBuffers; i++ {
1052					s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
1053				}
1054				s = append(s, queueingUnretryable) // last row triggers state change
1055				s = append(s, finished)            // query finishes
1056				return s
1057			}(),
1058		},
1059	} {
1060		t.Run(test.name, func(t *testing.T) {
1061			ms := testutil.NewMockCloudSpanner(t, trxTs)
1062			ms.Serve()
1063			cc := dialMock(t, ms)
1064			mc := sppb.NewSpannerClient(cc)
1065			if test.rpc == nil {
1066				// Avoid using test.sql directly in closure because for loop changes
1067				// test.
1068				sql := test.sql
1069				test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1070					return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1071						Sql:         sql,
1072						ResumeToken: resumeToken,
1073					})
1074				}
1075			}
1076			ctx, cancel := context.WithCancel(context.Background())
1077			defer cancel()
1078			r := newResumableStreamDecoder(
1079				ctx,
1080				test.rpc,
1081			)
1082			// Override backoff to make the test run faster.
1083			r.backoff = backoff.ExponentialBackoff{
1084				Min: 1 * time.Nanosecond,
1085				Max: 1 * time.Nanosecond,
1086			}
1087			// st is the set of observed state transitions.
1088			st := []resumableStreamDecoderState{}
1089			// q is the content of the decoder's partial result queue when expected
1090			// number of state transitions are done.
1091			q := []*sppb.PartialResultSet{}
1092			var lastErr error
1093			// Once the expected number of state transitions are observed, send a
1094			// signal to channel stateDone.
1095			stateDone := make(chan int)
1096			// Set stateWitness to listen to state changes.
1097			hl := len(test.stateHistory) // To avoid data race on test.
1098			r.stateWitness = func(rs resumableStreamDecoderState) {
1099				select {
1100				case <-stateDone:
1101					// Noop after expected number of state transitions
1102				default:
1103					// Record state transitions.
1104					st = append(st, rs)
1105					if len(st) == hl {
1106						lastErr = r.lastErr()
1107						q = r.q.dump()
1108						close(stateDone)
1109					}
1110				}
1111			}
1112			// Let mock server stream given messages to resumableStreamDecoder.
1113			for _, m := range test.msgs {
1114				ms.AddMsg(m.Err, m.ResumeToken)
1115			}
1116			var rs []*sppb.PartialResultSet
1117			go func() {
1118				for {
1119					if !r.next() {
1120						// Note that r.Next also exits on context cancel/timeout.
1121						return
1122					}
1123					rs = append(rs, r.get())
1124				}
1125			}()
1126			// Verify that resumableStreamDecoder reaches expected state.
1127			select {
1128			case <-stateDone: // Note that at this point, receiver is still blockingon r.next().
1129				// Check if resumableStreamDecoder carried out expected state
1130				// transitions.
1131				if !testEqual(st, test.stateHistory) {
1132					t.Fatalf("observed state transitions: \n%v\n, want \n%v\n", st, test.stateHistory)
1133				}
1134				// Check if resumableStreamDecoder returns expected array of
1135				// PartialResultSets.
1136				if !testEqual(rs, test.want) {
1137					t.Fatalf("received PartialResultSets: \n%v\n, want \n%v\n", rs, test.want)
1138				}
1139				// Verify that resumableStreamDecoder's internal buffering is also
1140				// correct.
1141				if !testEqual(q, test.queue) {
1142					t.Fatalf("PartialResultSets still queued: \n%v\n, want \n%v\n", q, test.queue)
1143				}
1144				// Verify resume token.
1145				if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
1146					t.Fatalf("Resume token is %v, want %v\n", r.resumeToken, test.resumeToken)
1147				}
1148				// Verify error message.
1149				if !testEqual(lastErr, test.wantErr) {
1150					t.Fatalf("got error %v, want %v", lastErr, test.wantErr)
1151				}
1152			case <-time.After(1 * time.Second):
1153				t.Fatal("Timeout in waiting for state change")
1154			}
1155			ms.Stop()
1156			if err := cc.Close(); err != nil {
1157				t.Fatal(err)
1158			}
1159		})
1160	}
1161}
1162
// sReceiver signals every receiving attempt through a channel, used by
// TestResumeToken and TestQueueBytes to determine if the receiving of a
// certain PartialResultSet will be attempted next.
//
// Field c receives one value at the start of every Recv call (see Recv), and
// rpcReceiver is the underlying ExecuteStreamingSql stream that Recv delegates
// to; the tests replace rpcReceiver each time the RPC is (re)issued.
type sReceiver struct {
	c           chan int
	rpcReceiver sppb.Spanner_ExecuteStreamingSqlClient
}
1170
1171// Recv() implements streamingReceiver.Recv for sReceiver.
1172func (sr *sReceiver) Recv() (*sppb.PartialResultSet, error) {
1173	sr.c <- 1
1174	return sr.rpcReceiver.Recv()
1175}
1176
// waitn blocks until n further receiving attempts have been signalled on sr.c,
// allowing up to 10 seconds for each one. Because stream() only calls Recv
// again after handling the previous PartialResultSet, observing the nth signal
// means the preceding n - 1 PartialResultSets have already been returned to
// the caller or queued, if no error happened. On timeout it returns an error
// naming the attempt that never arrived.
func (sr *sReceiver) waitn(n int) error {
	for got := 0; got < n; got++ {
		deadline := time.After(10 * time.Second)
		select {
		case <-sr.c:
			continue
		case <-deadline:
			return fmt.Errorf("timeout in waiting for %v-th Recv()", got+1)
		}
	}
	return nil
}
1192
// TestQueueBytes tests the handling of
// resumableStreamDecoder.bytesBetweenResumeTokens: it must equal the encoded
// size of the PartialResultSets received since the last resume token, drop to
// zero when a resume token arrives, and drop to zero again when the decoder
// transitions from queueingRetryable to queueingUnretryable.
func TestQueueBytes(t *testing.T) {
	restore := setMaxBytesBetweenResumeTokens()
	defer restore()
	ms := testutil.NewMockCloudSpanner(t, trxTs)
	ms.Serve()
	defer ms.Stop()
	cc := dialMock(t, ms)
	defer cc.Close()
	mc := sppb.NewSpannerClient(cc)
	sr := &sReceiver{
		c: make(chan int, 1000), // will never block in this test
	}
	// rowSize returns the encoded size of the i-th mocked KV row; the test
	// expects bytesBetweenResumeTokens to grow by this amount per queued row.
	rowSize := func(i int) int {
		return proto.Size(&sppb.PartialResultSet{
			Metadata: kvMeta,
			Values: []*proto3.Value{
				{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
				{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
			},
		})
	}
	wantQueueBytes := 0
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	r := newResumableStreamDecoder(
		ctx,
		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
			r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
				Sql:         "SELECT t.key key, t.value value FROM t_mock t",
				ResumeToken: resumeToken,
			})
			sr.rpcReceiver = r
			return sr, err
		},
	)
	go func() {
		// Keep pulling rows; r.next() also returns false on ctx timeout.
		for r.next() {
		}
	}()
	// Let server send maxBuffers / 2 rows.
	for i := 0; i < maxBuffers/2; i++ {
		wantQueueBytes += rowSize(i)
		ms.AddMsg(nil, false)
	}
	// One signal for the initial Recv plus one per row handled.
	if err := sr.waitn(maxBuffers/2 + 1); err != nil {
		t.Fatalf("failed to wait for the first %v recv() calls: %v", maxBuffers/2+1, err)
	}
	if int32(wantQueueBytes) != r.bytesBetweenResumeTokens {
		t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", r.bytesBetweenResumeTokens, wantQueueBytes)
	}
	// Now send a resume token to drain the queue.
	ms.AddMsg(nil, true)
	// Wait for all rows to be processed.
	if err := sr.waitn(1); err != nil {
		t.Fatalf("failed to wait for rows to be processed: %v", err)
	}
	if r.bytesBetweenResumeTokens != 0 {
		t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens)
	}
	// Let server send maxBuffers - 1 rows.
	wantQueueBytes = 0
	for i := 0; i < maxBuffers-1; i++ {
		wantQueueBytes += rowSize(i)
		ms.AddMsg(nil, false)
	}
	if err := sr.waitn(maxBuffers - 1); err != nil {
		t.Fatalf("failed to wait for %v rows to be processed: %v", maxBuffers-1, err)
	}
	if int32(wantQueueBytes) != r.bytesBetweenResumeTokens {
		t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", r.bytesBetweenResumeTokens, wantQueueBytes)
	}
	// Trigger a state transition: queueingRetryable -> queueingUnretryable.
	ms.AddMsg(nil, false)
	if err := sr.waitn(1); err != nil {
		t.Fatalf("failed to wait for state transition: %v", err)
	}
	if r.bytesBetweenResumeTokens != 0 {
		t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens)
	}
}
1277
// TestResumeToken verifies that the client deals with resume tokens correctly.
// It checks that:
//   - rows are held back until a resume token, a queue overflow or the end of
//     the stream releases them;
//   - a retryable error is retried from the last resume token while the
//     decoder is still queueingRetryable;
//   - the same kind of error fails the query once the decoder has moved to
//     queueingUnretryable.
func TestResumeToken(t *testing.T) {
	restore := setMaxBytesBetweenResumeTokens()
	defer restore()
	ms := testutil.NewMockCloudSpanner(t, trxTs)
	ms.Serve()
	defer ms.Stop()
	cc := dialMock(t, ms)
	defer cc.Close()
	mc := sppb.NewSpannerClient(cc)
	sr := &sReceiver{
		c: make(chan int, 1000), // will never block in this test
	}
	// kvRow builds the i-th mocked KV row as the application should see it.
	kvRow := func(i int) *Row {
		return &Row{
			fields: kvMeta.RowType.Fields,
			vals: []*proto3.Value{
				{Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
				{Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
			},
		}
	}
	// rows collects every row yielded by the current streaming() run.
	rows := []*Row{}
	// done receives the terminal error (nil on iterator.Done) of each run.
	done := make(chan error)
	streaming := func() {
		// Establish a stream to mock cloud spanner server.
		iter := stream(context.Background(),
			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
				r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
					Sql:         "SELECT t.key key, t.value value FROM t_mock t",
					ResumeToken: resumeToken,
				})
				sr.rpcReceiver = r
				return sr, err
			},
			nil,
			func(error) {})
		defer iter.Stop()
		var err error
		for {
			var row *Row
			row, err = iter.Next()
			if err == iterator.Done {
				err = nil
				break
			}
			if err != nil {
				break
			}
			rows = append(rows, row)
		}
		done <- err
	}
	go streaming()
	// Server streaming row 0 - 2, only row 1 has resume token.
	// Client will receive row 0 - 2, so it will try receiving for
	// 4 times (the last recv will block), and only row 0 - 1 will
	// be yielded.
	for i := 0; i < 3; i++ {
		if i == 1 {
			ms.AddMsg(nil, true)
		} else {
			ms.AddMsg(nil, false)
		}
	}
	// Wait for 4 receive attempts, as explained above.
	if err := sr.waitn(4); err != nil {
		t.Fatalf("failed to wait for row 0 - 2: %v", err)
	}
	want := []*Row{kvRow(0), kvRow(1)}
	if !testEqual(rows, want) {
		t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
	}
	// Inject resumable failure.
	ms.AddMsg(
		status.Errorf(codes.Unavailable, "mock server unavailable"),
		false,
	)
	// Test if client detects the resumable failure and retries.
	if err := sr.waitn(1); err != nil {
		t.Fatalf("failed to wait for client to retry: %v", err)
	}
	// Client has resumed the query, now server resend row 2.
	ms.AddMsg(nil, true)
	if err := sr.waitn(1); err != nil {
		t.Fatalf("failed to wait for resending row 2: %v", err)
	}
	// Now client should have received row 0 - 2.
	want = append(want, kvRow(2))
	if !testEqual(rows, want) {
		t.Errorf("received rows: \n%v\n, want\n%v\n", rows, want)
	}
	// Sending 3rd - (maxBuffers+1)th rows without resume tokens, client should buffer them.
	for i := 3; i < maxBuffers+2; i++ {
		ms.AddMsg(nil, false)
	}
	if err := sr.waitn(maxBuffers - 1); err != nil {
		t.Fatalf("failed to wait for row 3-%v: %v", maxBuffers+1, err)
	}
	// Received rows should be unchanged.
	if !testEqual(rows, want) {
		t.Errorf("receive rows: \n%v\n, want\n%v\n", rows, want)
	}
	// Send (maxBuffers+2)th row to trigger state change of resumableStreamDecoder:
	// queueingRetryable -> queueingUnretryable
	ms.AddMsg(nil, false)
	if err := sr.waitn(1); err != nil {
		t.Fatalf("failed to wait for row %v: %v", maxBuffers+2, err)
	}
	// Client should yield row 3rd - (maxBuffers+2)th to application. Therefore,
	// application should see row 0 - (maxBuffers+2)th so far.
	for i := 3; i < maxBuffers+3; i++ {
		want = append(want, kvRow(i))
	}
	if !testEqual(rows, want) {
		t.Errorf("received rows: \n%v\n; want\n%v\n", rows, want)
	}
	// Inject resumable error, but since resumableStreamDecoder is already at
	// queueingUnretryable state, query will just fail.
	ms.AddMsg(
		status.Errorf(codes.Unavailable, "mock server wants some sleep"),
		false,
	)
	var gotErr error
	select {
	case gotErr = <-done:
	case <-time.After(10 * time.Second):
		t.Fatalf("timeout in waiting for failed query to return.")
	}
	if wantErr := toSpannerError(status.Errorf(codes.Unavailable, "mock server wants some sleep")); !testEqual(gotErr, wantErr) {
		t.Fatalf("stream() returns error: %v, but want error: %v", gotErr, wantErr)
	}

	// Reconnect to mock Cloud Spanner.
	rows = []*Row{}
	go streaming()
	// Let server send two rows without resume token.
	for i := maxBuffers + 3; i < maxBuffers+5; i++ {
		ms.AddMsg(nil, false)
	}
	// One signal for the initial Recv plus one per row sent.
	if err := sr.waitn(3); err != nil {
		t.Fatalf("failed to wait for row %v - %v: %v", maxBuffers+3, maxBuffers+4, err)
	}
	if len(rows) > 0 {
		t.Errorf("client received some rows unexpectedly: %v, want nothing", rows)
	}
	// Let server end the query.
	ms.AddMsg(io.EOF, false)
	select {
	case gotErr = <-done:
	case <-time.After(10 * time.Second):
		t.Fatalf("timeout in waiting for failed query to return")
	}
	if gotErr != nil {
		t.Fatalf("stream() returns unexpected error: %v, but want no error", gotErr)
	}
	// Verify if a normal server side EOF flushes all queued rows.
	want = []*Row{kvRow(maxBuffers + 3), kvRow(maxBuffers + 4)}
	if !testEqual(rows, want) {
		t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
	}
}
1473
1474// Verify that streaming query get retried upon real gRPC server transport
1475// failures.
1476func TestGrpcReconnect(t *testing.T) {
1477	restore := setMaxBytesBetweenResumeTokens()
1478	defer restore()
1479	ms := testutil.NewMockCloudSpanner(t, trxTs)
1480	ms.Serve()
1481	defer ms.Stop()
1482	cc := dialMock(t, ms)
1483	defer cc.Close()
1484	mc := sppb.NewSpannerClient(cc)
1485	retry := make(chan int)
1486	row := make(chan int)
1487	var err error
1488	go func() {
1489		r := 0
1490		// Establish a stream to mock cloud spanner server.
1491		iter := stream(context.Background(),
1492			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1493				if r > 0 {
1494					// This RPC attempt is a retry, signal it.
1495					retry <- r
1496				}
1497				r++
1498				return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1499					Sql:         "SELECT t.key key, t.value value FROM t_mock t",
1500					ResumeToken: resumeToken,
1501				})
1502
1503			},
1504			nil,
1505			func(error) {})
1506		defer iter.Stop()
1507		for {
1508			_, err = iter.Next()
1509			if err == iterator.Done {
1510				err = nil
1511				break
1512			}
1513			if err != nil {
1514				break
1515			}
1516			row <- 0
1517		}
1518	}()
1519	// Add a message and wait for the receipt.
1520	ms.AddMsg(nil, true)
1521	select {
1522	case <-row:
1523	case <-time.After(10 * time.Second):
1524		t.Fatalf("expect stream to be established within 10 seconds, but it didn't")
1525	}
1526	// Error injection: force server to close all connections.
1527	ms.Stop()
1528	// Test to see if client respond to the real RPC failure correctly by
1529	// retrying RPC.
1530	select {
1531	case r, ok := <-retry:
1532		if ok && r == 1 {
1533			break
1534		}
1535		t.Errorf("retry count = %v, want 1", r)
1536	case <-time.After(10 * time.Second):
1537		t.Errorf("client library failed to respond after 10 seconds, aborting")
1538		return
1539	}
1540}
1541
1542// Test cancel/timeout for client operations.
1543func TestCancelTimeout(t *testing.T) {
1544	restore := setMaxBytesBetweenResumeTokens()
1545	defer restore()
1546	ms := testutil.NewMockCloudSpanner(t, trxTs)
1547	ms.Serve()
1548	defer ms.Stop()
1549	cc := dialMock(t, ms)
1550	defer cc.Close()
1551	mc := sppb.NewSpannerClient(cc)
1552	done := make(chan int)
1553	go func() {
1554		for {
1555			ms.AddMsg(nil, true)
1556		}
1557	}()
1558	// Test cancelling query.
1559	ctx, cancel := context.WithCancel(context.Background())
1560	var err error
1561	go func() {
1562		// Establish a stream to mock cloud spanner server.
1563		iter := stream(ctx,
1564			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1565				return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1566					Sql:         "SELECT t.key key, t.value value FROM t_mock t",
1567					ResumeToken: resumeToken,
1568				})
1569			},
1570			nil,
1571			func(error) {})
1572		defer iter.Stop()
1573		for {
1574			_, err = iter.Next()
1575			if err == iterator.Done {
1576				break
1577			}
1578			if err != nil {
1579				done <- 0
1580				break
1581			}
1582		}
1583	}()
1584	cancel()
1585	select {
1586	case <-done:
1587		if ErrCode(err) != codes.Canceled {
1588			t.Errorf("streaming query is canceled and returns error %v, want error code %v", err, codes.Canceled)
1589		}
1590	case <-time.After(1 * time.Second):
1591		t.Errorf("query doesn't exit timely after being cancelled")
1592	}
1593	// Test query timeout.
1594	ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
1595	defer cancel()
1596	go func() {
1597		// Establish a stream to mock cloud spanner server.
1598		iter := stream(ctx,
1599			func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
1600				return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
1601					Sql:         "SELECT t.key key, t.value value FROM t_mock t",
1602					ResumeToken: resumeToken,
1603				})
1604			},
1605			nil,
1606			func(error) {})
1607		defer iter.Stop()
1608		for {
1609			_, err = iter.Next()
1610			if err == iterator.Done {
1611				err = nil
1612				break
1613			}
1614			if err != nil {
1615				break
1616			}
1617		}
1618		done <- 0
1619	}()
1620	select {
1621	case <-done:
1622		if wantErr := codes.DeadlineExceeded; ErrCode(err) != wantErr {
1623			t.Errorf("streaming query timeout returns error %v, want error code %v", err, wantErr)
1624		}
1625	case <-time.After(2 * time.Second):
1626		t.Errorf("query doesn't timeout as expected")
1627	}
1628}
1629
// TestRowIteratorDo verifies that RowIterator.Do invokes its callback once for
// every row and returns nil when the server ends the stream with io.EOF.
func TestRowIteratorDo(t *testing.T) {
	restore := setMaxBytesBetweenResumeTokens()
	defer restore()
	ms := testutil.NewMockCloudSpanner(t, trxTs)
	ms.Serve()
	defer ms.Stop()
	cc := dialMock(t, ms)
	defer cc.Close()
	mc := sppb.NewSpannerClient(cc)

	// Three rows without resume tokens, followed by the end of the stream.
	for remaining := 3; remaining > 0; remaining-- {
		ms.AddMsg(nil, false)
	}
	ms.AddMsg(io.EOF, true)

	rpc := func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
		return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
			Sql:         "SELECT t.key key, t.value value FROM t_mock t",
			ResumeToken: resumeToken,
		})
	}
	iter := stream(context.Background(), rpc, nil, func(error) {})
	seen := 0
	countRow := func(*Row) error {
		seen++
		return nil
	}
	if err := iter.Do(countRow); err != nil {
		t.Errorf("Using Do: %v", err)
	}
	if seen != 3 {
		t.Errorf("got %d rows, want 3", seen)
	}
}
1662
// TestRowIteratorDoWithError verifies that an error returned by the callback
// passed to RowIterator.Do is propagated unchanged as Do's result.
func TestRowIteratorDoWithError(t *testing.T) {
	restore := setMaxBytesBetweenResumeTokens()
	defer restore()
	ms := testutil.NewMockCloudSpanner(t, trxTs)
	ms.Serve()
	defer ms.Stop()
	cc := dialMock(t, ms)
	defer cc.Close()
	mc := sppb.NewSpannerClient(cc)

	// Three rows without resume tokens, followed by the end of the stream.
	for remaining := 3; remaining > 0; remaining-- {
		ms.AddMsg(nil, false)
	}
	ms.AddMsg(io.EOF, true)

	rpc := func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
		return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
			Sql:         "SELECT t.key key, t.value value FROM t_mock t",
			ResumeToken: resumeToken,
		})
	}
	iter := stream(context.Background(), rpc, nil, func(error) {})
	injected := errors.New("Failed iterator")
	fail := func(*Row) error { return injected }
	if got := iter.Do(fail); got != injected {
		t.Errorf("got <%v>, want <%v>", got, injected)
	}
}
1692
// TestIteratorStopEarly verifies that calling Stop on a RowIterator before the
// stream is exhausted leaves the underlying decoder with a Canceled error.
func TestIteratorStopEarly(t *testing.T) {
	ctx := context.Background()
	restore := setMaxBytesBetweenResumeTokens()
	defer restore()
	ms := testutil.NewMockCloudSpanner(t, trxTs)
	ms.Serve()
	defer ms.Stop()
	cc := dialMock(t, ms)
	defer cc.Close()
	mc := sppb.NewSpannerClient(cc)

	ms.AddMsg(nil, false)
	ms.AddMsg(nil, false)
	ms.AddMsg(io.EOF, true)

	iter := stream(ctx,
		func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
			return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
				Sql:         "SELECT t.key key, t.value value FROM t_mock t",
				ResumeToken: resumeToken,
			})
		},
		nil,
		func(error) {})
	_, err := iter.Next()
	if err != nil {
		t.Fatalf("before Stop: %v", err)
	}
	iter.Stop()
	// Stop sets r.err to the FailedPrecondition error "Next called after Stop".
	// Override that here so this test can observe the Canceled error from the
	// stream.
	iter.err = nil
	iter.Next()
	// Report the decoder's error: err above is the (nil) result from before Stop.
	if gotErr := iter.streamd.lastErr(); ErrCode(gotErr) != codes.Canceled {
		t.Errorf("after Stop: got %v, wanted Canceled", gotErr)
	}
}
1731
1732func TestIteratorWithError(t *testing.T) {
1733	injected := errors.New("Failed iterator")
1734	iter := RowIterator{err: injected}
1735	defer iter.Stop()
1736	if _, err := iter.Next(); err != injected {
1737		t.Fatalf("Expected error: %v, got %v", injected, err)
1738	}
1739}
1740
// dialMock opens an insecure gRPC client connection to the mock Cloud Spanner
// server ms, blocking until the connection is established. It fails the
// calling test immediately if dialing fails. The caller is responsible for
// closing the returned connection.
func dialMock(t *testing.T, ms *testutil.MockCloudSpanner) *grpc.ClientConn {
	// Attribute a dial failure to the calling test line, not to this helper.
	t.Helper()
	cc, err := grpc.Dial(ms.Addr(), grpc.WithInsecure(), grpc.WithBlock())
	if err != nil {
		t.Fatalf("Dial(%q) = %v", ms.Addr(), err)
	}
	return cc
}
1748