1// Package rest implements a simple REST wrapper
2//
3// All methods are safe for concurrent calling.
4package rest
5
6import (
7	"bytes"
8	"context"
9	"encoding/json"
10	"encoding/xml"
11	"io"
12	"io/ioutil"
13	"mime/multipart"
14	"net/http"
15	"net/url"
16	"sync"
17
18	"github.com/pkg/errors"
19	"github.com/rclone/rclone/fs"
20	"github.com/rclone/rclone/lib/readers"
21)
22
// Client contains the info to sustain the API.
//
// The configuration fields are guarded by mu: SetErrorHandler, SetRoot,
// SetHeader, RemoveHeader and SetSigner take the write lock, and Call
// takes the read lock while it reads them.
type Client struct {
	mu           sync.RWMutex                    // guards the fields below
	c            *http.Client                    // client used to send requests
	rootURL      string                          // default URL prefix, see SetRoot
	errorHandler func(resp *http.Response) error // decodes non-2xx responses
	headers      map[string]string               // headers added to every request
	signer       SignerFn                        // optional hook run on each request
}
32
33// NewClient takes an oauth http.Client and makes a new api instance
34func NewClient(c *http.Client) *Client {
35	api := &Client{
36		c:            c,
37		errorHandler: defaultErrorHandler,
38		headers:      make(map[string]string),
39	}
40	return api
41}
42
43// ReadBody reads resp.Body into result, closing the body
44func ReadBody(resp *http.Response) (result []byte, err error) {
45	defer fs.CheckClose(resp.Body, &err)
46	return ioutil.ReadAll(resp.Body)
47}
48
49// defaultErrorHandler doesn't attempt to parse the http body, just
50// returns it in the error message closing resp.Body
51func defaultErrorHandler(resp *http.Response) (err error) {
52	body, err := ReadBody(resp)
53	if err != nil {
54		return errors.Wrap(err, "error reading error out of body")
55	}
56	return errors.Errorf("HTTP error %v (%v) returned body: %q", resp.StatusCode, resp.Status, body)
57}
58
59// SetErrorHandler sets the handler to decode an error response when
60// the HTTP status code is not 2xx.  The handler should close resp.Body.
61func (api *Client) SetErrorHandler(fn func(resp *http.Response) error) *Client {
62	api.mu.Lock()
63	defer api.mu.Unlock()
64	api.errorHandler = fn
65	return api
66}
67
68// SetRoot sets the default RootURL.  You can override this on a per
69// call basis using the RootURL field in Opts.
70func (api *Client) SetRoot(RootURL string) *Client {
71	api.mu.Lock()
72	defer api.mu.Unlock()
73	api.rootURL = RootURL
74	return api
75}
76
77// SetHeader sets a header for all requests
78// Start the key with "*" for don't canonicalise
79func (api *Client) SetHeader(key, value string) *Client {
80	api.mu.Lock()
81	defer api.mu.Unlock()
82	api.headers[key] = value
83	return api
84}
85
86// RemoveHeader unsets a header for all requests
87func (api *Client) RemoveHeader(key string) *Client {
88	api.mu.Lock()
89	defer api.mu.Unlock()
90	delete(api.headers, key)
91	return api
92}
93
// SignerFn is a hook that may modify (for example sign) an outgoing
// request just before it is sent.  Returning an error aborts the call.
type SignerFn func(req *http.Request) error
96
97// SetSigner sets a signer for all requests
98func (api *Client) SetSigner(signer SignerFn) *Client {
99	api.mu.Lock()
100	defer api.mu.Unlock()
101	api.signer = signer
102	return api
103}
104
105// SetUserPass creates an Authorization header for all requests with
106// the UserName and Password passed in
107func (api *Client) SetUserPass(UserName, Password string) *Client {
108	req, _ := http.NewRequest("GET", "http://example.com", nil)
109	req.SetBasicAuth(UserName, Password)
110	api.SetHeader("Authorization", req.Header.Get("Authorization"))
111	return api
112}
113
114// SetCookie creates a Cookies Header for all requests with the supplied
115// cookies passed in.
116// All cookies have to be supplied at once, all cookies will be overwritten
117// on a new call to the method
118func (api *Client) SetCookie(cks ...*http.Cookie) *Client {
119	req, _ := http.NewRequest("GET", "http://example.com", nil)
120	for _, ck := range cks {
121		req.AddCookie(ck)
122	}
123	api.SetHeader("Cookie", req.Header.Get("Cookie"))
124	return api
125}
126
// Opts contains parameters for Call, CallJSON, etc.
//
// Call itself does not read the Multipart* fields; CallJSON and CallXML
// use them to build a multipart form upload.
type Opts struct {
	Method                string            // HTTP method: GET, POST, etc.
	Path                  string            // appended to the root URL as is
	RootURL               string            // if set, overrides the root URL passed into SetRoot()
	Body                  io.Reader         // request body, may be nil
	NoResponse            bool              // if set, Call closes the response Body before returning
	ContentType           string            // if set, sent as the Content-Type header
	ContentLength         *int64            // if set, the request Content-Length; a value of 0 also drops Body
	ContentRange          string            // if set, sent as the Content-Range header
	ExtraHeaders          map[string]string // extra headers, start them with "*" for don't canonicalise
	UserName              string            // username for Basic Auth
	Password              string            // password for Basic Auth
	Options               []fs.OpenOption   // headers from these are added via fs.OpenOptionAddHeaders
	IgnoreStatus          bool              // if set then we don't check error status or parse error body
	MultipartParams       url.Values        // if set do multipart form upload with attached file
	MultipartMetadataName string            // ..this is used for the name of the metadata form part if set
	MultipartContentName  string            // ..name of the parameter which is the attached file
	MultipartFileName     string            // ..name of the file for the attached file
	Parameters            url.Values        // query parameters, URL-encoded and appended after "?"
	TransferEncoding      []string          // transfer encoding, set to "identity" to disable chunked encoding
	Close                 bool              // set to close the connection after this transaction
	NoRedirect            bool              // if this is set then the client won't follow redirects
}
151
152// Copy creates a copy of the options
153func (o *Opts) Copy() *Opts {
154	newOpts := *o
155	return &newOpts
156}
157
158// DecodeJSON decodes resp.Body into result
159func DecodeJSON(resp *http.Response, result interface{}) (err error) {
160	defer fs.CheckClose(resp.Body, &err)
161	decoder := json.NewDecoder(resp.Body)
162	return decoder.Decode(result)
163}
164
165// DecodeXML decodes resp.Body into result
166func DecodeXML(resp *http.Response, result interface{}) (err error) {
167	defer fs.CheckClose(resp.Body, &err)
168	decoder := xml.NewDecoder(resp.Body)
169	return decoder.Decode(result)
170}
171
// ClientWithNoRedirects returns a shallow copy of c which does not follow
// redirects: the redirect response itself is returned to the caller.
// c is left unmodified.
func ClientWithNoRedirects(c *http.Client) *http.Client {
	noFollow := new(http.Client)
	*noFollow = *c
	noFollow.CheckRedirect = func(*http.Request, []*http.Request) error {
		return http.ErrUseLastResponse
	}
	return noFollow
}
180
181// Call makes the call and returns the http.Response
182//
183// if err == nil then resp.Body will need to be closed unless
184// opt.NoResponse is set
185//
186// if err != nil then resp.Body will have been closed
187//
188// it will return resp if at all possible, even if err is set
189func (api *Client) Call(ctx context.Context, opts *Opts) (resp *http.Response, err error) {
190	api.mu.RLock()
191	defer api.mu.RUnlock()
192	if opts == nil {
193		return nil, errors.New("call() called with nil opts")
194	}
195	url := api.rootURL
196	if opts.RootURL != "" {
197		url = opts.RootURL
198	}
199	if url == "" {
200		return nil, errors.New("RootURL not set")
201	}
202	url += opts.Path
203	if opts.Parameters != nil && len(opts.Parameters) > 0 {
204		url += "?" + opts.Parameters.Encode()
205	}
206	body := readers.NoCloser(opts.Body)
207	// If length is set and zero then nil out the body to stop use
208	// use of chunked encoding and insert a "Content-Length: 0"
209	// header.
210	//
211	// If we don't do this we get "Content-Length" headers for all
212	// files except 0 length files.
213	if opts.ContentLength != nil && *opts.ContentLength == 0 {
214		body = nil
215	}
216	req, err := http.NewRequestWithContext(ctx, opts.Method, url, body)
217	if err != nil {
218		return
219	}
220	headers := make(map[string]string)
221	// Set default headers
222	for k, v := range api.headers {
223		headers[k] = v
224	}
225	if opts.ContentType != "" {
226		headers["Content-Type"] = opts.ContentType
227	}
228	if opts.ContentLength != nil {
229		req.ContentLength = *opts.ContentLength
230	}
231	if opts.ContentRange != "" {
232		headers["Content-Range"] = opts.ContentRange
233	}
234	if len(opts.TransferEncoding) != 0 {
235		req.TransferEncoding = opts.TransferEncoding
236	}
237	if opts.Close {
238		req.Close = true
239	}
240	// Set any extra headers
241	if opts.ExtraHeaders != nil {
242		for k, v := range opts.ExtraHeaders {
243			headers[k] = v
244		}
245	}
246	// add any options to the headers
247	fs.OpenOptionAddHeaders(opts.Options, headers)
248	// Now set the headers
249	for k, v := range headers {
250		if k != "" && v != "" {
251			if k[0] == '*' {
252				// Add non-canonical version if header starts with *
253				k = k[1:]
254				req.Header[k] = append(req.Header[k], v)
255			} else {
256				req.Header.Add(k, v)
257			}
258		}
259	}
260
261	if opts.UserName != "" || opts.Password != "" {
262		req.SetBasicAuth(opts.UserName, opts.Password)
263	}
264	var c *http.Client
265	if opts.NoRedirect {
266		c = ClientWithNoRedirects(api.c)
267	} else {
268		c = api.c
269	}
270	if api.signer != nil {
271		api.mu.RUnlock()
272		err = api.signer(req)
273		api.mu.RLock()
274		if err != nil {
275			return nil, errors.Wrap(err, "signer failed")
276		}
277	}
278	api.mu.RUnlock()
279	resp, err = c.Do(req)
280	api.mu.RLock()
281	if err != nil {
282		return nil, err
283	}
284	if !opts.IgnoreStatus {
285		if resp.StatusCode < 200 || resp.StatusCode > 299 {
286			err = api.errorHandler(resp)
287			if err.Error() == "" {
288				// replace empty errors with something
289				err = errors.Errorf("http error %d: %v", resp.StatusCode, resp.Status)
290			}
291			return resp, err
292		}
293	}
294	if opts.NoResponse {
295		return resp, resp.Body.Close()
296	}
297	return resp, nil
298}
299
300// MultipartUpload creates an io.Reader which produces an encoded a
301// multipart form upload from the params passed in and the  passed in
302//
303// in - the body of the file (may be nil)
304// params - the form parameters
305// fileName - is the name of the attached file
306// contentName - the name of the parameter for the file
307//
308// the int64 returned is the overhead in addition to the file contents, in case Content-Length is required
309//
310// NB This doesn't allow setting the content type of the attachment
311func MultipartUpload(ctx context.Context, in io.Reader, params url.Values, contentName, fileName string) (io.ReadCloser, string, int64, error) {
312	bodyReader, bodyWriter := io.Pipe()
313	writer := multipart.NewWriter(bodyWriter)
314	contentType := writer.FormDataContentType()
315
316	// Create a Multipart Writer as base for calculating the Content-Length
317	buf := &bytes.Buffer{}
318	dummyMultipartWriter := multipart.NewWriter(buf)
319	err := dummyMultipartWriter.SetBoundary(writer.Boundary())
320	if err != nil {
321		return nil, "", 0, err
322	}
323
324	for key, vals := range params {
325		for _, val := range vals {
326			err := dummyMultipartWriter.WriteField(key, val)
327			if err != nil {
328				return nil, "", 0, err
329			}
330		}
331	}
332	if in != nil {
333		_, err = dummyMultipartWriter.CreateFormFile(contentName, fileName)
334		if err != nil {
335			return nil, "", 0, err
336		}
337	}
338
339	err = dummyMultipartWriter.Close()
340	if err != nil {
341		return nil, "", 0, err
342	}
343
344	multipartLength := int64(buf.Len())
345
346	// Make sure we close the pipe writer to release the reader on context cancel
347	quit := make(chan struct{})
348	go func() {
349		select {
350		case <-quit:
351			break
352		case <-ctx.Done():
353			_ = bodyWriter.CloseWithError(ctx.Err())
354		}
355	}()
356
357	// Pump the data in the background
358	go func() {
359		defer close(quit)
360
361		var err error
362
363		for key, vals := range params {
364			for _, val := range vals {
365				err = writer.WriteField(key, val)
366				if err != nil {
367					_ = bodyWriter.CloseWithError(errors.Wrap(err, "create metadata part"))
368					return
369				}
370			}
371		}
372
373		if in != nil {
374			part, err := writer.CreateFormFile(contentName, fileName)
375			if err != nil {
376				_ = bodyWriter.CloseWithError(errors.Wrap(err, "failed to create form file"))
377				return
378			}
379
380			_, err = io.Copy(part, in)
381			if err != nil {
382				_ = bodyWriter.CloseWithError(errors.Wrap(err, "failed to copy data"))
383				return
384			}
385		}
386
387		err = writer.Close()
388		if err != nil {
389			_ = bodyWriter.CloseWithError(errors.Wrap(err, "failed to close form"))
390			return
391		}
392
393		_ = bodyWriter.Close()
394	}()
395
396	return bodyReader, contentType, multipartLength, nil
397}
398
399// CallJSON runs Call and decodes the body as a JSON object into response (if not nil)
400//
401// If request is not nil then it will be JSON encoded as the body of the request
402//
403// If response is not nil then the response will be JSON decoded into
404// it and resp.Body will be closed.
405//
406// If response is nil then the resp.Body will be closed only if
407// opts.NoResponse is set.
408//
409// If (opts.MultipartParams or opts.MultipartContentName) and
410// opts.Body are set then CallJSON will do a multipart upload with a
411// file attached.  opts.MultipartContentName is the name of the
412// parameter and opts.MultipartFileName is the name of the file.  If
413// MultpartContentName is set, and request != nil is supplied, then
414// the request will be marshalled into JSON and added to the form with
415// parameter name MultipartMetadataName.
416//
417// It will return resp if at all possible, even if err is set
418func (api *Client) CallJSON(ctx context.Context, opts *Opts, request interface{}, response interface{}) (resp *http.Response, err error) {
419	return api.callCodec(ctx, opts, request, response, json.Marshal, DecodeJSON, "application/json")
420}
421
422// CallXML runs Call and decodes the body as an XML object into response (if not nil)
423//
424// If request is not nil then it will be XML encoded as the body of the request
425//
426// If response is not nil then the response will be XML decoded into
427// it and resp.Body will be closed.
428//
429// If response is nil then the resp.Body will be closed only if
430// opts.NoResponse is set.
431//
432// See CallJSON for a description of MultipartParams and related opts
433//
434// It will return resp if at all possible, even if err is set
435func (api *Client) CallXML(ctx context.Context, opts *Opts, request interface{}, response interface{}) (resp *http.Response, err error) {
436	return api.callCodec(ctx, opts, request, response, xml.Marshal, DecodeXML, "application/xml")
437}
438
// Codec hooks through which callCodec encodes requests and decodes
// responses.
type (
	// marshalFn encodes a request value, e.g. json.Marshal or xml.Marshal.
	marshalFn func(v interface{}) ([]byte, error)
	// decodeFn decodes resp.Body into result, e.g. DecodeJSON or DecodeXML.
	decodeFn func(resp *http.Response, result interface{}) (err error)
)
441
442func (api *Client) callCodec(ctx context.Context, opts *Opts, request interface{}, response interface{}, marshal marshalFn, decode decodeFn, contentType string) (resp *http.Response, err error) {
443	var requestBody []byte
444	// Marshal the request if given
445	if request != nil {
446		requestBody, err = marshal(request)
447		if err != nil {
448			return nil, err
449		}
450		// Set the body up as a marshalled object if no body passed in
451		if opts.Body == nil {
452			opts = opts.Copy()
453			opts.ContentType = contentType
454			opts.Body = bytes.NewBuffer(requestBody)
455		}
456	}
457	if opts.MultipartParams != nil || opts.MultipartContentName != "" {
458		params := opts.MultipartParams
459		if params == nil {
460			params = url.Values{}
461		}
462		if opts.MultipartMetadataName != "" {
463			params.Add(opts.MultipartMetadataName, string(requestBody))
464		}
465		opts = opts.Copy()
466
467		var overhead int64
468		opts.Body, opts.ContentType, overhead, err = MultipartUpload(ctx, opts.Body, params, opts.MultipartContentName, opts.MultipartFileName)
469		if err != nil {
470			return nil, err
471		}
472		if opts.ContentLength != nil {
473			*opts.ContentLength += overhead
474		}
475	}
476	resp, err = api.Call(ctx, opts)
477	if err != nil {
478		return resp, err
479	}
480	// if opts.NoResponse is set, resp.Body will have been closed by Call()
481	if response == nil || opts.NoResponse {
482		return resp, nil
483	}
484	err = decode(resp, response)
485	return resp, err
486}
487