1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package jsonrpc2
6
7import (
8	"bufio"
9	"context"
10	"encoding/json"
11	"fmt"
12	"io"
13	"net"
14	"strconv"
15	"strings"
16)
17
18// Stream abstracts the transport mechanics from the JSON RPC protocol.
19// A Conn reads and writes messages using the stream it was provided on
20// construction, and assumes that each call to Read or Write fully transfers
21// a single message, or returns an error.
22// A stream is not safe for concurrent use, it is expected it will be used by
23// a single Conn in a safe manner.
24type Stream interface {
25	// Read gets the next message from the stream.
26	Read(context.Context) (Message, int64, error)
27	// Write sends a message to the stream.
28	Write(context.Context, Message) (int64, error)
29	// Close closes the connection.
30	// Any blocked Read or Write operations will be unblocked and return errors.
31	Close() error
32}
33
34// Framer wraps a network connection up into a Stream.
35// It is responsible for the framing and encoding of messages into wire form.
36// NewRawStream and NewHeaderStream are implementations of a Framer.
37type Framer func(conn net.Conn) Stream
38
39// NewRawStream returns a Stream built on top of a net.Conn.
40// The messages are sent with no wrapping, and rely on json decode consistency
41// to determine message boundaries.
42func NewRawStream(conn net.Conn) Stream {
43	return &rawStream{
44		conn: conn,
45		in:   json.NewDecoder(conn),
46	}
47}
48
// rawStream is the Stream returned by NewRawStream. Messages are written to
// conn as bare JSON, and in decodes incoming JSON values read from that same
// conn.
type rawStream struct {
	conn net.Conn
	in   *json.Decoder
}
53
54func (s *rawStream) Read(ctx context.Context) (Message, int64, error) {
55	select {
56	case <-ctx.Done():
57		return nil, 0, ctx.Err()
58	default:
59	}
60	var raw json.RawMessage
61	if err := s.in.Decode(&raw); err != nil {
62		return nil, 0, err
63	}
64	msg, err := DecodeMessage(raw)
65	return msg, int64(len(raw)), err
66}
67
68func (s *rawStream) Write(ctx context.Context, msg Message) (int64, error) {
69	select {
70	case <-ctx.Done():
71		return 0, ctx.Err()
72	default:
73	}
74	data, err := json.Marshal(msg)
75	if err != nil {
76		return 0, fmt.Errorf("marshaling message: %v", err)
77	}
78	n, err := s.conn.Write(data)
79	return int64(n), err
80}
81
82func (s *rawStream) Close() error {
83	return s.conn.Close()
84}
85
86// NewHeaderStream returns a Stream built on top of a net.Conn.
87// The messages are sent with HTTP content length and MIME type headers.
88// This is the format used by LSP and others.
89func NewHeaderStream(conn net.Conn) Stream {
90	return &headerStream{
91		conn: conn,
92		in:   bufio.NewReader(conn),
93	}
94}
95
// headerStream is the Stream returned by NewHeaderStream. Framed messages are
// written directly to conn. Incoming data from conn is read through the
// buffered reader in, so that header lines and message bodies can be
// consumed from it.
type headerStream struct {
	conn net.Conn
	in   *bufio.Reader
}
100
101func (s *headerStream) Read(ctx context.Context) (Message, int64, error) {
102	select {
103	case <-ctx.Done():
104		return nil, 0, ctx.Err()
105	default:
106	}
107	var total, length int64
108	// read the header, stop on the first empty line
109	for {
110		line, err := s.in.ReadString('\n')
111		total += int64(len(line))
112		if err != nil {
113			return nil, total, fmt.Errorf("failed reading header line: %w", err)
114		}
115		line = strings.TrimSpace(line)
116		// check we have a header line
117		if line == "" {
118			break
119		}
120		colon := strings.IndexRune(line, ':')
121		if colon < 0 {
122			return nil, total, fmt.Errorf("invalid header line %q", line)
123		}
124		name, value := line[:colon], strings.TrimSpace(line[colon+1:])
125		switch name {
126		case "Content-Length":
127			if length, err = strconv.ParseInt(value, 10, 32); err != nil {
128				return nil, total, fmt.Errorf("failed parsing Content-Length: %v", value)
129			}
130			if length <= 0 {
131				return nil, total, fmt.Errorf("invalid Content-Length: %v", length)
132			}
133		default:
134			// ignoring unknown headers
135		}
136	}
137	if length == 0 {
138		return nil, total, fmt.Errorf("missing Content-Length header")
139	}
140	data := make([]byte, length)
141	if _, err := io.ReadFull(s.in, data); err != nil {
142		return nil, total, err
143	}
144	total += length
145	msg, err := DecodeMessage(data)
146	return msg, total, err
147}
148
149func (s *headerStream) Write(ctx context.Context, msg Message) (int64, error) {
150	select {
151	case <-ctx.Done():
152		return 0, ctx.Err()
153	default:
154	}
155	data, err := json.Marshal(msg)
156	if err != nil {
157		return 0, fmt.Errorf("marshaling message: %v", err)
158	}
159	n, err := fmt.Fprintf(s.conn, "Content-Length: %v\r\n\r\n", len(data))
160	total := int64(n)
161	if err == nil {
162		n, err = s.conn.Write(data)
163		total += int64(n)
164	}
165	return total, err
166}
167
168func (s *headerStream) Close() error {
169	return s.conn.Close()
170}
171