1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package jsonrpc2 is a minimal implementation of the JSON RPC 2 spec.
6// https://www.jsonrpc.org/specification
7// It is intended to be compatible with other implementations at the wire level.
8package jsonrpc2
9
10import (
11	"context"
12	"encoding/json"
13	"fmt"
14	"sync"
15	"sync/atomic"
16	"time"
17
18	"golang.org/x/tools/internal/lsp/telemetry"
19	"golang.org/x/tools/internal/lsp/telemetry/tag"
20	"golang.org/x/tools/internal/lsp/telemetry/trace"
21)
22
23// Conn is a JSON RPC 2 client server connection.
24// Conn is bidirectional; it does not have a designated server or client end.
25type Conn struct {
26	seq                int64 // must only be accessed using atomic operations
27	Handler            Handler
28	Canceler           Canceler
29	Logger             Logger
30	Capacity           int
31	RejectIfOverloaded bool
32	stream             Stream
33	err                error
34	pendingMu          sync.Mutex // protects the pending map
35	pending            map[ID]chan *wireResponse
36	handlingMu         sync.Mutex // protects the handling map
37	handling           map[ID]*Request
38}
39
// requestState records how far an incoming request has progressed.
// The values are ordered, later stages compare greater than earlier ones.
type requestState int

const (
	// requestWaiting is the zero value, the request has not started running.
	requestWaiting requestState = iota
	// requestSerial means the handler is running and is blocking later requests.
	requestSerial
	// requestParallel means later requests have been released.
	requestParallel
	// requestReplied means a reply has been issued for the request.
	requestReplied
	// requestDone is the final state in the ordering.
	requestDone
)
49
50// Request is sent to a server to represent a Call or Notify operaton.
51type Request struct {
52	conn        *Conn
53	cancel      context.CancelFunc
54	start       time.Time
55	state       requestState
56	nextRequest chan struct{}
57
58	// Method is a string containing the method name to invoke.
59	Method string
60	// Params is either a struct or an array with the parameters of the method.
61	Params *json.RawMessage
62	// The id of this request, used to tie the response back to the request.
63	// Will be either a string or a number. If not set, the request is a notify,
64	// and no response is possible.
65	ID *ID
66}
67
68// Handler is an option you can pass to NewConn to handle incoming requests.
69// If the request returns false from IsNotify then the Handler must eventually
70// call Reply on the Conn with the supplied request.
71// Handlers are called synchronously, they should pass the work off to a go
72// routine if they are going to take a long time.
73type Handler func(context.Context, *Request)
74
75// Canceler is an option you can pass to NewConn which is invoked for
76// cancelled outgoing requests.
77// It is okay to use the connection to send notifications, but the context will
78// be in the cancelled state, so you must do it with the background context
79// instead.
80type Canceler func(context.Context, *Conn, ID)
81
82type rpcStats struct {
83	server bool
84	method string
85	span   trace.Span
86	start  time.Time
87}
88
89func start(ctx context.Context, server bool, method string, id *ID) (context.Context, *rpcStats) {
90	if method == "" {
91		panic("no method in rpc stats")
92	}
93	s := &rpcStats{
94		server: server,
95		method: method,
96		start:  time.Now(),
97	}
98	mode := telemetry.Outbound
99	if server {
100		mode = telemetry.Inbound
101	}
102	ctx, s.span = trace.StartSpan(ctx, method,
103		tag.Tag{Key: telemetry.Method, Value: method},
104		tag.Tag{Key: telemetry.RPCDirection, Value: mode},
105		tag.Tag{Key: telemetry.RPCID, Value: id},
106	)
107	telemetry.Started.Record(ctx, 1)
108	return ctx, s
109}
110
111func (s *rpcStats) end(ctx context.Context, err *error) {
112	if err != nil && *err != nil {
113		ctx = telemetry.StatusCode.With(ctx, "ERROR")
114	} else {
115		ctx = telemetry.StatusCode.With(ctx, "OK")
116	}
117	elapsedTime := time.Since(s.start)
118	latencyMillis := float64(elapsedTime) / float64(time.Millisecond)
119	telemetry.Latency.Record(ctx, latencyMillis)
120	s.span.End()
121}
122
123// NewErrorf builds a Error struct for the suppied message and code.
124// If args is not empty, message and args will be passed to Sprintf.
125func NewErrorf(code int64, format string, args ...interface{}) *Error {
126	return &Error{
127		Code:    code,
128		Message: fmt.Sprintf(format, args...),
129	}
130}
131
132// NewConn creates a new connection object around the supplied stream.
133// You must call Run for the connection to be active.
134func NewConn(s Stream) *Conn {
135	conn := &Conn{
136		stream:   s,
137		pending:  make(map[ID]chan *wireResponse),
138		handling: make(map[ID]*Request),
139	}
140	// the default handler reports a method error
141	conn.Handler = func(ctx context.Context, r *Request) {
142		if !r.IsNotify() {
143			r.Reply(ctx, nil, NewErrorf(CodeMethodNotFound, "method %q not found", r.Method))
144		}
145	}
146	// the default canceler does nothing
147	conn.Canceler = func(context.Context, *Conn, ID) {}
148	// the default logger does nothing
149	conn.Logger = func(Direction, *ID, time.Duration, string, *json.RawMessage, *Error) {}
150	return conn
151}
152
153// Cancel cancels a pending Call on the server side.
154// The call is identified by its id.
155// JSON RPC 2 does not specify a cancel message, so cancellation support is not
156// directly wired in. This method allows a higher level protocol to choose how
157// to propagate the cancel.
158func (c *Conn) Cancel(id ID) {
159	c.handlingMu.Lock()
160	handling, found := c.handling[id]
161	c.handlingMu.Unlock()
162	if found {
163		handling.cancel()
164	}
165}
166
167// Notify is called to send a notification request over the connection.
168// It will return as soon as the notification has been sent, as no response is
169// possible.
170func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (err error) {
171	ctx, rpcStats := start(ctx, false, method, nil)
172	defer rpcStats.end(ctx, &err)
173
174	jsonParams, err := marshalToRaw(params)
175	if err != nil {
176		return fmt.Errorf("marshalling notify parameters: %v", err)
177	}
178	request := &wireRequest{
179		Method: method,
180		Params: jsonParams,
181	}
182	data, err := json.Marshal(request)
183	if err != nil {
184		return fmt.Errorf("marshalling notify request: %v", err)
185	}
186	c.Logger(Send, nil, -1, request.Method, request.Params, nil)
187	n, err := c.stream.Write(ctx, data)
188	telemetry.SentBytes.Record(ctx, n)
189	return err
190}
191
192// Call sends a request over the connection and then waits for a response.
193// If the response is not an error, it will be decoded into result.
194// result must be of a type you an pass to json.Unmarshal.
195func (c *Conn) Call(ctx context.Context, method string, params, result interface{}) (err error) {
196	// generate a new request identifier
197	id := ID{Number: atomic.AddInt64(&c.seq, 1)}
198	ctx, rpcStats := start(ctx, false, method, &id)
199	defer rpcStats.end(ctx, &err)
200	jsonParams, err := marshalToRaw(params)
201	if err != nil {
202		return fmt.Errorf("marshalling call parameters: %v", err)
203	}
204	request := &wireRequest{
205		ID:     &id,
206		Method: method,
207		Params: jsonParams,
208	}
209	// marshal the request now it is complete
210	data, err := json.Marshal(request)
211	if err != nil {
212		return fmt.Errorf("marshalling call request: %v", err)
213	}
214	// we have to add ourselves to the pending map before we send, otherwise we
215	// are racing the response
216	rchan := make(chan *wireResponse)
217	c.pendingMu.Lock()
218	c.pending[id] = rchan
219	c.pendingMu.Unlock()
220	defer func() {
221		// clean up the pending response handler on the way out
222		c.pendingMu.Lock()
223		delete(c.pending, id)
224		c.pendingMu.Unlock()
225	}()
226	// now we are ready to send
227	before := time.Now()
228	c.Logger(Send, request.ID, -1, request.Method, request.Params, nil)
229	n, err := c.stream.Write(ctx, data)
230	telemetry.SentBytes.Record(ctx, n)
231	if err != nil {
232		// sending failed, we will never get a response, so don't leave it pending
233		return err
234	}
235	// now wait for the response
236	select {
237	case response := <-rchan:
238		elapsed := time.Since(before)
239		c.Logger(Receive, response.ID, elapsed, request.Method, response.Result, response.Error)
240		// is it an error response?
241		if response.Error != nil {
242			return response.Error
243		}
244		if result == nil || response.Result == nil {
245			return nil
246		}
247		if err := json.Unmarshal(*response.Result, result); err != nil {
248			return fmt.Errorf("unmarshalling result: %v", err)
249		}
250		return nil
251	case <-ctx.Done():
252		// allow the handler to propagate the cancel
253		c.Canceler(ctx, c, id)
254		return ctx.Err()
255	}
256}
257
258// Conn returns the connection that created this request.
259func (r *Request) Conn() *Conn { return r.conn }
260
261// IsNotify returns true if this request is a notification.
262func (r *Request) IsNotify() bool {
263	return r.ID == nil
264}
265
266// Parallel indicates that the system is now allowed to process other requests
267// in parallel with this one.
268// It is safe to call any number of times, but must only be called from the
269// request handling go routine.
270// It is implied by both reply and by the handler returning.
271func (r *Request) Parallel() {
272	if r.state >= requestParallel {
273		return
274	}
275	r.state = requestParallel
276	close(r.nextRequest)
277}
278
279// Reply sends a reply to the given request.
280// It is an error to call this if request was not a call.
281// You must call this exactly once for any given request.
282// It should only be called from the handler go routine.
283// If err is set then result will be ignored.
284// If the request has not yet dropped into parallel mode
285// it will be before this function returns.
286func (r *Request) Reply(ctx context.Context, result interface{}, err error) error {
287	if r.state >= requestReplied {
288		return fmt.Errorf("reply invoked more than once")
289	}
290	if r.IsNotify() {
291		return fmt.Errorf("reply not invoked with a valid call")
292	}
293	ctx, st := trace.StartSpan(ctx, r.Method+":reply")
294	defer st.End()
295
296	// reply ends the handling phase of a call, so if we are not yet
297	// parallel we should be now. The go routine is allowed to continue
298	// to do work after replying, which is why it is important to unlock
299	// the rpc system at this point.
300	r.Parallel()
301	r.state = requestReplied
302
303	elapsed := time.Since(r.start)
304	var raw *json.RawMessage
305	if err == nil {
306		raw, err = marshalToRaw(result)
307	}
308	response := &wireResponse{
309		Result: raw,
310		ID:     r.ID,
311	}
312	if err != nil {
313		if callErr, ok := err.(*Error); ok {
314			response.Error = callErr
315		} else {
316			response.Error = NewErrorf(0, "%s", err)
317		}
318	}
319	data, err := json.Marshal(response)
320	if err != nil {
321		return err
322	}
323	r.conn.Logger(Send, response.ID, elapsed, r.Method, response.Result, response.Error)
324	n, err := r.conn.stream.Write(ctx, data)
325	telemetry.SentBytes.Record(ctx, n)
326
327	if err != nil {
328		// TODO(iancottrell): if a stream write fails, we really need to shut down
329		// the whole stream
330		return err
331	}
332	return nil
333}
334
335func (c *Conn) setHandling(r *Request, active bool) {
336	if r.ID == nil {
337		return
338	}
339	r.conn.handlingMu.Lock()
340	defer r.conn.handlingMu.Unlock()
341	if active {
342		r.conn.handling[*r.ID] = r
343	} else {
344		delete(r.conn.handling, *r.ID)
345	}
346}
347
348// combined has all the fields of both Request and Response.
349// We can decode this and then work out which it is.
350type combined struct {
351	VersionTag VersionTag       `json:"jsonrpc"`
352	ID         *ID              `json:"id,omitempty"`
353	Method     string           `json:"method"`
354	Params     *json.RawMessage `json:"params,omitempty"`
355	Result     *json.RawMessage `json:"result,omitempty"`
356	Error      *Error           `json:"error,omitempty"`
357}
358
359// Run blocks until the connection is terminated, and returns any error that
360// caused the termination.
361// It must be called exactly once for each Conn.
362// It returns only when the reader is closed or there is an error in the stream.
363func (c *Conn) Run(ctx context.Context) error {
364	// we need to make the next request "lock" in an unlocked state to allow
365	// the first incoming request to proceed. All later requests are unlocked
366	// by the preceding request going to parallel mode.
367	nextRequest := make(chan struct{})
368	close(nextRequest)
369	for {
370		// get the data for a message
371		data, n, err := c.stream.Read(ctx)
372		if err != nil {
373			// the stream failed, we cannot continue
374			return err
375		}
376		// read a combined message
377		msg := &combined{}
378		if err := json.Unmarshal(data, msg); err != nil {
379			// a badly formed message arrived, log it and continue
380			// we trust the stream to have isolated the error to just this message
381			c.Logger(Receive, nil, -1, "", nil, NewErrorf(0, "unmarshal failed: %v", err))
382			continue
383		}
384		// work out which kind of message we have
385		switch {
386		case msg.Method != "":
387			// if method is set it must be a request
388			reqCtx, cancelReq := context.WithCancel(ctx)
389			reqCtx, rpcStats := start(reqCtx, true, msg.Method, msg.ID)
390			telemetry.ReceivedBytes.Record(ctx, n)
391			thisRequest := nextRequest
392			nextRequest = make(chan struct{})
393			req := &Request{
394				conn:        c,
395				cancel:      cancelReq,
396				nextRequest: nextRequest,
397				start:       time.Now(),
398				Method:      msg.Method,
399				Params:      msg.Params,
400				ID:          msg.ID,
401			}
402			c.setHandling(req, true)
403			go func() {
404				<-thisRequest
405				req.state = requestSerial
406				defer func() {
407					c.setHandling(req, false)
408					if !req.IsNotify() && req.state < requestReplied {
409						req.Reply(reqCtx, nil, NewErrorf(CodeInternalError, "method %q did not reply", req.Method))
410					}
411					req.Parallel()
412					rpcStats.end(reqCtx, nil)
413					cancelReq()
414				}()
415				c.Logger(Receive, req.ID, -1, req.Method, req.Params, nil)
416				c.Handler(reqCtx, req)
417			}()
418		case msg.ID != nil:
419			// we have a response, get the pending entry from the map
420			c.pendingMu.Lock()
421			rchan := c.pending[*msg.ID]
422			if rchan != nil {
423				delete(c.pending, *msg.ID)
424			}
425			c.pendingMu.Unlock()
426			// and send the reply to the channel
427			response := &wireResponse{
428				Result: msg.Result,
429				Error:  msg.Error,
430				ID:     msg.ID,
431			}
432			rchan <- response
433			close(rchan)
434		default:
435			c.Logger(Receive, nil, -1, "", nil, NewErrorf(0, "message not a call, notify or response, ignoring"))
436		}
437	}
438}
439
// marshalToRaw encodes obj as JSON and returns a pointer to the encoded bytes
// as a json.RawMessage.
// If obj cannot be encoded it returns a nil message and the encoding error.
func marshalToRaw(obj interface{}) (*json.RawMessage, error) {
	encoded, err := json.Marshal(obj)
	if err == nil {
		raw := json.RawMessage(encoded)
		return &raw, nil
	}
	return nil, err
}
448