1package queryutil
2
3import (
4	"encoding/base64"
5	"fmt"
6	"net/url"
7	"reflect"
8	"sort"
9	"strconv"
10	"strings"
11	"time"
12
13	"github.com/aws/aws-sdk-go/private/protocol"
14)
15
// Parse walks i via reflection and writes its query-protocol parameters into
// body. Parameter names are built from struct tags (locationName, queryName,
// flattened, ...); existing entries in body with the same names are replaced.
// When isEC2 is true the EC2 Query sub-protocol naming rules are applied.
// It returns an error if a value of an unsupported scalar type is encountered.
func Parse(body url.Values, i interface{}, isEC2 bool) error {
	parser := &queryParser{isEC2: isEC2}
	root := reflect.ValueOf(i)
	return parser.parseValue(body, root, "", "")
}
22
// elemOf follows pointer indirections until it reaches a non-pointer value.
// A nil pointer anywhere in the chain yields the zero (invalid) reflect.Value.
func elemOf(value reflect.Value) reflect.Value {
	if value.Kind() != reflect.Ptr {
		return value
	}
	return elemOf(value.Elem())
}
29
// queryParser serializes reflected values into url.Values parameters.
type queryParser struct {
	// isEC2 selects the EC2 Query sub-protocol: field names come from the
	// queryName tag (or a capitalized locationName), and lists/maps are not
	// wrapped in "member"/"entry" elements.
	isEC2 bool
}
33
// parseValue dereferences value and dispatches it by shape under the
// parameter name prefix.
//
// The shape is taken from the "type" struct tag when present. Otherwise
// structs, slices and maps are recognized from their Go kind, and anything
// else is treated as a scalar. A nil pointer (an invalid value after
// dereferencing) produces no parameters.
func (q *queryParser) parseValue(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
	value = elemOf(value)
	if !value.IsValid() {
		// Nothing to serialize for nil pointers.
		return nil
	}

	switch shape := tag.Get("type"); {
	case shape == "structure", shape == "" && value.Kind() == reflect.Struct:
		return q.parseStruct(v, value, prefix)
	case shape == "list", shape == "" && value.Kind() == reflect.Slice:
		return q.parseList(v, value, prefix, tag)
	case shape == "map", shape == "" && value.Kind() == reflect.Map:
		return q.parseMap(v, value, prefix, tag)
	default:
		return q.parseScalar(v, value, prefix, tag)
	}
}
65
// parseStruct serializes every exported field of the struct value that is not
// tagged "ignore". Each field is written under prefix + "." + its query name,
// or under the bare name when prefix is empty.
//
// A field for which protocol.CanSetIdempotencyToken reports true is
// serialized with a freshly generated protocol.GetIdempotencyToken() value
// instead of its own contents.
func (q *queryParser) parseStruct(v url.Values, value reflect.Value, prefix string) error {
	if !value.IsValid() {
		return nil
	}

	structType := value.Type()
	for idx := 0; idx < value.NumField(); idx++ {
		field := structType.Field(idx)
		// Unexported fields (non-empty PkgPath) and explicitly ignored
		// fields are never serialized.
		if field.PkgPath != "" || field.Tag.Get("ignore") != "" {
			continue
		}

		raw := value.Field(idx)
		member := elemOf(raw)
		if protocol.CanSetIdempotencyToken(raw, field) {
			member = reflect.ValueOf(protocol.GetIdempotencyToken())
		}

		key := q.queryFieldName(field)
		if prefix != "" {
			key = prefix + "." + key
		}

		if err := q.parseValue(v, member, key, field.Tag); err != nil {
			return err
		}
	}
	return nil
}

// queryFieldName resolves the parameter name for a single struct field.
//
// In EC2 mode a non-empty queryName tag wins outright. Otherwise the name is
// locationNameList for flattened fields that carry one, else locationName;
// in EC2 mode that tag-derived name gets its first byte upper-cased. When no
// tag supplies a name, the Go field name is used unchanged.
func (q *queryParser) queryFieldName(field reflect.StructField) string {
	if q.isEC2 {
		if queryName := field.Tag.Get("queryName"); queryName != "" {
			return queryName
		}
	}

	var tagName string
	if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
		tagName = field.Tag.Get("locationNameList")
	} else {
		tagName = field.Tag.Get("locationName")
	}

	if tagName == "" {
		return field.Name
	}
	if q.isEC2 {
		tagName = strings.ToUpper(tagName[:1]) + tagName[1:]
	}
	return tagName
}
116
// parseList serializes a slice under prefix.
//
// A non-nil empty slice is sent as a single empty parameter. A []byte is
// treated as a scalar blob rather than a list. Outside EC2 mode, lists that
// are not tagged "flattened" nest their items under the locationNameList tag
// (default "member"). Items are numbered from 1: prefix.1, prefix.2, ...
func (q *queryParser) parseList(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
	if !value.IsNil() && value.Len() == 0 {
		v.Set(prefix, "")
		return nil
	}

	if _, isBlob := value.Interface().([]byte); isBlob {
		return q.parseScalar(v, value, prefix, tag)
	}

	if !q.isEC2 && tag.Get("flattened") == "" {
		memberName := tag.Get("locationNameList")
		if memberName == "" {
			memberName = "member"
		}
		prefix += "." + memberName
	}

	for idx, count := 0, value.Len(); idx < count; idx++ {
		itemKey := strconv.Itoa(idx + 1)
		if prefix != "" {
			itemKey = prefix + "." + itemKey
		}
		if err := q.parseValue(v, value.Index(idx), itemKey, ""); err != nil {
			return err
		}
	}
	return nil
}
150
151func (q *queryParser) parseMap(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
152	// If it's empty, generate an empty value
153	if !value.IsNil() && value.Len() == 0 {
154		v.Set(prefix, "")
155		return nil
156	}
157
158	// check for unflattened list member
159	if !q.isEC2 && tag.Get("flattened") == "" {
160		prefix += ".entry"
161	}
162
163	// sort keys for improved serialization consistency.
164	// this is not strictly necessary for protocol support.
165	mapKeyValues := value.MapKeys()
166	mapKeys := map[string]reflect.Value{}
167	mapKeyNames := make([]string, len(mapKeyValues))
168	for i, mapKey := range mapKeyValues {
169		name := mapKey.String()
170		mapKeys[name] = mapKey
171		mapKeyNames[i] = name
172	}
173	sort.Strings(mapKeyNames)
174
175	for i, mapKeyName := range mapKeyNames {
176		mapKey := mapKeys[mapKeyName]
177		mapValue := value.MapIndex(mapKey)
178
179		kname := tag.Get("locationNameKey")
180		if kname == "" {
181			kname = "key"
182		}
183		vname := tag.Get("locationNameValue")
184		if vname == "" {
185			vname = "value"
186		}
187
188		// serialize key
189		var keyName string
190		if prefix == "" {
191			keyName = strconv.Itoa(i+1) + "." + kname
192		} else {
193			keyName = prefix + "." + strconv.Itoa(i+1) + "." + kname
194		}
195
196		if err := q.parseValue(v, mapKey, keyName, ""); err != nil {
197			return err
198		}
199
200		// serialize value
201		var valueName string
202		if prefix == "" {
203			valueName = strconv.Itoa(i+1) + "." + vname
204		} else {
205			valueName = prefix + "." + strconv.Itoa(i+1) + "." + vname
206		}
207
208		if err := q.parseValue(v, mapValue, valueName, ""); err != nil {
209			return err
210		}
211	}
212
213	return nil
214}
215
// parseScalar writes a single scalar value into v under name.
//
// Encoding by dynamic type:
//   - string: written as-is.
//   - bool, int, int64, float32, float64: strconv canonical form (floats use
//     'f' formatting with the shortest exact precision).
//   - []byte: base64 (StdEncoding); a nil slice writes nothing.
//   - time.Time: protocol.FormatTime using the field's "timestampFormat" tag,
//     defaulting to protocol.ISO8601TimeFormatName.
//
// Any other type returns an error that names the parameter, the value and
// its full Go type.
func (q *queryParser) parseScalar(v url.Values, r reflect.Value, name string, tag reflect.StructTag) error {
	switch value := r.Interface().(type) {
	case string:
		v.Set(name, value)
	case []byte:
		// Check the slice itself rather than r, so a nil blob held in an
		// interface-typed field is skipped as well.
		if value != nil {
			v.Set(name, base64.StdEncoding.EncodeToString(value))
		}
	case bool:
		v.Set(name, strconv.FormatBool(value))
	case int64:
		v.Set(name, strconv.FormatInt(value, 10))
	case int:
		v.Set(name, strconv.Itoa(value))
	case float64:
		v.Set(name, strconv.FormatFloat(value, 'f', -1, 64))
	case float32:
		v.Set(name, strconv.FormatFloat(float64(value), 'f', -1, 32))
	case time.Time:
		format := tag.Get("timestampFormat")
		if format == "" {
			format = protocol.ISO8601TimeFormatName
		}
		v.Set(name, protocol.FormatTime(format, value))
	default:
		// r.Type() (not r.Type().Name()) so unnamed types such as func or
		// chan types still produce a meaningful description.
		return fmt.Errorf("unsupported value for param %s: %v (%s)", name, r.Interface(), r.Type())
	}
	return nil
}
247