1/*
2Copyright 2018 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"os"
22	"sync"
23	"testing"
24	"time"
25
26	. "cloud.google.com/go/spanner/internal/testutil"
27	sppb "google.golang.org/genproto/googleapis/spanner/v1"
28)
29
// TestPartitionRoundTrip checks that both a read-request Partition and a
// query-request Partition compare equal to themselves after being encoded
// with MarshalBinary and decoded again with UnmarshalBinary.
func TestPartitionRoundTrip(t *testing.T) {
	t.Parallel()
	cases := []Partition{
		{rreq: &sppb.ReadRequest{Table: "t"}},
		{qreq: &sppb.ExecuteSqlRequest{Sql: "sql"}},
	}
	for idx := range cases {
		// Work on a local copy so the helper receives a stable pointer.
		expected := cases[idx]
		actual := serdesPartition(t, idx, &expected)
		if testEqual(actual, expected) {
			continue
		}
		t.Errorf("got: %#v\nwant:%#v", actual, expected)
	}
}
42
// TestBROTIDRoundTrip checks that a BatchReadOnlyTransactionID holding a
// transaction ID, session ID and read timestamp compares equal to itself
// after a MarshalBinary / UnmarshalBinary round trip.
func TestBROTIDRoundTrip(t *testing.T) {
	t.Parallel()
	original := BatchReadOnlyTransactionID{
		tid: []byte("tid"),
		sid: "sid",
		rts: time.Now(),
	}

	encoded, err := original.MarshalBinary()
	if err != nil {
		t.Fatal(err)
	}

	var decoded BatchReadOnlyTransactionID
	if err = decoded.UnmarshalBinary(encoded); err != nil {
		t.Fatal(err)
	}

	if testEqual(decoded, original) {
		return
	}
	t.Errorf("got: %#v\nwant:%#v", decoded, original)
}
63
64// serdesPartition is a helper that serialize a Partition then deserialize it.
65func serdesPartition(t *testing.T, i int, p1 *Partition) (p2 Partition) {
66	var (
67		data []byte
68		err  error
69	)
70	if data, err = p1.MarshalBinary(); err != nil {
71		t.Fatalf("#%d: encoding failed %v", i, err)
72	}
73	if err = p2.UnmarshalBinary(data); err != nil {
74		t.Fatalf("#%d: decoding failed %v", i, err)
75	}
76	return p2
77}
78
// TestPartitionQuery_QueryOptions runs each case from queryOptionsTestCases
// and checks that every partition produced by PartitionQuery (or
// PartitionQueryWithOptions when the case supplies query options) carries the
// expected optimizer version. Each case may set SPANNER_OPTIMIZER_VERSION in
// the environment, client-level QueryOptions, and/or per-query options.
func TestPartitionQuery_QueryOptions(t *testing.T) {
	for _, tt := range queryOptionsTestCases() {
		t.Run(tt.name, func(t *testing.T) {
			if tt.env.Options != nil {
				const envVar = "SPANNER_OPTIMIZER_VERSION"
				prev, hadPrev := os.LookupEnv(envVar)
				if err := os.Setenv(envVar, tt.env.Options.OptimizerVersion); err != nil {
					t.Fatal(err)
				}
				// Restore the exact previous state (value or absence) instead of
				// leaving an empty value behind for subsequent tests. Errors here
				// are ignored: cleanup is best-effort and the subtest is done.
				defer func() {
					if hadPrev {
						os.Setenv(envVar, prev)
					} else {
						os.Unsetenv(envVar)
					}
				}()
			}

			ctx := context.Background()
			_, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: tt.client})
			defer teardown()

			var (
				err  error
				txn  *BatchReadOnlyTransaction
				ps   []*Partition
				stmt = NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)
			)

			if txn, err = client.BatchReadOnlyTransaction(ctx, StrongRead()); err != nil {
				t.Fatal(err)
			}
			defer txn.Cleanup(ctx)

			if tt.query.Options == nil {
				ps, err = txn.PartitionQuery(ctx, stmt, PartitionOptions{0, 3})
			} else {
				ps, err = txn.PartitionQueryWithOptions(ctx, stmt, PartitionOptions{0, 3}, tt.query)
			}
			if err != nil {
				t.Fatal(err)
			}

			for _, p := range ps {
				if got, want := p.qreq.QueryOptions.OptimizerVersion, tt.want.Options.OptimizerVersion; got != want {
					t.Fatalf("Incorrect optimizer version: got %v, want %v", got, want)
				}
			}
		})
	}
}
120
// TestPartitionQuery_Parallel partitions a query and registers a result on
// the mock server for each partition token (built by
// CreateSingleRowSingersResult). It then executes all partitions
// concurrently and checks the total number of rows read across partitions.
func TestPartitionQuery_Parallel(t *testing.T) {
	ctx := context.Background()
	server, client, teardown := setupMockedTestServer(t)
	defer teardown()

	txn, err := client.BatchReadOnlyTransaction(ctx, StrongRead())
	if err != nil {
		t.Fatal(err)
	}
	defer txn.Cleanup(ctx)
	ps, err := txn.PartitionQuery(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums), PartitionOptions{0, 10})
	if err != nil {
		t.Fatal(err)
	}
	for i, p := range ps {
		server.TestSpanner.PutPartitionResult(p.pt, server.CreateSingleRowSingersResult(int64(i)))
	}

	var (
		wg    sync.WaitGroup
		mu    sync.Mutex // guards total
		total int64
	)
	for _, p := range ps {
		p := p // per-iteration copy for the closure (required before Go 1.22)
		// Add before launching the goroutine; otherwise the goroutine may
		// call Done first and panic with a negative WaitGroup counter.
		wg.Add(1)
		go func() {
			// Deferred so the WaitGroup is released on every path, including
			// errors; a missed Done would make wg.Wait block forever.
			defer wg.Done()
			iter := txn.Execute(ctx, p)
			defer iter.Stop()

			var count int64
			if err := iter.Do(func(row *Row) error {
				count++
				return nil
			}); err != nil {
				// t.Errorf (unlike t.Fatal) may be called from other goroutines.
				t.Errorf("executing partition: %v", err)
				return
			}

			mu.Lock()
			total += count
			mu.Unlock()
		}()
	}
	wg.Wait()

	if g, w := total, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount; g != w {
		t.Errorf("Row count mismatch\nGot: %d\nWant: %d", g, w)
	}
}
171