1package protocol
2
3import (
4	"context"
5	"encoding/json"
6	"fmt"
7	"io"
8	"strings"
9	"sync"
10	"time"
11
12	"golang.org/x/tools/internal/jsonrpc2"
13)
14
// loggingStream wraps a jsonrpc2.Stream and logs the messages that pass
// through it.
type loggingStream struct {
	stream jsonrpc2.Stream // underlying stream that Read and Write delegate to
	log    io.Writer       // destination handed to logIn and logOut
}
19
20// LoggingStream returns a stream that does LSP protocol logging too
21func LoggingStream(str jsonrpc2.Stream, w io.Writer) jsonrpc2.Stream {
22	return &loggingStream{str, w}
23}
24
25func (s *loggingStream) Read(ctx context.Context) ([]byte, int64, error) {
26	data, count, err := s.stream.Read(ctx)
27	if err == nil {
28		logIn(s.log, data)
29	}
30	return data, count, err
31}
32
33func (s *loggingStream) Write(ctx context.Context, data []byte) (int64, error) {
34	logOut(s.log, data)
35	count, err := s.stream.Write(ctx, data)
36	return count, err
37}
38
// Combined has all the fields of both Request and Response.
// We can decode this and then work out which it is.
//
// logIn and logOut classify a decoded message by which fields are set:
// ID and Method mean a request, ID without Method a response, and Method
// without ID a notification.
type Combined struct {
	VersionTag jsonrpc2.VersionTag `json:"jsonrpc"`
	ID         *jsonrpc2.ID        `json:"id,omitempty"`     // nil when the message has no "id"
	Method     string              `json:"method"`           // empty when the message has no "method"
	Params     *json.RawMessage    `json:"params,omitempty"` // raw, undecoded parameters
	Result     *json.RawMessage    `json:"result,omitempty"` // raw, undecoded result
	Error      *jsonrpc2.Error     `json:"error,omitempty"`  // non-nil when the message carries an error
}
49
// req records an outstanding request so that its response can be logged
// together with the method name and the elapsed time.
type req struct {
	method string    // method name of the request
	start  time.Time // time at which the request was logged
}
54
// mapped holds the outstanding requests for both directions, keyed by the
// string form of the request ID (see strID).
type mapped struct {
	mu          sync.Mutex     // guards both maps
	clientCalls map[string]req // requests recorded by logIn, consumed by logOut
	serverCalls map[string]req // requests recorded by logOut, consumed by logIn
}
60
61var maps = &mapped{
62	sync.Mutex{},
63	make(map[string]req),
64	make(map[string]req),
65}
66
// These four methods are each used exactly once, but encapsulating the
// locking here seemed better than repeating ad hoc mutex code in four
// places.
70func (m *mapped) client(id string, del bool) req {
71	m.mu.Lock()
72	defer m.mu.Unlock()
73	v := m.clientCalls[id]
74	if del {
75		delete(m.clientCalls, id)
76	}
77	return v
78}
79
80func (m *mapped) server(id string, del bool) req {
81	m.mu.Lock()
82	defer m.mu.Unlock()
83	v := m.serverCalls[id]
84	if del {
85		delete(m.serverCalls, id)
86	}
87	return v
88}
89
// setClient stores r under id in clientCalls, replacing any existing entry.
// The map is written with m.mu held.
func (m *mapped) setClient(id string, r req) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.clientCalls[id] = r
}
95
// setServer stores r under id in serverCalls, replacing any existing entry.
// The map is written with m.mu held.
func (m *mapped) setServer(id string, r req) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.serverCalls[id] = r
}
101
// eor is appended by logIn and logOut to the end of each record they write.
const eor = "\r\n\r\n\r\n"
103
104func strID(x *jsonrpc2.ID) string {
105	if x == nil {
106		// should never happen, but we need a number
107		return "999999999"
108	}
109	if x.Name != "" {
110		return x.Name
111	}
112	return fmt.Sprintf("%d", x.Number)
113}
114
115func logCommon(outfd io.Writer, data []byte) (*Combined, time.Time, string) {
116	if outfd == nil {
117		return nil, time.Time{}, ""
118	}
119	var v Combined
120	err := json.Unmarshal(data, &v)
121	if err != nil {
122		fmt.Fprintf(outfd, "Unmarshal %v\n", err)
123		panic(err) // do better
124	}
125	tm := time.Now()
126	tmfmt := tm.Format("15:04:05.000 PM")
127	return &v, tm, tmfmt
128}
129
// logOut and logIn could be combined ("Received"<->"Sending",
// serverCalls<->clientCalls), but the result would not be much shorter or
// clearer, and "shutdown" is a special case.
132
133// Writing a message to the client, log it
134func logOut(outfd io.Writer, data []byte) {
135	v, tm, tmfmt := logCommon(outfd, data)
136	if v == nil {
137		return
138	}
139	if v.Error != nil {
140		id := strID(v.ID)
141		fmt.Fprintf(outfd, "[Error - %s] Received #%s %s%s", tmfmt, id, v.Error, eor)
142		return
143	}
144	buf := strings.Builder{}
145	id := strID(v.ID)
146	fmt.Fprintf(&buf, "[Trace - %s] ", tmfmt) // common beginning
147	if v.ID != nil && v.Method != "" && v.Params != nil {
148		fmt.Fprintf(&buf, "Received request '%s - (%s)'.\n", v.Method, id)
149		fmt.Fprintf(&buf, "Params: %s%s", *v.Params, eor)
150		maps.setServer(id, req{method: v.Method, start: tm})
151	} else if v.ID != nil && v.Method == "" && v.Params == nil {
152		cc := maps.client(id, true)
153		elapsed := tm.Sub(cc.start)
154		fmt.Fprintf(&buf, "Received response '%s - (%s)' in %dms.\n",
155			cc.method, id, elapsed/time.Millisecond)
156		if v.Result == nil {
157			fmt.Fprintf(&buf, "Result: {}%s", eor)
158		} else {
159			fmt.Fprintf(&buf, "Result: %s%s", string(*v.Result), eor)
160		}
161	} else if v.ID == nil && v.Method != "" && v.Params != nil {
162		p := "null"
163		if v.Params != nil {
164			p = string(*v.Params)
165		}
166		fmt.Fprintf(&buf, "Received notification '%s'.\n", v.Method)
167		fmt.Fprintf(&buf, "Params: %s%s", p, eor)
168	} else { // for completeness, as it should never happen
169		buf = strings.Builder{} // undo common Trace
170		fmt.Fprintf(&buf, "[Error - %s] on write ID?%v method:%q Params:%v Result:%v Error:%v%s",
171			tmfmt, v.ID != nil, v.Method, v.Params != nil,
172			v.Result != nil, v.Error != nil, eor)
173		p := "null"
174		if v.Params != nil {
175			p = string(*v.Params)
176		}
177		r := "null"
178		if v.Result != nil {
179			r = string(*v.Result)
180		}
181		fmt.Fprintf(&buf, "%s\n%s\n%s%s", p, r, v.Error, eor)
182	}
183	outfd.Write([]byte(buf.String()))
184}
185
186// Got a message from the client, log it
187func logIn(outfd io.Writer, data []byte) {
188	v, tm, tmfmt := logCommon(outfd, data)
189	if v == nil {
190		return
191	}
192	// ID Method Params => Sending request
193	// ID !Method Result(might be null, but !Params) => Sending response (could we get an Error?)
194	// !ID Method Params => Sending notification
195	if v.Error != nil { // does this ever happen?
196		id := strID(v.ID)
197		fmt.Fprintf(outfd, "[Error - %s] Sent #%s %s%s", tmfmt, id, v.Error, eor)
198		return
199	}
200	buf := strings.Builder{}
201	id := strID(v.ID)
202	fmt.Fprintf(&buf, "[Trace - %s] ", tmfmt) // common beginning
203	if v.ID != nil && v.Method != "" && (v.Params != nil || v.Method == "shutdown") {
204		fmt.Fprintf(&buf, "Sending request '%s - (%s)'.\n", v.Method, id)
205		x := "{}"
206		if v.Params != nil {
207			x = string(*v.Params)
208		}
209		fmt.Fprintf(&buf, "Params: %s%s", x, eor)
210		maps.setClient(id, req{method: v.Method, start: tm})
211	} else if v.ID != nil && v.Method == "" && v.Params == nil {
212		sc := maps.server(id, true)
213		elapsed := tm.Sub(sc.start)
214		fmt.Fprintf(&buf, "Sending response '%s - (%s)' took %dms.\n",
215			sc.method, id, elapsed/time.Millisecond)
216		if v.Result == nil {
217			fmt.Fprintf(&buf, "Result: {}%s", eor)
218		} else {
219			fmt.Fprintf(&buf, "Result: %s%s", string(*v.Result), eor)
220		}
221	} else if v.ID == nil && v.Method != "" {
222		p := "null"
223		if v.Params != nil {
224			p = string(*v.Params)
225		}
226		fmt.Fprintf(&buf, "Sending notification '%s'.\n", v.Method)
227		fmt.Fprintf(&buf, "Params: %s%s", p, eor)
228	} else { // for completeness, as it should never happen
229		buf = strings.Builder{} // undo common Trace
230		fmt.Fprintf(&buf, "[Error - %s] on read ID?%v method:%q Params:%v Result:%v Error:%v%s",
231			tmfmt, v.ID != nil, v.Method, v.Params != nil,
232			v.Result != nil, v.Error != nil, eor)
233		p := "null"
234		if v.Params != nil {
235			p = string(*v.Params)
236		}
237		r := "null"
238		if v.Result != nil {
239			r = string(*v.Result)
240		}
241		fmt.Fprintf(&buf, "%s\n%s\n%s%s", p, r, v.Error, eor)
242	}
243	outfd.Write([]byte(buf.String()))
244}
245