1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "MLIRServer.h"
10 #include "lsp/Logging.h"
11 #include "lsp/Protocol.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/Parser.h"
14 #include "mlir/Parser/AsmParserState.h"
15 #include "llvm/Support/SourceMgr.h"
16
17 using namespace mlir;
18
19 /// Returns a language server position for the given source location.
getPosFromLoc(llvm::SourceMgr & mgr,llvm::SMLoc loc)20 static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, llvm::SMLoc loc) {
21 std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc);
22 lsp::Position pos;
23 pos.line = lineAndCol.first - 1;
24 pos.character = lineAndCol.second - 1;
25 return pos;
26 }
27
28 /// Returns a source location from the given language server position.
getPosFromLoc(llvm::SourceMgr & mgr,lsp::Position pos)29 static llvm::SMLoc getPosFromLoc(llvm::SourceMgr &mgr, lsp::Position pos) {
30 return mgr.FindLocForLineAndColumn(mgr.getMainFileID(), pos.line + 1,
31 pos.character);
32 }
33
34 /// Returns a language server range for the given source range.
getRangeFromLoc(llvm::SourceMgr & mgr,llvm::SMRange range)35 static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, llvm::SMRange range) {
36 return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, range.End)};
37 }
38
39 /// Returns a language server location from the given source range.
getLocationFromLoc(llvm::SourceMgr & mgr,llvm::SMRange range,const lsp::URIForFile & uri)40 static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr,
41 llvm::SMRange range,
42 const lsp::URIForFile &uri) {
43 return lsp::Location{uri, getRangeFromLoc(mgr, range)};
44 }
45
46 /// Returns a language server location from the given MLIR file location.
getLocationFromLoc(FileLineColLoc loc)47 static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) {
48 llvm::Expected<lsp::URIForFile> sourceURI =
49 lsp::URIForFile::fromFile(loc.getFilename());
50 if (!sourceURI) {
51 lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
52 loc.getFilename(),
53 llvm::toString(sourceURI.takeError()));
54 return llvm::None;
55 }
56
57 lsp::Position position;
58 position.line = loc.getLine() - 1;
59 position.character = loc.getColumn();
60 return lsp::Location{*sourceURI, lsp::Range(position)};
61 }
62
63 /// Returns a language server location from the given MLIR location, or None if
64 /// one couldn't be created. `uri` is an optional additional filter that, when
65 /// present, is used to filter sub locations that do not share the same uri.
66 static Optional<lsp::Location>
getLocationFromLoc(llvm::SourceMgr & sourceMgr,Location loc,const lsp::URIForFile * uri=nullptr)67 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
68 const lsp::URIForFile *uri = nullptr) {
69 Optional<lsp::Location> location;
70 loc->walk([&](Location nestedLoc) {
71 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
72 if (!fileLoc)
73 return WalkResult::advance();
74
75 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
76 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
77 location = *sourceLoc;
78 llvm::SMLoc loc = sourceMgr.FindLocForLineAndColumn(
79 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
80
81 // Use range of potential identifier starting at location, else length 1
82 // range.
83 location->range.end.character += 1;
84 if (Optional<llvm::SMRange> range =
85 AsmParserState::convertIdLocToRange(loc)) {
86 auto lineCol = sourceMgr.getLineAndColumn(range->End);
87 location->range.end.character =
88 std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
89 }
90 return WalkResult::interrupt();
91 }
92 return WalkResult::advance();
93 });
94 return location;
95 }
96
97 /// Collect all of the locations from the given MLIR location that are not
98 /// contained within the given URI.
collectLocationsFromLoc(Location loc,std::vector<lsp::Location> & locations,const lsp::URIForFile & uri)99 static void collectLocationsFromLoc(Location loc,
100 std::vector<lsp::Location> &locations,
101 const lsp::URIForFile &uri) {
102 SetVector<Location> visitedLocs;
103 loc->walk([&](Location nestedLoc) {
104 FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
105 if (!fileLoc || !visitedLocs.insert(nestedLoc))
106 return WalkResult::advance();
107
108 Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
109 if (sourceLoc && sourceLoc->uri != uri)
110 locations.push_back(*sourceLoc);
111 return WalkResult::advance();
112 });
113 }
114
115 /// Returns true if the given range contains the given source location. Note
116 /// that this has slightly different behavior than SMRange because it is
117 /// inclusive of the end location.
contains(llvm::SMRange range,llvm::SMLoc loc)118 static bool contains(llvm::SMRange range, llvm::SMLoc loc) {
119 return range.Start.getPointer() <= loc.getPointer() &&
120 loc.getPointer() <= range.End.getPointer();
121 }
122
123 /// Returns true if the given location is contained by the definition or one of
124 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
125 /// the range within `def` that the provided `loc` overlapped with.
isDefOrUse(const AsmParserState::SMDefinition & def,llvm::SMLoc loc,llvm::SMRange * overlappedRange=nullptr)126 static bool isDefOrUse(const AsmParserState::SMDefinition &def, llvm::SMLoc loc,
127 llvm::SMRange *overlappedRange = nullptr) {
128 // Check the main definition.
129 if (contains(def.loc, loc)) {
130 if (overlappedRange)
131 *overlappedRange = def.loc;
132 return true;
133 }
134
135 // Check the uses.
136 auto useIt = llvm::find_if(def.uses, [&](const llvm::SMRange &range) {
137 return contains(range, loc);
138 });
139 if (useIt != def.uses.end()) {
140 if (overlappedRange)
141 *overlappedRange = *useIt;
142 return true;
143 }
144 return false;
145 }
146
147 /// Given a location pointing to a result, return the result number it refers
148 /// to or None if it refers to all of the results.
getResultNumberFromLoc(llvm::SMLoc loc)149 static Optional<unsigned> getResultNumberFromLoc(llvm::SMLoc loc) {
150 // Skip all of the identifier characters.
151 auto isIdentifierChar = [](char c) {
152 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
153 c == '-';
154 };
155 const char *curPtr = loc.getPointer();
156 while (isIdentifierChar(*curPtr))
157 ++curPtr;
158
159 // Check to see if this location indexes into the result group, via `#`. If it
160 // doesn't, we can't extract a sub result number.
161 if (*curPtr != '#')
162 return llvm::None;
163
164 // Compute the sub result number from the remaining portion of the string.
165 const char *numberStart = ++curPtr;
166 while (llvm::isDigit(*curPtr))
167 ++curPtr;
168 StringRef numberStr(numberStart, curPtr - numberStart);
169 unsigned resultNumber = 0;
170 return numberStr.consumeInteger(10, resultNumber) ? Optional<unsigned>()
171 : resultNumber;
172 }
173
174 /// Given a source location range, return the text covered by the given range.
175 /// If the range is invalid, returns None.
getTextFromRange(llvm::SMRange range)176 static Optional<StringRef> getTextFromRange(llvm::SMRange range) {
177 if (!range.isValid())
178 return None;
179 const char *startPtr = range.Start.getPointer();
180 return StringRef(startPtr, range.End.getPointer() - startPtr);
181 }
182
183 /// Given a block, return its position in its parent region.
getBlockNumber(Block * block)184 static unsigned getBlockNumber(Block *block) {
185 return std::distance(block->getParent()->begin(), block->getIterator());
186 }
187
188 /// Given a block and source location, print the source name of the block to the
189 /// given output stream.
printDefBlockName(raw_ostream & os,Block * block,llvm::SMRange loc={})190 static void printDefBlockName(raw_ostream &os, Block *block,
191 llvm::SMRange loc = {}) {
192 // Try to extract a name from the source location.
193 Optional<StringRef> text = getTextFromRange(loc);
194 if (text && text->startswith("^")) {
195 os << *text;
196 return;
197 }
198
199 // Otherwise, we don't have a name so print the block number.
200 os << "<Block #" << getBlockNumber(block) << ">";
201 }
printDefBlockName(raw_ostream & os,const AsmParserState::BlockDefinition & def)202 static void printDefBlockName(raw_ostream &os,
203 const AsmParserState::BlockDefinition &def) {
204 printDefBlockName(os, def.block, def.definition.loc);
205 }
206
207 /// Convert the given MLIR diagnostic to the LSP form.
getLspDiagnoticFromDiag(llvm::SourceMgr & sourceMgr,Diagnostic & diag,const lsp::URIForFile & uri)208 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
209 Diagnostic &diag,
210 const lsp::URIForFile &uri) {
211 lsp::Diagnostic lspDiag;
212 lspDiag.source = "mlir";
213
214 // Note: Right now all of the diagnostics are treated as parser issues, but
215 // some are parser and some are verifier.
216 lspDiag.category = "Parse Error";
217
218 // Try to grab a file location for this diagnostic.
219 // TODO: For simplicity, we just grab the first one. It may be likely that we
220 // will need a more interesting heuristic here.'
221 Optional<lsp::Location> lspLocation =
222 getLocationFromLoc(sourceMgr, diag.getLocation(), &uri);
223 if (lspLocation)
224 lspDiag.range = lspLocation->range;
225
226 // Convert the severity for the diagnostic.
227 switch (diag.getSeverity()) {
228 case DiagnosticSeverity::Note:
229 llvm_unreachable("expected notes to be handled separately");
230 case DiagnosticSeverity::Warning:
231 lspDiag.severity = lsp::DiagnosticSeverity::Warning;
232 break;
233 case DiagnosticSeverity::Error:
234 lspDiag.severity = lsp::DiagnosticSeverity::Error;
235 break;
236 case DiagnosticSeverity::Remark:
237 lspDiag.severity = lsp::DiagnosticSeverity::Information;
238 break;
239 }
240 lspDiag.message = diag.str();
241
242 // Attach any notes to the main diagnostic as related information.
243 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
244 for (Diagnostic ¬e : diag.getNotes()) {
245 lsp::Location noteLoc;
246 if (Optional<lsp::Location> loc =
247 getLocationFromLoc(sourceMgr, note.getLocation()))
248 noteLoc = *loc;
249 else
250 noteLoc.uri = uri;
251 relatedDiags.emplace_back(noteLoc, note.str());
252 }
253 if (!relatedDiags.empty())
254 lspDiag.relatedInformation = std::move(relatedDiags);
255
256 return lspDiag;
257 }
258
259 //===----------------------------------------------------------------------===//
260 // MLIRDocument
261 //===----------------------------------------------------------------------===//
262
263 namespace {
264 /// This class represents all of the information pertaining to a specific MLIR
265 /// document.
266 struct MLIRDocument {
267 MLIRDocument(const lsp::URIForFile &uri, StringRef contents,
268 DialectRegistry ®istry,
269 std::vector<lsp::Diagnostic> &diagnostics);
270 MLIRDocument(const MLIRDocument &) = delete;
271 MLIRDocument &operator=(const MLIRDocument &) = delete;
272
273 //===--------------------------------------------------------------------===//
274 // Definitions and References
275 //===--------------------------------------------------------------------===//
276
277 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
278 std::vector<lsp::Location> &locations);
279 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
280 std::vector<lsp::Location> &references);
281
282 //===--------------------------------------------------------------------===//
283 // Hover
284 //===--------------------------------------------------------------------===//
285
286 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
287 const lsp::Position &hoverPos);
288 Optional<lsp::Hover>
289 buildHoverForOperation(llvm::SMRange hoverRange,
290 const AsmParserState::OperationDefinition &op);
291 lsp::Hover buildHoverForOperationResult(llvm::SMRange hoverRange,
292 Operation *op, unsigned resultStart,
293 unsigned resultEnd,
294 llvm::SMLoc posLoc);
295 lsp::Hover buildHoverForBlock(llvm::SMRange hoverRange,
296 const AsmParserState::BlockDefinition &block);
297 lsp::Hover
298 buildHoverForBlockArgument(llvm::SMRange hoverRange, BlockArgument arg,
299 const AsmParserState::BlockDefinition &block);
300
301 //===--------------------------------------------------------------------===//
302 // Document Symbols
303 //===--------------------------------------------------------------------===//
304
305 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
306 void findDocumentSymbols(Operation *op,
307 std::vector<lsp::DocumentSymbol> &symbols);
308
309 //===--------------------------------------------------------------------===//
310 // Fields
311 //===--------------------------------------------------------------------===//
312
313 /// The context used to hold the state contained by the parsed document.
314 MLIRContext context;
315
316 /// The high level parser state used to find definitions and references within
317 /// the source file.
318 AsmParserState asmState;
319
320 /// The container for the IR parsed from the input file.
321 Block parsedIR;
322
323 /// The source manager containing the contents of the input file.
324 llvm::SourceMgr sourceMgr;
325 };
326 } // namespace
327
MLIRDocument(const lsp::URIForFile & uri,StringRef contents,DialectRegistry & registry,std::vector<lsp::Diagnostic> & diagnostics)328 MLIRDocument::MLIRDocument(const lsp::URIForFile &uri, StringRef contents,
329 DialectRegistry ®istry,
330 std::vector<lsp::Diagnostic> &diagnostics)
331 : context(registry) {
332 context.allowUnregisteredDialects();
333 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
334 diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
335 });
336
337 // Try to parsed the given IR string.
338 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
339 if (!memBuffer) {
340 lsp::Logger::error("Failed to create memory buffer for file", uri.file());
341 return;
342 }
343
344 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), llvm::SMLoc());
345 if (failed(parseSourceFile(sourceMgr, &parsedIR, &context, nullptr,
346 &asmState))) {
347 // If parsing failed, clear out any of the current state.
348 parsedIR.clear();
349 asmState = AsmParserState();
350 return;
351 }
352 }
353
354 //===----------------------------------------------------------------------===//
355 // MLIRDocument: Definitions and References
356 //===----------------------------------------------------------------------===//
357
getLocationsOf(const lsp::URIForFile & uri,const lsp::Position & defPos,std::vector<lsp::Location> & locations)358 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
359 const lsp::Position &defPos,
360 std::vector<lsp::Location> &locations) {
361 llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, defPos);
362
363 // Functor used to check if an SM definition contains the position.
364 auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
365 if (!isDefOrUse(def, posLoc))
366 return false;
367 locations.push_back(getLocationFromLoc(sourceMgr, def.loc, uri));
368 return true;
369 };
370
371 // Check all definitions related to operations.
372 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
373 if (contains(op.loc, posLoc))
374 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
375 for (const auto &result : op.resultGroups)
376 if (containsPosition(result.second))
377 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
378 for (const auto &symUse : op.symbolUses) {
379 if (contains(symUse, posLoc)) {
380 locations.push_back(getLocationFromLoc(sourceMgr, op.loc, uri));
381 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
382 }
383 }
384 }
385
386 // Check all definitions related to blocks.
387 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
388 if (containsPosition(block.definition))
389 return;
390 for (const AsmParserState::SMDefinition &arg : block.arguments)
391 if (containsPosition(arg))
392 return;
393 }
394 }
395
findReferencesOf(const lsp::URIForFile & uri,const lsp::Position & pos,std::vector<lsp::Location> & references)396 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
397 const lsp::Position &pos,
398 std::vector<lsp::Location> &references) {
399 // Functor used to append all of the definitions/uses of the given SM
400 // definition to the reference list.
401 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
402 references.push_back(getLocationFromLoc(sourceMgr, def.loc, uri));
403 for (const llvm::SMRange &use : def.uses)
404 references.push_back(getLocationFromLoc(sourceMgr, use, uri));
405 };
406
407 llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, pos);
408
409 // Check all definitions related to operations.
410 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
411 if (contains(op.loc, posLoc)) {
412 for (const auto &result : op.resultGroups)
413 appendSMDef(result.second);
414 for (const auto &symUse : op.symbolUses)
415 if (contains(symUse, posLoc))
416 references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
417 return;
418 }
419 for (const auto &result : op.resultGroups)
420 if (isDefOrUse(result.second, posLoc))
421 return appendSMDef(result.second);
422 for (const auto &symUse : op.symbolUses) {
423 if (!contains(symUse, posLoc))
424 continue;
425 for (const auto &symUse : op.symbolUses)
426 references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
427 return;
428 }
429 }
430
431 // Check all definitions related to blocks.
432 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
433 if (isDefOrUse(block.definition, posLoc))
434 return appendSMDef(block.definition);
435
436 for (const AsmParserState::SMDefinition &arg : block.arguments)
437 if (isDefOrUse(arg, posLoc))
438 return appendSMDef(arg);
439 }
440 }
441
442 //===----------------------------------------------------------------------===//
443 // MLIRDocument: Hover
444 //===----------------------------------------------------------------------===//
445
findHover(const lsp::URIForFile & uri,const lsp::Position & hoverPos)446 Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri,
447 const lsp::Position &hoverPos) {
448 llvm::SMLoc posLoc = getPosFromLoc(sourceMgr, hoverPos);
449 llvm::SMRange hoverRange;
450
451 // Check for Hovers on operations and results.
452 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
453 // Check if the position points at this operation.
454 if (contains(op.loc, posLoc))
455 return buildHoverForOperation(op.loc, op);
456
457 // Check if the position points at the symbol name.
458 for (auto &use : op.symbolUses)
459 if (contains(use, posLoc))
460 return buildHoverForOperation(use, op);
461
462 // Check if the position points at a result group.
463 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
464 const auto &result = op.resultGroups[i];
465 if (!isDefOrUse(result.second, posLoc, &hoverRange))
466 continue;
467
468 // Get the range of results covered by the over position.
469 unsigned resultStart = result.first;
470 unsigned resultEnd =
471 (i == e - 1) ? op.op->getNumResults() : op.resultGroups[i + 1].first;
472 return buildHoverForOperationResult(hoverRange, op.op, resultStart,
473 resultEnd, posLoc);
474 }
475 }
476
477 // Check to see if the hover is over a block argument.
478 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
479 if (isDefOrUse(block.definition, posLoc, &hoverRange))
480 return buildHoverForBlock(hoverRange, block);
481
482 for (const auto &arg : llvm::enumerate(block.arguments)) {
483 if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
484 continue;
485
486 return buildHoverForBlockArgument(
487 hoverRange, block.block->getArgument(arg.index()), block);
488 }
489 }
490 return llvm::None;
491 }
492
buildHoverForOperation(llvm::SMRange hoverRange,const AsmParserState::OperationDefinition & op)493 Optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
494 llvm::SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
495 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
496 llvm::raw_string_ostream os(hover.contents.value);
497
498 // Add the operation name to the hover.
499 os << "\"" << op.op->getName() << "\"";
500 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
501 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
502 os << "\n\n";
503
504 os << "Generic Form:\n\n```mlir\n";
505
506 // Temporary drop the regions of this operation so that they don't get
507 // printed in the output. This helps keeps the size of the output hover
508 // small.
509 SmallVector<std::unique_ptr<Region>> regions;
510 for (Region ®ion : op.op->getRegions()) {
511 regions.emplace_back(std::make_unique<Region>());
512 regions.back()->takeBody(region);
513 }
514
515 op.op->print(
516 os, OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
517 os << "\n```\n";
518
519 // Move the regions back to the current operation.
520 for (Region ®ion : op.op->getRegions())
521 region.takeBody(*regions.back());
522
523 return hover;
524 }
525
buildHoverForOperationResult(llvm::SMRange hoverRange,Operation * op,unsigned resultStart,unsigned resultEnd,llvm::SMLoc posLoc)526 lsp::Hover MLIRDocument::buildHoverForOperationResult(llvm::SMRange hoverRange,
527 Operation *op,
528 unsigned resultStart,
529 unsigned resultEnd,
530 llvm::SMLoc posLoc) {
531 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
532 llvm::raw_string_ostream os(hover.contents.value);
533
534 // Add the parent operation name to the hover.
535 os << "Operation: \"" << op->getName() << "\"\n\n";
536
537 // Check to see if the location points to a specific result within the
538 // group.
539 if (Optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
540 if ((resultStart + *resultNumber) < resultEnd) {
541 resultStart += *resultNumber;
542 resultEnd = resultStart + 1;
543 }
544 }
545
546 // Add the range of results and their types to the hover info.
547 if ((resultStart + 1) == resultEnd) {
548 os << "Result #" << resultStart << "\n\n"
549 << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
550 } else {
551 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
552 << "Types: ";
553 llvm::interleaveComma(
554 op->getResults().slice(resultStart, resultEnd), os,
555 [&](Value result) { os << "`" << result.getType() << "`"; });
556 }
557
558 return hover;
559 }
560
561 lsp::Hover
buildHoverForBlock(llvm::SMRange hoverRange,const AsmParserState::BlockDefinition & block)562 MLIRDocument::buildHoverForBlock(llvm::SMRange hoverRange,
563 const AsmParserState::BlockDefinition &block) {
564 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
565 llvm::raw_string_ostream os(hover.contents.value);
566
567 // Print the given block to the hover output stream.
568 auto printBlockToHover = [&](Block *newBlock) {
569 if (const auto *def = asmState.getBlockDef(newBlock))
570 printDefBlockName(os, *def);
571 else
572 printDefBlockName(os, newBlock);
573 };
574
575 // Display the parent operation, block number, predecessors, and successors.
576 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
577 << "Block #" << getBlockNumber(block.block) << "\n\n";
578 if (!block.block->hasNoPredecessors()) {
579 os << "Predecessors: ";
580 llvm::interleaveComma(block.block->getPredecessors(), os,
581 printBlockToHover);
582 os << "\n\n";
583 }
584 if (!block.block->hasNoSuccessors()) {
585 os << "Successors: ";
586 llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
587 os << "\n\n";
588 }
589
590 return hover;
591 }
592
buildHoverForBlockArgument(llvm::SMRange hoverRange,BlockArgument arg,const AsmParserState::BlockDefinition & block)593 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
594 llvm::SMRange hoverRange, BlockArgument arg,
595 const AsmParserState::BlockDefinition &block) {
596 lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
597 llvm::raw_string_ostream os(hover.contents.value);
598
599 // Display the parent operation, block, the argument number, and the type.
600 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
601 << "Block: ";
602 printDefBlockName(os, block);
603 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
604 << "Type: `" << arg.getType() << "`\n\n";
605
606 return hover;
607 }
608
609 //===----------------------------------------------------------------------===//
610 // MLIRDocument: Document Symbols
611 //===----------------------------------------------------------------------===//
612
findDocumentSymbols(std::vector<lsp::DocumentSymbol> & symbols)613 void MLIRDocument::findDocumentSymbols(
614 std::vector<lsp::DocumentSymbol> &symbols) {
615 for (Operation &op : parsedIR)
616 findDocumentSymbols(&op, symbols);
617 }
618
findDocumentSymbols(Operation * op,std::vector<lsp::DocumentSymbol> & symbols)619 void MLIRDocument::findDocumentSymbols(
620 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
621 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
622
623 // Check for the source information of this operation.
624 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
625 // If this operation defines a symbol, record it.
626 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
627 symbols.emplace_back(symbol.getName(),
628 op->hasTrait<OpTrait::FunctionLike>()
629 ? lsp::SymbolKind::Function
630 : lsp::SymbolKind::Class,
631 getRangeFromLoc(sourceMgr, def->scopeLoc),
632 getRangeFromLoc(sourceMgr, def->loc));
633 childSymbols = &symbols.back().children;
634
635 } else if (op->hasTrait<OpTrait::SymbolTable>()) {
636 // Otherwise, if this is a symbol table push an anonymous document symbol.
637 symbols.emplace_back("<" + op->getName().getStringRef() + ">",
638 lsp::SymbolKind::Namespace,
639 getRangeFromLoc(sourceMgr, def->scopeLoc),
640 getRangeFromLoc(sourceMgr, def->loc));
641 childSymbols = &symbols.back().children;
642 }
643 }
644
645 // Recurse into the regions of this operation.
646 if (!op->getNumRegions())
647 return;
648 for (Region ®ion : op->getRegions())
649 for (Operation &childOp : region.getOps())
650 findDocumentSymbols(&childOp, *childSymbols);
651 }
652
653 //===----------------------------------------------------------------------===//
654 // MLIRTextFileChunk
655 //===----------------------------------------------------------------------===//
656
657 namespace {
658 /// This class represents a single chunk of an MLIR text file.
659 struct MLIRTextFileChunk {
MLIRTextFileChunk__anon8f678da50b11::MLIRTextFileChunk660 MLIRTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri,
661 StringRef contents, DialectRegistry ®istry,
662 std::vector<lsp::Diagnostic> &diagnostics)
663 : lineOffset(lineOffset), document(uri, contents, registry, diagnostics) {
664 }
665
666 /// Adjust the line number of the given range to anchor at the beginning of
667 /// the file, instead of the beginning of this chunk.
adjustLocForChunkOffset__anon8f678da50b11::MLIRTextFileChunk668 void adjustLocForChunkOffset(lsp::Range &range) {
669 adjustLocForChunkOffset(range.start);
670 adjustLocForChunkOffset(range.end);
671 }
672 /// Adjust the line number of the given position to anchor at the beginning of
673 /// the file, instead of the beginning of this chunk.
adjustLocForChunkOffset__anon8f678da50b11::MLIRTextFileChunk674 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
675
676 /// The line offset of this chunk from the beginning of the file.
677 uint64_t lineOffset;
678 /// The document referred to by this chunk.
679 MLIRDocument document;
680 };
681 } // namespace
682
683 //===----------------------------------------------------------------------===//
684 // MLIRTextFile
685 //===----------------------------------------------------------------------===//
686
687 namespace {
688 /// This class represents a text file containing one or more MLIR documents.
689 class MLIRTextFile {
690 public:
691 MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
692 int64_t version, DialectRegistry ®istry,
693 std::vector<lsp::Diagnostic> &diagnostics);
694
695 /// Return the current version of this text file.
getVersion() const696 int64_t getVersion() const { return version; }
697
698 //===--------------------------------------------------------------------===//
699 // LSP Queries
700 //===--------------------------------------------------------------------===//
701
702 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
703 std::vector<lsp::Location> &locations);
704 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
705 std::vector<lsp::Location> &references);
706 Optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
707 lsp::Position hoverPos);
708 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
709
710 private:
711 /// Find the MLIR document that contains the given position, and update the
712 /// position to be anchored at the start of the found chunk instead of the
713 /// beginning of the file.
714 MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
715
716 /// The full string contents of the file.
717 std::string contents;
718
719 /// The version of this file.
720 int64_t version;
721
722 /// The number of lines in the file.
723 int64_t totalNumLines;
724
725 /// The chunks of this file. The order of these chunks is the order in which
726 /// they appear in the text file.
727 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
728 };
729 } // namespace
730
MLIRTextFile(const lsp::URIForFile & uri,StringRef fileContents,int64_t version,DialectRegistry & registry,std::vector<lsp::Diagnostic> & diagnostics)731 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
732 int64_t version, DialectRegistry ®istry,
733 std::vector<lsp::Diagnostic> &diagnostics)
734 : contents(fileContents.str()), version(version), totalNumLines(0) {
735 // Split the file into separate MLIR documents.
736 // TODO: Find a way to share the split file marker with other tools. We don't
737 // want to use `splitAndProcessBuffer` here, but we do want to make sure this
738 // marker doesn't go out of sync.
739 SmallVector<StringRef, 8> subContents;
740 StringRef(contents).split(subContents, "// -----");
741 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
742 /*lineOffset=*/0, uri, subContents.front(), registry, diagnostics));
743
744 uint64_t lineOffset = subContents.front().count('\n');
745 for (StringRef docContents : llvm::drop_begin(subContents)) {
746 unsigned currentNumDiags = diagnostics.size();
747 auto chunk = std::make_unique<MLIRTextFileChunk>(
748 lineOffset, uri, docContents, registry, diagnostics);
749 lineOffset += docContents.count('\n');
750
751 // Adjust locations used in diagnostics to account for the offset from the
752 // beginning of the file.
753 for (lsp::Diagnostic &diag :
754 llvm::drop_begin(diagnostics, currentNumDiags)) {
755 chunk->adjustLocForChunkOffset(diag.range);
756
757 if (!diag.relatedInformation)
758 continue;
759 for (auto &it : *diag.relatedInformation)
760 if (it.location.uri == uri)
761 chunk->adjustLocForChunkOffset(it.location.range);
762 }
763 chunks.emplace_back(std::move(chunk));
764 }
765 totalNumLines = lineOffset;
766 }
767
getLocationsOf(const lsp::URIForFile & uri,lsp::Position defPos,std::vector<lsp::Location> & locations)768 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
769 lsp::Position defPos,
770 std::vector<lsp::Location> &locations) {
771 MLIRTextFileChunk &chunk = getChunkFor(defPos);
772 chunk.document.getLocationsOf(uri, defPos, locations);
773
774 // Adjust any locations within this file for the offset of this chunk.
775 if (chunk.lineOffset == 0)
776 return;
777 for (lsp::Location &loc : locations)
778 if (loc.uri == uri)
779 chunk.adjustLocForChunkOffset(loc.range);
780 }
781
findReferencesOf(const lsp::URIForFile & uri,lsp::Position pos,std::vector<lsp::Location> & references)782 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
783 lsp::Position pos,
784 std::vector<lsp::Location> &references) {
785 MLIRTextFileChunk &chunk = getChunkFor(pos);
786 chunk.document.findReferencesOf(uri, pos, references);
787
788 // Adjust any locations within this file for the offset of this chunk.
789 if (chunk.lineOffset == 0)
790 return;
791 for (lsp::Location &loc : references)
792 if (loc.uri == uri)
793 chunk.adjustLocForChunkOffset(loc.range);
794 }
795
findHover(const lsp::URIForFile & uri,lsp::Position hoverPos)796 Optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
797 lsp::Position hoverPos) {
798 MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
799 Optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
800
801 // Adjust any locations within this file for the offset of this chunk.
802 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
803 chunk.adjustLocForChunkOffset(*hoverInfo->range);
804 return hoverInfo;
805 }
806
findDocumentSymbols(std::vector<lsp::DocumentSymbol> & symbols)807 void MLIRTextFile::findDocumentSymbols(
808 std::vector<lsp::DocumentSymbol> &symbols) {
809 if (chunks.size() == 1)
810 return chunks.front()->document.findDocumentSymbols(symbols);
811
812 // If there are multiple chunks in this file, we create top-level symbols for
813 // each chunk.
814 for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
815 MLIRTextFileChunk &chunk = *chunks[i];
816 lsp::Position startPos(chunk.lineOffset);
817 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
818 : chunks[i + 1]->lineOffset);
819 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
820 lsp::SymbolKind::Namespace,
821 /*range=*/lsp::Range(startPos, endPos),
822 /*selectionRange=*/lsp::Range(startPos));
823 chunk.document.findDocumentSymbols(symbol.children);
824
825 // Fixup the locations of document symbols within this chunk.
826 if (i != 0) {
827 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
828 for (lsp::DocumentSymbol &childSymbol : symbol.children)
829 symbolsToFix.push_back(&childSymbol);
830
831 while (!symbolsToFix.empty()) {
832 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
833 chunk.adjustLocForChunkOffset(symbol->range);
834 chunk.adjustLocForChunkOffset(symbol->selectionRange);
835
836 for (lsp::DocumentSymbol &childSymbol : symbol->children)
837 symbolsToFix.push_back(&childSymbol);
838 }
839 }
840
841 // Push the symbol for this chunk.
842 symbols.emplace_back(std::move(symbol));
843 }
844 }
845
getChunkFor(lsp::Position & pos)846 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
847 if (chunks.size() == 1)
848 return *chunks.front();
849
850 // Search for the first chunk with a greater line offset, the previous chunk
851 // is the one that contains `pos`.
852 auto it = llvm::upper_bound(
853 chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
854 return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
855 });
856 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
857 pos.line -= chunk.lineOffset;
858 return chunk;
859 }
860
861 //===----------------------------------------------------------------------===//
862 // MLIRServer::Impl
863 //===----------------------------------------------------------------------===//
864
865 struct lsp::MLIRServer::Impl {
Impllsp::MLIRServer::Impl866 Impl(DialectRegistry ®istry) : registry(registry) {}
867
868 /// The registry containing dialects that can be recognized in parsed .mlir
869 /// files.
870 DialectRegistry ®istry;
871
872 /// The files held by the server, mapped by their URI file name.
873 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
874 };
875
876 //===----------------------------------------------------------------------===//
877 // MLIRServer
878 //===----------------------------------------------------------------------===//
879
MLIRServer(DialectRegistry & registry)880 lsp::MLIRServer::MLIRServer(DialectRegistry ®istry)
881 : impl(std::make_unique<Impl>(registry)) {}
~MLIRServer()882 lsp::MLIRServer::~MLIRServer() {}
883
addOrUpdateDocument(const URIForFile & uri,StringRef contents,int64_t version,std::vector<Diagnostic> & diagnostics)884 void lsp::MLIRServer::addOrUpdateDocument(
885 const URIForFile &uri, StringRef contents, int64_t version,
886 std::vector<Diagnostic> &diagnostics) {
887 impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
888 uri, contents, version, impl->registry, diagnostics);
889 }
890
removeDocument(const URIForFile & uri)891 Optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
892 auto it = impl->files.find(uri.file());
893 if (it == impl->files.end())
894 return llvm::None;
895
896 int64_t version = it->second->getVersion();
897 impl->files.erase(it);
898 return version;
899 }
900
getLocationsOf(const URIForFile & uri,const Position & defPos,std::vector<Location> & locations)901 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
902 const Position &defPos,
903 std::vector<Location> &locations) {
904 auto fileIt = impl->files.find(uri.file());
905 if (fileIt != impl->files.end())
906 fileIt->second->getLocationsOf(uri, defPos, locations);
907 }
908
findReferencesOf(const URIForFile & uri,const Position & pos,std::vector<Location> & references)909 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
910 const Position &pos,
911 std::vector<Location> &references) {
912 auto fileIt = impl->files.find(uri.file());
913 if (fileIt != impl->files.end())
914 fileIt->second->findReferencesOf(uri, pos, references);
915 }
916
findHover(const URIForFile & uri,const Position & hoverPos)917 Optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
918 const Position &hoverPos) {
919 auto fileIt = impl->files.find(uri.file());
920 if (fileIt != impl->files.end())
921 return fileIt->second->findHover(uri, hoverPos);
922 return llvm::None;
923 }
924
findDocumentSymbols(const URIForFile & uri,std::vector<DocumentSymbol> & symbols)925 void lsp::MLIRServer::findDocumentSymbols(
926 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
927 auto fileIt = impl->files.find(uri.file());
928 if (fileIt != impl->files.end())
929 fileIt->second->findDocumentSymbols(symbols);
930 }
931