1// Copyright (c) 2015-2019 Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
2// resty source code and usage is governed by a MIT style
3// license that can be found in the LICENSE file.
4
5package resty
6
7import (
8	"bytes"
9	"encoding/xml"
10	"errors"
11	"fmt"
12	"io"
13	"io/ioutil"
14	"mime/multipart"
15	"net/http"
16	"net/url"
17	"os"
18	"path/filepath"
19	"reflect"
20	"strings"
21	"time"
22)
23
const (
	// debugRequestLogKey is the Request.values key under which requestLogger
	// stores the formatted request debug text; responseLogger reads it back
	// and appends the response section to it.
	debugRequestLogKey = "__restyDebugRequestLog"
)
25
26//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
27// Request Middleware(s)
28//_______________________________________________________________________
29
30func parseRequestURL(c *Client, r *Request) error {
31	// GitHub #103 Path Params
32	if len(r.pathParams) > 0 {
33		for p, v := range r.pathParams {
34			r.URL = strings.Replace(r.URL, "{"+p+"}", url.PathEscape(v), -1)
35		}
36	}
37	if len(c.pathParams) > 0 {
38		for p, v := range c.pathParams {
39			r.URL = strings.Replace(r.URL, "{"+p+"}", url.PathEscape(v), -1)
40		}
41	}
42
43	// Parsing request URL
44	reqURL, err := url.Parse(r.URL)
45	if err != nil {
46		return err
47	}
48
49	// If Request.URL is relative path then added c.HostURL into
50	// the request URL otherwise Request.URL will be used as-is
51	if !reqURL.IsAbs() {
52		r.URL = reqURL.String()
53		if len(r.URL) > 0 && r.URL[0] != '/' {
54			r.URL = "/" + r.URL
55		}
56
57		reqURL, err = url.Parse(c.HostURL + r.URL)
58		if err != nil {
59			return err
60		}
61	}
62
63	// Adding Query Param
64	query := make(url.Values)
65	for k, v := range c.QueryParam {
66		for _, iv := range v {
67			query.Add(k, iv)
68		}
69	}
70
71	for k, v := range r.QueryParam {
72		// remove query param from client level by key
73		// since overrides happens for that key in the request
74		query.Del(k)
75
76		for _, iv := range v {
77			query.Add(k, iv)
78		}
79	}
80
81	// GitHub #123 Preserve query string order partially.
82	// Since not feasible in `SetQuery*` resty methods, because
83	// standard package `url.Encode(...)` sorts the query params
84	// alphabetically
85	if len(query) > 0 {
86		if IsStringEmpty(reqURL.RawQuery) {
87			reqURL.RawQuery = query.Encode()
88		} else {
89			reqURL.RawQuery = reqURL.RawQuery + "&" + query.Encode()
90		}
91	}
92
93	r.URL = reqURL.String()
94
95	return nil
96}
97
98func parseRequestHeader(c *Client, r *Request) error {
99	hdr := make(http.Header)
100	for k := range c.Header {
101		hdr[k] = append(hdr[k], c.Header[k]...)
102	}
103
104	for k := range r.Header {
105		hdr.Del(k)
106		hdr[k] = append(hdr[k], r.Header[k]...)
107	}
108
109	if IsStringEmpty(hdr.Get(hdrUserAgentKey)) {
110		hdr.Set(hdrUserAgentKey, hdrUserAgentValue)
111	}
112
113	ct := hdr.Get(hdrContentTypeKey)
114	if IsStringEmpty(hdr.Get(hdrAcceptKey)) && !IsStringEmpty(ct) &&
115		(IsJSONType(ct) || IsXMLType(ct)) {
116		hdr.Set(hdrAcceptKey, hdr.Get(hdrContentTypeKey))
117	}
118
119	r.Header = hdr
120
121	return nil
122}
123
124func parseRequestBody(c *Client, r *Request) (err error) {
125	if isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
126		// Handling Multipart
127		if r.isMultiPart && !(r.Method == MethodPatch) {
128			if err = handleMultipart(c, r); err != nil {
129				return
130			}
131
132			goto CL
133		}
134
135		// Handling Form Data
136		if len(c.FormData) > 0 || len(r.FormData) > 0 {
137			handleFormData(c, r)
138
139			goto CL
140		}
141
142		// Handling Request body
143		if r.Body != nil {
144			handleContentType(c, r)
145
146			if err = handleRequestBody(c, r); err != nil {
147				return
148			}
149		}
150	}
151
152CL:
153	// by default resty won't set content length, you can if you want to :)
154	if (c.setContentLength || r.setContentLength) && r.bodyBuf != nil {
155		r.Header.Set(hdrContentLengthKey, fmt.Sprintf("%d", r.bodyBuf.Len()))
156	}
157
158	return
159}
160
161func createHTTPRequest(c *Client, r *Request) (err error) {
162	if r.bodyBuf == nil {
163		if reader, ok := r.Body.(io.Reader); ok {
164			r.RawRequest, err = http.NewRequest(r.Method, r.URL, reader)
165		} else {
166			r.RawRequest, err = http.NewRequest(r.Method, r.URL, nil)
167		}
168	} else {
169		r.RawRequest, err = http.NewRequest(r.Method, r.URL, r.bodyBuf)
170	}
171
172	if err != nil {
173		return
174	}
175
176	// Assign close connection option
177	r.RawRequest.Close = c.closeConnection
178
179	// Add headers into http request
180	r.RawRequest.Header = r.Header
181
182	// Add cookies from client instance into http request
183	for _, cookie := range c.Cookies {
184		r.RawRequest.AddCookie(cookie)
185	}
186
187	// Add cookies from request instance into http request
188	for _, cookie := range r.Cookies {
189		r.RawRequest.AddCookie(cookie)
190	}
191
192	// it's for non-http scheme option
193	if r.RawRequest.URL != nil && r.RawRequest.URL.Scheme == "" {
194		r.RawRequest.URL.Scheme = c.scheme
195		r.RawRequest.URL.Host = r.URL
196	}
197
198	// Enable trace
199	if c.trace || r.trace {
200		r.clientTrace = &clientTrace{}
201		r.ctx = r.clientTrace.createContext(r.Context())
202	}
203
204	// Use context if it was specified
205	if r.ctx != nil {
206		r.RawRequest = r.RawRequest.WithContext(r.ctx)
207	}
208
209	// assign get body func for the underlying raw request instance
210	r.RawRequest.GetBody = func() (io.ReadCloser, error) {
211		// If r.bodyBuf present, return the copy
212		if r.bodyBuf != nil {
213			return ioutil.NopCloser(bytes.NewReader(r.bodyBuf.Bytes())), nil
214		}
215
216		// Maybe body is `io.Reader`.
217		// Note: Resty user have to watchout for large body size of `io.Reader`
218		if r.RawRequest.Body != nil {
219			b, err := ioutil.ReadAll(r.RawRequest.Body)
220			if err != nil {
221				return nil, err
222			}
223
224			// Restore the Body
225			closeq(r.RawRequest.Body)
226			r.RawRequest.Body = ioutil.NopCloser(bytes.NewBuffer(b))
227
228			// Return the Body bytes
229			return ioutil.NopCloser(bytes.NewBuffer(b)), nil
230		}
231
232		return nil, nil
233	}
234
235	return
236}
237
238func addCredentials(c *Client, r *Request) error {
239	var isBasicAuth bool
240	// Basic Auth
241	if r.UserInfo != nil { // takes precedence
242		r.RawRequest.SetBasicAuth(r.UserInfo.Username, r.UserInfo.Password)
243		isBasicAuth = true
244	} else if c.UserInfo != nil {
245		r.RawRequest.SetBasicAuth(c.UserInfo.Username, c.UserInfo.Password)
246		isBasicAuth = true
247	}
248
249	if !c.DisableWarn {
250		if isBasicAuth && !strings.HasPrefix(r.URL, "https") {
251			c.log.Warnf("Using Basic Auth in HTTP mode is not secure, use HTTPS")
252		}
253	}
254
255	// Token Auth
256	if !IsStringEmpty(r.Token) { // takes precedence
257		r.RawRequest.Header.Set(hdrAuthorizationKey, "Bearer "+r.Token)
258	} else if !IsStringEmpty(c.Token) {
259		r.RawRequest.Header.Set(hdrAuthorizationKey, "Bearer "+c.Token)
260	}
261
262	return nil
263}
264
265func requestLogger(c *Client, r *Request) error {
266	if c.Debug {
267		rr := r.RawRequest
268		rl := &RequestLog{Header: copyHeaders(rr.Header), Body: r.fmtBodyString()}
269		if c.requestLog != nil {
270			if err := c.requestLog(rl); err != nil {
271				return err
272			}
273		}
274		// fmt.Sprintf("COOKIES:\n%s\n", composeCookies(c.GetClient().Jar, *rr.URL)) +
275
276		reqLog := "\n==============================================================================\n" +
277			"~~~ REQUEST ~~~\n" +
278			fmt.Sprintf("%s  %s  %s\n", r.Method, rr.URL.RequestURI(), rr.Proto) +
279			fmt.Sprintf("HOST   : %s\n", rr.URL.Host) +
280			fmt.Sprintf("HEADERS:\n%s\n", composeHeaders(c, r, rl.Header)) +
281			fmt.Sprintf("BODY   :\n%v\n", rl.Body) +
282			"------------------------------------------------------------------------------\n"
283
284		r.initValuesMap()
285		r.values[debugRequestLogKey] = reqLog
286	}
287
288	return nil
289}
290
291//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
292// Response Middleware(s)
293//_______________________________________________________________________
294
295func responseLogger(c *Client, res *Response) error {
296	if c.Debug {
297		rl := &ResponseLog{Header: copyHeaders(res.Header()), Body: res.fmtBodyString(c.debugBodySizeLimit)}
298		if c.responseLog != nil {
299			if err := c.responseLog(rl); err != nil {
300				return err
301			}
302		}
303
304		debugLog := res.Request.values[debugRequestLogKey].(string)
305		debugLog += "~~~ RESPONSE ~~~\n" +
306			fmt.Sprintf("STATUS       : %s\n", res.Status()) +
307			fmt.Sprintf("RECEIVED AT  : %v\n", res.ReceivedAt().Format(time.RFC3339Nano)) +
308			fmt.Sprintf("TIME DURATION: %v\n", res.Time()) +
309			"HEADERS      :\n" +
310			composeHeaders(c, res.Request, rl.Header) + "\n"
311		if res.Request.isSaveResponse {
312			debugLog += fmt.Sprintf("BODY         :\n***** RESPONSE WRITTEN INTO FILE *****\n")
313		} else {
314			debugLog += fmt.Sprintf("BODY         :\n%v\n", rl.Body)
315		}
316		debugLog += "==============================================================================\n"
317
318		c.log.Debugf(debugLog)
319	}
320
321	return nil
322}
323
324func parseResponseBody(c *Client, res *Response) (err error) {
325	if res.StatusCode() == http.StatusNoContent {
326		return
327	}
328	// Handles only JSON or XML content type
329	ct := firstNonEmpty(res.Header().Get(hdrContentTypeKey), res.Request.fallbackContentType)
330	if IsJSONType(ct) || IsXMLType(ct) {
331		// HTTP status code > 199 and < 300, considered as Result
332		if res.IsSuccess() {
333			res.Request.Error = nil
334			if res.Request.Result != nil {
335				err = Unmarshalc(c, ct, res.body, res.Request.Result)
336				return
337			}
338		}
339
340		// HTTP status code > 399, considered as Error
341		if res.IsError() {
342			// global error interface
343			if res.Request.Error == nil && c.Error != nil {
344				res.Request.Error = reflect.New(c.Error).Interface()
345			}
346
347			if res.Request.Error != nil {
348				err = Unmarshalc(c, ct, res.body, res.Request.Error)
349			}
350		}
351	}
352
353	return
354}
355
356func handleMultipart(c *Client, r *Request) (err error) {
357	r.bodyBuf = acquireBuffer()
358	w := multipart.NewWriter(r.bodyBuf)
359
360	for k, v := range c.FormData {
361		for _, iv := range v {
362			if err = w.WriteField(k, iv); err != nil {
363				return err
364			}
365		}
366	}
367
368	for k, v := range r.FormData {
369		for _, iv := range v {
370			if strings.HasPrefix(k, "@") { // file
371				err = addFile(w, k[1:], iv)
372				if err != nil {
373					return
374				}
375			} else { // form value
376				if err = w.WriteField(k, iv); err != nil {
377					return err
378				}
379			}
380		}
381	}
382
383	// #21 - adding io.Reader support
384	if len(r.multipartFiles) > 0 {
385		for _, f := range r.multipartFiles {
386			err = addFileReader(w, f)
387			if err != nil {
388				return
389			}
390		}
391	}
392
393	// GitHub #130 adding multipart field support with content type
394	if len(r.multipartFields) > 0 {
395		for _, mf := range r.multipartFields {
396			if err = addMultipartFormField(w, mf); err != nil {
397				return
398			}
399		}
400	}
401
402	r.Header.Set(hdrContentTypeKey, w.FormDataContentType())
403	err = w.Close()
404
405	return
406}
407
408func handleFormData(c *Client, r *Request) {
409	formData := url.Values{}
410
411	for k, v := range c.FormData {
412		for _, iv := range v {
413			formData.Add(k, iv)
414		}
415	}
416
417	for k, v := range r.FormData {
418		// remove form data field from client level by key
419		// since overrides happens for that key in the request
420		formData.Del(k)
421
422		for _, iv := range v {
423			formData.Add(k, iv)
424		}
425	}
426
427	r.bodyBuf = bytes.NewBuffer([]byte(formData.Encode()))
428	r.Header.Set(hdrContentTypeKey, formContentType)
429	r.isFormData = true
430}
431
432func handleContentType(c *Client, r *Request) {
433	contentType := r.Header.Get(hdrContentTypeKey)
434	if IsStringEmpty(contentType) {
435		contentType = DetectContentType(r.Body)
436		r.Header.Set(hdrContentTypeKey, contentType)
437	}
438}
439
440func handleRequestBody(c *Client, r *Request) (err error) {
441	var bodyBytes []byte
442	contentType := r.Header.Get(hdrContentTypeKey)
443	kind := kindOf(r.Body)
444	r.bodyBuf = nil
445
446	if reader, ok := r.Body.(io.Reader); ok {
447		if c.setContentLength || r.setContentLength { // keep backward compability
448			r.bodyBuf = acquireBuffer()
449			_, err = r.bodyBuf.ReadFrom(reader)
450			r.Body = nil
451		} else {
452			// Otherwise buffer less processing for `io.Reader`, sounds good.
453			return
454		}
455	} else if b, ok := r.Body.([]byte); ok {
456		bodyBytes = b
457	} else if s, ok := r.Body.(string); ok {
458		bodyBytes = []byte(s)
459	} else if IsJSONType(contentType) &&
460		(kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) {
461		bodyBytes, err = jsonMarshal(c, r, r.Body)
462	} else if IsXMLType(contentType) && (kind == reflect.Struct) {
463		bodyBytes, err = xml.Marshal(r.Body)
464	}
465
466	if bodyBytes == nil && r.bodyBuf == nil {
467		err = errors.New("unsupported 'Body' type/value")
468	}
469
470	// if any errors during body bytes handling, return it
471	if err != nil {
472		return
473	}
474
475	// []byte into Buffer
476	if bodyBytes != nil && r.bodyBuf == nil {
477		r.bodyBuf = acquireBuffer()
478		_, _ = r.bodyBuf.Write(bodyBytes)
479	}
480
481	return
482}
483
484func saveResponseIntoFile(c *Client, res *Response) error {
485	if res.Request.isSaveResponse {
486		file := ""
487
488		if len(c.outputDirectory) > 0 && !filepath.IsAbs(res.Request.outputFile) {
489			file += c.outputDirectory + string(filepath.Separator)
490		}
491
492		file = filepath.Clean(file + res.Request.outputFile)
493		if err := createDirectory(filepath.Dir(file)); err != nil {
494			return err
495		}
496
497		outFile, err := os.Create(file)
498		if err != nil {
499			return err
500		}
501		defer closeq(outFile)
502
503		// io.Copy reads maximum 32kb size, it is perfect for large file download too
504		defer closeq(res.RawResponse.Body)
505
506		written, err := io.Copy(outFile, res.RawResponse.Body)
507		if err != nil {
508			return err
509		}
510
511		res.size = written
512	}
513
514	return nil
515}
516