1package http
2
3import (
4	"bytes"
5	"context"
6	"encoding/json"
7	"encoding/xml"
8	"io"
9	"io/ioutil"
10	"net/http"
11	"net/url"
12
13	"github.com/go-kit/kit/endpoint"
14)
15
// HTTPClient describes the one method of *http.Client that Client depends on.
// Any value providing Do — *http.Client included — can be supplied to a Client
// through the SetClient option.
type HTTPClient interface {
	Do(r *http.Request) (*http.Response, error)
}
20
// Client wraps a URL and provides a method that implements endpoint.Endpoint.
// Construct one with NewClient; the zero value has no HTTPClient set.
type Client struct {
	// client executes the requests (http.DefaultClient unless SetClient is used).
	client HTTPClient
	// method is the HTTP method passed to http.NewRequest.
	method string
	// tgt is the URL every request is sent to.
	tgt *url.URL
	// enc encodes the domain request into the outgoing *http.Request.
	enc EncodeRequestFunc
	// dec decodes the *http.Response into the domain response.
	dec DecodeResponseFunc
	// before runs on the outgoing request after encoding, before it is sent.
	before []RequestFunc
	// after runs on the response before it is decoded.
	after []ClientResponseFunc
	// finalizer runs (deferred) at the end of every endpoint invocation.
	finalizer []ClientFinalizerFunc
	// bufferedStream, when true, leaves the response body open for the caller.
	bufferedStream bool
}
33
34// NewClient constructs a usable Client for a single remote method.
35func NewClient(
36	method string,
37	tgt *url.URL,
38	enc EncodeRequestFunc,
39	dec DecodeResponseFunc,
40	options ...ClientOption,
41) *Client {
42	c := &Client{
43		client:         http.DefaultClient,
44		method:         method,
45		tgt:            tgt,
46		enc:            enc,
47		dec:            dec,
48		before:         []RequestFunc{},
49		after:          []ClientResponseFunc{},
50		bufferedStream: false,
51	}
52	for _, option := range options {
53		option(c)
54	}
55	return c
56}
57
// ClientOption sets an optional parameter for clients. NewClient applies the
// options it is given, in order, to the *Client it is constructing.
type ClientOption func(*Client)
60
61// SetClient sets the underlying HTTP client used for requests.
62// By default, http.DefaultClient is used.
63func SetClient(client HTTPClient) ClientOption {
64	return func(c *Client) { c.client = client }
65}
66
67// ClientBefore sets the RequestFuncs that are applied to the outgoing HTTP
68// request before it's invoked.
69func ClientBefore(before ...RequestFunc) ClientOption {
70	return func(c *Client) { c.before = append(c.before, before...) }
71}
72
73// ClientAfter sets the ClientResponseFuncs applied to the incoming HTTP
74// request prior to it being decoded. This is useful for obtaining anything off
75// of the response and adding onto the context prior to decoding.
76func ClientAfter(after ...ClientResponseFunc) ClientOption {
77	return func(c *Client) { c.after = append(c.after, after...) }
78}
79
80// ClientFinalizer is executed at the end of every HTTP request.
81// By default, no finalizer is registered.
82func ClientFinalizer(f ...ClientFinalizerFunc) ClientOption {
83	return func(s *Client) { s.finalizer = append(s.finalizer, f...) }
84}
85
86// BufferedStream sets whether the Response.Body is left open, allowing it
87// to be read from later. Useful for transporting a file as a buffered stream.
88// That body has to be Closed to propery end the request.
89func BufferedStream(buffered bool) ClientOption {
90	return func(c *Client) { c.bufferedStream = buffered }
91}
92
93// Endpoint returns a usable endpoint that invokes the remote endpoint.
94func (c Client) Endpoint() endpoint.Endpoint {
95	return func(ctx context.Context, request interface{}) (interface{}, error) {
96		ctx, cancel := context.WithCancel(ctx)
97
98		var (
99			resp *http.Response
100			err  error
101		)
102		if c.finalizer != nil {
103			defer func() {
104				if resp != nil {
105					ctx = context.WithValue(ctx, ContextKeyResponseHeaders, resp.Header)
106					ctx = context.WithValue(ctx, ContextKeyResponseSize, resp.ContentLength)
107				}
108				for _, f := range c.finalizer {
109					f(ctx, err)
110				}
111			}()
112		}
113
114		req, err := http.NewRequest(c.method, c.tgt.String(), nil)
115		if err != nil {
116			cancel()
117			return nil, err
118		}
119
120		if err = c.enc(ctx, req, request); err != nil {
121			cancel()
122			return nil, err
123		}
124
125		for _, f := range c.before {
126			ctx = f(ctx, req)
127		}
128
129		resp, err = c.client.Do(req.WithContext(ctx))
130
131		if err != nil {
132			cancel()
133			return nil, err
134		}
135
136		// If we expect a buffered stream, we don't cancel the context when the endpoint returns.
137		// Instead, we should call the cancel func when closing the response body.
138		if c.bufferedStream {
139			resp.Body = bodyWithCancel{ReadCloser: resp.Body, cancel: cancel}
140		} else {
141			defer resp.Body.Close()
142			defer cancel()
143		}
144
145		for _, f := range c.after {
146			ctx = f(ctx, resp)
147		}
148
149		response, err := c.dec(ctx, resp)
150		if err != nil {
151			return nil, err
152		}
153
154		return response, nil
155	}
156}
157
// bodyWithCancel wraps an io.ReadCloser together with a cancel function that
// is invoked when the body is closed.
type bodyWithCancel struct {
	io.ReadCloser

	cancel context.CancelFunc
}

// Close closes the wrapped ReadCloser and then calls cancel. The cancel
// function runs even if the underlying Close fails, and that failure is
// returned to the caller instead of being discarded.
func (bwc bodyWithCancel) Close() error {
	err := bwc.ReadCloser.Close()
	bwc.cancel()
	return err
}
171
// ClientFinalizerFunc can be used to perform work at the end of a client HTTP
// request, after the response is returned. The principal intended use is for
// error logging. Additional response parameters are provided in the context
// under keys with the ContextKeyResponse prefix.
// Note: err may be nil. There may also be no additional response parameters,
// depending on when an error occurs.
type ClientFinalizerFunc func(ctx context.Context, err error)
179
180// EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a
181// JSON object to the Request body. Many JSON-over-HTTP services can use it as
182// a sensible default. If the request implements Headerer, the provided headers
183// will be applied to the request.
184func EncodeJSONRequest(c context.Context, r *http.Request, request interface{}) error {
185	r.Header.Set("Content-Type", "application/json; charset=utf-8")
186	if headerer, ok := request.(Headerer); ok {
187		for k := range headerer.Headers() {
188			r.Header.Set(k, headerer.Headers().Get(k))
189		}
190	}
191	var b bytes.Buffer
192	r.Body = ioutil.NopCloser(&b)
193	return json.NewEncoder(&b).Encode(request)
194}
195
196// EncodeXMLRequest is an EncodeRequestFunc that serializes the request as a
197// XML object to the Request body. If the request implements Headerer,
198// the provided headers will be applied to the request.
199func EncodeXMLRequest(c context.Context, r *http.Request, request interface{}) error {
200	r.Header.Set("Content-Type", "text/xml; charset=utf-8")
201	if headerer, ok := request.(Headerer); ok {
202		for k := range headerer.Headers() {
203			r.Header.Set(k, headerer.Headers().Get(k))
204		}
205	}
206	var b bytes.Buffer
207	r.Body = ioutil.NopCloser(&b)
208	return xml.NewEncoder(&b).Encode(request)
209}
210