1// Copyright 2019 The Go Cloud Development Kit Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package gcpfirestore
16
17// Encoding and decoding between supported docstore types and Firestore protos.
18
19import (
20	"errors"
21	"fmt"
22	"path"
23	"reflect"
24	"time"
25
26	"github.com/golang/protobuf/ptypes"
27	ts "github.com/golang/protobuf/ptypes/timestamp"
28	"gocloud.dev/docstore/driver"
29	pb "google.golang.org/genproto/googleapis/firestore/v1"
30	"google.golang.org/genproto/googleapis/type/latlng"
31)
32
33// encodeDoc encodes a driver.Document into Firestore's representation.
34// A Firestore document (*pb.Document) is just a Go map from strings to *pb.Values.
35func encodeDoc(doc driver.Document, nameField string) (*pb.Document, error) {
36	var e encoder
37	if err := doc.Encode(&e); err != nil {
38		return nil, err
39	}
40	fields := e.pv.GetMapValue().Fields
41	// Do not put the name field in the document itself.
42	if nameField != "" {
43		delete(fields, nameField)
44	}
45	return &pb.Document{Fields: fields}, nil
46}
47
48// encodeValue encodes a Go value as a Firestore Value.
49// The Firestore proto definition for Value is a oneof of various types,
50// including basic types like string as well as lists and maps.
51func encodeValue(x interface{}) (*pb.Value, error) {
52	var e encoder
53	if err := driver.Encode(reflect.ValueOf(x), &e); err != nil {
54		return nil, err
55	}
56	return e.pv, nil
57}
58
59// encoder implements driver.Encoder. Its job is to encode a single Firestore value.
60type encoder struct {
61	pv *pb.Value
62}
63
64var nullValue = &pb.Value{ValueType: &pb.Value_NullValue{}}
65
66func (e *encoder) EncodeNil()            { e.pv = nullValue }
67func (e *encoder) EncodeBool(x bool)     { e.pv = &pb.Value{ValueType: &pb.Value_BooleanValue{x}} }
68func (e *encoder) EncodeInt(x int64)     { e.pv = &pb.Value{ValueType: &pb.Value_IntegerValue{x}} }
69func (e *encoder) EncodeUint(x uint64)   { e.pv = &pb.Value{ValueType: &pb.Value_IntegerValue{int64(x)}} }
70func (e *encoder) EncodeBytes(x []byte)  { e.pv = &pb.Value{ValueType: &pb.Value_BytesValue{x}} }
71func (e *encoder) EncodeFloat(x float64) { e.pv = floatval(x) }
72func (e *encoder) EncodeString(x string) { e.pv = &pb.Value{ValueType: &pb.Value_StringValue{x}} }
73
74func (e *encoder) ListIndex(int) { panic("impossible") }
75func (e *encoder) MapKey(string) { panic("impossible") }
76
77func (e *encoder) EncodeList(n int) driver.Encoder {
78	s := make([]*pb.Value, n)
79	e.pv = &pb.Value{ValueType: &pb.Value_ArrayValue{&pb.ArrayValue{Values: s}}}
80	return &listEncoder{s: s}
81}
82
83func (e *encoder) EncodeMap(n int) driver.Encoder {
84	m := make(map[string]*pb.Value, n)
85	e.pv = &pb.Value{ValueType: &pb.Value_MapValue{&pb.MapValue{Fields: m}}}
86	return &mapEncoder{m: m}
87}
88
89var (
90	typeOfGoTime         = reflect.TypeOf(time.Time{})
91	typeOfProtoTimestamp = reflect.TypeOf((*ts.Timestamp)(nil))
92	typeOfLatLng         = reflect.TypeOf((*latlng.LatLng)(nil))
93)
94
95// Encode time.Time, latlng.LatLng, and ts.Timestamp specially, because the Go Firestore
96// client does.
97func (e *encoder) EncodeSpecial(v reflect.Value) (bool, error) {
98	switch v.Type() {
99	case typeOfGoTime:
100		ts, err := ptypes.TimestampProto(v.Interface().(time.Time))
101		if err != nil {
102			return false, err
103		}
104		e.pv = &pb.Value{ValueType: &pb.Value_TimestampValue{ts}}
105		return true, nil
106	case typeOfProtoTimestamp:
107		if v.IsNil() {
108			e.pv = nullValue
109		} else {
110			e.pv = &pb.Value{ValueType: &pb.Value_TimestampValue{v.Interface().(*ts.Timestamp)}}
111		}
112		return true, nil
113	case typeOfLatLng:
114		if v.IsNil() {
115			e.pv = nullValue
116		} else {
117			e.pv = &pb.Value{ValueType: &pb.Value_GeoPointValue{v.Interface().(*latlng.LatLng)}}
118		}
119		return true, nil
120	default:
121		return false, nil
122	}
123}
124
125type listEncoder struct {
126	s []*pb.Value
127	encoder
128}
129
130func (e *listEncoder) ListIndex(i int) { e.s[i] = e.pv }
131
132type mapEncoder struct {
133	m map[string]*pb.Value
134	encoder
135}
136
137func (e *mapEncoder) MapKey(k string) { e.m[k] = e.pv }
138
139func floatval(x float64) *pb.Value { return &pb.Value{ValueType: &pb.Value_DoubleValue{x}} }
140
141////////////////////////////////////////////////////////////////
142
143// decodeDoc decodes a Firestore document into a driver.Document.
144func decodeDoc(pdoc *pb.Document, ddoc driver.Document, nameField, revField string) error {
145	if pdoc.Fields == nil {
146		pdoc.Fields = map[string]*pb.Value{}
147	}
148	if nameField != "" {
149		pdoc.Fields[nameField] = &pb.Value{ValueType: &pb.Value_StringValue{StringValue: path.Base(pdoc.Name)}}
150	}
151	mv := &pb.Value{ValueType: &pb.Value_MapValue{&pb.MapValue{Fields: pdoc.Fields}}}
152	if err := ddoc.Decode(decoder{mv}); err != nil {
153		return err
154	}
155	// Set the revision field in the document, if it exists, to the update time.
156	if ddoc.HasField(revField) && pdoc.UpdateTime != nil {
157		return ddoc.SetField(revField, pdoc.UpdateTime)
158	}
159	return nil
160}
161
162type decoder struct {
163	pv *pb.Value
164}
165
166func (d decoder) String() string { // for debugging
167	return fmt.Sprint(d.pv)
168}
169
170func (d decoder) AsNull() bool {
171	_, ok := d.pv.ValueType.(*pb.Value_NullValue)
172	return ok
173}
174
175func (d decoder) AsBool() (bool, bool) {
176	if b, ok := d.pv.ValueType.(*pb.Value_BooleanValue); ok {
177		return b.BooleanValue, true
178	}
179	return false, false
180}
181
182func (d decoder) AsString() (string, bool) {
183	if s, ok := d.pv.ValueType.(*pb.Value_StringValue); ok {
184		return s.StringValue, true
185	}
186	return "", false
187}
188
189func (d decoder) AsInt() (int64, bool) {
190	if i, ok := d.pv.ValueType.(*pb.Value_IntegerValue); ok {
191		return i.IntegerValue, true
192	}
193	return 0, false
194}
195
196func (d decoder) AsUint() (uint64, bool) {
197	if i, ok := d.pv.ValueType.(*pb.Value_IntegerValue); ok {
198		return uint64(i.IntegerValue), true
199	}
200	return 0, false
201}
202
203func (d decoder) AsFloat() (float64, bool) {
204	if f, ok := d.pv.ValueType.(*pb.Value_DoubleValue); ok {
205		return f.DoubleValue, true
206	}
207	return 0, false
208}
209
210func (d decoder) AsBytes() ([]byte, bool) {
211	if bs, ok := d.pv.ValueType.(*pb.Value_BytesValue); ok {
212		return bs.BytesValue, true
213	}
214	return nil, false
215}
216
217// AsInterface decodes the value in d into the most appropriate Go type.
218func (d decoder) AsInterface() (interface{}, error) {
219	return decodeValue(d.pv)
220}
221
222func decodeValue(v *pb.Value) (interface{}, error) {
223	switch v := v.ValueType.(type) {
224	case *pb.Value_NullValue:
225		return nil, nil
226	case *pb.Value_BooleanValue:
227		return v.BooleanValue, nil
228	case *pb.Value_IntegerValue:
229		return v.IntegerValue, nil
230	case *pb.Value_DoubleValue:
231		return v.DoubleValue, nil
232	case *pb.Value_StringValue:
233		return v.StringValue, nil
234	case *pb.Value_BytesValue:
235		return v.BytesValue, nil
236	case *pb.Value_TimestampValue:
237		// Return TimestampValue as time.Time.
238		t, err := ptypes.Timestamp(v.TimestampValue)
239		if err != nil {
240			return nil, err
241		}
242		return t, nil
243	case *pb.Value_ReferenceValue:
244		// TODO(jba): support references
245		return nil, errors.New("references are not currently supported")
246	case *pb.Value_GeoPointValue:
247		// Return GeoPointValue as *latlng.LatLng.
248		return v.GeoPointValue, nil
249	case *pb.Value_ArrayValue:
250		s := make([]interface{}, len(v.ArrayValue.Values))
251		for i, pv := range v.ArrayValue.Values {
252			e, err := decodeValue(pv)
253			if err != nil {
254				return nil, err
255			}
256			s[i] = e
257		}
258		return s, nil
259	case *pb.Value_MapValue:
260		m := make(map[string]interface{}, len(v.MapValue.Fields))
261		for k, pv := range v.MapValue.Fields {
262			e, err := decodeValue(pv)
263			if err != nil {
264				return nil, err
265			}
266			m[k] = e
267		}
268		return m, nil
269	}
270	return nil, fmt.Errorf("unknown firestore value type %T", v)
271}
272
273func (d decoder) ListLen() (int, bool) {
274	a := d.pv.GetArrayValue()
275	if a == nil {
276		return 0, false
277	}
278	return len(a.Values), true
279}
280
281func (d decoder) DecodeList(f func(int, driver.Decoder) bool) {
282	for i, e := range d.pv.GetArrayValue().Values {
283		if !f(i, decoder{e}) {
284			return
285		}
286	}
287}
288func (d decoder) MapLen() (int, bool) {
289	m := d.pv.GetMapValue()
290	if m == nil {
291		return 0, false
292	}
293	return len(m.Fields), true
294}
295func (d decoder) DecodeMap(f func(string, driver.Decoder, bool) bool) {
296	for k, v := range d.pv.GetMapValue().Fields {
297		if !f(k, decoder{v}, true) {
298			return
299		}
300	}
301}
302
303func (d decoder) AsSpecial(v reflect.Value) (bool, interface{}, error) {
304	switch v.Type() {
305	case typeOfGoTime:
306		if ts, ok := d.pv.ValueType.(*pb.Value_TimestampValue); ok {
307			if ts.TimestampValue == nil {
308				return true, time.Time{}, nil
309			}
310			t, err := ptypes.Timestamp(ts.TimestampValue)
311			return true, t, err
312		}
313		return true, nil, fmt.Errorf("expected TimestampValue for time.Time, got %+v", d.pv.ValueType)
314	case typeOfProtoTimestamp:
315		if ts, ok := d.pv.ValueType.(*pb.Value_TimestampValue); ok {
316			return true, ts.TimestampValue, nil
317		}
318		return true, nil, fmt.Errorf("expected TimestampValue for *ts.Timestamp, got %+v", d.pv.ValueType)
319
320	case typeOfLatLng:
321		if ll, ok := d.pv.ValueType.(*pb.Value_GeoPointValue); ok {
322			return true, ll.GeoPointValue, nil
323		}
324		return true, nil, fmt.Errorf("expected GeoPointValue for *latlng.LatLng, got %+v", d.pv.ValueType)
325
326	default:
327		return false, nil, nil
328	}
329}
330