1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package lsprpc implements a jsonrpc2.StreamServer that may be used to
6// serve the LSP on a jsonrpc2 channel.
7package lsprpc
8
9import (
10	"context"
11	"encoding/json"
12	"fmt"
13	stdlog "log"
14	"net"
15	"os"
16	"os/exec"
17	"strconv"
18	"sync/atomic"
19	"time"
20
21	"golang.org/x/sync/errgroup"
22	"golang.org/x/tools/internal/jsonrpc2"
23	"golang.org/x/tools/internal/lsp"
24	"golang.org/x/tools/internal/lsp/cache"
25	"golang.org/x/tools/internal/lsp/debug"
26	"golang.org/x/tools/internal/lsp/protocol"
27	"golang.org/x/tools/internal/telemetry/log"
28)
29
// AutoNetwork is the pseudo network type used to signal that gopls should use
// automatic discovery to resolve a remote address.
//
// When a Forwarder's network is AutoNetwork, a real network and address are
// derived before dialing, and a gopls daemon is started at that address if the
// initial dial fails.
const AutoNetwork = "auto"
33
// The StreamServer type is a jsonrpc2.StreamServer that handles incoming
// streams as a new LSP session, using a shared cache.
//
// withTelemetry controls whether a telemetry handler is added to each served
// connection. debug is the debug instance with which every connected client
// is registered, and whose log file and debug address are reported in
// handshake replies. cache is the shared cache from which a new session is
// created for each stream.
type StreamServer struct {
	withTelemetry bool
	debug         *debug.Instance
	cache         *cache.Cache

	// serverForTest may be set to a test fake for testing.
	// When non-nil, ServeStream uses it instead of creating an lsp server.
	serverForTest protocol.Server
}
44
// clientIndex and serverIndex are process-wide counters, incremented with
// sync/atomic, that supply sequential IDs for incoming client streams
// (StreamServer.ServeStream) and for remote servers dialed by a Forwarder
// (Forwarder.ServeStream), respectively.
var clientIndex, serverIndex int64
46
47// NewStreamServer creates a StreamServer using the shared cache. If
48// withTelemetry is true, each session is instrumented with telemetry that
49// records RPC statistics.
50func NewStreamServer(cache *cache.Cache, withTelemetry bool, debugInstance *debug.Instance) *StreamServer {
51	s := &StreamServer{
52		withTelemetry: withTelemetry,
53		debug:         debugInstance,
54		cache:         cache,
55	}
56	return s
57}
58
// debugInstance holds the identifying and debugging details shared by both
// ends of a gopls client/server connection: an id, the debug server address,
// the log file, and the path of the gopls executable.
type debugInstance struct {
	id, debugAddress, logfile, goplsPath string
}

// ID returns the identifier assigned to this instance.
func (inst debugInstance) ID() string { return inst.id }

// DebugAddress returns the address of this instance's debug server.
func (inst debugInstance) DebugAddress() string { return inst.debugAddress }

// Logfile returns the log file recorded for this instance.
func (inst debugInstance) Logfile() string { return inst.logfile }

// GoplsPath returns the gopls executable path recorded for this instance.
func (inst debugInstance) GoplsPath() string { return inst.goplsPath }
83
84// A debugServer is held by the client to identity the remove server to which
85// it is connected.
86type debugServer struct {
87	debugInstance
88	// clientID is the id of this client on the server.
89	clientID string
90}
91
92func (s debugServer) ClientID() string {
93	return s.clientID
94}
95
96// A debugClient is held by the server to identify an incoming client
97// connection.
98type debugClient struct {
99	debugInstance
100	// session is the session serving this client.
101	session *cache.Session
102	// serverID is this id of this server on the client.
103	serverID string
104}
105
106func (c debugClient) Session() debug.Session {
107	return cache.DebugSession{Session: c.session}
108}
109
110func (c debugClient) ServerID() string {
111	return c.serverID
112}
113
114// ServeStream implements the jsonrpc2.StreamServer interface, by handling
115// incoming streams using a new lsp server.
116func (s *StreamServer) ServeStream(ctx context.Context, stream jsonrpc2.Stream) error {
117	index := atomic.AddInt64(&clientIndex, 1)
118
119	conn := jsonrpc2.NewConn(stream)
120	client := protocol.ClientDispatcher(conn)
121	session := s.cache.NewSession()
122	dc := &debugClient{
123		debugInstance: debugInstance{
124			id: strconv.FormatInt(index, 10),
125		},
126		session: session,
127	}
128	s.debug.State.AddClient(dc)
129	defer s.debug.State.DropClient(dc)
130
131	server := s.serverForTest
132	if server == nil {
133		server = lsp.NewServer(session, client)
134	}
135	// Clients may or may not send a shutdown message. Make sure the server is
136	// shut down.
137	// TODO(rFindley): this shutdown should perhaps be on a disconnected context.
138	defer server.Shutdown(ctx)
139	conn.AddHandler(protocol.ServerHandler(server))
140	conn.AddHandler(protocol.Canceller{})
141	if s.withTelemetry {
142		conn.AddHandler(telemetryHandler{})
143	}
144	executable, err := os.Executable()
145	if err != nil {
146		stdlog.Printf("error getting gopls path: %v", err)
147		executable = ""
148	}
149	conn.AddHandler(&handshaker{
150		client:    dc,
151		debug:     s.debug,
152		goplsPath: executable,
153	})
154	return conn.Run(protocol.WithClient(ctx, client))
155}
156
// A Forwarder is a jsonrpc2.StreamServer that handles an LSP stream by
// forwarding it to a remote. This is used when the gopls process started by
// the editor is in the `-remote` mode, which means it finds and connects to a
// separate gopls daemon. In these cases, we still want the forwarder gopls to
// be instrumented with telemetry, and want to be able to in some cases hijack
// the jsonrpc2 connection with the daemon.
//
// network and addr identify the remote. When network is AutoNetwork, the real
// network and address are derived at dial time, and a daemon may be started
// there. dialTimeout bounds each dial attempt and is also the minimum spacing
// between retries. retries is the number of further dial attempts made after
// the first one fails. debug supplies the log file and debug address sent in
// the handshake, and records the remote server. goplsPath is this process's
// executable path. It is used to start a daemon and is compared with the
// remote's path.
type Forwarder struct {
	network, addr string

	// Configuration. Right now, not all of this may be customizable, but in the
	// future it probably will be.
	withTelemetry bool
	dialTimeout   time.Duration
	retries       int
	debug         *debug.Instance
	goplsPath     string
}
174
175// NewForwarder creates a new Forwarder, ready to forward connections to the
176// remote server specified by network and addr.
177func NewForwarder(network, addr string, withTelemetry bool, debugInstance *debug.Instance) *Forwarder {
178	gp, err := os.Executable()
179	if err != nil {
180		stdlog.Printf("error getting gopls path for forwarder: %v", err)
181		gp = ""
182	}
183
184	return &Forwarder{
185		network:       network,
186		addr:          addr,
187		withTelemetry: withTelemetry,
188		dialTimeout:   1 * time.Second,
189		retries:       5,
190		debug:         debugInstance,
191		goplsPath:     gp,
192	}
193}
194
195// ServeStream dials the forwarder remote and binds the remote to serve the LSP
196// on the incoming stream.
197func (f *Forwarder) ServeStream(ctx context.Context, stream jsonrpc2.Stream) error {
198	clientConn := jsonrpc2.NewConn(stream)
199	client := protocol.ClientDispatcher(clientConn)
200
201	netConn, err := f.connectToRemote(ctx)
202	if err != nil {
203		return fmt.Errorf("forwarder: connecting to remote: %v", err)
204	}
205	serverConn := jsonrpc2.NewConn(jsonrpc2.NewHeaderStream(netConn, netConn))
206	server := protocol.ServerDispatcher(serverConn)
207
208	// Forward between connections.
209	serverConn.AddHandler(protocol.ClientHandler(client))
210	serverConn.AddHandler(protocol.Canceller{})
211	clientConn.AddHandler(protocol.ServerHandler(server))
212	clientConn.AddHandler(protocol.Canceller{})
213	clientConn.AddHandler(forwarderHandler{})
214	if f.withTelemetry {
215		clientConn.AddHandler(telemetryHandler{})
216	}
217	g, ctx := errgroup.WithContext(ctx)
218	g.Go(func() error {
219		return serverConn.Run(ctx)
220	})
221	// Don't run the clientConn yet, so that we can complete the handshake before
222	// processing any client messages.
223
224	// Do a handshake with the server instance to exchange debug information.
225	index := atomic.AddInt64(&serverIndex, 1)
226	serverID := strconv.FormatInt(index, 10)
227	var (
228		hreq = handshakeRequest{
229			ServerID:  serverID,
230			Logfile:   f.debug.Logfile,
231			DebugAddr: f.debug.ListenedDebugAddress,
232			GoplsPath: f.goplsPath,
233		}
234		hresp handshakeResponse
235	)
236	if err := serverConn.Call(ctx, handshakeMethod, hreq, &hresp); err != nil {
237		log.Error(ctx, "forwarder: gopls handshake failed", err)
238	}
239	if hresp.GoplsPath != f.goplsPath {
240		log.Error(ctx, "", fmt.Errorf("forwarder: gopls path mismatch: forwarder is %q, remote is %q", f.goplsPath, hresp.GoplsPath))
241	}
242	f.debug.State.AddServer(debugServer{
243		debugInstance: debugInstance{
244			id:           serverID,
245			logfile:      hresp.Logfile,
246			debugAddress: hresp.DebugAddr,
247			goplsPath:    hresp.GoplsPath,
248		},
249		clientID: hresp.ClientID,
250	})
251	g.Go(func() error {
252		return clientConn.Run(ctx)
253	})
254
255	return g.Wait()
256}
257
258func (f *Forwarder) connectToRemote(ctx context.Context) (net.Conn, error) {
259	var (
260		netConn          net.Conn
261		err              error
262		network, address = f.network, f.addr
263	)
264	if f.network == AutoNetwork {
265		// f.network is overloaded to support a concept of 'automatic' addresses,
266		// which signals that the gopls remote address should be automatically
267		// derived.
268		// So we need to resolve a real network and address here.
269		network, address = autoNetworkAddress(f.goplsPath, f.addr)
270	}
271	// Try dialing our remote once, in case it is already running.
272	netConn, err = net.DialTimeout(network, address, f.dialTimeout)
273	if err == nil {
274		return netConn, nil
275	}
276	// If our remote is on the 'auto' network, start it if it doesn't exist.
277	if f.network == AutoNetwork {
278		if f.goplsPath == "" {
279			return nil, fmt.Errorf("cannot auto-start remote: gopls path is unknown")
280		}
281		if network == "unix" {
282			// Sometimes the socketfile isn't properly cleaned up when gopls shuts
283			// down. Since we have already tried and failed to dial this address, it
284			// should *usually* be safe to remove the socket before binding to the
285			// address.
286			// TODO(rfindley): there is probably a race here if multiple gopls
287			// instances are simultaneously starting up.
288			if _, err := os.Stat(address); err == nil {
289				if err := os.Remove(address); err != nil {
290					return nil, fmt.Errorf("removing remote socket file: %v", err)
291				}
292			}
293		}
294		if err := startRemote(f.goplsPath, network, address); err != nil {
295			return nil, fmt.Errorf("startRemote(%q, %q): %v", network, address, err)
296		}
297	}
298
299	// It can take some time for the newly started server to bind to our address,
300	// so we retry for a bit.
301	for retry := 0; retry < f.retries; retry++ {
302		startDial := time.Now()
303		netConn, err = net.DialTimeout(network, address, f.dialTimeout)
304		if err == nil {
305			return netConn, nil
306		}
307		log.Print(ctx, fmt.Sprintf("failed attempt #%d to connect to remote: %v\n", retry+2, err))
308		// In case our failure was a fast-failure, ensure we wait at least
309		// f.dialTimeout before trying again.
310		if retry != f.retries-1 {
311			time.Sleep(f.dialTimeout - time.Since(startDial))
312		}
313	}
314	return nil, fmt.Errorf("dialing remote: %v", err)
315}
316
// startRemote launches goplsPath in serve mode, listening on "network;address"
// with a one-minute listen timeout, a debug server on localhost:0 and an
// automatically chosen log file. The process is started without waiting for
// it to finish. Only errors from starting the process are reported.
func startRemote(goplsPath, network, address string) error {
	listenAddr := fmt.Sprintf(`%s;%s`, network, address)
	daemon := exec.Command(goplsPath,
		"serve",
		"-listen", listenAddr,
		"-listen.timeout", "1m",
		"-debug", "localhost:0",
		"-logfile", "auto",
	)
	if startErr := daemon.Start(); startErr != nil {
		return fmt.Errorf("starting remote gopls: %v", startErr)
	}
	return nil
}
330
// ForwarderExitFunc is used to exit the forwarder process. It is mutable for
// testing purposes.
//
// forwarderHandler calls it with status 0 when the client sends "exit".
// OverrideExitFuncsForTest temporarily replaces it with a no-op.
var ForwarderExitFunc = os.Exit
334
335// OverrideExitFuncsForTest can be used from test code to prevent the test
336// process from exiting on server shutdown. The returned func reverts the exit
337// funcs to their previous state.
338func OverrideExitFuncsForTest() func() {
339	// Override functions that would shut down the test process
340	cleanup := func(lspExit, forwarderExit func(code int)) func() {
341		return func() {
342			lsp.ServerExitFunc = lspExit
343			ForwarderExitFunc = forwarderExit
344		}
345	}(lsp.ServerExitFunc, ForwarderExitFunc)
346	// It is an error for a test to shutdown a server process.
347	lsp.ServerExitFunc = func(code int) {
348		panic(fmt.Sprintf("LSP server exited with code %d", code))
349	}
350	// We don't want our forwarders to exit, but it's OK if they would have.
351	ForwarderExitFunc = func(code int) {}
352	return cleanup
353}
354
355// forwarderHandler intercepts 'exit' messages to prevent the shared gopls
356// instance from exiting. In the future it may also intercept 'shutdown' to
357// provide more graceful shutdown of the client connection.
358type forwarderHandler struct {
359	jsonrpc2.EmptyHandler
360}
361
362func (forwarderHandler) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool {
363	// TODO(golang.org/issues/34111): we should more gracefully disconnect here,
364	// once that process exists.
365	if r.Method == "exit" {
366		ForwarderExitFunc(0)
367		// Still return true here to prevent the message from being delivered: in
368		// tests, ForwarderExitFunc may be overridden to something that doesn't
369		// exit the process.
370		return true
371	}
372	return false
373}
374
// handshaker is a jsonrpc2 handler, installed by StreamServer.ServeStream,
// that answers gopls/handshake requests. client is the debug record for the
// connection, which is updated with the caller's details. debug supplies this
// server's log file and debug address for the reply. goplsPath is this
// process's executable path, also sent in the reply.
type handshaker struct {
	jsonrpc2.EmptyHandler
	client    *debugClient
	debug     *debug.Instance
	goplsPath string
}
381
// handshakeRequest is the parameter of a gopls/handshake call, sent by a
// Forwarder to its remote to describe itself. ServerID is the id the forwarder
// assigned to the remote server. Logfile, DebugAddr and GoplsPath are the
// forwarder's own log file, debug address and executable path.
type handshakeRequest struct {
	ServerID  string `json:"serverID"`
	Logfile   string `json:"logfile"`
	DebugAddr string `json:"debugAddr"`
	GoplsPath string `json:"goplsPath"`
}
388
// handshakeResponse is the result of a gopls/handshake call, returned by the
// remote to describe itself. ClientID is the id it assigned to the calling
// client, and SessionID identifies the session serving that client. Logfile,
// DebugAddr and GoplsPath are the remote's own log file, debug address and
// executable path.
type handshakeResponse struct {
	ClientID  string `json:"clientID"`
	SessionID string `json:"sessionID"`
	Logfile   string `json:"logfile"`
	DebugAddr string `json:"debugAddr"`
	GoplsPath string `json:"goplsPath"`
}
396
// handshakeMethod is the jsonrpc2 method a Forwarder calls on its remote to
// exchange debug information before any client messages are forwarded.
const handshakeMethod = "gopls/handshake"
398
399func (h *handshaker) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool {
400	if r.Method == handshakeMethod {
401		var req handshakeRequest
402		if err := json.Unmarshal(*r.Params, &req); err != nil {
403			sendError(ctx, r, err)
404			return true
405		}
406		h.client.debugAddress = req.DebugAddr
407		h.client.logfile = req.Logfile
408		h.client.serverID = req.ServerID
409		h.client.goplsPath = req.GoplsPath
410		resp := handshakeResponse{
411			ClientID:  h.client.id,
412			SessionID: cache.DebugSession{Session: h.client.session}.ID(),
413			Logfile:   h.debug.Logfile,
414			DebugAddr: h.debug.ListenedDebugAddress,
415			GoplsPath: h.goplsPath,
416		}
417		if err := r.Reply(ctx, resp, nil); err != nil {
418			log.Error(ctx, "replying to handshake", err)
419		}
420		return true
421	}
422	return false
423}
424
425func sendError(ctx context.Context, req *jsonrpc2.Request, err error) {
426	if _, ok := err.(*jsonrpc2.Error); !ok {
427		err = jsonrpc2.NewErrorf(jsonrpc2.CodeParseError, "%v", err)
428	}
429	if err := req.Reply(ctx, nil, err); err != nil {
430		log.Error(ctx, "", err)
431	}
432}
433