1package xmlutil
2
3import (
4	"bytes"
5	"encoding/base64"
6	"encoding/xml"
7	"fmt"
8	"io"
9	"reflect"
10	"strconv"
11	"strings"
12	"time"
13
14	"github.com/aws/aws-sdk-go/aws/awserr"
15	"github.com/aws/aws-sdk-go/private/protocol"
16)
17
18// UnmarshalXMLError unmarshals the XML error from the stream into the value
19// type specified. The value must be a pointer. If the message fails to
20// unmarshal, the message content will be included in the returned error as a
21// awserr.UnmarshalError.
22func UnmarshalXMLError(v interface{}, stream io.Reader) error {
23	var errBuf bytes.Buffer
24	body := io.TeeReader(stream, &errBuf)
25
26	err := xml.NewDecoder(body).Decode(v)
27	if err != nil && err != io.EOF {
28		return awserr.NewUnmarshalError(err,
29			"failed to unmarshal error message", errBuf.Bytes())
30	}
31
32	return nil
33}
34
35// UnmarshalXML deserializes an xml.Decoder into the container v. V
36// needs to match the shape of the XML expected to be decoded.
37// If the shape doesn't match unmarshaling will fail.
38func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error {
39	n, err := XMLToStruct(d, nil)
40	if err != nil {
41		return err
42	}
43	if n.Children != nil {
44		for _, root := range n.Children {
45			for _, c := range root {
46				if wrappedChild, ok := c.Children[wrapper]; ok {
47					c = wrappedChild[0] // pull out wrapped element
48				}
49
50				err = parse(reflect.ValueOf(v), c, "")
51				if err != nil {
52					if err == io.EOF {
53						return nil
54					}
55					return err
56				}
57			}
58		}
59		return nil
60	}
61	return nil
62}
63
64// parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect
65// will be used to determine the type from r.
66func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
67	xml := tag.Get("xml")
68	if len(xml) != 0 {
69		name := strings.SplitAfterN(xml, ",", 2)[0]
70		if name == "-" {
71			return nil
72		}
73	}
74
75	rtype := r.Type()
76	if rtype.Kind() == reflect.Ptr {
77		rtype = rtype.Elem() // check kind of actual element type
78	}
79
80	t := tag.Get("type")
81	if t == "" {
82		switch rtype.Kind() {
83		case reflect.Struct:
84			// also it can't be a time object
85			if _, ok := r.Interface().(*time.Time); !ok {
86				t = "structure"
87			}
88		case reflect.Slice:
89			// also it can't be a byte slice
90			if _, ok := r.Interface().([]byte); !ok {
91				t = "list"
92			}
93		case reflect.Map:
94			t = "map"
95		}
96	}
97
98	switch t {
99	case "structure":
100		if field, ok := rtype.FieldByName("_"); ok {
101			tag = field.Tag
102		}
103		return parseStruct(r, node, tag)
104	case "list":
105		return parseList(r, node, tag)
106	case "map":
107		return parseMap(r, node, tag)
108	default:
109		return parseScalar(r, node, tag)
110	}
111}
112
113// parseStruct deserializes a structure and its fields from an XMLNode. Any nested
114// types in the structure will also be deserialized.
115func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
116	t := r.Type()
117	if r.Kind() == reflect.Ptr {
118		if r.IsNil() { // create the structure if it's nil
119			s := reflect.New(r.Type().Elem())
120			r.Set(s)
121			r = s
122		}
123
124		r = r.Elem()
125		t = t.Elem()
126	}
127
128	// unwrap any payloads
129	if payload := tag.Get("payload"); payload != "" {
130		field, _ := t.FieldByName(payload)
131		return parseStruct(r.FieldByName(payload), node, field.Tag)
132	}
133
134	for i := 0; i < t.NumField(); i++ {
135		field := t.Field(i)
136		if c := field.Name[0:1]; strings.ToLower(c) == c {
137			continue // ignore unexported fields
138		}
139
140		// figure out what this field is called
141		name := field.Name
142		if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
143			name = field.Tag.Get("locationNameList")
144		} else if locName := field.Tag.Get("locationName"); locName != "" {
145			name = locName
146		}
147
148		// try to find the field by name in elements
149		elems := node.Children[name]
150
151		if elems == nil { // try to find the field in attributes
152			if val, ok := node.findElem(name); ok {
153				elems = []*XMLNode{{Text: val}}
154			}
155		}
156
157		member := r.FieldByName(field.Name)
158		for _, elem := range elems {
159			err := parse(member, elem, field.Tag)
160			if err != nil {
161				return err
162			}
163		}
164	}
165	return nil
166}
167
168// parseList deserializes a list of values from an XML node. Each list entry
169// will also be deserialized.
170func parseList(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
171	t := r.Type()
172
173	if tag.Get("flattened") == "" { // look at all item entries
174		mname := "member"
175		if name := tag.Get("locationNameList"); name != "" {
176			mname = name
177		}
178
179		if Children, ok := node.Children[mname]; ok {
180			if r.IsNil() {
181				r.Set(reflect.MakeSlice(t, len(Children), len(Children)))
182			}
183
184			for i, c := range Children {
185				err := parse(r.Index(i), c, "")
186				if err != nil {
187					return err
188				}
189			}
190		}
191	} else { // flattened list means this is a single element
192		if r.IsNil() {
193			r.Set(reflect.MakeSlice(t, 0, 0))
194		}
195
196		childR := reflect.Zero(t.Elem())
197		r.Set(reflect.Append(r, childR))
198		err := parse(r.Index(r.Len()-1), node, "")
199		if err != nil {
200			return err
201		}
202	}
203
204	return nil
205}
206
207// parseMap deserializes a map from an XMLNode. The direct children of the XMLNode
208// will also be deserialized as map entries.
209func parseMap(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
210	if r.IsNil() {
211		r.Set(reflect.MakeMap(r.Type()))
212	}
213
214	if tag.Get("flattened") == "" { // look at all child entries
215		for _, entry := range node.Children["entry"] {
216			parseMapEntry(r, entry, tag)
217		}
218	} else { // this element is itself an entry
219		parseMapEntry(r, node, tag)
220	}
221
222	return nil
223}
224
225// parseMapEntry deserializes a map entry from a XML node.
226func parseMapEntry(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
227	kname, vname := "key", "value"
228	if n := tag.Get("locationNameKey"); n != "" {
229		kname = n
230	}
231	if n := tag.Get("locationNameValue"); n != "" {
232		vname = n
233	}
234
235	keys, ok := node.Children[kname]
236	values := node.Children[vname]
237	if ok {
238		for i, key := range keys {
239			keyR := reflect.ValueOf(key.Text)
240			value := values[i]
241			valueR := reflect.New(r.Type().Elem()).Elem()
242
243			parse(valueR, value, "")
244			r.SetMapIndex(keyR, valueR)
245		}
246	}
247	return nil
248}
249
250// parseScaller deserializes an XMLNode value into a concrete type based on the
251// interface type of r.
252//
253// Error is returned if the deserialization fails due to invalid type conversion,
254// or unsupported interface type.
255func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
256	switch r.Interface().(type) {
257	case *string:
258		r.Set(reflect.ValueOf(&node.Text))
259		return nil
260	case []byte:
261		b, err := base64.StdEncoding.DecodeString(node.Text)
262		if err != nil {
263			return err
264		}
265		r.Set(reflect.ValueOf(b))
266	case *bool:
267		v, err := strconv.ParseBool(node.Text)
268		if err != nil {
269			return err
270		}
271		r.Set(reflect.ValueOf(&v))
272	case *int64:
273		v, err := strconv.ParseInt(node.Text, 10, 64)
274		if err != nil {
275			return err
276		}
277		r.Set(reflect.ValueOf(&v))
278	case *float64:
279		v, err := strconv.ParseFloat(node.Text, 64)
280		if err != nil {
281			return err
282		}
283		r.Set(reflect.ValueOf(&v))
284	case *time.Time:
285		format := tag.Get("timestampFormat")
286		if len(format) == 0 {
287			format = protocol.ISO8601TimeFormatName
288		}
289
290		t, err := protocol.ParseTime(format, node.Text)
291		if err != nil {
292			return err
293		}
294		r.Set(reflect.ValueOf(&t))
295	default:
296		return fmt.Errorf("unsupported value: %v (%s)", r.Interface(), r.Type())
297	}
298	return nil
299}
300