1// Copyright 2020 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// Package lsprpc implements a jsonrpc2.StreamServer that may be used to 6// serve the LSP on a jsonrpc2 channel. 7package lsprpc 8 9import ( 10 "context" 11 "encoding/json" 12 "fmt" 13 stdlog "log" 14 "net" 15 "os" 16 "os/exec" 17 "strconv" 18 "sync/atomic" 19 "time" 20 21 "golang.org/x/sync/errgroup" 22 "golang.org/x/tools/internal/jsonrpc2" 23 "golang.org/x/tools/internal/lsp" 24 "golang.org/x/tools/internal/lsp/cache" 25 "golang.org/x/tools/internal/lsp/debug" 26 "golang.org/x/tools/internal/lsp/protocol" 27 "golang.org/x/tools/internal/telemetry/log" 28) 29 30// AutoNetwork is the pseudo network type used to signal that gopls should use 31// automatic discovery to resolve a remote address. 32const AutoNetwork = "auto" 33 34// The StreamServer type is a jsonrpc2.StreamServer that handles incoming 35// streams as a new LSP session, using a shared cache. 36type StreamServer struct { 37 withTelemetry bool 38 debug *debug.Instance 39 cache *cache.Cache 40 41 // serverForTest may be set to a test fake for testing. 42 serverForTest protocol.Server 43} 44 45var clientIndex, serverIndex int64 46 47// NewStreamServer creates a StreamServer using the shared cache. If 48// withTelemetry is true, each session is instrumented with telemetry that 49// records RPC statistics. 50func NewStreamServer(cache *cache.Cache, withTelemetry bool, debugInstance *debug.Instance) *StreamServer { 51 s := &StreamServer{ 52 withTelemetry: withTelemetry, 53 debug: debugInstance, 54 cache: cache, 55 } 56 return s 57} 58 59// debugInstance is the common functionality shared between client and server 60// gopls instances. 61type debugInstance struct { 62 id string 63 debugAddress string 64 logfile string 65 goplsPath string 66} 67 68func (d debugInstance) ID() string { 69 return d.id 70} 71 72func (d debugInstance) DebugAddress() string { 73 return d.debugAddress 74} 75 76func (d debugInstance) Logfile() string { 77 return d.logfile 78} 79 80func (d debugInstance) GoplsPath() string { 81 return d.goplsPath 82} 83 84// A debugServer is held by the client to identity the remove server to which 85// it is connected. 86type debugServer struct { 87 debugInstance 88 // clientID is the id of this client on the server. 89 clientID string 90} 91 92func (s debugServer) ClientID() string { 93 return s.clientID 94} 95 96// A debugClient is held by the server to identify an incoming client 97// connection. 98type debugClient struct { 99 debugInstance 100 // session is the session serving this client. 101 session *cache.Session 102 // serverID is this id of this server on the client. 103 serverID string 104} 105 106func (c debugClient) Session() debug.Session { 107 return cache.DebugSession{Session: c.session} 108} 109 110func (c debugClient) ServerID() string { 111 return c.serverID 112} 113 114// ServeStream implements the jsonrpc2.StreamServer interface, by handling 115// incoming streams using a new lsp server. 116func (s *StreamServer) ServeStream(ctx context.Context, stream jsonrpc2.Stream) error { 117 index := atomic.AddInt64(&clientIndex, 1) 118 119 conn := jsonrpc2.NewConn(stream) 120 client := protocol.ClientDispatcher(conn) 121 session := s.cache.NewSession() 122 dc := &debugClient{ 123 debugInstance: debugInstance{ 124 id: strconv.FormatInt(index, 10), 125 }, 126 session: session, 127 } 128 s.debug.State.AddClient(dc) 129 defer s.debug.State.DropClient(dc) 130 131 server := s.serverForTest 132 if server == nil { 133 server = lsp.NewServer(session, client) 134 } 135 // Clients may or may not send a shutdown message. Make sure the server is 136 // shut down. 137 // TODO(rFindley): this shutdown should perhaps be on a disconnected context. 138 defer server.Shutdown(ctx) 139 conn.AddHandler(protocol.ServerHandler(server)) 140 conn.AddHandler(protocol.Canceller{}) 141 if s.withTelemetry { 142 conn.AddHandler(telemetryHandler{}) 143 } 144 executable, err := os.Executable() 145 if err != nil { 146 stdlog.Printf("error getting gopls path: %v", err) 147 executable = "" 148 } 149 conn.AddHandler(&handshaker{ 150 client: dc, 151 debug: s.debug, 152 goplsPath: executable, 153 }) 154 return conn.Run(protocol.WithClient(ctx, client)) 155} 156 157// A Forwarder is a jsonrpc2.StreamServer that handles an LSP stream by 158// forwarding it to a remote. This is used when the gopls process started by 159// the editor is in the `-remote` mode, which means it finds and connects to a 160// separate gopls daemon. In these cases, we still want the forwarder gopls to 161// be instrumented with telemetry, and want to be able to in some cases hijack 162// the jsonrpc2 connection with the daemon. 163type Forwarder struct { 164 network, addr string 165 166 // Configuration. Right now, not all of this may be customizable, but in the 167 // future it probably will be. 168 withTelemetry bool 169 dialTimeout time.Duration 170 retries int 171 debug *debug.Instance 172 goplsPath string 173} 174 175// NewForwarder creates a new Forwarder, ready to forward connections to the 176// remote server specified by network and addr. 177func NewForwarder(network, addr string, withTelemetry bool, debugInstance *debug.Instance) *Forwarder { 178 gp, err := os.Executable() 179 if err != nil { 180 stdlog.Printf("error getting gopls path for forwarder: %v", err) 181 gp = "" 182 } 183 184 return &Forwarder{ 185 network: network, 186 addr: addr, 187 withTelemetry: withTelemetry, 188 dialTimeout: 1 * time.Second, 189 retries: 5, 190 debug: debugInstance, 191 goplsPath: gp, 192 } 193} 194 195// ServeStream dials the forwarder remote and binds the remote to serve the LSP 196// on the incoming stream. 197func (f *Forwarder) ServeStream(ctx context.Context, stream jsonrpc2.Stream) error { 198 clientConn := jsonrpc2.NewConn(stream) 199 client := protocol.ClientDispatcher(clientConn) 200 201 netConn, err := f.connectToRemote(ctx) 202 if err != nil { 203 return fmt.Errorf("forwarder: connecting to remote: %v", err) 204 } 205 serverConn := jsonrpc2.NewConn(jsonrpc2.NewHeaderStream(netConn, netConn)) 206 server := protocol.ServerDispatcher(serverConn) 207 208 // Forward between connections. 209 serverConn.AddHandler(protocol.ClientHandler(client)) 210 serverConn.AddHandler(protocol.Canceller{}) 211 clientConn.AddHandler(protocol.ServerHandler(server)) 212 clientConn.AddHandler(protocol.Canceller{}) 213 clientConn.AddHandler(forwarderHandler{}) 214 if f.withTelemetry { 215 clientConn.AddHandler(telemetryHandler{}) 216 } 217 g, ctx := errgroup.WithContext(ctx) 218 g.Go(func() error { 219 return serverConn.Run(ctx) 220 }) 221 // Don't run the clientConn yet, so that we can complete the handshake before 222 // processing any client messages. 223 224 // Do a handshake with the server instance to exchange debug information. 225 index := atomic.AddInt64(&serverIndex, 1) 226 serverID := strconv.FormatInt(index, 10) 227 var ( 228 hreq = handshakeRequest{ 229 ServerID: serverID, 230 Logfile: f.debug.Logfile, 231 DebugAddr: f.debug.ListenedDebugAddress, 232 GoplsPath: f.goplsPath, 233 } 234 hresp handshakeResponse 235 ) 236 if err := serverConn.Call(ctx, handshakeMethod, hreq, &hresp); err != nil { 237 log.Error(ctx, "forwarder: gopls handshake failed", err) 238 } 239 if hresp.GoplsPath != f.goplsPath { 240 log.Error(ctx, "", fmt.Errorf("forwarder: gopls path mismatch: forwarder is %q, remote is %q", f.goplsPath, hresp.GoplsPath)) 241 } 242 f.debug.State.AddServer(debugServer{ 243 debugInstance: debugInstance{ 244 id: serverID, 245 logfile: hresp.Logfile, 246 debugAddress: hresp.DebugAddr, 247 goplsPath: hresp.GoplsPath, 248 }, 249 clientID: hresp.ClientID, 250 }) 251 g.Go(func() error { 252 return clientConn.Run(ctx) 253 }) 254 255 return g.Wait() 256} 257 258func (f *Forwarder) connectToRemote(ctx context.Context) (net.Conn, error) { 259 var ( 260 netConn net.Conn 261 err error 262 network, address = f.network, f.addr 263 ) 264 if f.network == AutoNetwork { 265 // f.network is overloaded to support a concept of 'automatic' addresses, 266 // which signals that the gopls remote address should be automatically 267 // derived. 268 // So we need to resolve a real network and address here. 269 network, address = autoNetworkAddress(f.goplsPath, f.addr) 270 } 271 // Try dialing our remote once, in case it is already running. 272 netConn, err = net.DialTimeout(network, address, f.dialTimeout) 273 if err == nil { 274 return netConn, nil 275 } 276 // If our remote is on the 'auto' network, start it if it doesn't exist. 277 if f.network == AutoNetwork { 278 if f.goplsPath == "" { 279 return nil, fmt.Errorf("cannot auto-start remote: gopls path is unknown") 280 } 281 if network == "unix" { 282 // Sometimes the socketfile isn't properly cleaned up when gopls shuts 283 // down. Since we have already tried and failed to dial this address, it 284 // should *usually* be safe to remove the socket before binding to the 285 // address. 286 // TODO(rfindley): there is probably a race here if multiple gopls 287 // instances are simultaneously starting up. 288 if _, err := os.Stat(address); err == nil { 289 if err := os.Remove(address); err != nil { 290 return nil, fmt.Errorf("removing remote socket file: %v", err) 291 } 292 } 293 } 294 if err := startRemote(f.goplsPath, network, address); err != nil { 295 return nil, fmt.Errorf("startRemote(%q, %q): %v", network, address, err) 296 } 297 } 298 299 // It can take some time for the newly started server to bind to our address, 300 // so we retry for a bit. 301 for retry := 0; retry < f.retries; retry++ { 302 startDial := time.Now() 303 netConn, err = net.DialTimeout(network, address, f.dialTimeout) 304 if err == nil { 305 return netConn, nil 306 } 307 log.Print(ctx, fmt.Sprintf("failed attempt #%d to connect to remote: %v\n", retry+2, err)) 308 // In case our failure was a fast-failure, ensure we wait at least 309 // f.dialTimeout before trying again. 310 if retry != f.retries-1 { 311 time.Sleep(f.dialTimeout - time.Since(startDial)) 312 } 313 } 314 return nil, fmt.Errorf("dialing remote: %v", err) 315} 316 317func startRemote(goplsPath, network, address string) error { 318 args := []string{"serve", 319 "-listen", fmt.Sprintf(`%s;%s`, network, address), 320 "-listen.timeout", "1m", 321 "-debug", "localhost:0", 322 "-logfile", "auto", 323 } 324 cmd := exec.Command(goplsPath, args...) 325 if err := cmd.Start(); err != nil { 326 return fmt.Errorf("starting remote gopls: %v", err) 327 } 328 return nil 329} 330 331// ForwarderExitFunc is used to exit the forwarder process. It is mutable for 332// testing purposes. 333var ForwarderExitFunc = os.Exit 334 335// OverrideExitFuncsForTest can be used from test code to prevent the test 336// process from exiting on server shutdown. The returned func reverts the exit 337// funcs to their previous state. 338func OverrideExitFuncsForTest() func() { 339 // Override functions that would shut down the test process 340 cleanup := func(lspExit, forwarderExit func(code int)) func() { 341 return func() { 342 lsp.ServerExitFunc = lspExit 343 ForwarderExitFunc = forwarderExit 344 } 345 }(lsp.ServerExitFunc, ForwarderExitFunc) 346 // It is an error for a test to shutdown a server process. 347 lsp.ServerExitFunc = func(code int) { 348 panic(fmt.Sprintf("LSP server exited with code %d", code)) 349 } 350 // We don't want our forwarders to exit, but it's OK if they would have. 351 ForwarderExitFunc = func(code int) {} 352 return cleanup 353} 354 355// forwarderHandler intercepts 'exit' messages to prevent the shared gopls 356// instance from exiting. In the future it may also intercept 'shutdown' to 357// provide more graceful shutdown of the client connection. 358type forwarderHandler struct { 359 jsonrpc2.EmptyHandler 360} 361 362func (forwarderHandler) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool { 363 // TODO(golang.org/issues/34111): we should more gracefully disconnect here, 364 // once that process exists. 365 if r.Method == "exit" { 366 ForwarderExitFunc(0) 367 // Still return true here to prevent the message from being delivered: in 368 // tests, ForwarderExitFunc may be overridden to something that doesn't 369 // exit the process. 370 return true 371 } 372 return false 373} 374 375type handshaker struct { 376 jsonrpc2.EmptyHandler 377 client *debugClient 378 debug *debug.Instance 379 goplsPath string 380} 381 382type handshakeRequest struct { 383 ServerID string `json:"serverID"` 384 Logfile string `json:"logfile"` 385 DebugAddr string `json:"debugAddr"` 386 GoplsPath string `json:"goplsPath"` 387} 388 389type handshakeResponse struct { 390 ClientID string `json:"clientID"` 391 SessionID string `json:"sessionID"` 392 Logfile string `json:"logfile"` 393 DebugAddr string `json:"debugAddr"` 394 GoplsPath string `json:"goplsPath"` 395} 396 397const handshakeMethod = "gopls/handshake" 398 399func (h *handshaker) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool { 400 if r.Method == handshakeMethod { 401 var req handshakeRequest 402 if err := json.Unmarshal(*r.Params, &req); err != nil { 403 sendError(ctx, r, err) 404 return true 405 } 406 h.client.debugAddress = req.DebugAddr 407 h.client.logfile = req.Logfile 408 h.client.serverID = req.ServerID 409 h.client.goplsPath = req.GoplsPath 410 resp := handshakeResponse{ 411 ClientID: h.client.id, 412 SessionID: cache.DebugSession{Session: h.client.session}.ID(), 413 Logfile: h.debug.Logfile, 414 DebugAddr: h.debug.ListenedDebugAddress, 415 GoplsPath: h.goplsPath, 416 } 417 if err := r.Reply(ctx, resp, nil); err != nil { 418 log.Error(ctx, "replying to handshake", err) 419 } 420 return true 421 } 422 return false 423} 424 425func sendError(ctx context.Context, req *jsonrpc2.Request, err error) { 426 if _, ok := err.(*jsonrpc2.Error); !ok { 427 err = jsonrpc2.NewErrorf(jsonrpc2.CodeParseError, "%v", err) 428 } 429 if err := req.Reply(ctx, nil, err); err != nil { 430 log.Error(ctx, "", err) 431 } 432} 433