1// Copyright 2014 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package datastore
16
17import (
18	"context"
19	"errors"
20	"fmt"
21	"reflect"
22	"sort"
23	"testing"
24
25	"cloud.google.com/go/internal/testutil"
26	"github.com/golang/protobuf/proto"
27	"github.com/google/go-cmp/cmp"
28	pb "google.golang.org/genproto/googleapis/datastore/v1"
29	"google.golang.org/grpc"
30)
31
32var (
33	key1 = &pb.Key{
34		Path: []*pb.Key_PathElement{
35			{
36				Kind:   "Gopher",
37				IdType: &pb.Key_PathElement_Id{Id: 6},
38			},
39		},
40	}
41	key2 = &pb.Key{
42		Path: []*pb.Key_PathElement{
43			{
44				Kind:   "Gopher",
45				IdType: &pb.Key_PathElement_Id{Id: 6},
46			},
47			{
48				Kind:   "Gopher",
49				IdType: &pb.Key_PathElement_Id{Id: 8},
50			},
51		},
52	}
53)
54
55type fakeClient struct {
56	pb.DatastoreClient
57	queryFn  func(*pb.RunQueryRequest) (*pb.RunQueryResponse, error)
58	commitFn func(*pb.CommitRequest) (*pb.CommitResponse, error)
59}
60
61func (c *fakeClient) RunQuery(_ context.Context, req *pb.RunQueryRequest, _ ...grpc.CallOption) (*pb.RunQueryResponse, error) {
62	return c.queryFn(req)
63}
64
65func (c *fakeClient) Commit(_ context.Context, req *pb.CommitRequest, _ ...grpc.CallOption) (*pb.CommitResponse, error) {
66	return c.commitFn(req)
67}
68
69func fakeRunQuery(in *pb.RunQueryRequest) (*pb.RunQueryResponse, error) {
70	expectedIn := &pb.RunQueryRequest{
71		QueryType: &pb.RunQueryRequest_Query{Query: &pb.Query{
72			Kind: []*pb.KindExpression{{Name: "Gopher"}},
73		}},
74	}
75	if !proto.Equal(in, expectedIn) {
76		return nil, fmt.Errorf("unsupported argument: got %v want %v", in, expectedIn)
77	}
78	return &pb.RunQueryResponse{
79		Batch: &pb.QueryResultBatch{
80			MoreResults:      pb.QueryResultBatch_NO_MORE_RESULTS,
81			EntityResultType: pb.EntityResult_FULL,
82			EntityResults: []*pb.EntityResult{
83				{
84					Entity: &pb.Entity{
85						Key: key1,
86						Properties: map[string]*pb.Value{
87							"Name":   {ValueType: &pb.Value_StringValue{StringValue: "George"}},
88							"Height": {ValueType: &pb.Value_IntegerValue{IntegerValue: 32}},
89						},
90					},
91				},
92				{
93					Entity: &pb.Entity{
94						Key: key2,
95						Properties: map[string]*pb.Value{
96							"Name": {ValueType: &pb.Value_StringValue{StringValue: "Rufus"}},
97							// No height for Rufus.
98						},
99					},
100				},
101			},
102		},
103	}, nil
104}
105
106type StructThatImplementsPLS struct{}
107
108func (StructThatImplementsPLS) Load(p []Property) error   { return nil }
109func (StructThatImplementsPLS) Save() ([]Property, error) { return nil, nil }
110
111var _ PropertyLoadSaver = StructThatImplementsPLS{}
112
113type StructPtrThatImplementsPLS struct{}
114
115func (*StructPtrThatImplementsPLS) Load(p []Property) error   { return nil }
116func (*StructPtrThatImplementsPLS) Save() ([]Property, error) { return nil, nil }
117
118var _ PropertyLoadSaver = &StructPtrThatImplementsPLS{}
119
120type PropertyMap map[string]Property
121
122func (m PropertyMap) Load(props []Property) error {
123	for _, p := range props {
124		m[p.Name] = p
125	}
126	return nil
127}
128
129func (m PropertyMap) Save() ([]Property, error) {
130	props := make([]Property, 0, len(m))
131	for _, p := range m {
132		props = append(props, p)
133	}
134	return props, nil
135}
136
137var _ PropertyLoadSaver = PropertyMap{}
138
// Gopher is the plain-struct destination used by the query tests. Its fields
// match the "Name" and "Height" properties that fakeRunQuery returns.
type Gopher struct {
	Name   string
	Height int
}
143
// typeOfEmptyInterface is the reflect.Type of interface{}. It is derived from a
// pointer's element type because reflect.TypeOf only sees the dynamic type
// stored in its interface{} argument, so passing a nil interface{} directly
// would yield a nil Type.
var typeOfEmptyInterface = reflect.TypeOf(new(interface{})).Elem()
148
149func TestCheckMultiArg(t *testing.T) {
150	testCases := []struct {
151		v        interface{}
152		mat      multiArgType
153		elemType reflect.Type
154	}{
155		// Invalid cases.
156		{nil, multiArgTypeInvalid, nil},
157		{Gopher{}, multiArgTypeInvalid, nil},
158		{&Gopher{}, multiArgTypeInvalid, nil},
159		{PropertyList{}, multiArgTypeInvalid, nil}, // This is a special case.
160		{PropertyMap{}, multiArgTypeInvalid, nil},
161		{[]*PropertyList(nil), multiArgTypeInvalid, nil},
162		{[]*PropertyMap(nil), multiArgTypeInvalid, nil},
163		{[]**Gopher(nil), multiArgTypeInvalid, nil},
164		{[]*interface{}(nil), multiArgTypeInvalid, nil},
165		// Valid cases.
166		{
167			[]PropertyList(nil),
168			multiArgTypePropertyLoadSaver,
169			reflect.TypeOf(PropertyList{}),
170		},
171		{
172			[]PropertyMap(nil),
173			multiArgTypePropertyLoadSaver,
174			reflect.TypeOf(PropertyMap{}),
175		},
176		{
177			[]StructThatImplementsPLS(nil),
178			multiArgTypePropertyLoadSaver,
179			reflect.TypeOf(StructThatImplementsPLS{}),
180		},
181		{
182			[]StructPtrThatImplementsPLS(nil),
183			multiArgTypePropertyLoadSaver,
184			reflect.TypeOf(StructPtrThatImplementsPLS{}),
185		},
186		{
187			[]Gopher(nil),
188			multiArgTypeStruct,
189			reflect.TypeOf(Gopher{}),
190		},
191		{
192			[]*Gopher(nil),
193			multiArgTypeStructPtr,
194			reflect.TypeOf(Gopher{}),
195		},
196		{
197			[]interface{}(nil),
198			multiArgTypeInterface,
199			typeOfEmptyInterface,
200		},
201	}
202	for _, tc := range testCases {
203		mat, elemType := checkMultiArg(reflect.ValueOf(tc.v))
204		if mat != tc.mat || elemType != tc.elemType {
205			t.Errorf("checkMultiArg(%T): got %v, %v want %v, %v",
206				tc.v, mat, elemType, tc.mat, tc.elemType)
207		}
208	}
209}
210
211func TestSimpleQuery(t *testing.T) {
212	struct1 := Gopher{Name: "George", Height: 32}
213	struct2 := Gopher{Name: "Rufus"}
214	pList1 := PropertyList{
215		{
216			Name:  "Height",
217			Value: int64(32),
218		},
219		{
220			Name:  "Name",
221			Value: "George",
222		},
223	}
224	pList2 := PropertyList{
225		{
226			Name:  "Name",
227			Value: "Rufus",
228		},
229	}
230	pMap1 := PropertyMap{
231		"Name": Property{
232			Name:  "Name",
233			Value: "George",
234		},
235		"Height": Property{
236			Name:  "Height",
237			Value: int64(32),
238		},
239	}
240	pMap2 := PropertyMap{
241		"Name": Property{
242			Name:  "Name",
243			Value: "Rufus",
244		},
245	}
246
247	testCases := []struct {
248		dst  interface{}
249		want interface{}
250	}{
251		// The destination must have type *[]P, *[]S or *[]*S, for some non-interface
252		// type P such that *P implements PropertyLoadSaver, or for some struct type S.
253		{new([]Gopher), &[]Gopher{struct1, struct2}},
254		{new([]*Gopher), &[]*Gopher{&struct1, &struct2}},
255		{new([]PropertyList), &[]PropertyList{pList1, pList2}},
256		{new([]PropertyMap), &[]PropertyMap{pMap1, pMap2}},
257
258		// Any other destination type is invalid.
259		{0, nil},
260		{Gopher{}, nil},
261		{PropertyList{}, nil},
262		{PropertyMap{}, nil},
263		{[]int{}, nil},
264		{[]Gopher{}, nil},
265		{[]PropertyList{}, nil},
266		{new(int), nil},
267		{new(Gopher), nil},
268		{new(PropertyList), nil}, // This is a special case.
269		{new(PropertyMap), nil},
270		{new([]int), nil},
271		{new([]map[int]int), nil},
272		{new([]map[string]Property), nil},
273		{new([]map[string]interface{}), nil},
274		{new([]*int), nil},
275		{new([]*map[int]int), nil},
276		{new([]*map[string]Property), nil},
277		{new([]*map[string]interface{}), nil},
278		{new([]**Gopher), nil},
279		{new([]*PropertyList), nil},
280		{new([]*PropertyMap), nil},
281	}
282	for _, tc := range testCases {
283		nCall := 0
284		client := &Client{
285			client: &fakeClient{
286				queryFn: func(req *pb.RunQueryRequest) (*pb.RunQueryResponse, error) {
287					nCall++
288					return fakeRunQuery(req)
289				},
290			},
291		}
292		ctx := context.Background()
293
294		var (
295			expectedErr   error
296			expectedNCall int
297		)
298		if tc.want == nil {
299			expectedErr = ErrInvalidEntityType
300		} else {
301			expectedNCall = 1
302		}
303		keys, err := client.GetAll(ctx, NewQuery("Gopher"), tc.dst)
304		if err != expectedErr {
305			t.Errorf("dst type %T: got error %v, want %v", tc.dst, err, expectedErr)
306			continue
307		}
308		if nCall != expectedNCall {
309			t.Errorf("dst type %T: Context.Call was called an incorrect number of times: got %d want %d", tc.dst, nCall, expectedNCall)
310			continue
311		}
312		if err != nil {
313			continue
314		}
315
316		key1 := IDKey("Gopher", 6, nil)
317		expectedKeys := []*Key{
318			key1,
319			IDKey("Gopher", 8, key1),
320		}
321		if l1, l2 := len(keys), len(expectedKeys); l1 != l2 {
322			t.Errorf("dst type %T: got %d keys, want %d keys", tc.dst, l1, l2)
323			continue
324		}
325		for i, key := range keys {
326			if !keysEqual(key, expectedKeys[i]) {
327				t.Errorf("dst type %T: got key #%d %v, want %v", tc.dst, i, key, expectedKeys[i])
328				continue
329			}
330		}
331
332		// Make sure we sort any PropertyList items (the order is not deterministic).
333		if pLists, ok := tc.dst.(*[]PropertyList); ok {
334			for _, p := range *pLists {
335				sort.Sort(byName(p))
336			}
337		}
338
339		if !testutil.Equal(tc.dst, tc.want) {
340			t.Errorf("dst type %T: Entities\ngot  %+v\nwant %+v", tc.dst, tc.dst, tc.want)
341			continue
342		}
343	}
344}
345
346// keysEqual is like (*Key).Equal, but ignores the App ID.
347func keysEqual(a, b *Key) bool {
348	for a != nil && b != nil {
349		if a.Kind != b.Kind || a.Name != b.Name || a.ID != b.ID {
350			return false
351		}
352		a, b = a.Parent, b.Parent
353	}
354	return a == b
355}
356
357func TestQueriesAreImmutable(t *testing.T) {
358	// Test that deriving q2 from q1 does not modify q1.
359	q0 := NewQuery("foo")
360	q1 := NewQuery("foo")
361	q2 := q1.Offset(2)
362	if !testutil.Equal(q0, q1, cmp.AllowUnexported(Query{})) {
363		t.Errorf("q0 and q1 were not equal")
364	}
365	if testutil.Equal(q1, q2, cmp.AllowUnexported(Query{})) {
366		t.Errorf("q1 and q2 were equal")
367	}
368
369	// Test that deriving from q4 twice does not conflict, even though
370	// q4 has a long list of order clauses. This tests that the arrays
371	// backed by a query's slice of orders are not shared.
372	f := func() *Query {
373		q := NewQuery("bar")
374		// 47 is an ugly number that is unlikely to be near a re-allocation
375		// point in repeated append calls. For example, it's not near a power
376		// of 2 or a multiple of 10.
377		for i := 0; i < 47; i++ {
378			q = q.Order(fmt.Sprintf("x%d", i))
379		}
380		return q
381	}
382	q3 := f().Order("y")
383	q4 := f()
384	q5 := q4.Order("y")
385	q6 := q4.Order("z")
386	if !testutil.Equal(q3, q5, cmp.AllowUnexported(Query{})) {
387		t.Errorf("q3 and q5 were not equal")
388	}
389	if testutil.Equal(q5, q6, cmp.AllowUnexported(Query{})) {
390		t.Errorf("q5 and q6 were equal")
391	}
392}
393
394func TestFilterParser(t *testing.T) {
395	testCases := []struct {
396		filterStr     string
397		wantOK        bool
398		wantFieldName string
399		wantOp        operator
400	}{
401		// Supported ops.
402		{"x<", true, "x", lessThan},
403		{"x <", true, "x", lessThan},
404		{"x  <", true, "x", lessThan},
405		{"   x   <  ", true, "x", lessThan},
406		{"x <=", true, "x", lessEq},
407		{"x =", true, "x", equal},
408		{"x >=", true, "x", greaterEq},
409		{"x >", true, "x", greaterThan},
410		{"in >", true, "in", greaterThan},
411		{"in>", true, "in", greaterThan},
412		// Valid but (currently) unsupported ops.
413		{"x!=", false, "", 0},
414		{"x !=", false, "", 0},
415		{" x  !=  ", false, "", 0},
416		{"x IN", false, "", 0},
417		{"x in", false, "", 0},
418		// Invalid ops.
419		{"x EQ", false, "", 0},
420		{"x lt", false, "", 0},
421		{"x <>", false, "", 0},
422		{"x >>", false, "", 0},
423		{"x ==", false, "", 0},
424		{"x =<", false, "", 0},
425		{"x =>", false, "", 0},
426		{"x !", false, "", 0},
427		{"x ", false, "", 0},
428		{"x", false, "", 0},
429		// Quoted and interesting field names.
430		{"x > y =", true, "x > y", equal},
431		{"` x ` =", true, " x ", equal},
432		{`" x " =`, true, " x ", equal},
433		{`" \"x " =`, true, ` "x `, equal},
434		{`" x =`, false, "", 0},
435		{`" x ="`, false, "", 0},
436		{"` x \" =", false, "", 0},
437	}
438	for _, tc := range testCases {
439		q := NewQuery("foo").Filter(tc.filterStr, 42)
440		if ok := q.err == nil; ok != tc.wantOK {
441			t.Errorf("%q: ok=%t, want %t", tc.filterStr, ok, tc.wantOK)
442			continue
443		}
444		if !tc.wantOK {
445			continue
446		}
447		if len(q.filter) != 1 {
448			t.Errorf("%q: len=%d, want %d", tc.filterStr, len(q.filter), 1)
449			continue
450		}
451		got, want := q.filter[0], filter{tc.wantFieldName, tc.wantOp, 42}
452		if got != want {
453			t.Errorf("%q: got %v, want %v", tc.filterStr, got, want)
454			continue
455		}
456	}
457}
458
459func TestNamespaceQuery(t *testing.T) {
460	gotNamespace := make(chan string, 1)
461	ctx := context.Background()
462	client := &Client{
463		client: &fakeClient{
464			queryFn: func(req *pb.RunQueryRequest) (*pb.RunQueryResponse, error) {
465				if part := req.PartitionId; part != nil {
466					gotNamespace <- part.NamespaceId
467				} else {
468					gotNamespace <- ""
469				}
470				return nil, errors.New("not implemented")
471			},
472		},
473	}
474
475	var gs []Gopher
476
477	// Ignore errors for the rest of this test.
478	client.GetAll(ctx, NewQuery("gopher"), &gs)
479	if got, want := <-gotNamespace, ""; got != want {
480		t.Errorf("GetAll: got namespace %q, want %q", got, want)
481	}
482	client.Count(ctx, NewQuery("gopher"))
483	if got, want := <-gotNamespace, ""; got != want {
484		t.Errorf("Count: got namespace %q, want %q", got, want)
485	}
486
487	const ns = "not_default"
488	client.GetAll(ctx, NewQuery("gopher").Namespace(ns), &gs)
489	if got, want := <-gotNamespace, ns; got != want {
490		t.Errorf("GetAll: got namespace %q, want %q", got, want)
491	}
492	client.Count(ctx, NewQuery("gopher").Namespace(ns))
493	if got, want := <-gotNamespace, ns; got != want {
494		t.Errorf("Count: got namespace %q, want %q", got, want)
495	}
496}
497
498func TestReadOptions(t *testing.T) {
499	tid := []byte{1}
500	for _, test := range []struct {
501		q    *Query
502		want *pb.ReadOptions
503	}{
504		{
505			q:    NewQuery(""),
506			want: nil,
507		},
508		{
509			q:    NewQuery("").Transaction(nil),
510			want: nil,
511		},
512		{
513			q: NewQuery("").Transaction(&Transaction{id: tid}),
514			want: &pb.ReadOptions{
515				ConsistencyType: &pb.ReadOptions_Transaction{
516					Transaction: tid,
517				},
518			},
519		},
520		{
521			q: NewQuery("").EventualConsistency(),
522			want: &pb.ReadOptions{
523				ConsistencyType: &pb.ReadOptions_ReadConsistency_{
524					ReadConsistency: pb.ReadOptions_EVENTUAL,
525				},
526			},
527		},
528	} {
529		req := &pb.RunQueryRequest{}
530		if err := test.q.toProto(req); err != nil {
531			t.Fatalf("%+v: got %v, want no error", test.q, err)
532		}
533		if got := req.ReadOptions; !proto.Equal(got, test.want) {
534			t.Errorf("%+v:\ngot  %+v\nwant %+v", test.q, got, test.want)
535		}
536	}
537	// Test errors.
538	for _, q := range []*Query{
539		NewQuery("").Transaction(&Transaction{id: nil}),
540		NewQuery("").Transaction(&Transaction{id: tid}).EventualConsistency(),
541	} {
542		req := &pb.RunQueryRequest{}
543		if err := q.toProto(req); err == nil {
544			t.Errorf("%+v: got nil, wanted error", q)
545		}
546	}
547}
548
549func TestInvalidFilters(t *testing.T) {
550	client := &Client{
551		client: &fakeClient{
552			queryFn: func(req *pb.RunQueryRequest) (*pb.RunQueryResponse, error) {
553				return fakeRunQuery(req)
554			},
555		},
556	}
557
558	// Used for an invalid type
559	type MyType int
560	var v MyType = 1
561
562	for _, q := range []*Query{
563		NewQuery("SomeKey").Filter("", 0),
564		NewQuery("SomeKey").Filter("fld=", v),
565	} {
566		if _, err := client.Count(context.Background(), q); err == nil {
567			t.Errorf("%+v: got nil, wanted error", q)
568		}
569	}
570}
571