1// Copyright (c) 2015-2019 Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
2// resty source code and usage is governed by a MIT style
3// license that can be found in the LICENSE file.
4
5package resty
6
7import (
8	"bytes"
9	"encoding/xml"
10	"errors"
11	"fmt"
12	"io"
13	"mime/multipart"
14	"net/http"
15	"net/url"
16	"os"
17	"path/filepath"
18	"reflect"
19	"strings"
20	"time"
21)
22
23//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
24// Request Middleware(s)
25//___________________________________
26
27func parseRequestURL(c *Client, r *Request) error {
28	// GitHub #103 Path Params
29	if len(r.pathParams) > 0 {
30		for p, v := range r.pathParams {
31			r.URL = strings.Replace(r.URL, "{"+p+"}", url.PathEscape(v), -1)
32		}
33	}
34	if len(c.pathParams) > 0 {
35		for p, v := range c.pathParams {
36			r.URL = strings.Replace(r.URL, "{"+p+"}", url.PathEscape(v), -1)
37		}
38	}
39
40	// Parsing request URL
41	reqURL, err := url.Parse(r.URL)
42	if err != nil {
43		return err
44	}
45
46	// If Request.URL is relative path then added c.HostURL into
47	// the request URL otherwise Request.URL will be used as-is
48	if !reqURL.IsAbs() {
49		r.URL = reqURL.String()
50		if len(r.URL) > 0 && r.URL[0] != '/' {
51			r.URL = "/" + r.URL
52		}
53
54		reqURL, err = url.Parse(c.HostURL + r.URL)
55		if err != nil {
56			return err
57		}
58	}
59
60	// Adding Query Param
61	query := make(url.Values)
62	for k, v := range c.QueryParam {
63		for _, iv := range v {
64			query.Add(k, iv)
65		}
66	}
67
68	for k, v := range r.QueryParam {
69		// remove query param from client level by key
70		// since overrides happens for that key in the request
71		query.Del(k)
72
73		for _, iv := range v {
74			query.Add(k, iv)
75		}
76	}
77
78	// GitHub #123 Preserve query string order partially.
79	// Since not feasible in `SetQuery*` resty methods, because
80	// standard package `url.Encode(...)` sorts the query params
81	// alphabetically
82	if len(query) > 0 {
83		if IsStringEmpty(reqURL.RawQuery) {
84			reqURL.RawQuery = query.Encode()
85		} else {
86			reqURL.RawQuery = reqURL.RawQuery + "&" + query.Encode()
87		}
88	}
89
90	r.URL = reqURL.String()
91
92	return nil
93}
94
95func parseRequestHeader(c *Client, r *Request) error {
96	hdr := make(http.Header)
97	for k := range c.Header {
98		hdr[k] = append(hdr[k], c.Header[k]...)
99	}
100
101	for k := range r.Header {
102		hdr.Del(k)
103		hdr[k] = append(hdr[k], r.Header[k]...)
104	}
105
106	if IsStringEmpty(hdr.Get(hdrUserAgentKey)) {
107		hdr.Set(hdrUserAgentKey, fmt.Sprintf(hdrUserAgentValue, Version))
108	}
109
110	ct := hdr.Get(hdrContentTypeKey)
111	if IsStringEmpty(hdr.Get(hdrAcceptKey)) && !IsStringEmpty(ct) &&
112		(IsJSONType(ct) || IsXMLType(ct)) {
113		hdr.Set(hdrAcceptKey, hdr.Get(hdrContentTypeKey))
114	}
115
116	r.Header = hdr
117
118	return nil
119}
120
121func parseRequestBody(c *Client, r *Request) (err error) {
122	if isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
123		// Handling Multipart
124		if r.isMultiPart && !(r.Method == MethodPatch) {
125			if err = handleMultipart(c, r); err != nil {
126				return
127			}
128
129			goto CL
130		}
131
132		// Handling Form Data
133		if len(c.FormData) > 0 || len(r.FormData) > 0 {
134			handleFormData(c, r)
135
136			goto CL
137		}
138
139		// Handling Request body
140		if r.Body != nil {
141			handleContentType(c, r)
142
143			if err = handleRequestBody(c, r); err != nil {
144				return
145			}
146		}
147	}
148
149CL:
150	// by default resty won't set content length, you can if you want to :)
151	if (c.setContentLength || r.setContentLength) && r.bodyBuf != nil {
152		r.Header.Set(hdrContentLengthKey, fmt.Sprintf("%d", r.bodyBuf.Len()))
153	}
154
155	return
156}
157
158func createHTTPRequest(c *Client, r *Request) (err error) {
159	if r.bodyBuf == nil {
160		if reader, ok := r.Body.(io.Reader); ok {
161			r.RawRequest, err = http.NewRequest(r.Method, r.URL, reader)
162		} else {
163			r.RawRequest, err = http.NewRequest(r.Method, r.URL, nil)
164		}
165	} else {
166		r.RawRequest, err = http.NewRequest(r.Method, r.URL, r.bodyBuf)
167	}
168
169	if err != nil {
170		return
171	}
172
173	// Assign close connection option
174	r.RawRequest.Close = c.closeConnection
175
176	// Add headers into http request
177	r.RawRequest.Header = r.Header
178
179	// Add cookies into http request
180	for _, cookie := range c.Cookies {
181		r.RawRequest.AddCookie(cookie)
182	}
183
184	// it's for non-http scheme option
185	if r.RawRequest.URL != nil && r.RawRequest.URL.Scheme == "" {
186		r.RawRequest.URL.Scheme = c.scheme
187		r.RawRequest.URL.Host = r.URL
188	}
189
190	// Use context if it was specified
191	r.addContextIfAvailable()
192
193	return
194}
195
196func addCredentials(c *Client, r *Request) error {
197	var isBasicAuth bool
198	// Basic Auth
199	if r.UserInfo != nil { // takes precedence
200		r.RawRequest.SetBasicAuth(r.UserInfo.Username, r.UserInfo.Password)
201		isBasicAuth = true
202	} else if c.UserInfo != nil {
203		r.RawRequest.SetBasicAuth(c.UserInfo.Username, c.UserInfo.Password)
204		isBasicAuth = true
205	}
206
207	if !c.DisableWarn {
208		if isBasicAuth && !strings.HasPrefix(r.URL, "https") {
209			c.Log.Println("WARNING - Using Basic Auth in HTTP mode is not secure.")
210		}
211	}
212
213	// Token Auth
214	if !IsStringEmpty(r.Token) { // takes precedence
215		r.RawRequest.Header.Set(hdrAuthorizationKey, "Bearer "+r.Token)
216	} else if !IsStringEmpty(c.Token) {
217		r.RawRequest.Header.Set(hdrAuthorizationKey, "Bearer "+c.Token)
218	}
219
220	return nil
221}
222
223func requestLogger(c *Client, r *Request) error {
224	if c.Debug {
225		rr := r.RawRequest
226		rl := &RequestLog{Header: copyHeaders(rr.Header), Body: r.fmtBodyString()}
227		if c.requestLog != nil {
228			if err := c.requestLog(rl); err != nil {
229				return err
230			}
231		}
232
233		reqLog := "\n---------------------- REQUEST LOG -----------------------\n" +
234			fmt.Sprintf("%s  %s  %s\n", r.Method, rr.URL.RequestURI(), rr.Proto) +
235			fmt.Sprintf("HOST   : %s\n", rr.URL.Host) +
236			fmt.Sprintf("HEADERS:\n") +
237			composeHeaders(rl.Header) + "\n" +
238			fmt.Sprintf("BODY   :\n%v\n", rl.Body) +
239			"----------------------------------------------------------\n"
240
241		c.Log.Print(reqLog)
242	}
243
244	return nil
245}
246
247//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
248// Response Middleware(s)
249//___________________________________
250
251func responseLogger(c *Client, res *Response) error {
252	if c.Debug {
253		rl := &ResponseLog{Header: copyHeaders(res.Header()), Body: res.fmtBodyString(c.debugBodySizeLimit)}
254		if c.responseLog != nil {
255			if err := c.responseLog(rl); err != nil {
256				return err
257			}
258		}
259
260		resLog := "\n---------------------- RESPONSE LOG -----------------------\n" +
261			fmt.Sprintf("STATUS 		: %s\n", res.Status()) +
262			fmt.Sprintf("RECEIVED AT	: %v\n", res.ReceivedAt().Format(time.RFC3339Nano)) +
263			fmt.Sprintf("RESPONSE TIME	: %v\n", res.Time()) +
264			"HEADERS:\n" +
265			composeHeaders(rl.Header) + "\n"
266		if res.Request.isSaveResponse {
267			resLog += fmt.Sprintf("BODY   :\n***** RESPONSE WRITTEN INTO FILE *****\n")
268		} else {
269			resLog += fmt.Sprintf("BODY   :\n%v\n", rl.Body)
270		}
271		resLog += "----------------------------------------------------------\n"
272
273		c.Log.Print(resLog)
274	}
275
276	return nil
277}
278
279func parseResponseBody(c *Client, res *Response) (err error) {
280	if res.StatusCode() == http.StatusNoContent {
281		return
282	}
283	// Handles only JSON or XML content type
284	ct := firstNonEmpty(res.Header().Get(hdrContentTypeKey), res.Request.fallbackContentType)
285	if IsJSONType(ct) || IsXMLType(ct) {
286		// HTTP status code > 199 and < 300, considered as Result
287		if res.IsSuccess() {
288			if res.Request.Result != nil {
289				err = Unmarshalc(c, ct, res.body, res.Request.Result)
290				return
291			}
292		}
293
294		// HTTP status code > 399, considered as Error
295		if res.IsError() {
296			// global error interface
297			if res.Request.Error == nil && c.Error != nil {
298				res.Request.Error = reflect.New(c.Error).Interface()
299			}
300
301			if res.Request.Error != nil {
302				err = Unmarshalc(c, ct, res.body, res.Request.Error)
303			}
304		}
305	}
306
307	return
308}
309
310func handleMultipart(c *Client, r *Request) (err error) {
311	r.bodyBuf = acquireBuffer()
312	w := multipart.NewWriter(r.bodyBuf)
313
314	for k, v := range c.FormData {
315		for _, iv := range v {
316			if err = w.WriteField(k, iv); err != nil {
317				return err
318			}
319		}
320	}
321
322	for k, v := range r.FormData {
323		for _, iv := range v {
324			if strings.HasPrefix(k, "@") { // file
325				err = addFile(w, k[1:], iv)
326				if err != nil {
327					return
328				}
329			} else { // form value
330				if err = w.WriteField(k, iv); err != nil {
331					return err
332				}
333			}
334		}
335	}
336
337	// #21 - adding io.Reader support
338	if len(r.multipartFiles) > 0 {
339		for _, f := range r.multipartFiles {
340			err = addFileReader(w, f)
341			if err != nil {
342				return
343			}
344		}
345	}
346
347	// GitHub #130 adding multipart field support with content type
348	if len(r.multipartFields) > 0 {
349		for _, mf := range r.multipartFields {
350			if err = addMultipartFormField(w, mf); err != nil {
351				return
352			}
353		}
354	}
355
356	r.Header.Set(hdrContentTypeKey, w.FormDataContentType())
357	err = w.Close()
358
359	return
360}
361
362func handleFormData(c *Client, r *Request) {
363	formData := url.Values{}
364
365	for k, v := range c.FormData {
366		for _, iv := range v {
367			formData.Add(k, iv)
368		}
369	}
370
371	for k, v := range r.FormData {
372		// remove form data field from client level by key
373		// since overrides happens for that key in the request
374		formData.Del(k)
375
376		for _, iv := range v {
377			formData.Add(k, iv)
378		}
379	}
380
381	r.bodyBuf = bytes.NewBuffer([]byte(formData.Encode()))
382	r.Header.Set(hdrContentTypeKey, formContentType)
383	r.isFormData = true
384}
385
386func handleContentType(c *Client, r *Request) {
387	contentType := r.Header.Get(hdrContentTypeKey)
388	if IsStringEmpty(contentType) {
389		contentType = DetectContentType(r.Body)
390		r.Header.Set(hdrContentTypeKey, contentType)
391	}
392}
393
394func handleRequestBody(c *Client, r *Request) (err error) {
395	var bodyBytes []byte
396	contentType := r.Header.Get(hdrContentTypeKey)
397	kind := kindOf(r.Body)
398	r.bodyBuf = nil
399
400	if reader, ok := r.Body.(io.Reader); ok {
401		if c.setContentLength || r.setContentLength { // keep backward compability
402			r.bodyBuf = acquireBuffer()
403			_, err = r.bodyBuf.ReadFrom(reader)
404			r.Body = nil
405		} else {
406			// Otherwise buffer less processing for `io.Reader`, sounds good.
407			return
408		}
409	} else if b, ok := r.Body.([]byte); ok {
410		bodyBytes = b
411	} else if s, ok := r.Body.(string); ok {
412		bodyBytes = []byte(s)
413	} else if IsJSONType(contentType) &&
414		(kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) {
415		bodyBytes, err = jsonMarshal(c, r, r.Body)
416	} else if IsXMLType(contentType) && (kind == reflect.Struct) {
417		bodyBytes, err = xml.Marshal(r.Body)
418	}
419
420	if bodyBytes == nil && r.bodyBuf == nil {
421		err = errors.New("unsupported 'Body' type/value")
422	}
423
424	// if any errors during body bytes handling, return it
425	if err != nil {
426		return
427	}
428
429	// []byte into Buffer
430	if bodyBytes != nil && r.bodyBuf == nil {
431		r.bodyBuf = acquireBuffer()
432		_, _ = r.bodyBuf.Write(bodyBytes)
433	}
434
435	return
436}
437
438func saveResponseIntoFile(c *Client, res *Response) error {
439	if res.Request.isSaveResponse {
440		file := ""
441
442		if len(c.outputDirectory) > 0 && !filepath.IsAbs(res.Request.outputFile) {
443			file += c.outputDirectory + string(filepath.Separator)
444		}
445
446		file = filepath.Clean(file + res.Request.outputFile)
447		if err := createDirectory(filepath.Dir(file)); err != nil {
448			return err
449		}
450
451		outFile, err := os.Create(file)
452		if err != nil {
453			return err
454		}
455		defer closeq(outFile)
456
457		// io.Copy reads maximum 32kb size, it is perfect for large file download too
458		defer closeq(res.RawResponse.Body)
459
460		written, err := io.Copy(outFile, res.RawResponse.Body)
461		if err != nil {
462			return err
463		}
464
465		res.size = written
466	}
467
468	return nil
469}
470