1// Package rest provides RESTful serialization of AWS requests and responses.
2package rest
3
4import (
5	"bytes"
6	"encoding/base64"
7	"fmt"
8	"io"
9	"net/http"
10	"net/url"
11	"path"
12	"reflect"
13	"strconv"
14	"strings"
15	"time"
16
17	"github.com/aws/aws-sdk-go/aws"
18	"github.com/aws/aws-sdk-go/aws/awserr"
19	"github.com/aws/aws-sdk-go/aws/request"
20)
21
// RFC822 is the time layout used to format timestamp parameters for AWS REST
// protocols. Values are converted to UTC before being formatted with it.
const RFC822 = "Mon, 2 Jan 2006 15:04:05 GMT"

var (
	// noEscape is indexed by byte value and reports whether that byte may be
	// sent unescaped in AWS URLs. It is populated by init.
	noEscape [256]bool

	// errValueNotSet signals that a parameter has no value (e.g. a nil
	// pointer) and should be omitted rather than treated as a failure.
	errValueNotSet = fmt.Errorf("value not set")
)

func init() {
	// AWS expects every byte outside the RFC 3986 unreserved set
	// (ALPHA / DIGIT / "-" / "." / "_" / "~") to be percent-encoded, so only
	// those bytes are marked safe; every other entry keeps its zero value.
	for i := range noEscape {
		switch {
		case 'a' <= i && i <= 'z', 'A' <= i && i <= 'Z', '0' <= i && i <= '9':
			noEscape[i] = true
		case i == '-' || i == '.' || i == '_' || i == '~':
			noEscape[i] = true
		}
	}
}
42
43// BuildHandler is a named request handler for building rest protocol requests
44var BuildHandler = request.NamedHandler{Name: "awssdk.rest.Build", Fn: Build}
45
46// Build builds the REST component of a service request.
47func Build(r *request.Request) {
48	if r.ParamsFilled() {
49		v := reflect.ValueOf(r.Params).Elem()
50		buildLocationElements(r, v, false)
51		buildBody(r, v)
52	}
53}
54
55// BuildAsGET builds the REST component of a service request with the ability to hoist
56// data from the body.
57func BuildAsGET(r *request.Request) {
58	if r.ParamsFilled() {
59		v := reflect.ValueOf(r.Params).Elem()
60		buildLocationElements(r, v, true)
61		buildBody(r, v)
62	}
63}
64
65func buildLocationElements(r *request.Request, v reflect.Value, buildGETQuery bool) {
66	query := r.HTTPRequest.URL.Query()
67
68	// Setup the raw path to match the base path pattern. This is needed
69	// so that when the path is mutated a custom escaped version can be
70	// stored in RawPath that will be used by the Go client.
71	r.HTTPRequest.URL.RawPath = r.HTTPRequest.URL.Path
72
73	for i := 0; i < v.NumField(); i++ {
74		m := v.Field(i)
75		if n := v.Type().Field(i).Name; n[0:1] == strings.ToLower(n[0:1]) {
76			continue
77		}
78
79		if m.IsValid() {
80			field := v.Type().Field(i)
81			name := field.Tag.Get("locationName")
82			if name == "" {
83				name = field.Name
84			}
85			if m.Kind() == reflect.Ptr {
86				m = m.Elem()
87			}
88			if !m.IsValid() {
89				continue
90			}
91			if field.Tag.Get("ignore") != "" {
92				continue
93			}
94
95			var err error
96			switch field.Tag.Get("location") {
97			case "headers": // header maps
98				err = buildHeaderMap(&r.HTTPRequest.Header, m, field.Tag.Get("locationName"))
99			case "header":
100				err = buildHeader(&r.HTTPRequest.Header, m, name)
101			case "uri":
102				err = buildURI(r.HTTPRequest.URL, m, name)
103			case "querystring":
104				err = buildQueryString(query, m, name)
105			default:
106				if buildGETQuery {
107					err = buildQueryString(query, m, name)
108				}
109			}
110			r.Error = err
111		}
112		if r.Error != nil {
113			return
114		}
115	}
116
117	r.HTTPRequest.URL.RawQuery = query.Encode()
118	if !aws.BoolValue(r.Config.DisableRestProtocolURICleaning) {
119		cleanPath(r.HTTPRequest.URL)
120	}
121}
122
123func buildBody(r *request.Request, v reflect.Value) {
124	if field, ok := v.Type().FieldByName("_"); ok {
125		if payloadName := field.Tag.Get("payload"); payloadName != "" {
126			pfield, _ := v.Type().FieldByName(payloadName)
127			if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
128				payload := reflect.Indirect(v.FieldByName(payloadName))
129				if payload.IsValid() && payload.Interface() != nil {
130					switch reader := payload.Interface().(type) {
131					case io.ReadSeeker:
132						r.SetReaderBody(reader)
133					case []byte:
134						r.SetBufferBody(reader)
135					case string:
136						r.SetStringBody(reader)
137					default:
138						r.Error = awserr.New("SerializationError",
139							"failed to encode REST request",
140							fmt.Errorf("unknown payload type %s", payload.Type()))
141					}
142				}
143			}
144		}
145	}
146}
147
148func buildHeader(header *http.Header, v reflect.Value, name string) error {
149	str, err := convertType(v)
150	if err == errValueNotSet {
151		return nil
152	} else if err != nil {
153		return awserr.New("SerializationError", "failed to encode REST request", err)
154	}
155
156	header.Add(name, str)
157
158	return nil
159}
160
161func buildHeaderMap(header *http.Header, v reflect.Value, prefix string) error {
162	for _, key := range v.MapKeys() {
163		str, err := convertType(v.MapIndex(key))
164		if err == errValueNotSet {
165			continue
166		} else if err != nil {
167			return awserr.New("SerializationError", "failed to encode REST request", err)
168
169		}
170
171		header.Add(prefix+key.String(), str)
172	}
173	return nil
174}
175
176func buildURI(u *url.URL, v reflect.Value, name string) error {
177	value, err := convertType(v)
178	if err == errValueNotSet {
179		return nil
180	} else if err != nil {
181		return awserr.New("SerializationError", "failed to encode REST request", err)
182	}
183
184	u.Path = strings.Replace(u.Path, "{"+name+"}", value, -1)
185	u.Path = strings.Replace(u.Path, "{"+name+"+}", value, -1)
186
187	u.RawPath = strings.Replace(u.RawPath, "{"+name+"}", EscapePath(value, true), -1)
188	u.RawPath = strings.Replace(u.RawPath, "{"+name+"+}", EscapePath(value, false), -1)
189
190	return nil
191}
192
193func buildQueryString(query url.Values, v reflect.Value, name string) error {
194	switch value := v.Interface().(type) {
195	case []*string:
196		for _, item := range value {
197			query.Add(name, *item)
198		}
199	case map[string]*string:
200		for key, item := range value {
201			query.Add(key, *item)
202		}
203	case map[string][]*string:
204		for key, items := range value {
205			for _, item := range items {
206				query.Add(key, *item)
207			}
208		}
209	default:
210		str, err := convertType(v)
211		if err == errValueNotSet {
212			return nil
213		} else if err != nil {
214			return awserr.New("SerializationError", "failed to encode REST request", err)
215		}
216		query.Set(name, str)
217	}
218
219	return nil
220}
221
// cleanPath normalizes u.Path and u.RawPath with path.Clean, for example
// collapsing duplicate '/' separators. A trailing slash present on the
// original Path is preserved.
//
// Empty values are left untouched: path.Clean("") returns ".", which is not a
// usable request path, and an empty RawPath means "derive the escaped path
// from Path", so it must stay empty.
func cleanPath(u *url.URL) {
	hasSlash := strings.HasSuffix(u.Path, "/")

	// clean up path, removing duplicate `/`
	if u.Path != "" {
		u.Path = path.Clean(u.Path)
	}
	if u.RawPath != "" {
		u.RawPath = path.Clean(u.RawPath)
	}

	// path.Clean drops trailing slashes (except for the root); restore it.
	if hasSlash && !strings.HasSuffix(u.Path, "/") {
		u.Path += "/"
		if u.RawPath != "" {
			u.RawPath += "/"
		}
	}
}
234
235// EscapePath escapes part of a URL path in Amazon style
236func EscapePath(path string, encodeSep bool) string {
237	var buf bytes.Buffer
238	for i := 0; i < len(path); i++ {
239		c := path[i]
240		if noEscape[c] || (c == '/' && !encodeSep) {
241			buf.WriteByte(c)
242		} else {
243			fmt.Fprintf(&buf, "%%%02X", c)
244		}
245	}
246	return buf.String()
247}
248
249func convertType(v reflect.Value) (string, error) {
250	v = reflect.Indirect(v)
251	if !v.IsValid() {
252		return "", errValueNotSet
253	}
254
255	var str string
256	switch value := v.Interface().(type) {
257	case string:
258		str = value
259	case []byte:
260		str = base64.StdEncoding.EncodeToString(value)
261	case bool:
262		str = strconv.FormatBool(value)
263	case int64:
264		str = strconv.FormatInt(value, 10)
265	case float64:
266		str = strconv.FormatFloat(value, 'f', -1, 64)
267	case time.Time:
268		str = value.UTC().Format(RFC822)
269	default:
270		err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
271		return "", err
272	}
273	return str, nil
274}
275