1/*
2Copyright 2017 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package testutil
18
19import (
20	"context"
21	"errors"
22	"fmt"
23	"sync"
24	"testing"
25	"time"
26
27	"github.com/golang/protobuf/proto"
28	"github.com/golang/protobuf/ptypes/empty"
29	proto3 "github.com/golang/protobuf/ptypes/struct"
30	pbt "github.com/golang/protobuf/ptypes/timestamp"
31	pbs "google.golang.org/genproto/googleapis/rpc/status"
32	sppb "google.golang.org/genproto/googleapis/spanner/v1"
33	"google.golang.org/grpc"
34	"google.golang.org/grpc/codes"
35	"google.golang.org/grpc/status"
36)
37
// MockCloudSpannerClient is a mock implementation of sppb.SpannerClient.
//
// Construct it with NewMockCloudSpannerClient. The embedded
// sppb.SpannerClient is never set by that constructor, so calling an RPC
// method that this type does not override dereferences a nil interface and
// panics.
type MockCloudSpannerClient struct {
	sppb.SpannerClient

	// mu guards sessions, pings and freezed. It is not held while sending on
	// ReceivedRequests or while waiting for the client to be unfrozen.
	mu sync.Mutex
	// t is the test that created the client (not referenced by the methods
	// visible in this file).
	t  *testing.T
	// Live sessions on the client, keyed by session name. CreateSession adds
	// entries and DeleteSession removes them.
	sessions map[string]bool
	// Session ping history: every session name passed to GetSession, in call
	// order, whether or not the session exists.
	pings []string
	// Request gate. While this channel is open, every request stalls in
	// ready(); once it is closed, requests proceed. Freeze installs a fresh
	// open channel and Unfreeze closes it.
	freezed chan struct{}

	// Every request received by the RPC methods, sent before the request is
	// processed. Values are the *Request message pointers from sppb (for
	// example *sppb.GetSessionRequest), so consumers should type-switch on
	// them. NewMockCloudSpannerClient gives the channel a buffer of 100000
	// entries; once that buffer is full, further requests block until it is
	// drained.
	ReceivedRequests chan interface{}
}
56
// NewMockCloudSpannerClient returns a MockCloudSpannerClient bound to t with
// an empty session table and a ReceivedRequests channel buffered to 100000
// entries. The client starts unfrozen, so requests are served immediately.
func NewMockCloudSpannerClient(t *testing.T) *MockCloudSpannerClient {
	// A closed channel never blocks its receivers, which is exactly the
	// "not frozen" state that ready() waits for.
	released := make(chan struct{})
	close(released)

	return &MockCloudSpannerClient{
		t:                t,
		sessions:         make(map[string]bool),
		freezed:          released,
		ReceivedRequests: make(chan interface{}, 100000),
	}
}
71
72// DumpPings dumps the ping history.
73func (m *MockCloudSpannerClient) DumpPings() []string {
74	m.mu.Lock()
75	defer m.mu.Unlock()
76	return append([]string(nil), m.pings...)
77}
78
79// DumpSessions dumps the internal session table.
80func (m *MockCloudSpannerClient) DumpSessions() map[string]bool {
81	m.mu.Lock()
82	defer m.mu.Unlock()
83	st := map[string]bool{}
84	for s, v := range m.sessions {
85		st[s] = v
86	}
87	return st
88}
89
// CreateSession is a placeholder for SpannerClient.CreateSession.
//
// It records r on ReceivedRequests, then creates a session only when
// r.Database is "mockdb". Any other database yields a NotFound status error
// together with an empty, non-nil *sppb.Session. New session names have the
// form "mockdb-<unix-nanos>". If that name is already live, a "-<n>" suffix is
// appended, so two sessions created within the same clock tick do not
// overwrite each other in the session table.
func (m *MockCloudSpannerClient) CreateSession(ctx context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
	m.ready()
	m.ReceivedRequests <- r

	m.mu.Lock()
	defer m.mu.Unlock()
	s := &sppb.Session{}
	if r.Database != "mockdb" {
		// Reject other databases. The database name is a format argument, not
		// part of the format string, so '%' characters in it stay intact.
		return s, status.Errorf(codes.NotFound, "database not found: %v", r.Database)
	}
	// Generate a session name that is not already live. time.Now() can return
	// the same value for back-to-back calls on clocks with coarse resolution.
	base := fmt.Sprintf("mockdb-%v", time.Now().UnixNano())
	name := base
	for i := 1; ; i++ {
		if _, taken := m.sessions[name]; !taken {
			break
		}
		name = fmt.Sprintf("%s-%d", base, i)
	}
	s.Name = name
	m.sessions[s.Name] = true
	return s, nil
}
107
// GetSession is a placeholder for SpannerClient.GetSession.
//
// It records r on ReceivedRequests and appends r.Name to the ping history,
// even when the session does not exist. It returns a Session carrying only
// the name if the session is live, and a NotFound status error otherwise.
func (m *MockCloudSpannerClient) GetSession(ctx context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
	m.ready()
	m.ReceivedRequests <- r

	m.mu.Lock()
	defer m.mu.Unlock()
	m.pings = append(m.pings, r.Name)
	if _, ok := m.sessions[r.Name]; !ok {
		// The name is passed as an argument so '%' in it is not treated as a verb.
		return nil, status.Errorf(codes.NotFound, "Session not found: %v", r.Name)
	}
	return &sppb.Session{Name: r.Name}, nil
}
121
// DeleteSession is a placeholder for SpannerClient.DeleteSession.
//
// It records r on ReceivedRequests and removes r.Name from the session table.
// It always returns a non-nil *empty.Empty. The error is a NotFound status
// error when the session was not live, and nil otherwise.
func (m *MockCloudSpannerClient) DeleteSession(ctx context.Context, r *sppb.DeleteSessionRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
	m.ready()
	m.ReceivedRequests <- r

	m.mu.Lock()
	defer m.mu.Unlock()
	if _, ok := m.sessions[r.Name]; !ok {
		// Session not found. The name is passed as an argument so '%' in it
		// is not treated as a verb.
		return &empty.Empty{}, status.Errorf(codes.NotFound, "Session not found: %v", r.Name)
	}
	// Delete session from in-memory table.
	delete(m.sessions, r.Name)
	return &empty.Empty{}, nil
}
137
// ExecuteSql is a placeholder for SpannerClient.ExecuteSql.
//
// After recording r on ReceivedRequests it always succeeds with a ResultSet
// that has no rows and reports an exact row count of 7 in its stats.
func (m *MockCloudSpannerClient) ExecuteSql(ctx context.Context, r *sppb.ExecuteSqlRequest, opts ...grpc.CallOption) (*sppb.ResultSet, error) {
	m.ready()
	m.ReceivedRequests <- r

	m.mu.Lock()
	defer m.mu.Unlock()
	// Keyed field: vet's composites check rejects unkeyed literals of
	// imported struct types.
	return &sppb.ResultSet{Stats: &sppb.ResultSetStats{RowCount: &sppb.ResultSetStats_RowCountExact{RowCountExact: 7}}}, nil
}
147
// ExecuteBatchDml is a placeholder for SpannerClient.ExecuteBatchDml.
//
// After recording r on ReceivedRequests it always reports success: a status
// with code 0 and an empty, non-nil list of result sets.
func (m *MockCloudSpannerClient) ExecuteBatchDml(ctx context.Context, r *sppb.ExecuteBatchDmlRequest, opts ...grpc.CallOption) (*sppb.ExecuteBatchDmlResponse, error) {
	m.ready()
	m.ReceivedRequests <- r

	m.mu.Lock()
	defer m.mu.Unlock()
	okStatus := &pbs.Status{Code: 0}
	resp := &sppb.ExecuteBatchDmlResponse{
		ResultSets: []*sppb.ResultSet{},
		Status:     okStatus,
	}
	return resp, nil
}
157
// ExecuteStreamingSql is a mock implementation of SpannerClient.ExecuteStreamingSql.
//
// It records r on ReceivedRequests and never returns a stream. The canned
// request it expects is:
//   - session "mocksession";
//   - a strong, single-use, read-only transaction;
//   - SQL "mockquery" with the STRING parameter var1 = "abc".
//
// If r differs from that request, the error describes the mismatch.
// Otherwise the error states that queries never succeed on the mock.
func (m *MockCloudSpannerClient) ExecuteStreamingSql(ctx context.Context, r *sppb.ExecuteSqlRequest, opts ...grpc.CallOption) (sppb.Spanner_ExecuteStreamingSqlClient, error) {
	m.ready()
	m.ReceivedRequests <- r

	m.mu.Lock()
	defer m.mu.Unlock()

	strongReadOnly := &sppb.TransactionOptions_ReadOnly{
		TimestampBound:      &sppb.TransactionOptions_ReadOnly_Strong{Strong: true},
		ReturnReadTimestamp: false,
	}
	singleUse := &sppb.TransactionSelector{
		Selector: &sppb.TransactionSelector_SingleUse{
			SingleUse: &sppb.TransactionOptions{
				Mode: &sppb.TransactionOptions_ReadOnly_{ReadOnly: strongReadOnly},
			},
		},
	}
	abc := &proto3.Value{Kind: &proto3.Value_StringValue{StringValue: "abc"}}
	want := &sppb.ExecuteSqlRequest{
		Session:     "mocksession",
		Transaction: singleUse,
		Sql:         "mockquery",
		Params:      &proto3.Struct{Fields: map[string]*proto3.Value{"var1": abc}},
		ParamTypes:  map[string]*sppb.Type{"var1": {Code: sppb.TypeCode_STRING}},
	}

	if proto.Equal(r, want) {
		return nil, errors.New("query never succeeds on mock client")
	}
	return nil, fmt.Errorf("got query request: %v, want: %v", r, want)
}
192
// StreamingRead is a placeholder for SpannerClient.StreamingRead.
//
// It records r on ReceivedRequests and never returns a stream. The canned
// request it expects reads columns col1 and col2 of table "t_mock" at the
// single key ("foo"). It uses session "mocksession" and a strong, single-use,
// read-only transaction.
//
// If r differs from that request, the error describes the mismatch.
// Otherwise the error states that reads never succeed on the mock.
func (m *MockCloudSpannerClient) StreamingRead(ctx context.Context, r *sppb.ReadRequest, opts ...grpc.CallOption) (sppb.Spanner_StreamingReadClient, error) {
	m.ready()
	m.ReceivedRequests <- r

	m.mu.Lock()
	defer m.mu.Unlock()

	strongReadOnly := &sppb.TransactionOptions_ReadOnly{
		TimestampBound:      &sppb.TransactionOptions_ReadOnly_Strong{Strong: true},
		ReturnReadTimestamp: false,
	}
	singleUse := &sppb.TransactionSelector{
		Selector: &sppb.TransactionSelector_SingleUse{
			SingleUse: &sppb.TransactionOptions{
				Mode: &sppb.TransactionOptions_ReadOnly_{ReadOnly: strongReadOnly},
			},
		},
	}
	fooKey := &proto3.ListValue{
		Values: []*proto3.Value{{Kind: &proto3.Value_StringValue{StringValue: "foo"}}},
	}
	want := &sppb.ReadRequest{
		Session:     "mocksession",
		Transaction: singleUse,
		Table:       "t_mock",
		Columns:     []string{"col1", "col2"},
		KeySet: &sppb.KeySet{
			Keys:   []*proto3.ListValue{fooKey},
			Ranges: []*sppb.KeyRange{},
			All:    false,
		},
	}

	if proto.Equal(r, want) {
		return nil, errors.New("read never succeeds on mock client")
	}
	return nil, fmt.Errorf("got query request: %v, want: %v", r, want)
}
235
// BeginTransaction is a placeholder for SpannerClient.BeginTransaction.
//
// After recording r on ReceivedRequests it always returns a transaction with
// ID "transaction-1". When the requested mode is read-only, the response also
// carries the fixed read timestamp {Seconds: 3, Nanos: 4}. The mode is read
// through the nil-safe generated getters. A request without Options is
// therefore treated as non-read-only instead of panicking the test with a nil
// dereference.
func (m *MockCloudSpannerClient) BeginTransaction(ctx context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
	m.ready()
	m.ReceivedRequests <- r

	m.mu.Lock()
	defer m.mu.Unlock()
	resp := &sppb.Transaction{Id: []byte("transaction-1")}
	if _, ok := r.GetOptions().GetMode().(*sppb.TransactionOptions_ReadOnly_); ok {
		resp.ReadTimestamp = &pbt.Timestamp{Seconds: 3, Nanos: 4}
	}
	return resp, nil
}
249
// Commit is a placeholder for SpannerClient.Commit.
//
// After recording r on ReceivedRequests it always succeeds, reporting the
// fixed commit timestamp {Seconds: 1, Nanos: 2}.
func (m *MockCloudSpannerClient) Commit(ctx context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
	m.ready()
	m.ReceivedRequests <- r

	m.mu.Lock()
	defer m.mu.Unlock()
	committedAt := &pbt.Timestamp{Seconds: 1, Nanos: 2}
	resp := &sppb.CommitResponse{CommitTimestamp: committedAt}
	return resp, nil
}
259
// Rollback is a placeholder for SpannerClient.Rollback.
//
// After recording r on ReceivedRequests it always succeeds. It returns a
// non-nil *empty.Empty, matching DeleteSession and real gRPC clients, rather
// than a nil response paired with a nil error.
func (m *MockCloudSpannerClient) Rollback(ctx context.Context, r *sppb.RollbackRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
	m.ready()
	m.ReceivedRequests <- r

	m.mu.Lock()
	defer m.mu.Unlock()
	return &empty.Empty{}, nil
}
269
// PartitionQuery is a placeholder for SpannerClient.PartitionQuery.
//
// The mock does not support partitioned queries. After recording r on
// ReceivedRequests it always fails with a gRPC status error carrying
// codes.Unimplemented, so status.Code(err) reports the real reason instead of
// codes.Unknown.
func (m *MockCloudSpannerClient) PartitionQuery(ctx context.Context, r *sppb.PartitionQueryRequest, opts ...grpc.CallOption) (*sppb.PartitionResponse, error) {
	m.ready()
	m.ReceivedRequests <- r

	return nil, status.Error(codes.Unimplemented, "PartitionQuery is not implemented by the mock client")
}
277
// PartitionRead is a placeholder for SpannerClient.PartitionRead.
//
// The mock does not support partitioned reads. After recording r on
// ReceivedRequests it always fails with a gRPC status error carrying
// codes.Unimplemented, so status.Code(err) reports the real reason instead of
// codes.Unknown.
func (m *MockCloudSpannerClient) PartitionRead(ctx context.Context, r *sppb.PartitionReadRequest, opts ...grpc.CallOption) (*sppb.PartitionResponse, error) {
	m.ready()
	m.ReceivedRequests <- r

	return nil, status.Error(codes.Unimplemented, "PartitionRead is not implemented by the mock client")
}
285
286// Freeze stalls all requests.
287func (m *MockCloudSpannerClient) Freeze() {
288	m.mu.Lock()
289	defer m.mu.Unlock()
290	m.freezed = make(chan struct{})
291}
292
293// Unfreeze restores processing requests.
294func (m *MockCloudSpannerClient) Unfreeze() {
295	m.mu.Lock()
296	defer m.mu.Unlock()
297	close(m.freezed)
298}
299
300// ready checks conditions before executing requests
301// TODO: add checks for injected errors, actions
302func (m *MockCloudSpannerClient) ready() {
303	m.mu.Lock()
304	freezed := m.freezed
305	m.mu.Unlock()
306	// check if client should be freezed
307	<-freezed
308}
309