1// +build go1.13
2
3// Copyright (c) Microsoft Corporation. All rights reserved.
4// Licensed under the MIT License.
5
6package azcore
7
8import (
9	"bytes"
10	"encoding/base64"
11	"encoding/json"
12	"encoding/xml"
13	"fmt"
14	"io"
15	"io/ioutil"
16	"net/http"
17	"sort"
18	"strconv"
19	"strings"
20	"time"
21)
22
// Response represents the response from an HTTP request. It embeds the
// underlying *http.Response, so that type's fields and methods (StatusCode,
// Status, Header, Body, ...) are promoted and usable directly on a Response.
type Response struct{ *http.Response }
27
28func (r *Response) payload() ([]byte, error) {
29	// r.Body won't be a nopClosingBytesReader if downloading was skipped
30	if buf, ok := r.Body.(*nopClosingBytesReader); ok {
31		return buf.Bytes(), nil
32	}
33	bytesBody, err := ioutil.ReadAll(r.Body)
34	r.Body.Close()
35	if err != nil {
36		return nil, err
37	}
38	r.Body = &nopClosingBytesReader{s: bytesBody, i: 0}
39	return bytesBody, nil
40}
41
42// HasStatusCode returns true if the Response's status code is one of the specified values.
43func (r *Response) HasStatusCode(statusCodes ...int) bool {
44	if r == nil {
45		return false
46	}
47	for _, sc := range statusCodes {
48		if r.StatusCode == sc {
49			return true
50		}
51	}
52	return false
53}
54
55// UnmarshalAsByteArray will base-64 decode the received payload and place the result into the value pointed to by v.
56func (r *Response) UnmarshalAsByteArray(v **[]byte, format Base64Encoding) error {
57	p, err := r.payload()
58	if err != nil {
59		return err
60	}
61	if len(p) == 0 {
62		return nil
63	}
64	payload := string(p)
65	if payload[0] == '"' {
66		// remove surrounding quotes
67		payload = payload[1 : len(payload)-1]
68	}
69	switch format {
70	case Base64StdFormat:
71		decoded, err := base64.StdEncoding.DecodeString(payload)
72		if err == nil {
73			*v = &decoded
74			return nil
75		}
76		return err
77	case Base64URLFormat:
78		// use raw encoding as URL format should not contain any '=' characters
79		decoded, err := base64.RawURLEncoding.DecodeString(payload)
80		if err == nil {
81			*v = &decoded
82			return nil
83		}
84		return err
85	default:
86		return fmt.Errorf("unrecognized byte array format: %d", format)
87	}
88}
89
90// UnmarshalAsJSON calls json.Unmarshal() to unmarshal the received payload into the value pointed to by v.
91func (r *Response) UnmarshalAsJSON(v interface{}) error {
92	payload, err := r.payload()
93	if err != nil {
94		return err
95	}
96	// TODO: verify early exit is correct
97	if len(payload) == 0 {
98		return nil
99	}
100	err = r.removeBOM()
101	if err != nil {
102		return err
103	}
104	err = json.Unmarshal(payload, v)
105	if err != nil {
106		err = fmt.Errorf("unmarshalling type %T: %s", v, err)
107	}
108	return err
109}
110
111// UnmarshalAsXML calls xml.Unmarshal() to unmarshal the received payload into the value pointed to by v.
112func (r *Response) UnmarshalAsXML(v interface{}) error {
113	payload, err := r.payload()
114	if err != nil {
115		return err
116	}
117	// TODO: verify early exit is correct
118	if len(payload) == 0 {
119		return nil
120	}
121	err = r.removeBOM()
122	if err != nil {
123		return err
124	}
125	err = xml.Unmarshal(payload, v)
126	if err != nil {
127		err = fmt.Errorf("unmarshalling type %T: %s", v, err)
128	}
129	return err
130}
131
132// Drain reads the response body to completion then closes it.  The bytes read are discarded.
133func (r *Response) Drain() {
134	if r != nil && r.Body != nil {
135		_, _ = io.Copy(ioutil.Discard, r.Body)
136		r.Body.Close()
137	}
138}
139
140// removeBOM removes any byte-order mark prefix from the payload if present.
141func (r *Response) removeBOM() error {
142	payload, err := r.payload()
143	if err != nil {
144		return err
145	}
146	// UTF8
147	trimmed := bytes.TrimPrefix(payload, []byte("\xef\xbb\xbf"))
148	if len(trimmed) < len(payload) {
149		r.Body.(*nopClosingBytesReader).Set(trimmed)
150	}
151	return nil
152}
153
154// helper to reduce nil Response checks
155func (r *Response) retryAfter() time.Duration {
156	if r == nil {
157		return 0
158	}
159	return RetryAfter(r.Response)
160}
161
162// writes to a buffer, used for logging purposes
163func (r *Response) writeBody(b *bytes.Buffer) error {
164	if ct := r.Header.Get(HeaderContentType); !shouldLogBody(b, ct) {
165		return nil
166	}
167	body, err := r.payload()
168	if err != nil {
169		fmt.Fprintf(b, "   Failed to read response body: %s\n", err.Error())
170		return err
171	}
172	if len(body) > 0 {
173		logBody(b, body)
174	} else {
175		fmt.Fprint(b, "   Response contained no body\n")
176	}
177	return nil
178}
179
180// RetryAfter returns non-zero if the response contains a Retry-After header value.
181func RetryAfter(resp *http.Response) time.Duration {
182	if resp == nil {
183		return 0
184	}
185	ra := resp.Header.Get(HeaderRetryAfter)
186	if ra == "" {
187		return 0
188	}
189	// retry-after values are expressed in either number of
190	// seconds or an HTTP-date indicating when to try again
191	if retryAfter, _ := strconv.Atoi(ra); retryAfter > 0 {
192		return time.Duration(retryAfter) * time.Second
193	} else if t, err := time.Parse(time.RFC1123, ra); err == nil {
194		return time.Until(t)
195	}
196	return 0
197}
198
199// writeRequestWithResponse appends a formatted HTTP request into a Buffer. If request and/or err are
200// not nil, then these are also written into the Buffer.
201func writeRequestWithResponse(b *bytes.Buffer, request *Request, response *Response, err error) {
202	// Write the request into the buffer.
203	fmt.Fprint(b, "   "+request.Method+" "+request.URL.String()+"\n")
204	writeHeader(b, request.Header)
205	if response != nil {
206		fmt.Fprintln(b, "   --------------------------------------------------------------------------------")
207		fmt.Fprint(b, "   RESPONSE Status: "+response.Status+"\n")
208		writeHeader(b, response.Header)
209	}
210	if err != nil {
211		fmt.Fprintln(b, "   --------------------------------------------------------------------------------")
212		fmt.Fprint(b, "   ERROR:\n"+err.Error()+"\n")
213	}
214}
215
// writeHeader appends an HTTP request's or response's header into a Buffer.
//
// It writes one "   Name: [values]" line per header name, sorted
// alphabetically so the output is deterministic. The value of any
// Authorization header (name matched case-insensitively) is written as
// REDACTED. An empty header produces a single "   (no headers)" line.
func writeHeader(b *bytes.Buffer, header http.Header) {
	if len(header) == 0 {
		b.WriteString("   (no headers)\n")
		return
	}
	names := make([]string, 0, len(header))
	for name := range header {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		var value interface{} = header[name]
		if strings.EqualFold(name, "Authorization") {
			// Keep credentials out of the logs.
			value = "REDACTED"
		}
		fmt.Fprintf(b, "   %s: %+v\n", name, value)
	}
}
237