1package xmlutil 2 3import ( 4 "bytes" 5 "encoding/base64" 6 "encoding/xml" 7 "fmt" 8 "io" 9 "reflect" 10 "strconv" 11 "strings" 12 "time" 13 14 "github.com/aws/aws-sdk-go/aws/awserr" 15 "github.com/aws/aws-sdk-go/private/protocol" 16) 17 18// UnmarshalXMLError unmarshals the XML error from the stream into the value 19// type specified. The value must be a pointer. If the message fails to 20// unmarshal, the message content will be included in the returned error as a 21// awserr.UnmarshalError. 22func UnmarshalXMLError(v interface{}, stream io.Reader) error { 23 var errBuf bytes.Buffer 24 body := io.TeeReader(stream, &errBuf) 25 26 err := xml.NewDecoder(body).Decode(v) 27 if err != nil && err != io.EOF { 28 return awserr.NewUnmarshalError(err, 29 "failed to unmarshal error message", errBuf.Bytes()) 30 } 31 32 return nil 33} 34 35// UnmarshalXML deserializes an xml.Decoder into the container v. V 36// needs to match the shape of the XML expected to be decoded. 37// If the shape doesn't match unmarshaling will fail. 38func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error { 39 n, err := XMLToStruct(d, nil) 40 if err != nil { 41 return err 42 } 43 if n.Children != nil { 44 for _, root := range n.Children { 45 for _, c := range root { 46 if wrappedChild, ok := c.Children[wrapper]; ok { 47 c = wrappedChild[0] // pull out wrapped element 48 } 49 50 err = parse(reflect.ValueOf(v), c, "") 51 if err != nil { 52 if err == io.EOF { 53 return nil 54 } 55 return err 56 } 57 } 58 } 59 return nil 60 } 61 return nil 62} 63 64// parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect 65// will be used to determine the type from r. 66func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { 67 xml := tag.Get("xml") 68 if len(xml) != 0 { 69 name := strings.SplitAfterN(xml, ",", 2)[0] 70 if name == "-" { 71 return nil 72 } 73 } 74 75 rtype := r.Type() 76 if rtype.Kind() == reflect.Ptr { 77 rtype = rtype.Elem() // check kind of actual element type 78 } 79 80 t := tag.Get("type") 81 if t == "" { 82 switch rtype.Kind() { 83 case reflect.Struct: 84 // also it can't be a time object 85 if _, ok := r.Interface().(*time.Time); !ok { 86 t = "structure" 87 } 88 case reflect.Slice: 89 // also it can't be a byte slice 90 if _, ok := r.Interface().([]byte); !ok { 91 t = "list" 92 } 93 case reflect.Map: 94 t = "map" 95 } 96 } 97 98 switch t { 99 case "structure": 100 if field, ok := rtype.FieldByName("_"); ok { 101 tag = field.Tag 102 } 103 return parseStruct(r, node, tag) 104 case "list": 105 return parseList(r, node, tag) 106 case "map": 107 return parseMap(r, node, tag) 108 default: 109 return parseScalar(r, node, tag) 110 } 111} 112 113// parseStruct deserializes a structure and its fields from an XMLNode. Any nested 114// types in the structure will also be deserialized. 115func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { 116 t := r.Type() 117 if r.Kind() == reflect.Ptr { 118 if r.IsNil() { // create the structure if it's nil 119 s := reflect.New(r.Type().Elem()) 120 r.Set(s) 121 r = s 122 } 123 124 r = r.Elem() 125 t = t.Elem() 126 } 127 128 // unwrap any payloads 129 if payload := tag.Get("payload"); payload != "" { 130 field, _ := t.FieldByName(payload) 131 return parseStruct(r.FieldByName(payload), node, field.Tag) 132 } 133 134 for i := 0; i < t.NumField(); i++ { 135 field := t.Field(i) 136 if c := field.Name[0:1]; strings.ToLower(c) == c { 137 continue // ignore unexported fields 138 } 139 140 // figure out what this field is called 141 name := field.Name 142 if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" { 143 name = field.Tag.Get("locationNameList") 144 } else if locName := field.Tag.Get("locationName"); locName != "" { 145 name = locName 146 } 147 148 // try to find the field by name in elements 149 elems := node.Children[name] 150 151 if elems == nil { // try to find the field in attributes 152 if val, ok := node.findElem(name); ok { 153 elems = []*XMLNode{{Text: val}} 154 } 155 } 156 157 member := r.FieldByName(field.Name) 158 for _, elem := range elems { 159 err := parse(member, elem, field.Tag) 160 if err != nil { 161 return err 162 } 163 } 164 } 165 return nil 166} 167 168// parseList deserializes a list of values from an XML node. Each list entry 169// will also be deserialized. 170func parseList(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { 171 t := r.Type() 172 173 if tag.Get("flattened") == "" { // look at all item entries 174 mname := "member" 175 if name := tag.Get("locationNameList"); name != "" { 176 mname = name 177 } 178 179 if Children, ok := node.Children[mname]; ok { 180 if r.IsNil() { 181 r.Set(reflect.MakeSlice(t, len(Children), len(Children))) 182 } 183 184 for i, c := range Children { 185 err := parse(r.Index(i), c, "") 186 if err != nil { 187 return err 188 } 189 } 190 } 191 } else { // flattened list means this is a single element 192 if r.IsNil() { 193 r.Set(reflect.MakeSlice(t, 0, 0)) 194 } 195 196 childR := reflect.Zero(t.Elem()) 197 r.Set(reflect.Append(r, childR)) 198 err := parse(r.Index(r.Len()-1), node, "") 199 if err != nil { 200 return err 201 } 202 } 203 204 return nil 205} 206 207// parseMap deserializes a map from an XMLNode. The direct children of the XMLNode 208// will also be deserialized as map entries. 209func parseMap(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { 210 if r.IsNil() { 211 r.Set(reflect.MakeMap(r.Type())) 212 } 213 214 if tag.Get("flattened") == "" { // look at all child entries 215 for _, entry := range node.Children["entry"] { 216 parseMapEntry(r, entry, tag) 217 } 218 } else { // this element is itself an entry 219 parseMapEntry(r, node, tag) 220 } 221 222 return nil 223} 224 225// parseMapEntry deserializes a map entry from a XML node. 226func parseMapEntry(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { 227 kname, vname := "key", "value" 228 if n := tag.Get("locationNameKey"); n != "" { 229 kname = n 230 } 231 if n := tag.Get("locationNameValue"); n != "" { 232 vname = n 233 } 234 235 keys, ok := node.Children[kname] 236 values := node.Children[vname] 237 if ok { 238 for i, key := range keys { 239 keyR := reflect.ValueOf(key.Text) 240 value := values[i] 241 valueR := reflect.New(r.Type().Elem()).Elem() 242 243 parse(valueR, value, "") 244 r.SetMapIndex(keyR, valueR) 245 } 246 } 247 return nil 248} 249 250// parseScaller deserializes an XMLNode value into a concrete type based on the 251// interface type of r. 252// 253// Error is returned if the deserialization fails due to invalid type conversion, 254// or unsupported interface type. 255func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { 256 switch r.Interface().(type) { 257 case *string: 258 r.Set(reflect.ValueOf(&node.Text)) 259 return nil 260 case []byte: 261 b, err := base64.StdEncoding.DecodeString(node.Text) 262 if err != nil { 263 return err 264 } 265 r.Set(reflect.ValueOf(b)) 266 case *bool: 267 v, err := strconv.ParseBool(node.Text) 268 if err != nil { 269 return err 270 } 271 r.Set(reflect.ValueOf(&v)) 272 case *int64: 273 v, err := strconv.ParseInt(node.Text, 10, 64) 274 if err != nil { 275 return err 276 } 277 r.Set(reflect.ValueOf(&v)) 278 case *float64: 279 v, err := strconv.ParseFloat(node.Text, 64) 280 if err != nil { 281 return err 282 } 283 r.Set(reflect.ValueOf(&v)) 284 case *time.Time: 285 format := tag.Get("timestampFormat") 286 if len(format) == 0 { 287 format = protocol.ISO8601TimeFormatName 288 } 289 290 t, err := protocol.ParseTime(format, node.Text) 291 if err != nil { 292 return err 293 } 294 r.Set(reflect.ValueOf(&t)) 295 default: 296 return fmt.Errorf("unsupported value: %v (%s)", r.Interface(), r.Type()) 297 } 298 return nil 299} 300