1// +build go1.13
2
3// Copyright (c) Microsoft Corporation. All rights reserved.
4// Licensed under the MIT License.
5
6package azcore
7
8import (
9	"bytes"
10	"context"
11	"encoding/base64"
12	"encoding/json"
13	"encoding/xml"
14	"errors"
15	"fmt"
16	"io"
17	"io/ioutil"
18	"mime/multipart"
19	"net/http"
20	"reflect"
21	"strconv"
22	"strings"
23
24	"golang.org/x/net/http/httpguts"
25)
26
// MIME types written to the Content-Type header by the MarshalAs* helpers.
const (
	contentTypeAppXML  = "application/xml"
	contentTypeAppJSON = "application/json"
)
31
// Base64Encoding is used to specify which base-64 encoder/decoder to use when
// encoding/decoding a slice of bytes to/from a string.
type Base64Encoding int

const (
	// Base64StdFormat uses base64.StdEncoding for encoding and decoding payloads.
	Base64StdFormat Base64Encoding = iota

	// Base64URLFormat uses base64.RawURLEncoding for encoding and decoding payloads.
	Base64URLFormat
)
43
// Request is an abstraction over the creation of an HTTP request as it passes through the pipeline.
// Don't use this type directly, use NewRequest() instead.
type Request struct {
	*http.Request
	// body is the stream most recently installed by SetBody or SetMultipartFormData.
	// It is kept separately from Request.Body so RewindBody can seek it back to the
	// start and restore it if Request.Body has been replaced (e.g. by
	// httputil.DumpRequest and friends).
	body     ReadSeekCloser
	// policies holds the pipeline policies that have not run yet; Next invokes the
	// first one with a copy of the request whose policies slice excludes it.
	policies []Policy
	// values holds per-operation values keyed by their dynamic type; see
	// SetOperationValue and OperationValue. It is nil until the first value is set.
	values   opValues
}
52
53type opValues map[reflect.Type]interface{}
54
55// Set adds/changes a value
56func (ov opValues) set(value interface{}) {
57	ov[reflect.TypeOf(value)] = value
58}
59
60// Get looks for a value set by SetValue first
61func (ov opValues) get(value interface{}) bool {
62	v, ok := ov[reflect.ValueOf(value).Elem().Type()]
63	if ok {
64		reflect.ValueOf(value).Elem().Set(reflect.ValueOf(v))
65	}
66	return ok
67}
68
// JoinPaths concatenates multiple URL path segments into one path,
// inserting path separation characters as required.
// A '/' is added between two segments that lack one, and a doubled '/' at a
// segment boundary is collapsed to a single one. Empty segments are ignored;
// calling it with no segments (or only empty ones) returns "".
func JoinPaths(paths ...string) string {
	path := ""
	for _, seg := range paths {
		switch {
		case seg == "":
			// contributes nothing; skipping it also avoids indexing an empty string
		case path == "":
			path = seg
		case strings.HasSuffix(path, "/") && strings.HasPrefix(seg, "/"):
			// avoid doubling up the separator
			path += seg[1:]
		case !strings.HasSuffix(path, "/") && !strings.HasPrefix(seg, "/"):
			path += "/" + seg
		default:
			// exactly one side already supplies the separator
			path += seg
		}
	}
	return path
}
88
// NewRequest creates a new Request for httpMethod and endpoint, bound to ctx.
// It fails if http.NewRequestWithContext rejects the method or URL, if the URL
// has no host, or if its scheme is anything other than http or https.
func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Request, error) {
	httpReq, err := http.NewRequestWithContext(ctx, httpMethod, endpoint, nil)
	if err != nil {
		return nil, err
	}
	u := httpReq.URL
	if u.Host == "" {
		return nil, errors.New("no Host in request URL")
	}
	switch u.Scheme {
	case "http", "https":
		return &Request{Request: httpReq}, nil
	default:
		return nil, fmt.Errorf("unsupported protocol scheme %s", u.Scheme)
	}
}
103
104// Next calls the next policy in the pipeline.
105// If there are no more policies, nil and ErrNoMorePolicies are returned.
106// This method is intended to be called from pipeline policies.
107// To send a request through a pipeline call Pipeline.Do().
108func (req *Request) Next() (*Response, error) {
109	if len(req.policies) == 0 {
110		return nil, ErrNoMorePolicies
111	}
112	nextPolicy := req.policies[0]
113	nextReq := *req
114	nextReq.policies = nextReq.policies[1:]
115	return nextPolicy.Do(&nextReq)
116}
117
// MarshalAsByteArray base-64 encodes v in the requested format, wraps the result
// in double quotes so it is a JSON string, and installs it as the request body
// via SetBody with an application/json content type.
func (req *Request) MarshalAsByteArray(v []byte, format Base64Encoding) error {
	// base-64 output never contains characters that need JSON escaping,
	// so plain quoting yields a valid JSON string literal
	quoted := "\"" + EncodeByteArray(v, format) + "\""
	return req.SetBody(NopCloser(strings.NewReader(quoted)), contentTypeAppJSON)
}
125
// MarshalAsJSON calls json.Marshal() to get the JSON encoding of v then calls SetBody.
// Values of fields tagged `azure:"ro"` are omitted from the payload (see
// cloneWithoutReadOnlyFields); v itself is not modified.
// A marshalling error is wrapped with %w so callers can inspect it with errors.Is/As.
func (req *Request) MarshalAsJSON(v interface{}) error {
	payload := cloneWithoutReadOnlyFields(v)
	b, err := json.Marshal(payload)
	if err != nil {
		// report the type the caller passed, not that of the (possibly cloned) payload
		return fmt.Errorf("error marshalling type %T: %w", v, err)
	}
	return req.SetBody(NopCloser(bytes.NewReader(b)), contentTypeAppJSON)
}
135
// MarshalAsXML calls xml.Marshal() to get the XML encoding of v then calls SetBody.
// A marshalling error is wrapped with %w so callers can inspect it with errors.Is/As.
func (req *Request) MarshalAsXML(v interface{}) error {
	b, err := xml.Marshal(v)
	if err != nil {
		return fmt.Errorf("error marshalling type %T: %w", v, err)
	}
	return req.SetBody(NopCloser(bytes.NewReader(b)), contentTypeAppXML)
}
144
145// SetOperationValue adds/changes a mutable key/value associated with a single operation.
146func (req *Request) SetOperationValue(value interface{}) {
147	if req.values == nil {
148		req.values = opValues{}
149	}
150	req.values.set(value)
151}
152
153// OperationValue looks for a value set by SetOperationValue().
154func (req *Request) OperationValue(value interface{}) bool {
155	if req.values == nil {
156		return false
157	}
158	return req.values.get(value)
159}
160
// SetBody sets the specified ReadSeekCloser as the HTTP request body.
// The stream's length is measured by seeking to its end. An empty stream is
// closed and the request is left unchanged. Otherwise the stream is rewound and
// installed as the body, and the Content-Type and Content-Length headers are set.
// Seek errors are returned as-is.
func (req *Request) SetBody(body ReadSeekCloser, contentType string) error {
	length, err := body.Seek(0, io.SeekEnd)
	if err != nil {
		return err
	}
	if length == 0 {
		// nothing to send
		body.Close()
		return nil
	}
	if _, err = body.Seek(0, io.SeekStart); err != nil {
		return err
	}
	// Remember the stream separately so it can be restored (and rewound) if
	// req.Body is later swapped out, e.g. by httputil.DumpRequest and friends.
	req.body = body
	req.Request.Body = body
	req.Request.ContentLength = length
	req.Header.Set(HeaderContentType, contentType)
	req.Header.Set(HeaderContentLength, strconv.FormatInt(length, 10))
	return nil
}
185
186// SetMultipartFormData writes the specified keys/values as multi-part form
187// fields with the specified value.  File content must be specified as a ReadSeekCloser.
188// All other values are treated as string values.
189func (req *Request) SetMultipartFormData(formData map[string]interface{}) error {
190	body := bytes.Buffer{}
191	writer := multipart.NewWriter(&body)
192	for k, v := range formData {
193		if rsc, ok := v.(ReadSeekCloser); ok {
194			// this is the body to upload, the key is its file name
195			fd, err := writer.CreateFormFile(k, k)
196			if err != nil {
197				return err
198			}
199			// copy the data to the form file
200			if _, err = io.Copy(fd, rsc); err != nil {
201				return err
202			}
203			continue
204		}
205		// ensure the value is in string format
206		s, ok := v.(string)
207		if !ok {
208			s = fmt.Sprintf("%v", v)
209		}
210		if err := writer.WriteField(k, s); err != nil {
211			return err
212		}
213	}
214	if err := writer.Close(); err != nil {
215		return err
216	}
217	req.body = NopCloser(bytes.NewReader(body.Bytes()))
218	req.Body = req.body
219	req.ContentLength = int64(body.Len())
220	req.Header.Set(HeaderContentType, writer.FormDataContentType())
221	req.Header.Set(HeaderContentLength, strconv.FormatInt(req.ContentLength, 10))
222	return nil
223}
224
225// SkipBodyDownload will disable automatic downloading of the response body.
226func (req *Request) SkipBodyDownload() {
227	req.SetOperationValue(bodyDownloadPolicyOpValues{skip: true})
228}
229
// RewindBody seeks the request's Body stream back to the beginning so it can be resent when retrying an operation.
// It is a no-op when no body was installed via SetBody/SetMultipartFormData.
// The saved stream is reinstalled as req.Body even if the seek fails; the seek error is returned.
func (req *Request) RewindBody() error {
	if req.body == nil {
		return nil
	}
	_, err := req.body.Seek(0, io.SeekStart)
	req.Body = req.body
	return err
}
240
241// Close closes the request body.
242func (req *Request) Close() error {
243	if req.Body == nil {
244		return nil
245	}
246	return req.Body.Close()
247}
248
249// Telemetry adds telemetry data to the request.
250// If telemetry reporting is disabled the value is discarded.
251func (req *Request) Telemetry(v string) {
252	req.SetOperationValue(requestTelemetry(v))
253}
254
// clone returns a copy of the request with its context changed to ctx.
// Only the embedded *http.Request is copied via http.Request.Clone. Every other
// field is copied shallowly and is therefore shared with req: the saved body
// stream, the backing array of the remaining policies, and the operation-values
// map (if non-nil).
func (req *Request) clone(ctx context.Context) *Request {
	r2 := *req
	r2.Request = req.Request.Clone(ctx)
	return &r2
}
262
// valid returns nil if the underlying http.Request is well-formed.
// It applies the same header name/value checks as net/http's Transport.roundTrip.
// On the first invalid header name or value it closes the request body and
// returns an error describing the offending entry.
func (req *Request) valid() error {
	for name, values := range req.Header {
		if !httpguts.ValidHeaderFieldName(name) {
			req.Close()
			return fmt.Errorf("invalid header field name %q", name)
		}
		for _, value := range values {
			if httpguts.ValidHeaderFieldValue(value) {
				continue
			}
			req.Close()
			return fmt.Errorf("invalid header field value %q for key %v", value, name)
		}
	}
	return nil
}
280
// writeBody writes the request body to b for logging purposes.
// If there is no body, a note saying so is written instead. If shouldLogBody
// returns false for the body's Content-Type, the body is not read at all.
// Reading the body drains it, so afterwards it is restored so the request can
// still be sent.
func (req *Request) writeBody(b *bytes.Buffer) error {
	if req.Body == nil {
		fmt.Fprint(b, "   Request contained no body\n")
		return nil
	}
	if ct := req.Header.Get(HeaderContentType); !shouldLogBody(b, ct) {
		return nil
	}
	body, err := ioutil.ReadAll(req.Body)
	if err != nil {
		fmt.Fprintf(b, "   Failed to read request body: %s\n", err.Error())
		// Best effort: put a seekable body back at its start so the partial read
		// doesn't truncate what gets sent. The read error is the one reported.
		_ = req.RewindBody()
		return err
	}
	if req.body == nil {
		// req.Body was assigned directly on the embedded *http.Request rather than
		// through SetBody, so there is no seekable copy for RewindBody to restore.
		// Replace the drained stream with the bytes just read; nothing else will
		// close the original once it is replaced, so close it here.
		_ = req.Body.Close()
		req.Body = ioutil.NopCloser(bytes.NewReader(body))
	} else if err := req.RewindBody(); err != nil {
		return err
	}
	logBody(b, body)
	return nil
}
301
// EncodeByteArray base-64 encodes v. Base64URLFormat selects the unpadded URL-safe
// alphabet (base64.RawURLEncoding). Any other format uses base64.StdEncoding.
func EncodeByteArray(v []byte, format Base64Encoding) string {
	enc := base64.StdEncoding
	if format == Base64URLFormat {
		enc = base64.RawURLEncoding
	}
	return enc.EncodeToString(v)
}
309
// cloneWithoutReadOnlyFields returns a clone of the object graph pointed to by v,
// omitting values of all read-only fields (those whose `azure` tag contains "ro").
// If v is not a struct or pointer to a struct, or if there are no read-only fields
// in its object graph, v is returned unchanged and no clone is created.
// When a clone is made it is always returned as a pointer to v's struct type.
func cloneWithoutReadOnlyFields(v interface{}) interface{} {
	val := reflect.Indirect(reflect.ValueOf(v))
	if val.Kind() != reflect.Struct {
		// not a struct, skip
		return v
	}
	// first walk the graph to find any R/O fields.
	// if there aren't any, skip cloning the graph.
	if !recursiveFindReadOnlyField(val) {
		return v
	}
	return recursiveCloneWithoutReadOnlyFields(val)
}

// recursiveFindReadOnlyField returns true if any field in the object graph of the
// struct value val is tagged as read-only. Nested structs are reached through
// fields holding a struct or a non-nil pointer to one.
func recursiveFindReadOnlyField(val reflect.Value) bool {
	t := val.Type()
	for i := 0; i < t.NumField(); i++ {
		if azureTagIsReadOnly(t.Field(i).Tag.Get("azure")) {
			return true
		}
		if nested := reflect.Indirect(val.Field(i)); nested.Kind() == reflect.Struct && recursiveFindReadOnlyField(nested) {
			return true
		}
	}
	return false
}

// recursiveCloneWithoutReadOnlyFields returns a pointer to a new value of val's
// struct type. Every non-R/O exported field is copied from val, and R/O fields are
// left at their zero value. A nested struct (or pointer to one) is itself cloned
// only when it contains R/O fields. Otherwise it is copied as-is, which keeps
// structs with unexported state such as time.Time intact; reflection cannot set
// their unexported fields.
func recursiveCloneWithoutReadOnlyFields(val reflect.Value) interface{} {
	t := val.Type()
	clone := reflect.New(t)
	dst := clone.Elem()
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		if azureTagIsReadOnly(field.Tag.Get("azure")) {
			// omit from payload
			continue
		}
		if field.PkgPath != "" && !field.Anonymous {
			// unexported fields can't be set via reflection and encoding/json never
			// marshals them, so they stay zero in the clone
			continue
		}
		src := val.Field(i)
		if nested := reflect.Indirect(src); nested.Kind() == reflect.Struct && recursiveFindReadOnlyField(nested) {
			// recursive case: the nested struct has R/O fields that must be stripped
			c := reflect.ValueOf(recursiveCloneWithoutReadOnlyFields(nested))
			// c is a *T; match the declared field type (pointer or value).
			// NOTE: this does not handle the case of embedded fields of unexported struct types.
			// this should be ok as we don't generate any code like this at present
			if field.Type.Kind() == reflect.Ptr {
				src = c
			} else {
				src = c.Elem()
			}
		}
		dst.Field(i).Set(src)
	}
	return clone.Interface()
}

// azureTagIsReadOnly returns true if the comma-separated "azure" tag contains the option "ro".
func azureTagIsReadOnly(tag string) bool {
	if tag == "" {
		return false
	}
	for _, opt := range strings.Split(tag, ",") {
		if opt == "ro" {
			return true
		}
	}
	return false
}
383
384func logBody(b *bytes.Buffer, body []byte) {
385	fmt.Fprintln(b, "   --------------------------------------------------------------------------------")
386	fmt.Fprintln(b, string(body))
387	fmt.Fprintln(b, "   --------------------------------------------------------------------------------")
388}
389