1package jsonutil
2
import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"math/big"
	"reflect"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/private/protocol"
)
18
// millisecondsFloat is the number of milliseconds in a second (1000) held as a
// big.Float. It is used to scale numeric epoch-seconds timestamps to
// milliseconds with arbitrary-precision arithmetic.
var millisecondsFloat = new(big.Float).SetInt64(1e3)
20
21// UnmarshalJSONError unmarshal's the reader's JSON document into the passed in
22// type. The value to unmarshal the json document into must be a pointer to the
23// type.
24func UnmarshalJSONError(v interface{}, stream io.Reader) error {
25	var errBuf bytes.Buffer
26	body := io.TeeReader(stream, &errBuf)
27
28	err := json.NewDecoder(body).Decode(v)
29	if err != nil {
30		msg := "failed decoding error message"
31		if err == io.EOF {
32			msg = "error message missing"
33			err = nil
34		}
35		return awserr.NewUnmarshalError(err, msg, errBuf.Bytes())
36	}
37
38	return nil
39}
40
41// UnmarshalJSON reads a stream and unmarshals the results in object v.
42func UnmarshalJSON(v interface{}, stream io.Reader) error {
43	var out interface{}
44
45	decoder := json.NewDecoder(stream)
46	decoder.UseNumber()
47	err := decoder.Decode(&out)
48	if err == io.EOF {
49		return nil
50	} else if err != nil {
51		return err
52	}
53
54	return unmarshaler{}.unmarshalAny(reflect.ValueOf(v), out, "")
55}
56
57// UnmarshalJSONCaseInsensitive reads a stream and unmarshals the result into the
58// object v. Ignores casing for structure members.
59func UnmarshalJSONCaseInsensitive(v interface{}, stream io.Reader) error {
60	var out interface{}
61
62	decoder := json.NewDecoder(stream)
63	decoder.UseNumber()
64	err := decoder.Decode(&out)
65	if err == io.EOF {
66		return nil
67	} else if err != nil {
68		return err
69	}
70
71	return unmarshaler{
72		caseInsensitive: true,
73	}.unmarshalAny(reflect.ValueOf(v), out, "")
74}
75
// unmarshaler holds the options used while walking a generically decoded JSON
// document into Go values via reflection.
type unmarshaler struct {
	// caseInsensitive enables a fallback, case-insensitive match of JSON
	// object keys to structure member names when no exact key matches.
	caseInsensitive bool
}
79
80func (u unmarshaler) unmarshalAny(value reflect.Value, data interface{}, tag reflect.StructTag) error {
81	vtype := value.Type()
82	if vtype.Kind() == reflect.Ptr {
83		vtype = vtype.Elem() // check kind of actual element type
84	}
85
86	t := tag.Get("type")
87	if t == "" {
88		switch vtype.Kind() {
89		case reflect.Struct:
90			// also it can't be a time object
91			if _, ok := value.Interface().(*time.Time); !ok {
92				t = "structure"
93			}
94		case reflect.Slice:
95			// also it can't be a byte slice
96			if _, ok := value.Interface().([]byte); !ok {
97				t = "list"
98			}
99		case reflect.Map:
100			// cannot be a JSONValue map
101			if _, ok := value.Interface().(aws.JSONValue); !ok {
102				t = "map"
103			}
104		}
105	}
106
107	switch t {
108	case "structure":
109		if field, ok := vtype.FieldByName("_"); ok {
110			tag = field.Tag
111		}
112		return u.unmarshalStruct(value, data, tag)
113	case "list":
114		return u.unmarshalList(value, data, tag)
115	case "map":
116		return u.unmarshalMap(value, data, tag)
117	default:
118		return u.unmarshalScalar(value, data, tag)
119	}
120}
121
122func (u unmarshaler) unmarshalStruct(value reflect.Value, data interface{}, tag reflect.StructTag) error {
123	if data == nil {
124		return nil
125	}
126	mapData, ok := data.(map[string]interface{})
127	if !ok {
128		return fmt.Errorf("JSON value is not a structure (%#v)", data)
129	}
130
131	t := value.Type()
132	if value.Kind() == reflect.Ptr {
133		if value.IsNil() { // create the structure if it's nil
134			s := reflect.New(value.Type().Elem())
135			value.Set(s)
136			value = s
137		}
138
139		value = value.Elem()
140		t = t.Elem()
141	}
142
143	// unwrap any payloads
144	if payload := tag.Get("payload"); payload != "" {
145		field, _ := t.FieldByName(payload)
146		return u.unmarshalAny(value.FieldByName(payload), data, field.Tag)
147	}
148
149	for i := 0; i < t.NumField(); i++ {
150		field := t.Field(i)
151		if field.PkgPath != "" {
152			continue // ignore unexported fields
153		}
154
155		// figure out what this field is called
156		name := field.Name
157		if locName := field.Tag.Get("locationName"); locName != "" {
158			name = locName
159		}
160		if u.caseInsensitive {
161			if _, ok := mapData[name]; !ok {
162				// Fallback to uncased name search if the exact name didn't match.
163				for kn, v := range mapData {
164					if strings.EqualFold(kn, name) {
165						mapData[name] = v
166					}
167				}
168			}
169		}
170
171		member := value.FieldByIndex(field.Index)
172		err := u.unmarshalAny(member, mapData[name], field.Tag)
173		if err != nil {
174			return err
175		}
176	}
177	return nil
178}
179
180func (u unmarshaler) unmarshalList(value reflect.Value, data interface{}, tag reflect.StructTag) error {
181	if data == nil {
182		return nil
183	}
184	listData, ok := data.([]interface{})
185	if !ok {
186		return fmt.Errorf("JSON value is not a list (%#v)", data)
187	}
188
189	if value.IsNil() {
190		l := len(listData)
191		value.Set(reflect.MakeSlice(value.Type(), l, l))
192	}
193
194	for i, c := range listData {
195		err := u.unmarshalAny(value.Index(i), c, "")
196		if err != nil {
197			return err
198		}
199	}
200
201	return nil
202}
203
204func (u unmarshaler) unmarshalMap(value reflect.Value, data interface{}, tag reflect.StructTag) error {
205	if data == nil {
206		return nil
207	}
208	mapData, ok := data.(map[string]interface{})
209	if !ok {
210		return fmt.Errorf("JSON value is not a map (%#v)", data)
211	}
212
213	if value.IsNil() {
214		value.Set(reflect.MakeMap(value.Type()))
215	}
216
217	for k, v := range mapData {
218		kvalue := reflect.ValueOf(k)
219		vvalue := reflect.New(value.Type().Elem()).Elem()
220
221		u.unmarshalAny(vvalue, v, "")
222		value.SetMapIndex(kvalue, vvalue)
223	}
224
225	return nil
226}
227
228func (u unmarshaler) unmarshalScalar(value reflect.Value, data interface{}, tag reflect.StructTag) error {
229
230	switch d := data.(type) {
231	case nil:
232		return nil // nothing to do here
233	case string:
234		switch value.Interface().(type) {
235		case *string:
236			value.Set(reflect.ValueOf(&d))
237		case []byte:
238			b, err := base64.StdEncoding.DecodeString(d)
239			if err != nil {
240				return err
241			}
242			value.Set(reflect.ValueOf(b))
243		case *time.Time:
244			format := tag.Get("timestampFormat")
245			if len(format) == 0 {
246				format = protocol.ISO8601TimeFormatName
247			}
248
249			t, err := protocol.ParseTime(format, d)
250			if err != nil {
251				return err
252			}
253			value.Set(reflect.ValueOf(&t))
254		case aws.JSONValue:
255			// No need to use escaping as the value is a non-quoted string.
256			v, err := protocol.DecodeJSONValue(d, protocol.NoEscape)
257			if err != nil {
258				return err
259			}
260			value.Set(reflect.ValueOf(v))
261		default:
262			return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
263		}
264	case json.Number:
265		switch value.Interface().(type) {
266		case *int64:
267			// Retain the old behavior where we would just truncate the float64
268			// calling d.Int64() here could cause an invalid syntax error due to the usage of strconv.ParseInt
269			f, err := d.Float64()
270			if err != nil {
271				return err
272			}
273			di := int64(f)
274			value.Set(reflect.ValueOf(&di))
275		case *float64:
276			f, err := d.Float64()
277			if err != nil {
278				return err
279			}
280			value.Set(reflect.ValueOf(&f))
281		case *time.Time:
282			float, ok := new(big.Float).SetString(d.String())
283			if !ok {
284				return fmt.Errorf("unsupported float time representation: %v", d.String())
285			}
286			float = float.Mul(float, millisecondsFloat)
287			ms, _ := float.Int64()
288			t := time.Unix(0, ms*1e6).UTC()
289			value.Set(reflect.ValueOf(&t))
290		default:
291			return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
292		}
293	case bool:
294		switch value.Interface().(type) {
295		case *bool:
296			value.Set(reflect.ValueOf(&d))
297		default:
298			return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
299		}
300	default:
301		return fmt.Errorf("unsupported JSON value (%v)", data)
302	}
303	return nil
304}
305