1package rest
2
import (
	"bytes"
	"encoding/base64"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"reflect"
	"strconv"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/request"
)
18
19// UnmarshalHandler is a named request handler for unmarshaling rest protocol requests
20var UnmarshalHandler = request.NamedHandler{Name: "awssdk.rest.Unmarshal", Fn: Unmarshal}
21
22// UnmarshalMetaHandler is a named request handler for unmarshaling rest protocol request metadata
23var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta", Fn: UnmarshalMeta}
24
25// Unmarshal unmarshals the REST component of a response in a REST service.
26func Unmarshal(r *request.Request) {
27	if r.DataFilled() {
28		v := reflect.Indirect(reflect.ValueOf(r.Data))
29		unmarshalBody(r, v)
30	}
31}
32
33// UnmarshalMeta unmarshals the REST metadata of a response in a REST service
34func UnmarshalMeta(r *request.Request) {
35	r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
36	if r.RequestID == "" {
37		// Alternative version of request id in the header
38		r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id")
39	}
40	if r.DataFilled() {
41		v := reflect.Indirect(reflect.ValueOf(r.Data))
42		unmarshalLocationElements(r, v)
43	}
44}
45
46func unmarshalBody(r *request.Request, v reflect.Value) {
47	if field, ok := v.Type().FieldByName("_"); ok {
48		if payloadName := field.Tag.Get("payload"); payloadName != "" {
49			pfield, _ := v.Type().FieldByName(payloadName)
50			if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
51				payload := v.FieldByName(payloadName)
52				if payload.IsValid() {
53					switch payload.Interface().(type) {
54					case []byte:
55						defer r.HTTPResponse.Body.Close()
56						b, err := ioutil.ReadAll(r.HTTPResponse.Body)
57						if err != nil {
58							r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
59						} else {
60							payload.Set(reflect.ValueOf(b))
61						}
62					case *string:
63						defer r.HTTPResponse.Body.Close()
64						b, err := ioutil.ReadAll(r.HTTPResponse.Body)
65						if err != nil {
66							r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
67						} else {
68							str := string(b)
69							payload.Set(reflect.ValueOf(&str))
70						}
71					default:
72						switch payload.Type().String() {
73						case "io.ReadSeeker":
74							payload.Set(reflect.ValueOf(aws.ReadSeekCloser(r.HTTPResponse.Body)))
75						case "aws.ReadSeekCloser", "io.ReadCloser":
76							payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
77						default:
78							io.Copy(ioutil.Discard, r.HTTPResponse.Body)
79							defer r.HTTPResponse.Body.Close()
80							r.Error = awserr.New("SerializationError",
81								"failed to decode REST response",
82								fmt.Errorf("unknown payload type %s", payload.Type()))
83						}
84					}
85				}
86			}
87		}
88	}
89}
90
91func unmarshalLocationElements(r *request.Request, v reflect.Value) {
92	for i := 0; i < v.NumField(); i++ {
93		m, field := v.Field(i), v.Type().Field(i)
94		if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
95			continue
96		}
97
98		if m.IsValid() {
99			name := field.Tag.Get("locationName")
100			if name == "" {
101				name = field.Name
102			}
103
104			switch field.Tag.Get("location") {
105			case "statusCode":
106				unmarshalStatusCode(m, r.HTTPResponse.StatusCode)
107			case "header":
108				err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name))
109				if err != nil {
110					r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
111					break
112				}
113			case "headers":
114				prefix := field.Tag.Get("locationName")
115				err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix)
116				if err != nil {
117					r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
118					break
119				}
120			}
121		}
122		if r.Error != nil {
123			return
124		}
125	}
126}
127
// unmarshalStatusCode stores statusCode into v when v is a valid *int64
// member; for an invalid value or any other member type it does nothing.
func unmarshalStatusCode(v reflect.Value, statusCode int) {
	if !v.IsValid() {
		return
	}
	if _, isInt64Ptr := v.Interface().(*int64); !isInt64Ptr {
		return
	}
	code := int64(statusCode)
	v.Set(reflect.ValueOf(&code))
}
139
// unmarshalHeaderMap collects every header whose canonicalized name starts
// with prefix (compared case-insensitively) into a map[string]*string, keyed
// by the remainder of the canonical name after the prefix, and stores it in r.
//
// Only the first value of each header is kept. Header keys with no values are
// skipped. Members of any type other than map[string]*string are left
// untouched. The error result is always nil.
func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error {
	switch r.Interface().(type) {
	case map[string]*string: // we only support string map value types
		out := map[string]*string{}
		for k, v := range headers {
			if len(v) == 0 {
				// Nothing to report, and indexing v[0] would panic.
				continue
			}
			k = http.CanonicalHeaderKey(k)
			// Compare exactly the bytes that are sliced off below, so the key
			// offset always matches what was tested (lower-casing can change
			// byte lengths for non-ASCII input).
			if len(k) < len(prefix) || !strings.EqualFold(k[:len(prefix)], prefix) {
				continue
			}
			// Copy the value so the result does not alias the header's storage.
			val := v[0]
			out[k[len(prefix):]] = &val
		}
		r.Set(reflect.ValueOf(out))
	}
	return nil
}
154
155func unmarshalHeader(v reflect.Value, header string) error {
156	if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
157		return nil
158	}
159
160	switch v.Interface().(type) {
161	case *string:
162		v.Set(reflect.ValueOf(&header))
163	case []byte:
164		b, err := base64.StdEncoding.DecodeString(header)
165		if err != nil {
166			return err
167		}
168		v.Set(reflect.ValueOf(&b))
169	case *bool:
170		b, err := strconv.ParseBool(header)
171		if err != nil {
172			return err
173		}
174		v.Set(reflect.ValueOf(&b))
175	case *int64:
176		i, err := strconv.ParseInt(header, 10, 64)
177		if err != nil {
178			return err
179		}
180		v.Set(reflect.ValueOf(&i))
181	case *float64:
182		f, err := strconv.ParseFloat(header, 64)
183		if err != nil {
184			return err
185		}
186		v.Set(reflect.ValueOf(&f))
187	case *time.Time:
188		t, err := time.Parse(RFC822, header)
189		if err != nil {
190			return err
191		}
192		v.Set(reflect.ValueOf(&t))
193	default:
194		err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
195		return err
196	}
197	return nil
198}
199