1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package jsonrpc2 is a minimal implementation of the JSON RPC 2 spec.
6// https://www.jsonrpc.org/specification
7// It is intended to be compatible with other implementations at the wire level.
8package jsonrpc2
9
10import (
11	"context"
12	"encoding/json"
13	"fmt"
14	"sync"
15	"sync/atomic"
16)
17
// Conn is a JSON RPC 2 client server connection.
// Conn is bidirectional; it does not have a designated server or client end.
//
// A Conn is created with NewConn and becomes active once Run is called.
type Conn struct {
	// seq is the last request id handed out by Call.
	// NOTE(review): sync/atomic only guarantees 64-bit alignment for the first
	// word of an allocated struct on 32-bit platforms; keep seq as the first
	// field so the atomic.AddInt64 in Call stays safe there.
	seq int64 // must only be accessed using atomic operations
	// handlers are consulted in slice order; AddHandler prepends, so the most
	// recently added handler runs first.
	handlers []Handler
	// stream is the transport used for every Read (in Run) and Write.
	stream Stream
	// NOTE(review): err is not referenced in the code visible here — confirm
	// whether it is still needed.
	err       error
	pendingMu sync.Mutex // protects the pending map
	// pending maps the id of each outgoing Call to the channel on which Run
	// delivers its response.
	pending    map[ID]chan *WireResponse
	handlingMu sync.Mutex // protects the handling map
	// handling maps the id of each incoming call that is still being handled
	// to its Request, so that Cancel can find it.
	handling map[ID]*Request
}
30
// requestState records how far an incoming Request has progressed.
// The values are ordered and callers compare them with < and >=, so the
// declaration order below is significant.
type requestState int

const (
	// requestWaiting is the zero value: the request exists but its handler
	// goroutine is still waiting for the previous request to release it.
	requestWaiting requestState = iota
	// requestSerial means the handler is running and later requests are
	// still blocked behind it.
	requestSerial
	// requestParallel means Parallel has released the next request.
	requestParallel
	// requestReplied means Reply has been invoked for this request.
	requestReplied
	// requestDone is the final state.
	// NOTE(review): it is not assigned anywhere in the code visible here.
	requestDone
)
40
// Request is sent to a server to represent a Call or Notify operation.
type Request struct {
	// conn is the connection whose Run loop created this request.
	conn *Conn
	// cancel cancels the context that Run passes to the handlers for this
	// request; Conn.Cancel invokes it.
	cancel context.CancelFunc
	// state tracks progress through the requestState values. It is read and
	// written without locking, so it must only be touched from the request
	// handling goroutine.
	state requestState
	// nextRequest is closed by Parallel to let the next incoming request start.
	nextRequest chan struct{}

	// The Wire values of the request.
	WireRequest
}
51
52// NewErrorf builds a Error struct for the supplied message and code.
53// If args is not empty, message and args will be passed to Sprintf.
54func NewErrorf(code int64, format string, args ...interface{}) *Error {
55	return &Error{
56		Code:    code,
57		Message: fmt.Sprintf(format, args...),
58	}
59}
60
61// NewConn creates a new connection object around the supplied stream.
62// You must call Run for the connection to be active.
63func NewConn(s Stream) *Conn {
64	conn := &Conn{
65		handlers: []Handler{defaultHandler{}},
66		stream:   s,
67		pending:  make(map[ID]chan *WireResponse),
68		handling: make(map[ID]*Request),
69	}
70	return conn
71}
72
73// AddHandler adds a new handler to the set the connection will invoke.
74// Handlers are invoked in the reverse order of how they were added, this
75// allows the most recent addition to be the first one to attempt to handle a
76// message.
77func (c *Conn) AddHandler(handler Handler) {
78	// prepend the new handlers so we use them first
79	c.handlers = append([]Handler{handler}, c.handlers...)
80}
81
82// Cancel cancels a pending Call on the server side.
83// The call is identified by its id.
84// JSON RPC 2 does not specify a cancel message, so cancellation support is not
85// directly wired in. This method allows a higher level protocol to choose how
86// to propagate the cancel.
87func (c *Conn) Cancel(id ID) {
88	c.handlingMu.Lock()
89	handling, found := c.handling[id]
90	c.handlingMu.Unlock()
91	if found {
92		handling.cancel()
93	}
94}
95
96// Notify is called to send a notification request over the connection.
97// It will return as soon as the notification has been sent, as no response is
98// possible.
99func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (err error) {
100	jsonParams, err := marshalToRaw(params)
101	if err != nil {
102		return fmt.Errorf("marshalling notify parameters: %v", err)
103	}
104	request := &WireRequest{
105		Method: method,
106		Params: jsonParams,
107	}
108	data, err := json.Marshal(request)
109	if err != nil {
110		return fmt.Errorf("marshalling notify request: %v", err)
111	}
112	for _, h := range c.handlers {
113		ctx = h.Request(ctx, c, Send, request)
114	}
115	defer func() {
116		for _, h := range c.handlers {
117			h.Done(ctx, err)
118		}
119	}()
120	n, err := c.stream.Write(ctx, data)
121	for _, h := range c.handlers {
122		ctx = h.Wrote(ctx, n)
123	}
124	return err
125}
126
127// Call sends a request over the connection and then waits for a response.
128// If the response is not an error, it will be decoded into result.
129// result must be of a type you an pass to json.Unmarshal.
130func (c *Conn) Call(ctx context.Context, method string, params, result interface{}) (err error) {
131	// generate a new request identifier
132	id := ID{Number: atomic.AddInt64(&c.seq, 1)}
133	jsonParams, err := marshalToRaw(params)
134	if err != nil {
135		return fmt.Errorf("marshalling call parameters: %v", err)
136	}
137	request := &WireRequest{
138		ID:     &id,
139		Method: method,
140		Params: jsonParams,
141	}
142	// marshal the request now it is complete
143	data, err := json.Marshal(request)
144	if err != nil {
145		return fmt.Errorf("marshalling call request: %v", err)
146	}
147	for _, h := range c.handlers {
148		ctx = h.Request(ctx, c, Send, request)
149	}
150	// we have to add ourselves to the pending map before we send, otherwise we
151	// are racing the response
152	rchan := make(chan *WireResponse)
153	c.pendingMu.Lock()
154	c.pending[id] = rchan
155	c.pendingMu.Unlock()
156	defer func() {
157		// clean up the pending response handler on the way out
158		c.pendingMu.Lock()
159		delete(c.pending, id)
160		c.pendingMu.Unlock()
161		for _, h := range c.handlers {
162			h.Done(ctx, err)
163		}
164	}()
165	// now we are ready to send
166	n, err := c.stream.Write(ctx, data)
167	for _, h := range c.handlers {
168		ctx = h.Wrote(ctx, n)
169	}
170	if err != nil {
171		// sending failed, we will never get a response, so don't leave it pending
172		return err
173	}
174	// now wait for the response
175	select {
176	case response := <-rchan:
177		for _, h := range c.handlers {
178			ctx = h.Response(ctx, c, Receive, response)
179		}
180		// is it an error response?
181		if response.Error != nil {
182			return response.Error
183		}
184		if result == nil || response.Result == nil {
185			return nil
186		}
187		if err := json.Unmarshal(*response.Result, result); err != nil {
188			return fmt.Errorf("unmarshalling result: %v", err)
189		}
190		return nil
191	case <-ctx.Done():
192		// allow the handler to propagate the cancel
193		cancelled := false
194		for _, h := range c.handlers {
195			if h.Cancel(ctx, c, id, cancelled) {
196				cancelled = true
197			}
198		}
199		return ctx.Err()
200	}
201}
202
203// Conn returns the connection that created this request.
204func (r *Request) Conn() *Conn { return r.conn }
205
206// IsNotify returns true if this request is a notification.
207func (r *Request) IsNotify() bool {
208	return r.ID == nil
209}
210
211// Parallel indicates that the system is now allowed to process other requests
212// in parallel with this one.
213// It is safe to call any number of times, but must only be called from the
214// request handling go routine.
215// It is implied by both reply and by the handler returning.
216func (r *Request) Parallel() {
217	if r.state >= requestParallel {
218		return
219	}
220	r.state = requestParallel
221	close(r.nextRequest)
222}
223
224// Reply sends a reply to the given request.
225// It is an error to call this if request was not a call.
226// You must call this exactly once for any given request.
227// It should only be called from the handler go routine.
228// If err is set then result will be ignored.
229// If the request has not yet dropped into parallel mode
230// it will be before this function returns.
231func (r *Request) Reply(ctx context.Context, result interface{}, err error) error {
232	if r.state >= requestReplied {
233		return fmt.Errorf("reply invoked more than once")
234	}
235	if r.IsNotify() {
236		return fmt.Errorf("reply not invoked with a valid call")
237	}
238	// reply ends the handling phase of a call, so if we are not yet
239	// parallel we should be now. The go routine is allowed to continue
240	// to do work after replying, which is why it is important to unlock
241	// the rpc system at this point.
242	r.Parallel()
243	r.state = requestReplied
244
245	var raw *json.RawMessage
246	if err == nil {
247		raw, err = marshalToRaw(result)
248	}
249	response := &WireResponse{
250		Result: raw,
251		ID:     r.ID,
252	}
253	if err != nil {
254		if callErr, ok := err.(*Error); ok {
255			response.Error = callErr
256		} else {
257			response.Error = NewErrorf(0, "%s", err)
258		}
259	}
260	data, err := json.Marshal(response)
261	if err != nil {
262		return err
263	}
264	for _, h := range r.conn.handlers {
265		ctx = h.Response(ctx, r.conn, Send, response)
266	}
267	n, err := r.conn.stream.Write(ctx, data)
268	for _, h := range r.conn.handlers {
269		ctx = h.Wrote(ctx, n)
270	}
271
272	if err != nil {
273		// TODO(iancottrell): if a stream write fails, we really need to shut down
274		// the whole stream
275		return err
276	}
277	return nil
278}
279
280func (c *Conn) setHandling(r *Request, active bool) {
281	if r.ID == nil {
282		return
283	}
284	r.conn.handlingMu.Lock()
285	defer r.conn.handlingMu.Unlock()
286	if active {
287		r.conn.handling[*r.ID] = r
288	} else {
289		delete(r.conn.handling, *r.ID)
290	}
291}
292
// combined has all the fields of both Request and Response.
// We can decode this and then work out which it is.
// Run treats a decoded message as a request when Method is non-empty,
// otherwise as a response when ID is set, and otherwise reports it as invalid.
type combined struct {
	VersionTag VersionTag       `json:"jsonrpc"`
	ID         *ID              `json:"id,omitempty"`     // nil for notifications
	Method     string           `json:"method"`           // non-empty only for requests
	Params     *json.RawMessage `json:"params,omitempty"` // request parameters, left undecoded
	Result     *json.RawMessage `json:"result,omitempty"` // response result, left undecoded
	Error      *Error           `json:"error,omitempty"`  // response error, if any
}
303
304// Run blocks until the connection is terminated, and returns any error that
305// caused the termination.
306// It must be called exactly once for each Conn.
307// It returns only when the reader is closed or there is an error in the stream.
308func (c *Conn) Run(runCtx context.Context) error {
309	// we need to make the next request "lock" in an unlocked state to allow
310	// the first incoming request to proceed. All later requests are unlocked
311	// by the preceding request going to parallel mode.
312	nextRequest := make(chan struct{})
313	close(nextRequest)
314	for {
315		// get the data for a message
316		data, n, err := c.stream.Read(runCtx)
317		if err != nil {
318			// the stream failed, we cannot continue
319			return err
320		}
321		// read a combined message
322		msg := &combined{}
323		if err := json.Unmarshal(data, msg); err != nil {
324			// a badly formed message arrived, log it and continue
325			// we trust the stream to have isolated the error to just this message
326			for _, h := range c.handlers {
327				h.Error(runCtx, fmt.Errorf("unmarshal failed: %v", err))
328			}
329			continue
330		}
331		// work out which kind of message we have
332		switch {
333		case msg.Method != "":
334			// if method is set it must be a request
335			reqCtx, cancelReq := context.WithCancel(runCtx)
336			thisRequest := nextRequest
337			nextRequest = make(chan struct{})
338			req := &Request{
339				conn:        c,
340				cancel:      cancelReq,
341				nextRequest: nextRequest,
342				WireRequest: WireRequest{
343					VersionTag: msg.VersionTag,
344					Method:     msg.Method,
345					Params:     msg.Params,
346					ID:         msg.ID,
347				},
348			}
349			for _, h := range c.handlers {
350				reqCtx = h.Request(reqCtx, c, Receive, &req.WireRequest)
351				reqCtx = h.Read(reqCtx, n)
352			}
353			c.setHandling(req, true)
354			go func() {
355				<-thisRequest
356				req.state = requestSerial
357				defer func() {
358					c.setHandling(req, false)
359					if !req.IsNotify() && req.state < requestReplied {
360						req.Reply(reqCtx, nil, NewErrorf(CodeInternalError, "method %q did not reply", req.Method))
361					}
362					req.Parallel()
363					for _, h := range c.handlers {
364						h.Done(reqCtx, err)
365					}
366					cancelReq()
367				}()
368				delivered := false
369				for _, h := range c.handlers {
370					if h.Deliver(reqCtx, req, delivered) {
371						delivered = true
372					}
373				}
374			}()
375		case msg.ID != nil:
376			// we have a response, get the pending entry from the map
377			c.pendingMu.Lock()
378			rchan := c.pending[*msg.ID]
379			if rchan != nil {
380				delete(c.pending, *msg.ID)
381			}
382			c.pendingMu.Unlock()
383			// and send the reply to the channel
384			response := &WireResponse{
385				Result: msg.Result,
386				Error:  msg.Error,
387				ID:     msg.ID,
388			}
389			rchan <- response
390			close(rchan)
391		default:
392			for _, h := range c.handlers {
393				h.Error(runCtx, fmt.Errorf("message not a call, notify or response, ignoring"))
394			}
395		}
396	}
397}
398
// marshalToRaw encodes obj as JSON and returns the encoding as a pointer to a
// json.RawMessage, so it can be embedded verbatim in a larger message.
// On failure it returns a nil message and the json.Marshal error.
func marshalToRaw(obj interface{}) (*json.RawMessage, error) {
	var (
		raw json.RawMessage
		err error
	)
	if raw, err = json.Marshal(obj); err != nil {
		return nil, err
	}
	return &raw, nil
}
407