1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package jsonrpc2 is a minimal implementation of the JSON RPC 2 spec.
6// https://www.jsonrpc.org/specification
7// It is intended to be compatible with other implementations at the wire level.
8package jsonrpc2
9
10import (
11	"context"
12	"encoding/json"
13	"fmt"
14	"sync"
15	"sync/atomic"
16)
17
// Conn is a JSON RPC 2 client server connection.
// Conn is bidirectional; it does not have a designated server or client end.
// Outgoing calls are tracked in pending until Run routes their response back;
// incoming calls are tracked in handling so that Cancel can find them by ID.
type Conn struct {
	// seq is the counter Call increments to allocate request IDs.
	// NOTE(review): keep seq as the first field; on 32-bit platforms only the
	// first word of an allocated struct is guaranteed 64-bit aligned, which
	// sync/atomic requires for 64-bit operations.
	seq        int64 // must only be accessed using atomic operations
	// handlers are consulted in slice order; AddHandler prepends to it.
	handlers   []Handler
	// stream is the transport that Run reads from and Call/Notify/Reply write to.
	stream     Stream
	// NOTE(review): err is not read or written anywhere in this file —
	// possibly unused; confirm against the rest of the package.
	err        error
	pendingMu  sync.Mutex // protects the pending map
	// pending maps the ID of each outstanding outgoing Call to the channel
	// on which Run delivers its response.
	pending    map[ID]chan *WireResponse
	handlingMu sync.Mutex // protects the handling map
	// handling maps the ID of each incoming call currently being processed
	// to its Request, so Cancel can cancel it.
	handling   map[ID]*Request
}
30
// requestState records how far an incoming Request has progressed. The values
// are ordered, and code compares them with < and >=, so the declaration order
// below is significant.
type requestState int

const (
	// requestWaiting is the zero value: the request has been received but is
	// still queued behind the previous request.
	requestWaiting requestState = iota
	// requestSerial: the request is running and holds up later requests.
	requestSerial
	// requestParallel: Parallel has been called, so later requests may run.
	requestParallel
	// requestReplied: Reply has been called for this request.
	requestReplied
	// requestDone is not assigned anywhere in this file.
	requestDone
)
40
// Request is sent to a server to represent a Call or Notify operation.
// Within this file, Run builds one Request for every incoming message that
// has a method, and hands it to the handlers' Deliver.
type Request struct {
	// conn is the connection that received the request.
	conn        *Conn
	// cancel cancels the context Run created for this request; Conn.Cancel
	// invokes it, and Run calls it once handling has finished.
	cancel      context.CancelFunc
	// state tracks progress through the requestState values.
	state       requestState
	// nextRequest is closed by Parallel to release the next incoming request.
	nextRequest chan struct{}

	// The Wire values of the request.
	WireRequest
}
51
52// NewErrorf builds a Error struct for the supplied message and code.
53// If args is not empty, message and args will be passed to Sprintf.
54func NewErrorf(code int64, format string, args ...interface{}) *Error {
55	return &Error{
56		Code:    code,
57		Message: fmt.Sprintf(format, args...),
58	}
59}
60
61// NewConn creates a new connection object around the supplied stream.
62// You must call Run for the connection to be active.
63func NewConn(s Stream) *Conn {
64	conn := &Conn{
65		handlers: []Handler{defaultHandler{}},
66		stream:   s,
67		pending:  make(map[ID]chan *WireResponse),
68		handling: make(map[ID]*Request),
69	}
70	return conn
71}
72
73// AddHandler adds a new handler to the set the connection will invoke.
74// Handlers are invoked in the reverse order of how they were added, this
75// allows the most recent addition to be the first one to attempt to handle a
76// message.
77func (c *Conn) AddHandler(handler Handler) {
78	// prepend the new handlers so we use them first
79	c.handlers = append([]Handler{handler}, c.handlers...)
80}
81
82// Cancel cancels a pending Call on the server side.
83// The call is identified by its id.
84// JSON RPC 2 does not specify a cancel message, so cancellation support is not
85// directly wired in. This method allows a higher level protocol to choose how
86// to propagate the cancel.
87func (c *Conn) Cancel(id ID) {
88	c.handlingMu.Lock()
89	handling, found := c.handling[id]
90	c.handlingMu.Unlock()
91	if found {
92		handling.cancel()
93	}
94}
95
96// Notify is called to send a notification request over the connection.
97// It will return as soon as the notification has been sent, as no response is
98// possible.
99func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (err error) {
100	jsonParams, err := marshalToRaw(params)
101	if err != nil {
102		return fmt.Errorf("marshalling notify parameters: %v", err)
103	}
104	request := &WireRequest{
105		Method: method,
106		Params: jsonParams,
107	}
108	data, err := json.Marshal(request)
109	if err != nil {
110		return fmt.Errorf("marshalling notify request: %v", err)
111	}
112	for _, h := range c.handlers {
113		ctx = h.Request(ctx, c, Send, request)
114	}
115	defer func() {
116		for _, h := range c.handlers {
117			h.Done(ctx, err)
118		}
119	}()
120	n, err := c.stream.Write(ctx, data)
121	for _, h := range c.handlers {
122		ctx = h.Wrote(ctx, n)
123	}
124	return err
125}
126
127// Call sends a request over the connection and then waits for a response.
128// If the response is not an error, it will be decoded into result.
129// result must be of a type you an pass to json.Unmarshal.
130func (c *Conn) Call(ctx context.Context, method string, params, result interface{}) (err error) {
131	// generate a new request identifier
132	id := ID{Number: atomic.AddInt64(&c.seq, 1)}
133	jsonParams, err := marshalToRaw(params)
134	if err != nil {
135		return fmt.Errorf("marshalling call parameters: %v", err)
136	}
137	request := &WireRequest{
138		ID:     &id,
139		Method: method,
140		Params: jsonParams,
141	}
142	// marshal the request now it is complete
143	data, err := json.Marshal(request)
144	if err != nil {
145		return fmt.Errorf("marshalling call request: %v", err)
146	}
147	for _, h := range c.handlers {
148		ctx = h.Request(ctx, c, Send, request)
149	}
150	// We have to add ourselves to the pending map before we send, otherwise we
151	// are racing the response. Also add a buffer to rchan, so that if we get a
152	// wire response between the time this call is cancelled and id is deleted
153	// from c.pending, the send to rchan will not block.
154	rchan := make(chan *WireResponse, 1)
155	c.pendingMu.Lock()
156	c.pending[id] = rchan
157	c.pendingMu.Unlock()
158	defer func() {
159		c.pendingMu.Lock()
160		delete(c.pending, id)
161		c.pendingMu.Unlock()
162		for _, h := range c.handlers {
163			h.Done(ctx, err)
164		}
165	}()
166	// now we are ready to send
167	n, err := c.stream.Write(ctx, data)
168	for _, h := range c.handlers {
169		ctx = h.Wrote(ctx, n)
170	}
171	if err != nil {
172		// sending failed, we will never get a response, so don't leave it pending
173		return err
174	}
175	// now wait for the response
176	select {
177	case response := <-rchan:
178		for _, h := range c.handlers {
179			ctx = h.Response(ctx, c, Receive, response)
180		}
181		// is it an error response?
182		if response.Error != nil {
183			return response.Error
184		}
185		if result == nil || response.Result == nil {
186			return nil
187		}
188		if err := json.Unmarshal(*response.Result, result); err != nil {
189			return fmt.Errorf("unmarshalling result: %v", err)
190		}
191		return nil
192	case <-ctx.Done():
193		// Allow the handler to propagate the cancel.
194		cancelled := false
195		for _, h := range c.handlers {
196			if h.Cancel(ctx, c, id, cancelled) {
197				cancelled = true
198			}
199		}
200		return ctx.Err()
201	}
202}
203
204// Conn returns the connection that created this request.
205func (r *Request) Conn() *Conn { return r.conn }
206
207// IsNotify returns true if this request is a notification.
208func (r *Request) IsNotify() bool {
209	return r.ID == nil
210}
211
212// Parallel indicates that the system is now allowed to process other requests
213// in parallel with this one.
214// It is safe to call any number of times, but must only be called from the
215// request handling go routine.
216// It is implied by both reply and by the handler returning.
217func (r *Request) Parallel() {
218	if r.state >= requestParallel {
219		return
220	}
221	r.state = requestParallel
222	close(r.nextRequest)
223}
224
225// Reply sends a reply to the given request.
226// It is an error to call this if request was not a call.
227// You must call this exactly once for any given request.
228// It should only be called from the handler go routine.
229// If err is set then result will be ignored.
230// If the request has not yet dropped into parallel mode
231// it will be before this function returns.
232func (r *Request) Reply(ctx context.Context, result interface{}, err error) error {
233	if r.state >= requestReplied {
234		return fmt.Errorf("reply invoked more than once")
235	}
236	if r.IsNotify() {
237		return fmt.Errorf("reply not invoked with a valid call: %v, %s", r.Method, r.Params)
238	}
239	// reply ends the handling phase of a call, so if we are not yet
240	// parallel we should be now. The go routine is allowed to continue
241	// to do work after replying, which is why it is important to unlock
242	// the rpc system at this point.
243	r.Parallel()
244	r.state = requestReplied
245
246	var raw *json.RawMessage
247	if err == nil {
248		raw, err = marshalToRaw(result)
249	}
250	response := &WireResponse{
251		Result: raw,
252		ID:     r.ID,
253	}
254	if err != nil {
255		if callErr, ok := err.(*Error); ok {
256			response.Error = callErr
257		} else {
258			response.Error = NewErrorf(0, "%s", err)
259		}
260	}
261	data, err := json.Marshal(response)
262	if err != nil {
263		return err
264	}
265	for _, h := range r.conn.handlers {
266		ctx = h.Response(ctx, r.conn, Send, response)
267	}
268	n, err := r.conn.stream.Write(ctx, data)
269	for _, h := range r.conn.handlers {
270		ctx = h.Wrote(ctx, n)
271	}
272
273	if err != nil {
274		// TODO(iancottrell): if a stream write fails, we really need to shut down
275		// the whole stream
276		return err
277	}
278	return nil
279}
280
281func (c *Conn) setHandling(r *Request, active bool) {
282	if r.ID == nil {
283		return
284	}
285	r.conn.handlingMu.Lock()
286	defer r.conn.handlingMu.Unlock()
287	if active {
288		r.conn.handling[*r.ID] = r
289	} else {
290		delete(r.conn.handling, *r.ID)
291	}
292}
293
// combined has all the fields of both Request and Response.
// We can decode this and then work out which it is.
// Run treats a message with a non-empty Method as a request, otherwise a
// message with an ID as a response, and reports anything else as an error.
type combined struct {
	VersionTag VersionTag       `json:"jsonrpc"`
	ID         *ID              `json:"id,omitempty"`
	Method     string           `json:"method"`
	Params     *json.RawMessage `json:"params,omitempty"`
	Result     *json.RawMessage `json:"result,omitempty"`
	Error      *Error           `json:"error,omitempty"`
}
304
305// Run blocks until the connection is terminated, and returns any error that
306// caused the termination.
307// It must be called exactly once for each Conn.
308// It returns only when the reader is closed or there is an error in the stream.
309func (c *Conn) Run(runCtx context.Context) error {
310	// we need to make the next request "lock" in an unlocked state to allow
311	// the first incoming request to proceed. All later requests are unlocked
312	// by the preceding request going to parallel mode.
313	nextRequest := make(chan struct{})
314	close(nextRequest)
315	for {
316		// get the data for a message
317		data, n, err := c.stream.Read(runCtx)
318		if err != nil {
319			// The stream failed, we cannot continue. If the client disconnected
320			// normally, we should get ErrDisconnected here.
321			return err
322		}
323		// read a combined message
324		msg := &combined{}
325		if err := json.Unmarshal(data, msg); err != nil {
326			// a badly formed message arrived, log it and continue
327			// we trust the stream to have isolated the error to just this message
328			for _, h := range c.handlers {
329				h.Error(runCtx, fmt.Errorf("unmarshal failed: %v", err))
330			}
331			continue
332		}
333		// Work out whether this is a request or response.
334		switch {
335		case msg.Method != "":
336			// If method is set it must be a request.
337			reqCtx, cancelReq := context.WithCancel(runCtx)
338			thisRequest := nextRequest
339			nextRequest = make(chan struct{})
340			req := &Request{
341				conn:        c,
342				cancel:      cancelReq,
343				nextRequest: nextRequest,
344				WireRequest: WireRequest{
345					VersionTag: msg.VersionTag,
346					Method:     msg.Method,
347					Params:     msg.Params,
348					ID:         msg.ID,
349				},
350			}
351			for _, h := range c.handlers {
352				reqCtx = h.Request(reqCtx, c, Receive, &req.WireRequest)
353				reqCtx = h.Read(reqCtx, n)
354			}
355			c.setHandling(req, true)
356			go func() {
357				<-thisRequest
358				req.state = requestSerial
359				defer func() {
360					c.setHandling(req, false)
361					if !req.IsNotify() && req.state < requestReplied {
362						req.Reply(reqCtx, nil, NewErrorf(CodeInternalError, "method %q did not reply", req.Method))
363					}
364					req.Parallel()
365					for _, h := range c.handlers {
366						h.Done(reqCtx, err)
367					}
368					cancelReq()
369				}()
370				delivered := false
371				for _, h := range c.handlers {
372					if h.Deliver(reqCtx, req, delivered) {
373						delivered = true
374					}
375				}
376			}()
377		case msg.ID != nil:
378			// If method is not set, this should be a response, in which case we must
379			// have an id to send the response back to the caller.
380			c.pendingMu.Lock()
381			rchan, ok := c.pending[*msg.ID]
382			c.pendingMu.Unlock()
383			if ok {
384				response := &WireResponse{
385					Result: msg.Result,
386					Error:  msg.Error,
387					ID:     msg.ID,
388				}
389				rchan <- response
390			}
391		default:
392			for _, h := range c.handlers {
393				h.Error(runCtx, fmt.Errorf("message not a call, notify or response, ignoring"))
394			}
395		}
396	}
397}
398
// marshalToRaw encodes obj as JSON and returns it as a *json.RawMessage,
// ready to embed in a wire message. On failure it returns nil and the
// encoding error.
func marshalToRaw(obj interface{}) (*json.RawMessage, error) {
	var raw json.RawMessage
	var err error
	if raw, err = json.Marshal(obj); err != nil {
		return nil, err
	}
	return &raw, nil
}
407