// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package jsonrpc2

import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
)

// Stream abstracts the transport mechanics from the JSON RPC protocol.
// A Conn reads and writes messages using the stream it was provided on
// construction, and assumes that each call to Read or Write fully transfers
// a single message, or returns an error.
// A Stream is not safe for concurrent use; it is expected to be used by
// a single Conn in a safe manner.
type Stream interface {
	// Read gets the next message from the stream.
	// The int64 result is a byte count attributed to the message; see each
	// implementation for exactly which bytes are counted.
	Read(context.Context) (Message, int64, error)
	// Write sends a message to the stream.
	// The int64 result is a byte count for the message; the implementations
	// alongside this interface report the bytes written to their connection.
	Write(context.Context, Message) (int64, error)
	// Close closes the connection.
	// Any blocked Read or Write operations will be unblocked and return errors.
	Close() error
}

// Framer wraps a network connection up into a Stream.
// It is responsible for the framing and encoding of messages into wire form.
// NewRawStream and NewHeaderStream are implementations of a Framer.
type Framer func(conn net.Conn) Stream

// NewRawStream returns a Stream built on top of a net.Conn.
// The messages are sent with no wrapping, and rely on JSON decode consistency
// to determine message boundaries.
42func NewRawStream(conn net.Conn) Stream { 43 return &rawStream{ 44 conn: conn, 45 in: json.NewDecoder(conn), 46 } 47} 48 49type rawStream struct { 50 conn net.Conn 51 in *json.Decoder 52} 53 54func (s *rawStream) Read(ctx context.Context) (Message, int64, error) { 55 select { 56 case <-ctx.Done(): 57 return nil, 0, ctx.Err() 58 default: 59 } 60 var raw json.RawMessage 61 if err := s.in.Decode(&raw); err != nil { 62 return nil, 0, err 63 } 64 msg, err := DecodeMessage(raw) 65 return msg, int64(len(raw)), err 66} 67 68func (s *rawStream) Write(ctx context.Context, msg Message) (int64, error) { 69 select { 70 case <-ctx.Done(): 71 return 0, ctx.Err() 72 default: 73 } 74 data, err := json.Marshal(msg) 75 if err != nil { 76 return 0, fmt.Errorf("marshaling message: %v", err) 77 } 78 n, err := s.conn.Write(data) 79 return int64(n), err 80} 81 82func (s *rawStream) Close() error { 83 return s.conn.Close() 84} 85 86// NewHeaderStream returns a Stream built on top of a net.Conn. 87// The messages are sent with HTTP content length and MIME type headers. 88// This is the format used by LSP and others. 
89func NewHeaderStream(conn net.Conn) Stream { 90 return &headerStream{ 91 conn: conn, 92 in: bufio.NewReader(conn), 93 } 94} 95 96type headerStream struct { 97 conn net.Conn 98 in *bufio.Reader 99} 100 101func (s *headerStream) Read(ctx context.Context) (Message, int64, error) { 102 select { 103 case <-ctx.Done(): 104 return nil, 0, ctx.Err() 105 default: 106 } 107 var total, length int64 108 // read the header, stop on the first empty line 109 for { 110 line, err := s.in.ReadString('\n') 111 total += int64(len(line)) 112 if err != nil { 113 return nil, total, fmt.Errorf("failed reading header line: %w", err) 114 } 115 line = strings.TrimSpace(line) 116 // check we have a header line 117 if line == "" { 118 break 119 } 120 colon := strings.IndexRune(line, ':') 121 if colon < 0 { 122 return nil, total, fmt.Errorf("invalid header line %q", line) 123 } 124 name, value := line[:colon], strings.TrimSpace(line[colon+1:]) 125 switch name { 126 case "Content-Length": 127 if length, err = strconv.ParseInt(value, 10, 32); err != nil { 128 return nil, total, fmt.Errorf("failed parsing Content-Length: %v", value) 129 } 130 if length <= 0 { 131 return nil, total, fmt.Errorf("invalid Content-Length: %v", length) 132 } 133 default: 134 // ignoring unknown headers 135 } 136 } 137 if length == 0 { 138 return nil, total, fmt.Errorf("missing Content-Length header") 139 } 140 data := make([]byte, length) 141 if _, err := io.ReadFull(s.in, data); err != nil { 142 return nil, total, err 143 } 144 total += length 145 msg, err := DecodeMessage(data) 146 return msg, total, err 147} 148 149func (s *headerStream) Write(ctx context.Context, msg Message) (int64, error) { 150 select { 151 case <-ctx.Done(): 152 return 0, ctx.Err() 153 default: 154 } 155 data, err := json.Marshal(msg) 156 if err != nil { 157 return 0, fmt.Errorf("marshaling message: %v", err) 158 } 159 n, err := fmt.Fprintf(s.conn, "Content-Length: %v\r\n\r\n", len(data)) 160 total := int64(n) 161 if err == nil { 162 n, err = 
s.conn.Write(data) 163 total += int64(n) 164 } 165 return total, err 166} 167 168func (s *headerStream) Close() error { 169 return s.conn.Close() 170} 171