1// Copyright 2014 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package datastore
16
17import (
18	"context"
19	"errors"
20	"fmt"
21	"log"
22	"os"
23	"reflect"
24
25	"cloud.google.com/go/internal/trace"
26	"google.golang.org/api/option"
27	"google.golang.org/api/transport"
28	gtransport "google.golang.org/api/transport/grpc"
29	pb "google.golang.org/genproto/googleapis/datastore/v1"
30	"google.golang.org/grpc"
31)
32
const (
	// prodAddr is the production Cloud Datastore gRPC endpoint.
	prodAddr = "datastore.googleapis.com:443"

	// userAgent is attached to connections made to the production endpoint.
	userAgent = "gcloud-golang-datastore/20160401"

	// ScopeDatastore is the OAuth2 scope that grants permission to view
	// and/or manage Datastore entities.
	ScopeDatastore = "https://www.googleapis.com/auth/datastore"

	// DetectProjectID is a sentinel value that instructs NewClient to detect
	// the project ID. Pass it in place of the projectID argument. NewClient
	// will then use the project ID from the given credentials, or from the
	// default credentials
	// (https://developers.google.com/accounts/docs/application-default-credentials)
	// if no credentials were provided. Not every kind of credential lets
	// NewClient extract a project ID; in particular, a JWT does not have the
	// project ID encoded.
	DetectProjectID = "*detect-project-id*"

	// resourcePrefixHeader is the name of the metadata header used to
	// indicate the resource being operated on.
	resourcePrefixHeader = "google-cloud-resource-prefix"
)
53
// Client is a client for reading and writing data in a datastore dataset.
// Create one with NewClient and release its connections with Close.
type Client struct {
	connPool gtransport.ConnPool // Pool of gRPC connections; closed by Close.
	client   pb.DatastoreClient  // RPC stub built by newDatastoreClient over connPool.
	dataset  string              // Called dataset by the datastore API, synonym for project ID.
}
60
61// NewClient creates a new Client for a given dataset.  If the project ID is
62// empty, it is derived from the DATASTORE_PROJECT_ID environment variable.
63// If the DATASTORE_EMULATOR_HOST environment variable is set, client will use
64// its value to connect to a locally-running datastore emulator.
65// DetectProjectID can be passed as the projectID argument to instruct
66// NewClient to detect the project ID from the credentials.
67// Call (*Client).Close() when done with the client.
68func NewClient(ctx context.Context, projectID string, opts ...option.ClientOption) (*Client, error) {
69	var o []option.ClientOption
70	// Environment variables for gcd emulator:
71	// https://cloud.google.com/datastore/docs/tools/datastore-emulator
72	// If the emulator is available, dial it without passing any credentials.
73	if addr := os.Getenv("DATASTORE_EMULATOR_HOST"); addr != "" {
74		o = []option.ClientOption{
75			option.WithEndpoint(addr),
76			option.WithoutAuthentication(),
77			option.WithGRPCDialOption(grpc.WithInsecure()),
78		}
79		if projectID == DetectProjectID {
80			projectID, _ = detectProjectID(ctx, opts...)
81			if projectID == "" {
82				projectID = "dummy-emulator-datastore-project"
83			}
84		}
85	} else {
86		o = []option.ClientOption{
87			option.WithEndpoint(prodAddr),
88			option.WithScopes(ScopeDatastore),
89			option.WithUserAgent(userAgent),
90		}
91	}
92	// Warn if we see the legacy emulator environment variables.
93	if os.Getenv("DATASTORE_HOST") != "" && os.Getenv("DATASTORE_EMULATOR_HOST") == "" {
94		log.Print("WARNING: legacy environment variable DATASTORE_HOST is ignored. Use DATASTORE_EMULATOR_HOST instead.")
95	}
96	if os.Getenv("DATASTORE_DATASET") != "" && os.Getenv("DATASTORE_PROJECT_ID") == "" {
97		log.Print("WARNING: legacy environment variable DATASTORE_DATASET is ignored. Use DATASTORE_PROJECT_ID instead.")
98	}
99	if projectID == "" {
100		projectID = os.Getenv("DATASTORE_PROJECT_ID")
101	}
102
103	o = append(o, opts...)
104
105	if projectID == DetectProjectID {
106		detected, err := detectProjectID(ctx, opts...)
107		if err != nil {
108			return nil, err
109		}
110		projectID = detected
111	}
112
113	if projectID == "" {
114		return nil, errors.New("datastore: missing project/dataset id")
115	}
116	connPool, err := gtransport.DialPool(ctx, o...)
117	if err != nil {
118		return nil, fmt.Errorf("dialing: %v", err)
119	}
120	return &Client{
121		connPool: connPool,
122		client:   newDatastoreClient(connPool, projectID),
123		dataset:  projectID,
124	}, nil
125}
126
127func detectProjectID(ctx context.Context, opts ...option.ClientOption) (string, error) {
128	creds, err := transport.Creds(ctx, opts...)
129	if err != nil {
130		return "", fmt.Errorf("fetching creds: %v", err)
131	}
132	if creds.ProjectID == "" {
133		return "", errors.New("datastore: see the docs on DetectProjectID")
134	}
135	return creds.ProjectID, nil
136}
137
// Sentinel errors returned by this package; callers can compare returned
// errors against them.
var (
	// ErrInvalidEntityType is returned when a function such as Get or Next
	// is passed a dst or src argument whose type it cannot handle.
	ErrInvalidEntityType = errors.New("datastore: invalid entity type")

	// ErrInvalidKey is returned when a key fails validation.
	ErrInvalidKey = errors.New("datastore: invalid key")

	// ErrNoSuchEntity is returned when no entity exists for a requested key.
	ErrNoSuchEntity = errors.New("datastore: no such entity")
)
147
// multiArgType classifies the element type of a slice passed to the batch
// operations (it is produced by checkMultiArg and consumed by get and
// putMutations).
type multiArgType int

const (
	// multiArgTypeInvalid is the zero value: the argument is not a supported slice.
	multiArgTypeInvalid multiArgType = iota
	// multiArgTypePropertyLoadSaver: []P where *P implements PropertyLoadSaver.
	multiArgTypePropertyLoadSaver
	// multiArgTypeStruct: []S for some struct type S.
	multiArgTypeStruct
	// multiArgTypeStructPtr: []*S for some struct type S.
	multiArgTypeStructPtr
	// multiArgTypeInterface: []I for some interface type I.
	multiArgTypeInterface
)
157
// ErrFieldMismatch reports that a stored property could not be loaded into
// the destination struct: either the field's type differs from the type the
// value was stored with, or the field is missing or unexported.
//
// StructType is the type of the struct pointed to by the destination
// argument passed to Get or to Iterator.Next.
type ErrFieldMismatch struct {
	StructType reflect.Type
	FieldName  string
	Reason     string
}

// Error implements the error interface, naming the field, the destination
// struct type and the reason for the mismatch.
func (e *ErrFieldMismatch) Error() string {
	const format = "datastore: cannot load field %q into a %q: %s"
	return fmt.Sprintf(format, e.FieldName, e.StructType, e.Reason)
}
173
// GeoPoint is a geographic location given as latitude and longitude, both in
// degrees.
type GeoPoint struct {
	Lat, Lng float64
}

// Valid reports whether g has a latitude in [-90, 90] and a longitude in
// [-180, 180], bounds inclusive. NaN coordinates are never valid.
func (g GeoPoint) Valid() bool {
	latOK := g.Lat >= -90 && g.Lat <= 90
	lngOK := g.Lng >= -180 && g.Lng <= 180
	return latOK && lngOK
}
183
184func keyToProto(k *Key) *pb.Key {
185	if k == nil {
186		return nil
187	}
188
189	var path []*pb.Key_PathElement
190	for {
191		el := &pb.Key_PathElement{Kind: k.Kind}
192		if k.ID != 0 {
193			el.IdType = &pb.Key_PathElement_Id{Id: k.ID}
194		} else if k.Name != "" {
195			el.IdType = &pb.Key_PathElement_Name{Name: k.Name}
196		}
197		path = append(path, el)
198		if k.Parent == nil {
199			break
200		}
201		k = k.Parent
202	}
203
204	// The path should be in order [grandparent, parent, child]
205	// We did it backward above, so reverse back.
206	for i := 0; i < len(path)/2; i++ {
207		path[i], path[len(path)-i-1] = path[len(path)-i-1], path[i]
208	}
209
210	key := &pb.Key{Path: path}
211	if k.Namespace != "" {
212		key.PartitionId = &pb.PartitionId{
213			NamespaceId: k.Namespace,
214		}
215	}
216	return key
217}
218
219// protoToKey decodes a protocol buffer representation of a key into an
220// equivalent *Key object. If the key is invalid, protoToKey will return the
221// invalid key along with ErrInvalidKey.
222func protoToKey(p *pb.Key) (*Key, error) {
223	var key *Key
224	var namespace string
225	if partition := p.PartitionId; partition != nil {
226		namespace = partition.NamespaceId
227	}
228	for _, el := range p.Path {
229		key = &Key{
230			Namespace: namespace,
231			Kind:      el.Kind,
232			ID:        el.GetId(),
233			Name:      el.GetName(),
234			Parent:    key,
235		}
236	}
237	if !key.valid() { // Also detects key == nil.
238		return key, ErrInvalidKey
239	}
240	return key, nil
241}
242
243// multiKeyToProto is a batch version of keyToProto.
244func multiKeyToProto(keys []*Key) []*pb.Key {
245	ret := make([]*pb.Key, len(keys))
246	for i, k := range keys {
247		ret[i] = keyToProto(k)
248	}
249	return ret
250}
251
252// multiKeyToProto is a batch version of keyToProto.
253func multiProtoToKey(keys []*pb.Key) ([]*Key, error) {
254	hasErr := false
255	ret := make([]*Key, len(keys))
256	err := make(MultiError, len(keys))
257	for i, k := range keys {
258		ret[i], err[i] = protoToKey(k)
259		if err[i] != nil {
260			hasErr = true
261		}
262	}
263	if hasErr {
264		return nil, err
265	}
266	return ret, nil
267}
268
269// multiValid is a batch version of Key.valid. It returns an error, not a
270// []bool.
271func multiValid(key []*Key) error {
272	invalid := false
273	for _, k := range key {
274		if !k.valid() {
275			invalid = true
276			break
277		}
278	}
279	if !invalid {
280		return nil
281	}
282	err := make(MultiError, len(key))
283	for i, k := range key {
284		if !k.valid() {
285			err[i] = ErrInvalidKey
286		}
287	}
288	return err
289}
290
291// checkMultiArg checks that v has type []S, []*S, []I, or []P, for some struct
292// type S, for some interface type I, or some non-interface non-pointer type P
293// such that P or *P implements PropertyLoadSaver.
294//
295// It returns what category the slice's elements are, and the reflect.Type
296// that represents S, I or P.
297//
298// As a special case, PropertyList is an invalid type for v.
299func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
300	// TODO(djd): multiArg is very confusing. Fold this logic into the
301	// relevant Put/Get methods to make the logic less opaque.
302	if v.Kind() != reflect.Slice {
303		return multiArgTypeInvalid, nil
304	}
305	if v.Type() == typeOfPropertyList {
306		return multiArgTypeInvalid, nil
307	}
308	elemType = v.Type().Elem()
309	if reflect.PtrTo(elemType).Implements(typeOfPropertyLoadSaver) {
310		return multiArgTypePropertyLoadSaver, elemType
311	}
312	switch elemType.Kind() {
313	case reflect.Struct:
314		return multiArgTypeStruct, elemType
315	case reflect.Interface:
316		return multiArgTypeInterface, elemType
317	case reflect.Ptr:
318		elemType = elemType.Elem()
319		if elemType.Kind() == reflect.Struct {
320			return multiArgTypeStructPtr, elemType
321		}
322	}
323	return multiArgTypeInvalid, nil
324}
325
326// Close closes the Client. Call Close to clean up resources when done with the
327// Client.
328func (c *Client) Close() error {
329	return c.connPool.Close()
330}
331
332// Get loads the entity stored for key into dst, which must be a struct pointer
333// or implement PropertyLoadSaver. If there is no such entity for the key, Get
334// returns ErrNoSuchEntity.
335//
336// The values of dst's unmatched struct fields are not modified, and matching
337// slice-typed fields are not reset before appending to them. In particular, it
338// is recommended to pass a pointer to a zero valued struct on each Get call.
339//
340// ErrFieldMismatch is returned when a field is to be loaded into a different
341// type than the one it was stored from, or when a field is missing or
342// unexported in the destination struct. ErrFieldMismatch is only returned if
343// dst is a struct pointer.
344func (c *Client) Get(ctx context.Context, key *Key, dst interface{}) (err error) {
345	ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.Get")
346	defer func() { trace.EndSpan(ctx, err) }()
347
348	if dst == nil { // get catches nil interfaces; we need to catch nil ptr here
349		return ErrInvalidEntityType
350	}
351	err = c.get(ctx, []*Key{key}, []interface{}{dst}, nil)
352	if me, ok := err.(MultiError); ok {
353		return me[0]
354	}
355	return err
356}
357
358// GetMulti is a batch version of Get.
359//
360// dst must be a []S, []*S, []I or []P, for some struct type S, some interface
361// type I, or some non-interface non-pointer type P such that P or *P
362// implements PropertyLoadSaver. If an []I, each element must be a valid dst
363// for Get: it must be a struct pointer or implement PropertyLoadSaver.
364//
365// As a special case, PropertyList is an invalid type for dst, even though a
366// PropertyList is a slice of structs. It is treated as invalid to avoid being
367// mistakenly passed when []PropertyList was intended.
368//
369// err may be a MultiError. See ExampleMultiError to check it.
370func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst interface{}) (err error) {
371	ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.GetMulti")
372	defer func() { trace.EndSpan(ctx, err) }()
373
374	return c.get(ctx, keys, dst, nil)
375}
376
377func (c *Client) get(ctx context.Context, keys []*Key, dst interface{}, opts *pb.ReadOptions) error {
378	v := reflect.ValueOf(dst)
379	multiArgType, _ := checkMultiArg(v)
380
381	// Confidence checks
382	if multiArgType == multiArgTypeInvalid {
383		return errors.New("datastore: dst has invalid type")
384	}
385	if len(keys) != v.Len() {
386		return errors.New("datastore: keys and dst slices have different length")
387	}
388	if len(keys) == 0 {
389		return nil
390	}
391
392	// Go through keys, validate them, serialize then, and create a dict mapping them to their indices.
393	// Equal keys are deduped.
394	multiErr, any := make(MultiError, len(keys)), false
395	keyMap := make(map[string][]int, len(keys))
396	pbKeys := make([]*pb.Key, 0, len(keys))
397	for i, k := range keys {
398		if !k.valid() {
399			multiErr[i] = ErrInvalidKey
400			any = true
401		} else if k.Incomplete() {
402			multiErr[i] = fmt.Errorf("datastore: can't get the incomplete key: %v", k)
403			any = true
404		} else {
405			ks := k.String()
406			if _, ok := keyMap[ks]; !ok {
407				pbKeys = append(pbKeys, keyToProto(k))
408			}
409			keyMap[ks] = append(keyMap[ks], i)
410		}
411	}
412	if any {
413		return multiErr
414	}
415	req := &pb.LookupRequest{
416		ProjectId:   c.dataset,
417		Keys:        pbKeys,
418		ReadOptions: opts,
419	}
420	resp, err := c.client.Lookup(ctx, req)
421	if err != nil {
422		return err
423	}
424	found := resp.Found
425	missing := resp.Missing
426	// Upper bound 1000 iterations to prevent infinite loop. This matches the max
427	// number of Entities you can request from Datastore.
428	// Note that if ctx has a deadline, the deadline will probably
429	// be hit before we reach 1000 iterations.
430	for i := 0; len(resp.Deferred) > 0 && i < 1000; i++ {
431		req.Keys = resp.Deferred
432		resp, err = c.client.Lookup(ctx, req)
433		if err != nil {
434			return err
435		}
436		found = append(found, resp.Found...)
437		missing = append(missing, resp.Missing...)
438	}
439
440	filled := 0
441	for _, e := range found {
442		k, err := protoToKey(e.Entity.Key)
443		if err != nil {
444			return errors.New("datastore: internal error: server returned an invalid key")
445		}
446		filled += len(keyMap[k.String()])
447		for _, index := range keyMap[k.String()] {
448			elem := v.Index(index)
449			if multiArgType == multiArgTypePropertyLoadSaver || multiArgType == multiArgTypeStruct {
450				elem = elem.Addr()
451			}
452			if multiArgType == multiArgTypeStructPtr && elem.IsNil() {
453				elem.Set(reflect.New(elem.Type().Elem()))
454			}
455			if err := loadEntityProto(elem.Interface(), e.Entity); err != nil {
456				multiErr[index] = err
457				any = true
458			}
459		}
460	}
461	for _, e := range missing {
462		k, err := protoToKey(e.Entity.Key)
463		if err != nil {
464			return errors.New("datastore: internal error: server returned an invalid key")
465		}
466		filled += len(keyMap[k.String()])
467		for _, index := range keyMap[k.String()] {
468			multiErr[index] = ErrNoSuchEntity
469		}
470		any = true
471	}
472
473	if filled != len(keys) {
474		return errors.New("datastore: internal error: server returned the wrong number of entities")
475	}
476
477	if any {
478		return multiErr
479	}
480	return nil
481}
482
483// Put saves the entity src into the datastore with the given key. src must be
484// a struct pointer or implement PropertyLoadSaver; if the struct pointer has
485// any unexported fields they will be skipped. If the key is incomplete, the
486// returned key will be a unique key generated by the datastore.
487func (c *Client) Put(ctx context.Context, key *Key, src interface{}) (*Key, error) {
488	k, err := c.PutMulti(ctx, []*Key{key}, []interface{}{src})
489	if err != nil {
490		if me, ok := err.(MultiError); ok {
491			return nil, me[0]
492		}
493		return nil, err
494	}
495	return k[0], nil
496}
497
498// PutMulti is a batch version of Put.
499//
500// src must satisfy the same conditions as the dst argument to GetMulti.
501// err may be a MultiError. See ExampleMultiError to check it.
502func (c *Client) PutMulti(ctx context.Context, keys []*Key, src interface{}) (ret []*Key, err error) {
503	// TODO(jba): rewrite in terms of Mutate.
504	ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.PutMulti")
505	defer func() { trace.EndSpan(ctx, err) }()
506
507	mutations, err := putMutations(keys, src)
508	if err != nil {
509		return nil, err
510	}
511
512	// Make the request.
513	req := &pb.CommitRequest{
514		ProjectId: c.dataset,
515		Mutations: mutations,
516		Mode:      pb.CommitRequest_NON_TRANSACTIONAL,
517	}
518	resp, err := c.client.Commit(ctx, req)
519	if err != nil {
520		return nil, err
521	}
522
523	// Copy any newly minted keys into the returned keys.
524	ret = make([]*Key, len(keys))
525	for i, key := range keys {
526		if key.Incomplete() {
527			// This key is in the mutation results.
528			ret[i], err = protoToKey(resp.MutationResults[i].Key)
529			if err != nil {
530				return nil, errors.New("datastore: internal error: server returned an invalid key")
531			}
532		} else {
533			ret[i] = key
534		}
535	}
536	return ret, nil
537}
538
539func putMutations(keys []*Key, src interface{}) ([]*pb.Mutation, error) {
540	v := reflect.ValueOf(src)
541	multiArgType, _ := checkMultiArg(v)
542	if multiArgType == multiArgTypeInvalid {
543		return nil, errors.New("datastore: src has invalid type")
544	}
545	if len(keys) != v.Len() {
546		return nil, errors.New("datastore: key and src slices have different length")
547	}
548	if len(keys) == 0 {
549		return nil, nil
550	}
551	if err := multiValid(keys); err != nil {
552		return nil, err
553	}
554	mutations := make([]*pb.Mutation, 0, len(keys))
555	multiErr := make(MultiError, len(keys))
556	hasErr := false
557	for i, k := range keys {
558		elem := v.Index(i)
559		// Two cases where we need to take the address:
560		// 1) multiArgTypePropertyLoadSaver => &elem implements PLS
561		// 2) multiArgTypeStruct => saveEntity needs *struct
562		if multiArgType == multiArgTypePropertyLoadSaver || multiArgType == multiArgTypeStruct {
563			elem = elem.Addr()
564		}
565		p, err := saveEntity(k, elem.Interface())
566		if err != nil {
567			multiErr[i] = err
568			hasErr = true
569		}
570		var mut *pb.Mutation
571		if k.Incomplete() {
572			mut = &pb.Mutation{Operation: &pb.Mutation_Insert{Insert: p}}
573		} else {
574			mut = &pb.Mutation{Operation: &pb.Mutation_Upsert{Upsert: p}}
575		}
576		mutations = append(mutations, mut)
577	}
578	if hasErr {
579		return nil, multiErr
580	}
581	return mutations, nil
582}
583
584// Delete deletes the entity for the given key.
585func (c *Client) Delete(ctx context.Context, key *Key) error {
586	err := c.DeleteMulti(ctx, []*Key{key})
587	if me, ok := err.(MultiError); ok {
588		return me[0]
589	}
590	return err
591}
592
593// DeleteMulti is a batch version of Delete.
594//
595// err may be a MultiError. See ExampleMultiError to check it.
596func (c *Client) DeleteMulti(ctx context.Context, keys []*Key) (err error) {
597	// TODO(jba): rewrite in terms of Mutate.
598	ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.DeleteMulti")
599	defer func() { trace.EndSpan(ctx, err) }()
600
601	mutations, err := deleteMutations(keys)
602	if err != nil {
603		return err
604	}
605
606	req := &pb.CommitRequest{
607		ProjectId: c.dataset,
608		Mutations: mutations,
609		Mode:      pb.CommitRequest_NON_TRANSACTIONAL,
610	}
611	_, err = c.client.Commit(ctx, req)
612	return err
613}
614
615func deleteMutations(keys []*Key) ([]*pb.Mutation, error) {
616	mutations := make([]*pb.Mutation, 0, len(keys))
617	set := make(map[string]bool, len(keys))
618	multiErr := make(MultiError, len(keys))
619	hasErr := false
620	for i, k := range keys {
621		if !k.valid() {
622			multiErr[i] = ErrInvalidKey
623			hasErr = true
624		} else if k.Incomplete() {
625			multiErr[i] = fmt.Errorf("datastore: can't delete the incomplete key: %v", k)
626			hasErr = true
627		} else {
628			ks := k.String()
629			if !set[ks] {
630				mutations = append(mutations, &pb.Mutation{
631					Operation: &pb.Mutation_Delete{Delete: keyToProto(k)},
632				})
633			}
634			set[ks] = true
635		}
636	}
637	if hasErr {
638		return nil, multiErr
639	}
640	return mutations, nil
641}
642
643// Mutate applies one or more mutations. Mutations are applied in
644// non-transactional mode. If you need atomicity, use Transaction.Mutate.
645// It returns the keys of the argument Mutations, in the same order.
646//
647// If any of the mutations are invalid, Mutate returns a MultiError with the errors.
648// Mutate returns a MultiError in this case even if there is only one Mutation.
649// See ExampleMultiError to check it.
650func (c *Client) Mutate(ctx context.Context, muts ...*Mutation) (ret []*Key, err error) {
651	ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.Mutate")
652	defer func() { trace.EndSpan(ctx, err) }()
653
654	pmuts, err := mutationProtos(muts)
655	if err != nil {
656		return nil, err
657	}
658	req := &pb.CommitRequest{
659		ProjectId: c.dataset,
660		Mutations: pmuts,
661		Mode:      pb.CommitRequest_NON_TRANSACTIONAL,
662	}
663	resp, err := c.client.Commit(ctx, req)
664	if err != nil {
665		return nil, err
666	}
667	// Copy any newly minted keys into the returned keys.
668	ret = make([]*Key, len(muts))
669	for i, mut := range muts {
670		if mut.key.Incomplete() {
671			// This key is in the mutation results.
672			ret[i], err = protoToKey(resp.MutationResults[i].Key)
673			if err != nil {
674				return nil, errors.New("datastore: internal error: server returned an invalid key")
675			}
676		} else {
677			ret[i] = mut.key
678		}
679	}
680	return ret, nil
681}
682