1// Package rest provides RESTful serialization of AWS requests and responses.
2package rest
3
4import (
5	"bytes"
6	"encoding/base64"
7	"fmt"
8	"io"
9	"net/http"
10	"net/url"
11	"path"
12	"reflect"
13	"strconv"
14	"strings"
15	"time"
16
17	"github.com/aws/aws-sdk-go/aws"
18	"github.com/aws/aws-sdk-go/aws/awserr"
19	"github.com/aws/aws-sdk-go/aws/request"
20	"github.com/aws/aws-sdk-go/private/protocol"
21)
22
// RFC822 is the time layout used when AWS REST protocols serialize a
// timestamp as a string. The layout ends in a literal "GMT", so times must be
// converted to UTC before being formatted with it.
const RFC822 = "Mon, 2 Jan 2006 15:04:05 GMT"

var (
	// noEscape is indexed by byte value and reports whether that byte may be
	// written into an AWS URL path without percent-encoding. It is filled in
	// by init.
	noEscape [256]bool

	// errValueNotSet is a sentinel meaning "this member has no value"; the
	// builders treat it as a signal to omit the member rather than as a
	// failure.
	errValueNotSet = fmt.Errorf("value not set")
)

func init() {
	// AWS expects every byte except the RFC 3986 unreserved characters to be
	// percent-encoded.
	const unreserved = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
		"abcdefghijklmnopqrstuvwxyz" +
		"0123456789" +
		"-._~"
	for i := range noEscape {
		noEscape[i] = strings.IndexByte(unreserved, byte(i)) >= 0
	}
}
43
44// BuildHandler is a named request handler for building rest protocol requests
45var BuildHandler = request.NamedHandler{Name: "awssdk.rest.Build", Fn: Build}
46
47// Build builds the REST component of a service request.
48func Build(r *request.Request) {
49	if r.ParamsFilled() {
50		v := reflect.ValueOf(r.Params).Elem()
51		buildLocationElements(r, v, false)
52		buildBody(r, v)
53	}
54}
55
56// BuildAsGET builds the REST component of a service request with the ability to hoist
57// data from the body.
58func BuildAsGET(r *request.Request) {
59	if r.ParamsFilled() {
60		v := reflect.ValueOf(r.Params).Elem()
61		buildLocationElements(r, v, true)
62		buildBody(r, v)
63	}
64}
65
66func buildLocationElements(r *request.Request, v reflect.Value, buildGETQuery bool) {
67	query := r.HTTPRequest.URL.Query()
68
69	// Setup the raw path to match the base path pattern. This is needed
70	// so that when the path is mutated a custom escaped version can be
71	// stored in RawPath that will be used by the Go client.
72	r.HTTPRequest.URL.RawPath = r.HTTPRequest.URL.Path
73
74	for i := 0; i < v.NumField(); i++ {
75		m := v.Field(i)
76		if n := v.Type().Field(i).Name; n[0:1] == strings.ToLower(n[0:1]) {
77			continue
78		}
79
80		if m.IsValid() {
81			field := v.Type().Field(i)
82			name := field.Tag.Get("locationName")
83			if name == "" {
84				name = field.Name
85			}
86			if kind := m.Kind(); kind == reflect.Ptr {
87				m = m.Elem()
88			} else if kind == reflect.Interface {
89				if !m.Elem().IsValid() {
90					continue
91				}
92			}
93			if !m.IsValid() {
94				continue
95			}
96			if field.Tag.Get("ignore") != "" {
97				continue
98			}
99
100			var err error
101			switch field.Tag.Get("location") {
102			case "headers": // header maps
103				err = buildHeaderMap(&r.HTTPRequest.Header, m, field.Tag)
104			case "header":
105				err = buildHeader(&r.HTTPRequest.Header, m, name, field.Tag)
106			case "uri":
107				err = buildURI(r.HTTPRequest.URL, m, name, field.Tag)
108			case "querystring":
109				err = buildQueryString(query, m, name, field.Tag)
110			default:
111				if buildGETQuery {
112					err = buildQueryString(query, m, name, field.Tag)
113				}
114			}
115			r.Error = err
116		}
117		if r.Error != nil {
118			return
119		}
120	}
121
122	r.HTTPRequest.URL.RawQuery = query.Encode()
123	if !aws.BoolValue(r.Config.DisableRestProtocolURICleaning) {
124		cleanPath(r.HTTPRequest.URL)
125	}
126}
127
128func buildBody(r *request.Request, v reflect.Value) {
129	if field, ok := v.Type().FieldByName("_"); ok {
130		if payloadName := field.Tag.Get("payload"); payloadName != "" {
131			pfield, _ := v.Type().FieldByName(payloadName)
132			if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
133				payload := reflect.Indirect(v.FieldByName(payloadName))
134				if payload.IsValid() && payload.Interface() != nil {
135					switch reader := payload.Interface().(type) {
136					case io.ReadSeeker:
137						r.SetReaderBody(reader)
138					case []byte:
139						r.SetBufferBody(reader)
140					case string:
141						r.SetStringBody(reader)
142					default:
143						r.Error = awserr.New("SerializationError",
144							"failed to encode REST request",
145							fmt.Errorf("unknown payload type %s", payload.Type()))
146					}
147				}
148			}
149		}
150	}
151}
152
153func buildHeader(header *http.Header, v reflect.Value, name string, tag reflect.StructTag) error {
154	str, err := convertType(v, tag)
155	if err == errValueNotSet {
156		return nil
157	} else if err != nil {
158		return awserr.New("SerializationError", "failed to encode REST request", err)
159	}
160
161	header.Add(name, str)
162
163	return nil
164}
165
166func buildHeaderMap(header *http.Header, v reflect.Value, tag reflect.StructTag) error {
167	prefix := tag.Get("locationName")
168	for _, key := range v.MapKeys() {
169		str, err := convertType(v.MapIndex(key), tag)
170		if err == errValueNotSet {
171			continue
172		} else if err != nil {
173			return awserr.New("SerializationError", "failed to encode REST request", err)
174
175		}
176
177		header.Add(prefix+key.String(), str)
178	}
179	return nil
180}
181
182func buildURI(u *url.URL, v reflect.Value, name string, tag reflect.StructTag) error {
183	value, err := convertType(v, tag)
184	if err == errValueNotSet {
185		return nil
186	} else if err != nil {
187		return awserr.New("SerializationError", "failed to encode REST request", err)
188	}
189
190	u.Path = strings.Replace(u.Path, "{"+name+"}", value, -1)
191	u.Path = strings.Replace(u.Path, "{"+name+"+}", value, -1)
192
193	u.RawPath = strings.Replace(u.RawPath, "{"+name+"}", EscapePath(value, true), -1)
194	u.RawPath = strings.Replace(u.RawPath, "{"+name+"+}", EscapePath(value, false), -1)
195
196	return nil
197}
198
199func buildQueryString(query url.Values, v reflect.Value, name string, tag reflect.StructTag) error {
200	switch value := v.Interface().(type) {
201	case []*string:
202		for _, item := range value {
203			query.Add(name, *item)
204		}
205	case map[string]*string:
206		for key, item := range value {
207			query.Add(key, *item)
208		}
209	case map[string][]*string:
210		for key, items := range value {
211			for _, item := range items {
212				query.Add(key, *item)
213			}
214		}
215	default:
216		str, err := convertType(v, tag)
217		if err == errValueNotSet {
218			return nil
219		} else if err != nil {
220			return awserr.New("SerializationError", "failed to encode REST request", err)
221		}
222		query.Set(name, str)
223	}
224
225	return nil
226}
227
// cleanPath normalizes u.Path and u.RawPath with path.Clean. This collapses
// duplicate slashes and resolves "." and ".." segments. A trailing slash on
// the original u.Path is preserved.
//
// Empty values are left empty. path.Clean("") returns ".", which would
// otherwise turn a URL with no path into one whose path is ".", and would
// give an unset RawPath a bogus "." or "./" value.
func cleanPath(u *url.URL) {
	hasSlash := strings.HasSuffix(u.Path, "/")

	// clean up path, removing duplicate `/`
	if u.Path != "" {
		u.Path = path.Clean(u.Path)
	}
	if u.RawPath != "" {
		u.RawPath = path.Clean(u.RawPath)
	}

	if hasSlash && !strings.HasSuffix(u.Path, "/") {
		u.Path += "/"
		// An empty RawPath makes net/url derive the escaped form from Path,
		// so only extend RawPath when it holds an explicit value.
		if u.RawPath != "" {
			u.RawPath += "/"
		}
	}
}
240
241// EscapePath escapes part of a URL path in Amazon style
242func EscapePath(path string, encodeSep bool) string {
243	var buf bytes.Buffer
244	for i := 0; i < len(path); i++ {
245		c := path[i]
246		if noEscape[c] || (c == '/' && !encodeSep) {
247			buf.WriteByte(c)
248		} else {
249			fmt.Fprintf(&buf, "%%%02X", c)
250		}
251	}
252	return buf.String()
253}
254
255func convertType(v reflect.Value, tag reflect.StructTag) (str string, err error) {
256	v = reflect.Indirect(v)
257	if !v.IsValid() {
258		return "", errValueNotSet
259	}
260
261	switch value := v.Interface().(type) {
262	case string:
263		str = value
264	case []byte:
265		str = base64.StdEncoding.EncodeToString(value)
266	case bool:
267		str = strconv.FormatBool(value)
268	case int64:
269		str = strconv.FormatInt(value, 10)
270	case float64:
271		str = strconv.FormatFloat(value, 'f', -1, 64)
272	case time.Time:
273		str = value.UTC().Format(RFC822)
274	case aws.JSONValue:
275		if len(value) == 0 {
276			return "", errValueNotSet
277		}
278		escaping := protocol.NoEscape
279		if tag.Get("location") == "header" {
280			escaping = protocol.Base64Escape
281		}
282		str, err = protocol.EncodeJSONValue(value, escaping)
283		if err != nil {
284			return "", fmt.Errorf("unable to encode JSONValue, %v", err)
285		}
286	default:
287		err := fmt.Errorf("unsupported value for param %v (%s)", v.Interface(), v.Type())
288		return "", err
289	}
290	return str, nil
291}
292