1package jsonrpc
2
3import (
4	"bytes"
5	"context"
6	"encoding/json"
7	"io/ioutil"
8	"net/http"
9	"net/url"
10	"sync/atomic"
11
12	"github.com/go-kit/kit/endpoint"
13	httptransport "github.com/go-kit/kit/transport/http"
14)
15
// Client wraps a single JSON RPC method on a remote server and exposes it as
// an endpoint.Endpoint through its Endpoint method. Construct one with
// NewClient; the unexported fields are configured through ClientOptions.
type Client struct {
	// client performs the HTTP round trip. NewClient defaults it to
	// http.DefaultClient.
	client httptransport.HTTPClient

	// tgt is the URL of the JSON RPC endpoint that requests are POSTed to.
	tgt *url.URL

	// method is the JSON RPC method name placed in every request.
	method string

	// enc encodes the endpoint request into the JSON RPC params and dec turns
	// the JSON RPC response into the endpoint response. before and after are
	// hooks run around the HTTP call. finalizer is optional and runs when the
	// endpoint returns. requestID supplies the ID of each request.
	// bufferedStream, when true, stops Endpoint from closing the response
	// body.
	enc            EncodeRequestFunc
	dec            DecodeResponseFunc
	before         []httptransport.RequestFunc
	after          []httptransport.ClientResponseFunc
	finalizer      httptransport.ClientFinalizerFunc
	requestID      RequestIDGenerator
	bufferedStream bool
}
34
// clientRequest is the JSON envelope that Endpoint encodes into the body of
// each outgoing HTTP request. JSONRPC carries the protocol version string,
// Method the remote method name, Params the already-encoded parameters and ID
// the value produced by the client's RequestIDGenerator.
type clientRequest struct {
	JSONRPC string          `json:"jsonrpc"`
	Method  string          `json:"method"`
	Params  json.RawMessage `json:"params"`
	ID      interface{}     `json:"id"`
}
41
42// NewClient constructs a usable Client for a single remote method.
43func NewClient(
44	tgt *url.URL,
45	method string,
46	options ...ClientOption,
47) *Client {
48	c := &Client{
49		client:         http.DefaultClient,
50		method:         method,
51		tgt:            tgt,
52		enc:            DefaultRequestEncoder,
53		dec:            DefaultResponseDecoder,
54		before:         []httptransport.RequestFunc{},
55		after:          []httptransport.ClientResponseFunc{},
56		requestID:      NewAutoIncrementID(0),
57		bufferedStream: false,
58	}
59	for _, option := range options {
60		option(c)
61	}
62	return c
63}
64
// DefaultRequestEncoder is the EncodeRequestFunc used when none is configured.
// It ignores the context and marshals req with encoding/json, returning the
// encoded bytes as the JSON RPC params, or a nil message and the marshalling
// error.
func DefaultRequestEncoder(_ context.Context, req interface{}) (json.RawMessage, error) {
	encoded, err := json.Marshal(req)
	if err != nil {
		return nil, err
	}
	return json.RawMessage(encoded), nil
}
69
70// DefaultResponseDecoder unmarshals the result to interface{}, or returns an
71// error, if found.
72func DefaultResponseDecoder(_ context.Context, res Response) (interface{}, error) {
73	if res.Error != nil {
74		return nil, *res.Error
75	}
76	var result interface{}
77	err := json.Unmarshal(res.Result, &result)
78	if err != nil {
79		return nil, err
80	}
81	return result, nil
82}
83
// ClientOption sets an optional parameter on a Client. Options are passed to
// NewClient, which applies them in order after installing its defaults.
type ClientOption func(*Client)
86
87// SetClient sets the underlying HTTP client used for requests.
88// By default, http.DefaultClient is used.
89func SetClient(client httptransport.HTTPClient) ClientOption {
90	return func(c *Client) { c.client = client }
91}
92
93// ClientBefore sets the RequestFuncs that are applied to the outgoing HTTP
94// request before it's invoked.
95func ClientBefore(before ...httptransport.RequestFunc) ClientOption {
96	return func(c *Client) { c.before = append(c.before, before...) }
97}
98
99// ClientAfter sets the ClientResponseFuncs applied to the server's HTTP
100// response prior to it being decoded. This is useful for obtaining anything
101// from the response and adding onto the context prior to decoding.
102func ClientAfter(after ...httptransport.ClientResponseFunc) ClientOption {
103	return func(c *Client) { c.after = append(c.after, after...) }
104}
105
106// ClientFinalizer is executed at the end of every HTTP request.
107// By default, no finalizer is registered.
108func ClientFinalizer(f httptransport.ClientFinalizerFunc) ClientOption {
109	return func(c *Client) { c.finalizer = f }
110}
111
112// ClientRequestEncoder sets the func used to encode the request params to JSON.
113// If not set, DefaultRequestEncoder is used.
114func ClientRequestEncoder(enc EncodeRequestFunc) ClientOption {
115	return func(c *Client) { c.enc = enc }
116}
117
118// ClientResponseDecoder sets the func used to decode the response params from
119// JSON. If not set, DefaultResponseDecoder is used.
120func ClientResponseDecoder(dec DecodeResponseFunc) ClientOption {
121	return func(c *Client) { c.dec = dec }
122}
123
// RequestIDGenerator returns an ID for the request. The client's endpoint
// calls Generate once per invocation and places the result in the "id" member
// of the outgoing JSON RPC request.
type RequestIDGenerator interface {
	// Generate returns the ID to use for the next request.
	Generate() interface{}
}
128
129// ClientRequestIDGenerator is executed before each request to generate an ID
130// for the request.
131// By default, AutoIncrementRequestID is used.
132func ClientRequestIDGenerator(g RequestIDGenerator) ClientOption {
133	return func(c *Client) { c.requestID = g }
134}
135
136// BufferedStream sets whether the Response.Body is left open, allowing it
137// to be read from later. Useful for transporting a file as a buffered stream.
138func BufferedStream(buffered bool) ClientOption {
139	return func(c *Client) { c.bufferedStream = buffered }
140}
141
142// Endpoint returns a usable endpoint that invokes the remote endpoint.
143func (c Client) Endpoint() endpoint.Endpoint {
144	return func(ctx context.Context, request interface{}) (interface{}, error) {
145		ctx, cancel := context.WithCancel(ctx)
146		defer cancel()
147
148		var (
149			resp *http.Response
150			err  error
151		)
152		if c.finalizer != nil {
153			defer func() {
154				if resp != nil {
155					ctx = context.WithValue(ctx, httptransport.ContextKeyResponseHeaders, resp.Header)
156					ctx = context.WithValue(ctx, httptransport.ContextKeyResponseSize, resp.ContentLength)
157				}
158				c.finalizer(ctx, err)
159			}()
160		}
161
162		var params json.RawMessage
163		if params, err = c.enc(ctx, request); err != nil {
164			return nil, err
165		}
166		rpcReq := clientRequest{
167			JSONRPC: "",
168			Method:  c.method,
169			Params:  params,
170			ID:      c.requestID.Generate(),
171		}
172
173		req, err := http.NewRequest("POST", c.tgt.String(), nil)
174		if err != nil {
175			return nil, err
176		}
177
178		req.Header.Set("Content-Type", "application/json; charset=utf-8")
179		var b bytes.Buffer
180		req.Body = ioutil.NopCloser(&b)
181		err = json.NewEncoder(&b).Encode(rpcReq)
182		if err != nil {
183			return nil, err
184		}
185
186		for _, f := range c.before {
187			ctx = f(ctx, req)
188		}
189
190		resp, err = c.client.Do(req.WithContext(ctx))
191		if err != nil {
192			return nil, err
193		}
194
195		if !c.bufferedStream {
196			defer resp.Body.Close()
197		}
198
199		// Decode the body into an object
200		var rpcRes Response
201		err = json.NewDecoder(resp.Body).Decode(&rpcRes)
202		if err != nil {
203			return nil, err
204		}
205
206		for _, f := range c.after {
207			ctx = f(ctx, resp)
208		}
209
210		return c.dec(ctx, rpcRes)
211	}
212}
213
// ClientFinalizerFunc can be used to perform work at the end of a client HTTP
// request, after the response is returned. The principal intended use is for
// error logging. Additional response parameters are provided in the context
// under keys with the ContextKeyResponse prefix.
// Note: err may be nil. There may also be no additional response parameters,
// depending on when an error occurs.
type ClientFinalizerFunc func(ctx context.Context, err error)
221
// autoIncrementID is a RequestIDGenerator that generates
// auto-incrementing integer IDs.
type autoIncrementID struct {
	// v points at the counter. It holds the value most recently returned by
	// Generate, or one less than the initial value before the first call.
	v *uint64
}
227
228// NewAutoIncrementID returns an auto-incrementing request ID generator,
229// initialised with the given value.
230func NewAutoIncrementID(init uint64) RequestIDGenerator {
231	// Offset by one so that the first generated value = init.
232	v := init - 1
233	return &autoIncrementID{v: &v}
234}
235
236// Generate satisfies RequestIDGenerator
237func (i *autoIncrementID) Generate() interface{} {
238	id := atomic.AddUint64(i.v, 1)
239	return id
240}
241