1// Copyright 2019 The Go Cloud Development Kit Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package awsdynamodb
16
17import (
18	"errors"
19	"fmt"
20	"reflect"
21	"strconv"
22	"time"
23
24	dyn "github.com/aws/aws-sdk-go/service/dynamodb"
25	"gocloud.dev/docstore/driver"
26)
27
28var nullValue = new(dyn.AttributeValue).SetNULL(true)
29
30type encoder struct {
31	av *dyn.AttributeValue
32}
33
34func (e *encoder) EncodeNil()        { e.av = nullValue }
35func (e *encoder) EncodeBool(x bool) { e.av = new(dyn.AttributeValue).SetBOOL(x) }
36func (e *encoder) EncodeInt(x int64) { e.av = new(dyn.AttributeValue).SetN(strconv.FormatInt(x, 10)) }
37func (e *encoder) EncodeUint(x uint64) {
38	e.av = new(dyn.AttributeValue).SetN(strconv.FormatUint(x, 10))
39}
40func (e *encoder) EncodeBytes(x []byte)  { e.av = new(dyn.AttributeValue).SetB(x) }
41func (e *encoder) EncodeFloat(x float64) { e.av = encodeFloat(x) }
42
43func (e *encoder) ListIndex(int) { panic("impossible") }
44func (e *encoder) MapKey(string) { panic("impossible") }
45
46func (e *encoder) EncodeString(x string) {
47	if len(x) == 0 {
48		e.av = nullValue
49	} else {
50		e.av = new(dyn.AttributeValue).SetS(x)
51	}
52}
53
54func (e *encoder) EncodeComplex(x complex128) {
55	e.av = new(dyn.AttributeValue).SetL([]*dyn.AttributeValue{encodeFloat(real(x)), encodeFloat(imag(x))})
56}
57
58func (e *encoder) EncodeList(n int) driver.Encoder {
59	s := make([]*dyn.AttributeValue, n)
60	e.av = new(dyn.AttributeValue).SetL(s)
61	return &listEncoder{s: s}
62}
63
64func (e *encoder) EncodeMap(n int) driver.Encoder {
65	m := make(map[string]*dyn.AttributeValue, n)
66	e.av = new(dyn.AttributeValue).SetM(m)
67	return &mapEncoder{m: m}
68}
69
// typeOfGoTime is the reflect.Type of time.Time, used to recognize values that
// get special encoding and decoding.
var typeOfGoTime = reflect.TypeOf((*time.Time)(nil)).Elem()
71
72// EncodeSpecial encodes time.Time specially.
73func (e *encoder) EncodeSpecial(v reflect.Value) (bool, error) {
74	switch v.Type() {
75	case typeOfGoTime:
76		ts := v.Interface().(time.Time).Format(time.RFC3339Nano)
77		e.EncodeString(ts)
78	default:
79		return false, nil
80	}
81	return true, nil
82}
83
84type listEncoder struct {
85	s []*dyn.AttributeValue
86	encoder
87}
88
89func (e *listEncoder) ListIndex(i int) { e.s[i] = e.av }
90
91type mapEncoder struct {
92	m map[string]*dyn.AttributeValue
93	encoder
94}
95
96func (e *mapEncoder) MapKey(k string) { e.m[k] = e.av }
97
98func encodeDoc(doc driver.Document) (*dyn.AttributeValue, error) {
99	var e encoder
100	if err := doc.Encode(&e); err != nil {
101		return nil, err
102	}
103	return e.av, nil
104}
105
106// Encode the key fields of the given document into a map AttributeValue.
107// pkey and skey are the names of the partition key field and the sort key field.
108// pkey must always be non-empty, but skey may be empty if the collection has no sort key.
109func encodeDocKeyFields(doc driver.Document, pkey, skey string) (*dyn.AttributeValue, error) {
110	m := map[string]*dyn.AttributeValue{}
111
112	set := func(fieldName string) error {
113		fieldVal, err := doc.GetField(fieldName)
114		if err != nil {
115			return err
116		}
117		attrVal, err := encodeValue(fieldVal)
118		if err != nil {
119			return err
120		}
121		m[fieldName] = attrVal
122		return nil
123	}
124
125	if err := set(pkey); err != nil {
126		return nil, err
127	}
128	if skey != "" {
129		if err := set(skey); err != nil {
130			return nil, err
131		}
132	}
133	return new(dyn.AttributeValue).SetM(m), nil
134}
135
136func encodeValue(v interface{}) (*dyn.AttributeValue, error) {
137	var e encoder
138	if err := driver.Encode(reflect.ValueOf(v), &e); err != nil {
139		return nil, err
140	}
141	return e.av, nil
142}
143
144func encodeFloat(f float64) *dyn.AttributeValue {
145	return new(dyn.AttributeValue).SetN(strconv.FormatFloat(f, 'f', -1, 64))
146}
147
148////////////////////////////////////////////////////////////////
149
150func decodeDoc(item *dyn.AttributeValue, doc driver.Document) error {
151	return doc.Decode(decoder{av: item})
152}
153
154type decoder struct {
155	av *dyn.AttributeValue
156}
157
158func (d decoder) String() string {
159	return d.av.String()
160}
161
162func (d decoder) AsBool() (bool, bool) {
163	if d.av.BOOL == nil {
164		return false, false
165	}
166	return *d.av.BOOL, true
167}
168
169func (d decoder) AsNull() bool {
170	return d.av.NULL != nil
171}
172
173func (d decoder) AsString() (string, bool) {
174	// Empty string is represented by NULL.
175	if d.av.NULL != nil {
176		return "", true
177	}
178	if d.av.S == nil {
179		return "", false
180	}
181	return *d.av.S, true
182}
183
184func (d decoder) AsInt() (int64, bool) {
185	if d.av.N == nil {
186		return 0, false
187	}
188	i, err := strconv.ParseInt(*d.av.N, 10, 64)
189	if err != nil {
190		return 0, false
191	}
192	return i, true
193}
194
195func (d decoder) AsUint() (uint64, bool) {
196	if d.av.N == nil {
197		return 0, false
198	}
199	u, err := strconv.ParseUint(*d.av.N, 10, 64)
200	if err != nil {
201		return 0, false
202	}
203	return u, true
204}
205
206func (d decoder) AsFloat() (float64, bool) {
207	if d.av.N == nil {
208		return 0, false
209	}
210	f, err := strconv.ParseFloat(*d.av.N, 64)
211	if err != nil {
212		return 0, false
213	}
214	return f, true
215
216}
217
218func (d decoder) AsComplex() (complex128, bool) {
219	if d.av.L == nil {
220		return 0, false
221	}
222	if len(d.av.L) != 2 {
223		return 0, false
224	}
225	r, ok := decoder{d.av.L[0]}.AsFloat()
226	if !ok {
227		return 0, false
228	}
229	i, ok := decoder{d.av.L[1]}.AsFloat()
230	if !ok {
231		return 0, false
232	}
233	return complex(r, i), true
234}
235
236func (d decoder) AsBytes() ([]byte, bool) {
237	if d.av.B == nil {
238		return nil, false
239	}
240	return d.av.B, true
241}
242
243func (d decoder) ListLen() (int, bool) {
244	if d.av.L == nil {
245		return 0, false
246	}
247	return len(d.av.L), true
248}
249
250func (d decoder) DecodeList(f func(i int, vd driver.Decoder) bool) {
251	for i, el := range d.av.L {
252		if !f(i, decoder{el}) {
253			break
254		}
255	}
256}
257
258func (d decoder) MapLen() (int, bool) {
259	if d.av.M == nil {
260		return 0, false
261	}
262	return len(d.av.M), true
263}
264
265func (d decoder) DecodeMap(f func(key string, vd driver.Decoder, exactMatch bool) bool) {
266	for k, av := range d.av.M {
267		if !f(k, decoder{av}, true) {
268			break
269		}
270	}
271}
272
273func (d decoder) AsInterface() (interface{}, error) {
274	return toGoValue(d.av)
275}
276
277func toGoValue(av *dyn.AttributeValue) (interface{}, error) {
278	switch {
279	case av.NULL != nil:
280		return nil, nil
281	case av.BOOL != nil:
282		return *av.BOOL, nil
283	case av.N != nil:
284		f, err := strconv.ParseFloat(*av.N, 64)
285		if err != nil {
286			return nil, err
287		}
288		i := int64(f)
289		if float64(i) == f {
290			return i, nil
291		}
292		u := uint64(f)
293		if float64(u) == f {
294			return u, nil
295		}
296		return f, nil
297
298	case av.B != nil:
299		return av.B, nil
300	case av.S != nil:
301		return *av.S, nil
302
303	case av.L != nil:
304		s := make([]interface{}, len(av.L))
305		for i, v := range av.L {
306			x, err := toGoValue(v)
307			if err != nil {
308				return nil, err
309			}
310			s[i] = x
311		}
312		return s, nil
313
314	case av.M != nil:
315		m := make(map[string]interface{}, len(av.M))
316		for k, v := range av.M {
317			x, err := toGoValue(v)
318			if err != nil {
319				return nil, err
320			}
321			m[k] = x
322		}
323		return m, nil
324
325	default:
326		return nil, fmt.Errorf("awsdynamodb: AttributeValue %s not supported", av)
327	}
328}
329
330func (d decoder) AsSpecial(v reflect.Value) (bool, interface{}, error) {
331	unsupportedTypes := `unsupported type, the docstore driver for DynamoDB does
332	not decode DynamoDB set types, such as string set, number set and binary set`
333	if d.av.SS != nil || d.av.NS != nil || d.av.BS != nil {
334		return true, nil, errors.New(unsupportedTypes)
335	}
336	switch v.Type() {
337	case typeOfGoTime:
338		if d.av.S == nil {
339			return false, nil, errors.New("expected string field for time.Time")
340		}
341		t, err := time.Parse(time.RFC3339Nano, *d.av.S)
342		return true, t, err
343	}
344	return false, nil, nil
345}
346