1// Copyright 2014 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package datastore
16
17import (
18	"fmt"
19	"reflect"
20	"strings"
21	"time"
22
23	"cloud.google.com/go/civil"
24	"cloud.google.com/go/internal/fields"
25	pb "google.golang.org/genproto/googleapis/datastore/v1"
26)
27
28var (
29	typeOfByteSlice     = reflect.TypeOf([]byte(nil))
30	typeOfTime          = reflect.TypeOf(time.Time{})
31	typeOfCivilDate     = reflect.TypeOf(civil.Date{})
32	typeOfCivilDateTime = reflect.TypeOf(civil.DateTime{})
33	typeOfCivilTime     = reflect.TypeOf(civil.Time{})
34	typeOfGeoPoint      = reflect.TypeOf(GeoPoint{})
35	typeOfKeyPtr        = reflect.TypeOf(&Key{})
36)
37
38// typeMismatchReason returns a string explaining why the property p could not
39// be stored in an entity field of type v.Type().
40func typeMismatchReason(p Property, v reflect.Value) string {
41	entityType := "empty"
42	switch p.Value.(type) {
43	case int64:
44		entityType = "int"
45	case bool:
46		entityType = "bool"
47	case string:
48		entityType = "string"
49	case float64:
50		entityType = "float"
51	case *Key:
52		entityType = "*datastore.Key"
53	case *Entity:
54		entityType = "*datastore.Entity"
55	case GeoPoint:
56		entityType = "GeoPoint"
57	case time.Time:
58		entityType = "time.Time"
59	case []byte:
60		entityType = "[]byte"
61	}
62
63	return fmt.Sprintf("type mismatch: %s versus %v", entityType, v.Type())
64}
65
// overflowReason builds the error text used when the numeric value x does
// not fit in the struct field v (as reported by OverflowInt/OverflowFloat).
func overflowReason(x interface{}, v reflect.Value) string {
	fieldType := v.Type()
	return fmt.Sprintf("value %v overflows struct field of type %v", x, fieldType)
}
69
// propertyLoader carries state across the properties handled by one Load
// call. Its zero value is ready to use.
type propertyLoader struct {
	m map[string]int // per flattened name (e.g. "Foo.Bar.Baz"): how many times it has been seen so far; allocated on first use
}
75
76func (l *propertyLoader) load(codec fields.List, structValue reflect.Value, p Property, prev map[string]struct{}) string {
77	sl, ok := p.Value.([]interface{})
78	if !ok {
79		return l.loadOneElement(codec, structValue, p, prev)
80	}
81	for _, val := range sl {
82		p.Value = val
83		if errStr := l.loadOneElement(codec, structValue, p, prev); errStr != "" {
84			return errStr
85		}
86	}
87	return ""
88}
89
// loadOneElement loads the value of Property p into structValue based on the provided
// codec. codec is used to find the field in structValue into which p should be loaded.
// prev is the set of property names already seen for structValue.
//
// p.Name may be a dotted path such as "A.B.C". The loop below resolves it one
// codec match at a time, descending into structs, pointers to structs and
// slice elements, and delegates to a PropertyLoadSaver as soon as one is
// reached. It returns "" on success, or a human-readable reason on failure.
// p.Name is recorded in prev on every path that reaches setVal.
func (l *propertyLoader) loadOneElement(codec fields.List, structValue reflect.Value, p Property, prev map[string]struct{}) string {
	// sliceOk becomes true once the path has passed through a slice, so a
	// repeated name is treated as the next element rather than a duplicate.
	var sliceOk bool
	// sliceIndex is the element index chosen in the most recent slice traversed.
	var sliceIndex int
	var v reflect.Value

	name := p.Name
	fieldNames := strings.Split(name, ".")

	for len(fieldNames) > 0 {
		var field *fields.Field

		// Start by trying to find a field with name. If none found,
		// cut off the last field (delimited by ".") and find its parent
		// in the codec.
		// eg. for name "A.B.C.D", split off "A.B.C" and try to
		// find a field in the codec with this name.
		// Loop again with "A.B", etc.
		for i := len(fieldNames); i > 0; i-- {
			parent := strings.Join(fieldNames[:i], ".")
			field = codec.Match(parent)
			if field != nil {
				fieldNames = fieldNames[i:]
				break
			}
		}

		// If we never found a matching field in the codec, return
		// error message.
		if field == nil {
			return "no such struct field"
		}

		// initField allocates any nil embedded pointers along field.Index.
		v = initField(structValue, field.Index)
		if !v.IsValid() {
			return "no such struct field"
		}
		if !v.CanSet() {
			return "cannot set struct field"
		}

		// If field implements PLS, we delegate loading to the PLS's Load early,
		// and stop iterating through fields.
		ok, err := plsFieldLoad(v, p, fieldNames)
		if err != nil {
			return err.Error()
		}
		if ok {
			return ""
		}

		// Pointer to struct: switch to the pointee's codec for the
		// remaining name segments, allocating the pointee if needed.
		if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct {
			codec, err = structCache.Fields(field.Type.Elem())
			if err != nil {
				return err.Error()
			}

			// Init value if its nil
			if v.IsNil() {
				v.Set(reflect.New(field.Type.Elem()))
			}
			structValue = v.Elem()
		}

		// Struct value: descend into it for the remaining name segments.
		if field.Type.Kind() == reflect.Struct {
			codec, err = structCache.Fields(field.Type)
			if err != nil {
				return err.Error()
			}
			structValue = v
		}

		// If the element is a slice, we need to accommodate it.
		// Each occurrence of the same p.Name advances l.m[p.Name] by one, so
		// repeated properties fill elements 0, 1, 2, ... in turn; the slice is
		// grown with zero elements until that index exists.
		if v.Kind() == reflect.Slice && v.Type() != typeOfByteSlice {
			if l.m == nil {
				l.m = make(map[string]int)
			}
			sliceIndex = l.m[p.Name]
			l.m[p.Name] = sliceIndex + 1
			for v.Len() <= sliceIndex {
				v.Set(reflect.Append(v, reflect.New(v.Type().Elem()).Elem()))
			}
			structValue = v.Index(sliceIndex)

			// If structValue implements PLS, we delegate loading to the PLS's
			// Load early, and stop iterating through fields.
			ok, err := plsFieldLoad(structValue, p, fieldNames)
			if err != nil {
				return err.Error()
			}
			if ok {
				return ""
			}

			if structValue.Type().Kind() == reflect.Struct {
				codec, err = structCache.Fields(structValue.Type())
				if err != nil {
					return err.Error()
				}
			}
			sliceOk = true
		}
		// NOTE(review): if the matched field is none of PLS, struct, pointer
		// to struct or slice while fieldNames is still non-empty, the next
		// iteration searches the remaining names in the same codec — confirm
		// this is intended.
	}

	// If the final field is a slice (other than a byte slice), decode into a
	// fresh element value and write it at sliceIndex after setVal succeeds.
	// Otherwise a second property with an already-seen name cannot be stored
	// in a non-slice field.
	var slice reflect.Value
	if v.Kind() == reflect.Slice && v.Type().Elem().Kind() != reflect.Uint8 {
		slice = v
		v = reflect.New(v.Type().Elem()).Elem()
	} else if _, ok := prev[p.Name]; ok && !sliceOk {
		// Zero the field back out that was set previously, turns out
		// it's a slice and we don't know what to do with it
		v.Set(reflect.Zero(v.Type()))
		return "multiple-valued property requires a slice field type"
	}

	prev[p.Name] = struct{}{}

	if errReason := setVal(v, p); errReason != "" {
		// Set the slice back to its zero value.
		if slice.IsValid() {
			slice.Set(reflect.Zero(slice.Type()))
		}
		return errReason
	}

	if slice.IsValid() {
		slice.Index(sliceIndex).Set(v)
	}

	return ""
}
223
224// plsFieldLoad first tries to converts v's value to a PLS, then v's addressed
225// value to a PLS. If neither succeeds, plsFieldLoad returns false for first return
226// value. Otherwise, the first return value will be true.
227// If v is successfully converted to a PLS, plsFieldLoad will then try to Load
228// the property p into v (by way of the PLS's Load method).
229//
230// If the field v has been flattened, the Property's name must be altered
231// before calling Load to reflect the field v.
232// For example, if our original field name was "A.B.C.D",
233// and at this point in iteration we had initialized the field
234// corresponding to "A" and have moved into the struct, so that now
235// v corresponds to the field named "B", then we want to let the
236// PLS handle this field (B)'s subfields ("C", "D"),
237// so we send the property to the PLS's Load, renamed to "C.D".
238//
239// If subfields are present, the field v has been flattened.
240func plsFieldLoad(v reflect.Value, p Property, subfields []string) (ok bool, err error) {
241	vpls, err := plsForLoad(v)
242	if err != nil {
243		return false, err
244	}
245
246	if vpls == nil {
247		return false, nil
248	}
249
250	// If Entity, load properties as well as key.
251	if e, ok := p.Value.(*Entity); ok {
252		err = loadEntity(vpls, e)
253		return true, err
254	}
255
256	// If flattened, we must alter the property's name to reflect
257	// the field v.
258	if len(subfields) > 0 {
259		p.Name = strings.Join(subfields, ".")
260	}
261
262	return true, vpls.Load([]Property{p})
263}
264
265// setVal sets 'v' to the value of the Property 'p'.
266func setVal(v reflect.Value, p Property) (s string) {
267	pValue := p.Value
268	switch v.Kind() {
269	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
270		x, ok := pValue.(int64)
271		if !ok && pValue != nil {
272			return typeMismatchReason(p, v)
273		}
274		if v.OverflowInt(x) {
275			return overflowReason(x, v)
276		}
277		v.SetInt(x)
278	case reflect.Bool:
279		x, ok := pValue.(bool)
280		if !ok && pValue != nil {
281			return typeMismatchReason(p, v)
282		}
283		v.SetBool(x)
284	case reflect.String:
285		x, ok := pValue.(string)
286		if !ok && pValue != nil {
287			return typeMismatchReason(p, v)
288		}
289		v.SetString(x)
290	case reflect.Float32, reflect.Float64:
291		x, ok := pValue.(float64)
292		if !ok && pValue != nil {
293			return typeMismatchReason(p, v)
294		}
295		if v.OverflowFloat(x) {
296			return overflowReason(x, v)
297		}
298		v.SetFloat(x)
299
300	case reflect.Interface:
301		if !v.CanSet() {
302			return fmt.Sprintf("%v is unsettable", v.Type())
303		}
304
305		rpValue := reflect.ValueOf(pValue)
306		if !rpValue.Type().AssignableTo(v.Type()) {
307			return fmt.Sprintf("%q is not assignable to %q", rpValue.Type(), v.Type())
308		}
309		v.Set(rpValue)
310
311	case reflect.Ptr:
312		// v must be a pointer to either a Key, an Entity, or one of the supported basic types.
313		if v.Type() != typeOfKeyPtr && v.Type().Elem().Kind() != reflect.Struct && !isValidPointerType(v.Type().Elem()) {
314			return typeMismatchReason(p, v)
315		}
316
317		if pValue == nil {
318			// If v is populated already, set it to nil.
319			if !v.IsNil() {
320				v.Set(reflect.New(v.Type()).Elem())
321			}
322			return ""
323		}
324
325		if x, ok := p.Value.(*Key); ok {
326			if _, ok := v.Interface().(*Key); !ok {
327				return typeMismatchReason(p, v)
328			}
329			v.Set(reflect.ValueOf(x))
330			return ""
331		}
332		if v.IsNil() {
333			v.Set(reflect.New(v.Type().Elem()))
334		}
335		switch x := pValue.(type) {
336		case *Entity:
337			err := loadEntity(v.Interface(), x)
338			if err != nil {
339				return err.Error()
340			}
341		case int64:
342			if v.Elem().OverflowInt(x) {
343				return overflowReason(x, v.Elem())
344			}
345			v.Elem().SetInt(x)
346		case float64:
347			if v.Elem().OverflowFloat(x) {
348				return overflowReason(x, v.Elem())
349			}
350			v.Elem().SetFloat(x)
351		case bool:
352			v.Elem().SetBool(x)
353		case string:
354			v.Elem().SetString(x)
355		case GeoPoint, time.Time:
356			v.Elem().Set(reflect.ValueOf(x))
357		default:
358			return typeMismatchReason(p, v)
359		}
360	case reflect.Struct:
361		switch v.Type() {
362		case typeOfTime:
363			// Some time values are converted into microsecond integer values
364			// (for example when used with projects). So, here we check first
365			// whether this value is an int64, and next whether it's time.
366			//
367			// See more at https://cloud.google.com/datastore/docs/concepts/queries#limitations_on_projections
368			micros, ok := pValue.(int64)
369			if ok {
370				s := micros / 1e6
371				ns := micros % 1e6
372				v.Set(reflect.ValueOf(time.Unix(s, ns).In(time.UTC)))
373				break
374			}
375			x, ok := pValue.(time.Time)
376			if !ok && pValue != nil {
377				return typeMismatchReason(p, v)
378			}
379			v.Set(reflect.ValueOf(x))
380		case typeOfGeoPoint:
381			x, ok := pValue.(GeoPoint)
382			if !ok && pValue != nil {
383				return typeMismatchReason(p, v)
384			}
385			v.Set(reflect.ValueOf(x))
386		case typeOfCivilDate:
387			date := civil.DateOf(pValue.(time.Time).In(time.UTC))
388			v.Set(reflect.ValueOf(date))
389		case typeOfCivilDateTime:
390			dateTime := civil.DateTimeOf(pValue.(time.Time).In(time.UTC))
391			v.Set(reflect.ValueOf(dateTime))
392		case typeOfCivilTime:
393			timeVal := civil.TimeOf(pValue.(time.Time).In(time.UTC))
394			v.Set(reflect.ValueOf(timeVal))
395		default:
396			ent, ok := pValue.(*Entity)
397			if !ok {
398				return typeMismatchReason(p, v)
399			}
400			err := loadEntity(v.Addr().Interface(), ent)
401			if err != nil {
402				return err.Error()
403			}
404		}
405	case reflect.Slice:
406		x, ok := pValue.([]byte)
407		if !ok && pValue != nil {
408			return typeMismatchReason(p, v)
409		}
410		if v.Type().Elem().Kind() != reflect.Uint8 {
411			return typeMismatchReason(p, v)
412		}
413		v.SetBytes(x)
414	default:
415		return typeMismatchReason(p, v)
416	}
417	return ""
418}
419
// initField behaves like reflect.Value.FieldByIndex: it follows index through
// nested struct fields and returns the final field. Unlike FieldByIndex, any
// nil pointer met on an intermediate step is replaced with a newly allocated
// value before being dereferenced. val must be settable for that allocation
// to succeed.
func initField(val reflect.Value, index []int) reflect.Value {
	n := len(index)
	for pos := 0; pos < n-1; pos++ {
		val = val.Field(index[pos])
		if val.Kind() != reflect.Ptr {
			continue
		}
		if val.IsNil() {
			val.Set(reflect.New(val.Type().Elem()))
		}
		val = val.Elem()
	}
	return val.Field(index[n-1])
}
435
436// loadEntityProto loads an EntityProto into PropertyLoadSaver or struct pointer.
437func loadEntityProto(dst interface{}, src *pb.Entity) error {
438	ent, err := protoToEntity(src)
439	if err != nil {
440		return err
441	}
442	return loadEntity(dst, ent)
443}
444
445func loadEntity(dst interface{}, ent *Entity) error {
446	if pls, ok := dst.(PropertyLoadSaver); ok {
447		// Load both key and properties. Try to load as much as possible, even
448		// if an error occurs during loading either the key or the
449		// properties.
450		var keyLoadErr error
451		if e, ok := dst.(KeyLoader); ok {
452			keyLoadErr = e.LoadKey(ent.Key)
453		}
454		loadErr := pls.Load(ent.Properties)
455		// Let any error returned by LoadKey prevail above any error from Load.
456		if keyLoadErr != nil {
457			return keyLoadErr
458		}
459		return loadErr
460	}
461	return loadEntityToStruct(dst, ent)
462}
463
464func loadEntityToStruct(dst interface{}, ent *Entity) error {
465	pls, err := newStructPLS(dst)
466	if err != nil {
467		return err
468	}
469
470	// Try and load key.
471	keyField := pls.codec.Match(keyFieldName)
472	if keyField != nil && ent.Key != nil {
473		pls.v.FieldByIndex(keyField.Index).Set(reflect.ValueOf(ent.Key))
474	}
475
476	// Load properties.
477	return pls.Load(ent.Properties)
478}
479
480func (s structPLS) Load(props []Property) error {
481	var fieldName, errReason string
482	var l propertyLoader
483
484	prev := make(map[string]struct{})
485	for _, p := range props {
486		if errStr := l.load(s.codec, s.v, p, prev); errStr != "" {
487			// We don't return early, as we try to load as many properties as possible.
488			// It is valid to load an entity into a struct that cannot fully represent it.
489			// That case returns an error, but the caller is free to ignore it.
490			fieldName, errReason = p.Name, errStr
491		}
492	}
493	if errReason != "" {
494		return &ErrFieldMismatch{
495			StructType: s.v.Type(),
496			FieldName:  fieldName,
497			Reason:     errReason,
498		}
499	}
500	return nil
501}
502
503func protoToEntity(src *pb.Entity) (*Entity, error) {
504	props := make([]Property, 0, len(src.Properties))
505	for name, val := range src.Properties {
506		v, err := propToValue(val)
507		if err != nil {
508			return nil, err
509		}
510		props = append(props, Property{
511			Name:    name,
512			Value:   v,
513			NoIndex: val.ExcludeFromIndexes,
514		})
515	}
516	var key *Key
517	if src.Key != nil {
518		// Ignore any error, since nested entity values
519		// are allowed to have an invalid key.
520		key, _ = protoToKey(src.Key)
521	}
522
523	return &Entity{key, props}, nil
524}
525
526// propToValue returns a Go value that represents the PropertyValue. For
527// example, a TimestampValue becomes a time.Time.
528func propToValue(v *pb.Value) (interface{}, error) {
529	switch v := v.ValueType.(type) {
530	case *pb.Value_NullValue:
531		return nil, nil
532	case *pb.Value_BooleanValue:
533		return v.BooleanValue, nil
534	case *pb.Value_IntegerValue:
535		return v.IntegerValue, nil
536	case *pb.Value_DoubleValue:
537		return v.DoubleValue, nil
538	case *pb.Value_TimestampValue:
539		return time.Unix(v.TimestampValue.Seconds, int64(v.TimestampValue.Nanos)).In(time.UTC), nil
540	case *pb.Value_KeyValue:
541		return protoToKey(v.KeyValue)
542	case *pb.Value_StringValue:
543		return v.StringValue, nil
544	case *pb.Value_BlobValue:
545		return []byte(v.BlobValue), nil
546	case *pb.Value_GeoPointValue:
547		return GeoPoint{Lat: v.GeoPointValue.Latitude, Lng: v.GeoPointValue.Longitude}, nil
548	case *pb.Value_EntityValue:
549		return protoToEntity(v.EntityValue)
550	case *pb.Value_ArrayValue:
551		arr := make([]interface{}, 0, len(v.ArrayValue.Values))
552		for _, v := range v.ArrayValue.Values {
553			vv, err := propToValue(v)
554			if err != nil {
555				return nil, err
556			}
557			arr = append(arr, vv)
558		}
559		return arr, nil
560	default:
561		return nil, nil
562	}
563}
564