1 //===- LSPServer.cpp - MLIR Language Server -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "LSPServer.h"
10 #include "MLIRServer.h"
11 #include "lsp/Logging.h"
12 #include "lsp/Protocol.h"
13 #include "lsp/Transport.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/ADT/StringMap.h"
16 
17 #define DEBUG_TYPE "mlir-lsp-server"
18 
19 using namespace mlir;
20 using namespace mlir::lsp;
21 
22 //===----------------------------------------------------------------------===//
23 // LSPServer::Impl
24 //===----------------------------------------------------------------------===//
25 
26 struct LSPServer::Impl {
ImplLSPServer::Impl27   Impl(MLIRServer &server, JSONTransport &transport)
28       : server(server), transport(transport) {}
29 
30   //===--------------------------------------------------------------------===//
31   // Initialization
32 
33   void onInitialize(const InitializeParams &params,
34                     Callback<llvm::json::Value> reply);
35   void onInitialized(const InitializedParams &params);
36   void onShutdown(const NoParams &params, Callback<std::nullptr_t> reply);
37 
38   //===--------------------------------------------------------------------===//
39   // Document Change
40 
41   void onDocumentDidOpen(const DidOpenTextDocumentParams &params);
42   void onDocumentDidClose(const DidCloseTextDocumentParams &params);
43   void onDocumentDidChange(const DidChangeTextDocumentParams &params);
44 
45   //===--------------------------------------------------------------------===//
46   // Definitions and References
47 
48   void onGoToDefinition(const TextDocumentPositionParams &params,
49                         Callback<std::vector<Location>> reply);
50   void onReference(const ReferenceParams &params,
51                    Callback<std::vector<Location>> reply);
52 
53   //===--------------------------------------------------------------------===//
54   // Hover
55 
56   void onHover(const TextDocumentPositionParams &params,
57                Callback<Optional<Hover>> reply);
58 
59   //===--------------------------------------------------------------------===//
60   // Document Symbols
61 
62   void onDocumentSymbol(const DocumentSymbolParams &params,
63                         Callback<std::vector<DocumentSymbol>> reply);
64 
65   //===--------------------------------------------------------------------===//
66   // Fields
67   //===--------------------------------------------------------------------===//
68 
69   MLIRServer &server;
70   JSONTransport &transport;
71 
72   /// An outgoing notification used to send diagnostics to the client when they
73   /// are ready to be processed.
74   OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
75 
76   /// Used to indicate that the 'shutdown' request was received from the
77   /// Language Server client.
78   bool shutdownRequestReceived = false;
79 };
80 
81 //===----------------------------------------------------------------------===//
82 // Initialization
83 
onInitialize(const InitializeParams & params,Callback<llvm::json::Value> reply)84 void LSPServer::Impl::onInitialize(const InitializeParams &params,
85                                    Callback<llvm::json::Value> reply) {
86   // Send a response with the capabilities of this server.
87   llvm::json::Object serverCaps{
88       {"textDocumentSync",
89        llvm::json::Object{
90            {"openClose", true},
91            {"change", (int)TextDocumentSyncKind::Full},
92            {"save", true},
93        }},
94       {"definitionProvider", true},
95       {"referencesProvider", true},
96       {"hoverProvider", true},
97 
98       // For now we only support documenting symbols when the client supports
99       // hierarchical symbols.
100       {"documentSymbolProvider",
101        params.capabilities.hierarchicalDocumentSymbol},
102   };
103 
104   llvm::json::Object result{
105       {{"serverInfo",
106         llvm::json::Object{{"name", "mlir-lsp-server"}, {"version", "0.0.0"}}},
107        {"capabilities", std::move(serverCaps)}}};
108   reply(std::move(result));
109 }
onInitialized(const InitializedParams &)110 void LSPServer::Impl::onInitialized(const InitializedParams &) {}
onShutdown(const NoParams &,Callback<std::nullptr_t> reply)111 void LSPServer::Impl::onShutdown(const NoParams &,
112                                  Callback<std::nullptr_t> reply) {
113   shutdownRequestReceived = true;
114   reply(nullptr);
115 }
116 
117 //===----------------------------------------------------------------------===//
118 // Document Change
119 
onDocumentDidOpen(const DidOpenTextDocumentParams & params)120 void LSPServer::Impl::onDocumentDidOpen(
121     const DidOpenTextDocumentParams &params) {
122   PublishDiagnosticsParams diagParams(params.textDocument.uri,
123                                       params.textDocument.version);
124   server.addOrUpdateDocument(params.textDocument.uri, params.textDocument.text,
125                              params.textDocument.version,
126                              diagParams.diagnostics);
127 
128   // Publish any recorded diagnostics.
129   publishDiagnostics(diagParams);
130 }
onDocumentDidClose(const DidCloseTextDocumentParams & params)131 void LSPServer::Impl::onDocumentDidClose(
132     const DidCloseTextDocumentParams &params) {
133   Optional<int64_t> version = server.removeDocument(params.textDocument.uri);
134   if (!version)
135     return;
136 
137   // Empty out the diagnostics shown for this document. This will clear out
138   // anything currently displayed by the client for this document (e.g. in the
139   // "Problems" pane of VSCode).
140   publishDiagnostics(
141       PublishDiagnosticsParams(params.textDocument.uri, *version));
142 }
onDocumentDidChange(const DidChangeTextDocumentParams & params)143 void LSPServer::Impl::onDocumentDidChange(
144     const DidChangeTextDocumentParams &params) {
145   // TODO: We currently only support full document updates, we should refactor
146   // to avoid this.
147   if (params.contentChanges.size() != 1)
148     return;
149   PublishDiagnosticsParams diagParams(params.textDocument.uri,
150                                       params.textDocument.version);
151   server.addOrUpdateDocument(
152       params.textDocument.uri, params.contentChanges.front().text,
153       params.textDocument.version, diagParams.diagnostics);
154 
155   // Publish any recorded diagnostics.
156   publishDiagnostics(diagParams);
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // Definitions and References
161 
onGoToDefinition(const TextDocumentPositionParams & params,Callback<std::vector<Location>> reply)162 void LSPServer::Impl::onGoToDefinition(const TextDocumentPositionParams &params,
163                                        Callback<std::vector<Location>> reply) {
164   std::vector<Location> locations;
165   server.getLocationsOf(params.textDocument.uri, params.position, locations);
166   reply(std::move(locations));
167 }
168 
onReference(const ReferenceParams & params,Callback<std::vector<Location>> reply)169 void LSPServer::Impl::onReference(const ReferenceParams &params,
170                                   Callback<std::vector<Location>> reply) {
171   std::vector<Location> locations;
172   server.findReferencesOf(params.textDocument.uri, params.position, locations);
173   reply(std::move(locations));
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // Hover
178 
onHover(const TextDocumentPositionParams & params,Callback<Optional<Hover>> reply)179 void LSPServer::Impl::onHover(const TextDocumentPositionParams &params,
180                               Callback<Optional<Hover>> reply) {
181   reply(server.findHover(params.textDocument.uri, params.position));
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // Document Symbols
186 
onDocumentSymbol(const DocumentSymbolParams & params,Callback<std::vector<DocumentSymbol>> reply)187 void LSPServer::Impl::onDocumentSymbol(
188     const DocumentSymbolParams &params,
189     Callback<std::vector<DocumentSymbol>> reply) {
190   std::vector<DocumentSymbol> symbols;
191   server.findDocumentSymbols(params.textDocument.uri, symbols);
192   reply(std::move(symbols));
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // LSPServer
197 //===----------------------------------------------------------------------===//
198 
LSPServer(MLIRServer & server,JSONTransport & transport)199 LSPServer::LSPServer(MLIRServer &server, JSONTransport &transport)
200     : impl(std::make_unique<Impl>(server, transport)) {}
~LSPServer()201 LSPServer::~LSPServer() {}
202 
run()203 LogicalResult LSPServer::run() {
204   MessageHandler messageHandler(impl->transport);
205 
206   // Initialization
207   messageHandler.method("initialize", impl.get(), &Impl::onInitialize);
208   messageHandler.notification("initialized", impl.get(), &Impl::onInitialized);
209   messageHandler.method("shutdown", impl.get(), &Impl::onShutdown);
210 
211   // Document Changes
212   messageHandler.notification("textDocument/didOpen", impl.get(),
213                               &Impl::onDocumentDidOpen);
214   messageHandler.notification("textDocument/didClose", impl.get(),
215                               &Impl::onDocumentDidClose);
216   messageHandler.notification("textDocument/didChange", impl.get(),
217                               &Impl::onDocumentDidChange);
218 
219   // Definitions and References
220   messageHandler.method("textDocument/definition", impl.get(),
221                         &Impl::onGoToDefinition);
222   messageHandler.method("textDocument/references", impl.get(),
223                         &Impl::onReference);
224 
225   // Hover
226   messageHandler.method("textDocument/hover", impl.get(), &Impl::onHover);
227 
228   // Document Symbols
229   messageHandler.method("textDocument/documentSymbol", impl.get(),
230                         &Impl::onDocumentSymbol);
231 
232   // Diagnostics
233   impl->publishDiagnostics =
234       messageHandler.outgoingNotification<PublishDiagnosticsParams>(
235           "textDocument/publishDiagnostics");
236 
237   // Run the main loop of the transport.
238   LogicalResult result = success();
239   if (llvm::Error error = impl->transport.run(messageHandler)) {
240     Logger::error("Transport error: {0}", error);
241     llvm::consumeError(std::move(error));
242     result = failure();
243   } else {
244     result = success(impl->shutdownRequestReceived);
245   }
246   return result;
247 }
248