1 //===- Diagnostics.cpp - MLIR Diagnostics ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Diagnostics.h"
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/Identifier.h"
12 #include "mlir/IR/Location.h"
13 #include "mlir/IR/MLIRContext.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/IR/Types.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/SmallString.h"
18 #include "llvm/ADT/StringMap.h"
19 #include "llvm/Support/Mutex.h"
20 #include "llvm/Support/PrettyStackTrace.h"
21 #include "llvm/Support/Regex.h"
22 #include "llvm/Support/Signals.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include "llvm/Support/raw_ostream.h"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 // DiagnosticArgument
31 //===----------------------------------------------------------------------===//
32 
33 /// Construct from an Attribute.
DiagnosticArgument(Attribute attr)34 DiagnosticArgument::DiagnosticArgument(Attribute attr)
35     : kind(DiagnosticArgumentKind::Attribute),
36       opaqueVal(reinterpret_cast<intptr_t>(attr.getAsOpaquePointer())) {}
37 
38 /// Construct from a Type.
DiagnosticArgument(Type val)39 DiagnosticArgument::DiagnosticArgument(Type val)
40     : kind(DiagnosticArgumentKind::Type),
41       opaqueVal(reinterpret_cast<intptr_t>(val.getAsOpaquePointer())) {}
42 
43 /// Returns this argument as an Attribute.
getAsAttribute() const44 Attribute DiagnosticArgument::getAsAttribute() const {
45   assert(getKind() == DiagnosticArgumentKind::Attribute);
46   return Attribute::getFromOpaquePointer(
47       reinterpret_cast<const void *>(opaqueVal));
48 }
49 
50 /// Returns this argument as a Type.
getAsType() const51 Type DiagnosticArgument::getAsType() const {
52   assert(getKind() == DiagnosticArgumentKind::Type);
53   return Type::getFromOpaquePointer(reinterpret_cast<const void *>(opaqueVal));
54 }
55 
56 /// Outputs this argument to a stream.
print(raw_ostream & os) const57 void DiagnosticArgument::print(raw_ostream &os) const {
58   switch (kind) {
59   case DiagnosticArgumentKind::Attribute:
60     os << getAsAttribute();
61     break;
62   case DiagnosticArgumentKind::Double:
63     os << getAsDouble();
64     break;
65   case DiagnosticArgumentKind::Integer:
66     os << getAsInteger();
67     break;
68   case DiagnosticArgumentKind::String:
69     os << getAsString();
70     break;
71   case DiagnosticArgumentKind::Type:
72     os << '\'' << getAsType() << '\'';
73     break;
74   case DiagnosticArgumentKind::Unsigned:
75     os << getAsUnsigned();
76     break;
77   }
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // Diagnostic
82 //===----------------------------------------------------------------------===//
83 
84 /// Convert a Twine to a StringRef. Memory used for generating the StringRef is
85 /// stored in 'strings'.
twineToStrRef(const Twine & val,std::vector<std::unique_ptr<char[]>> & strings)86 static StringRef twineToStrRef(const Twine &val,
87                                std::vector<std::unique_ptr<char[]>> &strings) {
88   // Allocate memory to hold this string.
89   SmallString<64> data;
90   auto strRef = val.toStringRef(data);
91   strings.push_back(std::unique_ptr<char[]>(new char[strRef.size()]));
92   memcpy(&strings.back()[0], strRef.data(), strRef.size());
93 
94   // Return a reference to the new string.
95   return StringRef(&strings.back()[0], strRef.size());
96 }
97 
98 /// Stream in a Twine argument.
operator <<(char val)99 Diagnostic &Diagnostic::operator<<(char val) { return *this << Twine(val); }
operator <<(const Twine & val)100 Diagnostic &Diagnostic::operator<<(const Twine &val) {
101   arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
102   return *this;
103 }
operator <<(Twine && val)104 Diagnostic &Diagnostic::operator<<(Twine &&val) {
105   arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
106   return *this;
107 }
108 
109 /// Stream in an Identifier.
operator <<(Identifier val)110 Diagnostic &Diagnostic::operator<<(Identifier val) {
111   // An identifier is stored in the context, so we don't need to worry about the
112   // lifetime of its data.
113   arguments.push_back(DiagnosticArgument(val.strref()));
114   return *this;
115 }
116 
117 /// Stream in an OperationName.
operator <<(OperationName val)118 Diagnostic &Diagnostic::operator<<(OperationName val) {
119   // An OperationName is stored in the context, so we don't need to worry about
120   // the lifetime of its data.
121   arguments.push_back(DiagnosticArgument(val.getStringRef()));
122   return *this;
123 }
124 
125 /// Stream in an Operation.
operator <<(Operation & val)126 Diagnostic &Diagnostic::operator<<(Operation &val) {
127   std::string str;
128   llvm::raw_string_ostream os(str);
129   os << val;
130   return *this << os.str();
131 }
132 
133 /// Outputs this diagnostic to a stream.
print(raw_ostream & os) const134 void Diagnostic::print(raw_ostream &os) const {
135   for (auto &arg : getArguments())
136     arg.print(os);
137 }
138 
139 /// Convert the diagnostic to a string.
str() const140 std::string Diagnostic::str() const {
141   std::string str;
142   llvm::raw_string_ostream os(str);
143   print(os);
144   return os.str();
145 }
146 
147 /// Attaches a note to this diagnostic. A new location may be optionally
148 /// provided, if not, then the location defaults to the one specified for this
149 /// diagnostic. Notes may not be attached to other notes.
attachNote(Optional<Location> noteLoc)150 Diagnostic &Diagnostic::attachNote(Optional<Location> noteLoc) {
151   // We don't allow attaching notes to notes.
152   assert(severity != DiagnosticSeverity::Note &&
153          "cannot attach a note to a note");
154 
155   // If a location wasn't provided then reuse our location.
156   if (!noteLoc)
157     noteLoc = loc;
158 
159   /// Append and return a new note.
160   notes.push_back(
161       std::make_unique<Diagnostic>(*noteLoc, DiagnosticSeverity::Note));
162   return *notes.back();
163 }
164 
165 /// Allow a diagnostic to be converted to 'failure'.
operator LogicalResult() const166 Diagnostic::operator LogicalResult() const { return failure(); }
167 
168 //===----------------------------------------------------------------------===//
169 // InFlightDiagnostic
170 //===----------------------------------------------------------------------===//
171 
172 /// Allow an inflight diagnostic to be converted to 'failure', otherwise
173 /// 'success' if this is an empty diagnostic.
operator LogicalResult() const174 InFlightDiagnostic::operator LogicalResult() const {
175   return failure(isActive());
176 }
177 
178 /// Reports the diagnostic to the engine.
report()179 void InFlightDiagnostic::report() {
180   // If this diagnostic is still inflight and it hasn't been abandoned, then
181   // report it.
182   if (isInFlight()) {
183     owner->emit(std::move(*impl));
184     owner = nullptr;
185   }
186   impl.reset();
187 }
188 
189 /// Abandons this diagnostic.
abandon()190 void InFlightDiagnostic::abandon() { owner = nullptr; }
191 
192 //===----------------------------------------------------------------------===//
193 // DiagnosticEngineImpl
194 //===----------------------------------------------------------------------===//
195 
196 namespace mlir {
197 namespace detail {
198 struct DiagnosticEngineImpl {
199   /// Emit a diagnostic using the registered issue handle if present, or with
200   /// the default behavior if not.
201   void emit(Diagnostic diag);
202 
203   /// A mutex to ensure that diagnostics emission is thread-safe.
204   llvm::sys::SmartMutex<true> mutex;
205 
206   /// These are the handlers used to report diagnostics.
207   llvm::SmallMapVector<DiagnosticEngine::HandlerID, DiagnosticEngine::HandlerTy,
208                        2>
209       handlers;
210 
211   /// This is a unique identifier counter for diagnostic handlers in the
212   /// context. This id starts at 1 to allow for 0 to be used as a sentinel.
213   DiagnosticEngine::HandlerID uniqueHandlerId = 1;
214 };
215 } // namespace detail
216 } // namespace mlir
217 
218 /// Emit a diagnostic using the registered issue handle if present, or with
219 /// the default behavior if not.
emit(Diagnostic diag)220 void DiagnosticEngineImpl::emit(Diagnostic diag) {
221   llvm::sys::SmartScopedLock<true> lock(mutex);
222 
223   // Try to process the given diagnostic on one of the registered handlers.
224   // Handlers are walked in reverse order, so that the most recent handler is
225   // processed first.
226   for (auto &handlerIt : llvm::reverse(handlers))
227     if (succeeded(handlerIt.second(diag)))
228       return;
229 
230   // Otherwise, if this is an error we emit it to stderr.
231   if (diag.getSeverity() != DiagnosticSeverity::Error)
232     return;
233 
234   auto &os = llvm::errs();
235   if (!diag.getLocation().isa<UnknownLoc>())
236     os << diag.getLocation() << ": ";
237   os << "error: ";
238 
239   // The default behavior for errors is to emit them to stderr.
240   os << diag << '\n';
241   os.flush();
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // DiagnosticEngine
246 //===----------------------------------------------------------------------===//
247 
DiagnosticEngine()248 DiagnosticEngine::DiagnosticEngine() : impl(new DiagnosticEngineImpl()) {}
~DiagnosticEngine()249 DiagnosticEngine::~DiagnosticEngine() {}
250 
251 /// Register a new handler for diagnostics to the engine. This function returns
252 /// a unique identifier for the registered handler, which can be used to
253 /// unregister this handler at a later time.
registerHandler(const HandlerTy & handler)254 auto DiagnosticEngine::registerHandler(const HandlerTy &handler) -> HandlerID {
255   llvm::sys::SmartScopedLock<true> lock(impl->mutex);
256   auto uniqueID = impl->uniqueHandlerId++;
257   impl->handlers.insert({uniqueID, handler});
258   return uniqueID;
259 }
260 
261 /// Erase the registered diagnostic handler with the given identifier.
eraseHandler(HandlerID handlerID)262 void DiagnosticEngine::eraseHandler(HandlerID handlerID) {
263   llvm::sys::SmartScopedLock<true> lock(impl->mutex);
264   impl->handlers.erase(handlerID);
265 }
266 
267 /// Emit a diagnostic using the registered issue handler if present, or with
268 /// the default behavior if not.
emit(Diagnostic diag)269 void DiagnosticEngine::emit(Diagnostic diag) {
270   assert(diag.getSeverity() != DiagnosticSeverity::Note &&
271          "notes should not be emitted directly");
272   impl->emit(std::move(diag));
273 }
274 
275 /// Helper function used to emit a diagnostic with an optionally empty twine
276 /// message. If the message is empty, then it is not inserted into the
277 /// diagnostic.
278 static InFlightDiagnostic
emitDiag(Location location,DiagnosticSeverity severity,const Twine & message)279 emitDiag(Location location, DiagnosticSeverity severity, const Twine &message) {
280   MLIRContext *ctx = location->getContext();
281   auto &diagEngine = ctx->getDiagEngine();
282   auto diag = diagEngine.emit(location, severity);
283   if (!message.isTriviallyEmpty())
284     diag << message;
285 
286   // Add the stack trace as a note if necessary.
287   if (ctx->shouldPrintStackTraceOnDiagnostic()) {
288     std::string bt;
289     {
290       llvm::raw_string_ostream stream(bt);
291       llvm::sys::PrintStackTrace(stream);
292     }
293     if (!bt.empty())
294       diag.attachNote() << "diagnostic emitted with trace:\n" << bt;
295   }
296 
297   return diag;
298 }
299 
300 /// Emit an error message using this location.
emitError(Location loc)301 InFlightDiagnostic mlir::emitError(Location loc) { return emitError(loc, {}); }
emitError(Location loc,const Twine & message)302 InFlightDiagnostic mlir::emitError(Location loc, const Twine &message) {
303   return emitDiag(loc, DiagnosticSeverity::Error, message);
304 }
305 
306 /// Emit a warning message using this location.
emitWarning(Location loc)307 InFlightDiagnostic mlir::emitWarning(Location loc) {
308   return emitWarning(loc, {});
309 }
emitWarning(Location loc,const Twine & message)310 InFlightDiagnostic mlir::emitWarning(Location loc, const Twine &message) {
311   return emitDiag(loc, DiagnosticSeverity::Warning, message);
312 }
313 
314 /// Emit a remark message using this location.
emitRemark(Location loc)315 InFlightDiagnostic mlir::emitRemark(Location loc) {
316   return emitRemark(loc, {});
317 }
emitRemark(Location loc,const Twine & message)318 InFlightDiagnostic mlir::emitRemark(Location loc, const Twine &message) {
319   return emitDiag(loc, DiagnosticSeverity::Remark, message);
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // ScopedDiagnosticHandler
324 //===----------------------------------------------------------------------===//
325 
~ScopedDiagnosticHandler()326 ScopedDiagnosticHandler::~ScopedDiagnosticHandler() {
327   if (handlerID)
328     ctx->getDiagEngine().eraseHandler(handlerID);
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // SourceMgrDiagnosticHandler
333 //===----------------------------------------------------------------------===//
334 namespace mlir {
335 namespace detail {
336 struct SourceMgrDiagnosticHandlerImpl {
337   /// Return the SrcManager buffer id for the specified file, or zero if none
338   /// can be found.
getSourceMgrBufferIDForFilemlir::detail::SourceMgrDiagnosticHandlerImpl339   unsigned getSourceMgrBufferIDForFile(llvm::SourceMgr &mgr,
340                                        StringRef filename) {
341     // Check for an existing mapping to the buffer id for this file.
342     auto bufferIt = filenameToBufId.find(filename);
343     if (bufferIt != filenameToBufId.end())
344       return bufferIt->second;
345 
346     // Look for a buffer in the manager that has this filename.
347     for (unsigned i = 1, e = mgr.getNumBuffers() + 1; i != e; ++i) {
348       auto *buf = mgr.getMemoryBuffer(i);
349       if (buf->getBufferIdentifier() == filename)
350         return filenameToBufId[filename] = i;
351     }
352 
353     // Otherwise, try to load the source file.
354     std::string ignored;
355     unsigned id =
356         mgr.AddIncludeFile(std::string(filename), llvm::SMLoc(), ignored);
357     filenameToBufId[filename] = id;
358     return id;
359   }
360 
361   /// Mapping between file name and buffer ID's.
362   llvm::StringMap<unsigned> filenameToBufId;
363 };
364 } // end namespace detail
365 } // end namespace mlir
366 
367 /// Return a processable FileLineColLoc from the given location.
getFileLineColLoc(Location loc)368 static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
369   if (auto nameLoc = loc.dyn_cast<NameLoc>())
370     return getFileLineColLoc(loc.cast<NameLoc>().getChildLoc());
371   if (auto fileLoc = loc.dyn_cast<FileLineColLoc>())
372     return fileLoc;
373   if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
374     return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
375   if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
376     for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
377       if (auto callLoc = getFileLineColLoc(subLoc)) {
378         return callLoc;
379       }
380     }
381     return llvm::None;
382   }
383   return llvm::None;
384 }
385 
386 /// Return a processable CallSiteLoc from the given location.
getCallSiteLoc(Location loc)387 static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
388   if (auto nameLoc = loc.dyn_cast<NameLoc>())
389     return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
390   if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
391     return callLoc;
392   if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
393     for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
394       if (auto callLoc = getCallSiteLoc(subLoc)) {
395         return callLoc;
396       }
397     }
398     return llvm::None;
399   }
400   return llvm::None;
401 }
402 
403 /// Given a diagnostic kind, returns the LLVM DiagKind.
getDiagKind(DiagnosticSeverity kind)404 static llvm::SourceMgr::DiagKind getDiagKind(DiagnosticSeverity kind) {
405   switch (kind) {
406   case DiagnosticSeverity::Note:
407     return llvm::SourceMgr::DK_Note;
408   case DiagnosticSeverity::Warning:
409     return llvm::SourceMgr::DK_Warning;
410   case DiagnosticSeverity::Error:
411     return llvm::SourceMgr::DK_Error;
412   case DiagnosticSeverity::Remark:
413     return llvm::SourceMgr::DK_Remark;
414   }
415   llvm_unreachable("Unknown DiagnosticSeverity");
416 }
417 
SourceMgrDiagnosticHandler(llvm::SourceMgr & mgr,MLIRContext * ctx,raw_ostream & os)418 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr,
419                                                        MLIRContext *ctx,
420                                                        raw_ostream &os)
421     : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
422       impl(new SourceMgrDiagnosticHandlerImpl()) {
423   setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); });
424 }
425 
SourceMgrDiagnosticHandler(llvm::SourceMgr & mgr,MLIRContext * ctx)426 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr,
427                                                        MLIRContext *ctx)
428     : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs()) {}
429 
~SourceMgrDiagnosticHandler()430 SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() {}
431 
emitDiagnostic(Location loc,Twine message,DiagnosticSeverity kind,bool displaySourceLine)432 void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
433                                                 DiagnosticSeverity kind,
434                                                 bool displaySourceLine) {
435   // Extract a file location from this loc.
436   auto fileLoc = getFileLineColLoc(loc);
437 
438   // If one doesn't exist, then print the raw message without a source location.
439   if (!fileLoc) {
440     std::string str;
441     llvm::raw_string_ostream strOS(str);
442     if (!loc.isa<UnknownLoc>())
443       strOS << loc << ": ";
444     strOS << message;
445     return mgr.PrintMessage(os, llvm::SMLoc(), getDiagKind(kind), strOS.str());
446   }
447 
448   // Otherwise if we are displaying the source line, try to convert the file
449   // location to an SMLoc.
450   if (displaySourceLine) {
451     auto smloc = convertLocToSMLoc(*fileLoc);
452     if (smloc.isValid())
453       return mgr.PrintMessage(os, smloc, getDiagKind(kind), message);
454   }
455 
456   // If the conversion was unsuccessful, create a diagnostic with the file
457   // information. We manually combine the line and column to avoid asserts in
458   // the constructor of SMDiagnostic that takes a location.
459   std::string locStr;
460   llvm::raw_string_ostream locOS(locStr);
461   locOS << fileLoc->getFilename() << ":" << fileLoc->getLine() << ":"
462         << fileLoc->getColumn();
463   llvm::SMDiagnostic diag(locOS.str(), getDiagKind(kind), message.str());
464   diag.print(nullptr, os);
465 }
466 
467 /// Emit the given diagnostic with the held source manager.
emitDiagnostic(Diagnostic & diag)468 void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) {
469   // Emit the diagnostic.
470   Location loc = diag.getLocation();
471   emitDiagnostic(loc, diag.str(), diag.getSeverity());
472 
473   // If the diagnostic location was a call site location, then print the call
474   // stack as well.
475   if (auto callLoc = getCallSiteLoc(loc)) {
476     // Print the call stack while valid, or until the limit is reached.
477     loc = callLoc->getCaller();
478     for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
479       emitDiagnostic(loc, "called from", DiagnosticSeverity::Note);
480       if ((callLoc = getCallSiteLoc(loc)))
481         loc = callLoc->getCaller();
482       else
483         break;
484     }
485   }
486 
487   // Emit each of the notes. Only display the source code if the location is
488   // different from the previous location.
489   for (auto &note : diag.getNotes()) {
490     emitDiagnostic(note.getLocation(), note.str(), note.getSeverity(),
491                    /*displaySourceLine=*/loc != note.getLocation());
492     loc = note.getLocation();
493   }
494 }
495 
496 /// Get a memory buffer for the given file, or nullptr if one is not found.
497 const llvm::MemoryBuffer *
getBufferForFile(StringRef filename)498 SourceMgrDiagnosticHandler::getBufferForFile(StringRef filename) {
499   if (unsigned id = impl->getSourceMgrBufferIDForFile(mgr, filename))
500     return mgr.getMemoryBuffer(id);
501   return nullptr;
502 }
503 
504 /// Get a memory buffer for the given file, or the main file of the source
505 /// manager if one doesn't exist. This always returns non-null.
convertLocToSMLoc(FileLineColLoc loc)506 llvm::SMLoc SourceMgrDiagnosticHandler::convertLocToSMLoc(FileLineColLoc loc) {
507   // The column and line may be zero to represent unknown column and/or unknown
508   /// line/column information.
509   if (loc.getLine() == 0 || loc.getColumn() == 0)
510     return llvm::SMLoc();
511 
512   unsigned bufferId = impl->getSourceMgrBufferIDForFile(mgr, loc.getFilename());
513   if (!bufferId)
514     return llvm::SMLoc();
515   return mgr.FindLocForLineAndColumn(bufferId, loc.getLine(), loc.getColumn());
516 }
517 
518 //===----------------------------------------------------------------------===//
519 // SourceMgrDiagnosticVerifierHandler
520 //===----------------------------------------------------------------------===//
521 
522 namespace mlir {
523 namespace detail {
524 // Record the expected diagnostic's position, substring and whether it was
525 // seen.
526 struct ExpectedDiag {
527   DiagnosticSeverity kind;
528   unsigned lineNo;
529   StringRef substring;
530   llvm::SMLoc fileLoc;
531   bool matched;
532 };
533 
534 struct SourceMgrDiagnosticVerifierHandlerImpl {
SourceMgrDiagnosticVerifierHandlerImplmlir::detail::SourceMgrDiagnosticVerifierHandlerImpl535   SourceMgrDiagnosticVerifierHandlerImpl() : status(success()) {}
536 
537   /// Returns the expected diagnostics for the given source file.
538   Optional<MutableArrayRef<ExpectedDiag>> getExpectedDiags(StringRef bufName);
539 
540   /// Computes the expected diagnostics for the given source buffer.
541   MutableArrayRef<ExpectedDiag>
542   computeExpectedDiags(const llvm::MemoryBuffer *buf);
543 
544   /// The current status of the verifier.
545   LogicalResult status;
546 
547   /// A list of expected diagnostics for each buffer of the source manager.
548   llvm::StringMap<SmallVector<ExpectedDiag, 2>> expectedDiagsPerFile;
549 
550   /// Regex to match the expected diagnostics format.
551   llvm::Regex expected = llvm::Regex("expected-(error|note|remark|warning) "
552                                      "*(@([+-][0-9]+|above|below))? *{{(.*)}}");
553 };
554 } // end namespace detail
555 } // end namespace mlir
556 
557 /// Given a diagnostic kind, return a human readable string for it.
getDiagKindStr(DiagnosticSeverity kind)558 static StringRef getDiagKindStr(DiagnosticSeverity kind) {
559   switch (kind) {
560   case DiagnosticSeverity::Note:
561     return "note";
562   case DiagnosticSeverity::Warning:
563     return "warning";
564   case DiagnosticSeverity::Error:
565     return "error";
566   case DiagnosticSeverity::Remark:
567     return "remark";
568   }
569   llvm_unreachable("Unknown DiagnosticSeverity");
570 }
571 
572 /// Returns the expected diagnostics for the given source file.
573 Optional<MutableArrayRef<ExpectedDiag>>
getExpectedDiags(StringRef bufName)574 SourceMgrDiagnosticVerifierHandlerImpl::getExpectedDiags(StringRef bufName) {
575   auto expectedDiags = expectedDiagsPerFile.find(bufName);
576   if (expectedDiags != expectedDiagsPerFile.end())
577     return MutableArrayRef<ExpectedDiag>(expectedDiags->second);
578   return llvm::None;
579 }
580 
581 /// Computes the expected diagnostics for the given source buffer.
582 MutableArrayRef<ExpectedDiag>
computeExpectedDiags(const llvm::MemoryBuffer * buf)583 SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
584     const llvm::MemoryBuffer *buf) {
585   // If the buffer is invalid, return an empty list.
586   if (!buf)
587     return llvm::None;
588   auto &expectedDiags = expectedDiagsPerFile[buf->getBufferIdentifier()];
589 
590   // The number of the last line that did not correlate to a designator.
591   unsigned lastNonDesignatorLine = 0;
592 
593   // The indices of designators that apply to the next non designator line.
594   SmallVector<unsigned, 1> designatorsForNextLine;
595 
596   // Scan the file for expected-* designators.
597   SmallVector<StringRef, 100> lines;
598   buf->getBuffer().split(lines, '\n');
599   for (unsigned lineNo = 0, e = lines.size(); lineNo < e; ++lineNo) {
600     SmallVector<StringRef, 4> matches;
601     if (!expected.match(lines[lineNo], &matches)) {
602       // Check for designators that apply to this line.
603       if (!designatorsForNextLine.empty()) {
604         for (unsigned diagIndex : designatorsForNextLine)
605           expectedDiags[diagIndex].lineNo = lineNo + 1;
606         designatorsForNextLine.clear();
607       }
608       lastNonDesignatorLine = lineNo;
609       continue;
610     }
611 
612     // Point to the start of expected-*.
613     auto expectedStart = llvm::SMLoc::getFromPointer(matches[0].data());
614 
615     DiagnosticSeverity kind;
616     if (matches[1] == "error")
617       kind = DiagnosticSeverity::Error;
618     else if (matches[1] == "warning")
619       kind = DiagnosticSeverity::Warning;
620     else if (matches[1] == "remark")
621       kind = DiagnosticSeverity::Remark;
622     else {
623       assert(matches[1] == "note");
624       kind = DiagnosticSeverity::Note;
625     }
626 
627     ExpectedDiag record{kind, lineNo + 1, matches[4], expectedStart, false};
628     auto offsetMatch = matches[2];
629     if (!offsetMatch.empty()) {
630       offsetMatch = offsetMatch.drop_front(1);
631 
632       // Get the integer value without the @ and +/- prefix.
633       if (offsetMatch[0] == '+' || offsetMatch[0] == '-') {
634         int offset;
635         offsetMatch.drop_front().getAsInteger(0, offset);
636 
637         if (offsetMatch.front() == '+')
638           record.lineNo += offset;
639         else
640           record.lineNo -= offset;
641       } else if (offsetMatch.consume_front("above")) {
642         // If the designator applies 'above' we add it to the last non
643         // designator line.
644         record.lineNo = lastNonDesignatorLine + 1;
645       } else {
646         // Otherwise, this is a 'below' designator and applies to the next
647         // non-designator line.
648         assert(offsetMatch.consume_front("below"));
649         designatorsForNextLine.push_back(expectedDiags.size());
650 
651         // Set the line number to the last in the case that this designator ends
652         // up dangling.
653         record.lineNo = e;
654       }
655     }
656     expectedDiags.push_back(record);
657   }
658   return expectedDiags;
659 }
660 
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr & srcMgr,MLIRContext * ctx,raw_ostream & out)661 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
662     llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out)
663     : SourceMgrDiagnosticHandler(srcMgr, ctx, out),
664       impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
665   // Compute the expected diagnostics for each of the current files in the
666   // source manager.
667   for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
668     (void)impl->computeExpectedDiags(mgr.getMemoryBuffer(i + 1));
669 
670   // Register a handler to verify the diagnostics.
671   setHandler([&](Diagnostic &diag) {
672     // Process the main diagnostics.
673     process(diag);
674 
675     // Process each of the notes.
676     for (auto &note : diag.getNotes())
677       process(note);
678   });
679 }
680 
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr & srcMgr,MLIRContext * ctx)681 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
682     llvm::SourceMgr &srcMgr, MLIRContext *ctx)
683     : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {}
684 
~SourceMgrDiagnosticVerifierHandler()685 SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
686   // Ensure that all expected diagnostics were handled.
687   (void)verify();
688 }
689 
690 /// Returns the status of the verifier and verifies that all expected
691 /// diagnostics were emitted. This return success if all diagnostics were
692 /// verified correctly, failure otherwise.
verify()693 LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
694   // Verify that all expected errors were seen.
695   for (auto &expectedDiagsPair : impl->expectedDiagsPerFile) {
696     for (auto &err : expectedDiagsPair.second) {
697       if (err.matched)
698         continue;
699       llvm::SMRange range(err.fileLoc,
700                           llvm::SMLoc::getFromPointer(err.fileLoc.getPointer() +
701                                                       err.substring.size()));
702       mgr.PrintMessage(os, err.fileLoc, llvm::SourceMgr::DK_Error,
703                        "expected " + getDiagKindStr(err.kind) + " \"" +
704                            err.substring + "\" was not produced",
705                        range);
706       impl->status = failure();
707     }
708   }
709   impl->expectedDiagsPerFile.clear();
710   return impl->status;
711 }
712 
713 /// Process a single diagnostic.
process(Diagnostic & diag)714 void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
715   auto kind = diag.getSeverity();
716 
717   // Process a FileLineColLoc.
718   if (auto fileLoc = getFileLineColLoc(diag.getLocation()))
719     return process(*fileLoc, diag.str(), kind);
720 
721   emitDiagnostic(diag.getLocation(),
722                  "unexpected " + getDiagKindStr(kind) + ": " + diag.str(),
723                  DiagnosticSeverity::Error);
724   impl->status = failure();
725 }
726 
727 /// Process a FileLineColLoc diagnostic.
process(FileLineColLoc loc,StringRef msg,DiagnosticSeverity kind)728 void SourceMgrDiagnosticVerifierHandler::process(FileLineColLoc loc,
729                                                  StringRef msg,
730                                                  DiagnosticSeverity kind) {
731   // Get the expected diagnostics for this file.
732   auto diags = impl->getExpectedDiags(loc.getFilename());
733   if (!diags)
734     diags = impl->computeExpectedDiags(getBufferForFile(loc.getFilename()));
735 
736   // Search for a matching expected diagnostic.
737   // If we find something that is close then emit a more specific error.
738   ExpectedDiag *nearMiss = nullptr;
739 
740   // If this was an expected error, remember that we saw it and return.
741   unsigned line = loc.getLine();
742   for (auto &e : *diags) {
743     if (line == e.lineNo && msg.contains(e.substring)) {
744       if (e.kind == kind) {
745         e.matched = true;
746         return;
747       }
748 
749       // If this only differs based on the diagnostic kind, then consider it
750       // to be a near miss.
751       nearMiss = &e;
752     }
753   }
754 
755   // Otherwise, emit an error for the near miss.
756   if (nearMiss)
757     mgr.PrintMessage(os, nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
758                      "'" + getDiagKindStr(kind) +
759                          "' diagnostic emitted when expecting a '" +
760                          getDiagKindStr(nearMiss->kind) + "'");
761   else
762     emitDiagnostic(loc, "unexpected " + getDiagKindStr(kind) + ": " + msg,
763                    DiagnosticSeverity::Error);
764   impl->status = failure();
765 }
766 
767 //===----------------------------------------------------------------------===//
768 // ParallelDiagnosticHandler
769 //===----------------------------------------------------------------------===//
770 
771 namespace mlir {
772 namespace detail {
773 struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
774   struct ThreadDiagnostic {
ThreadDiagnosticmlir::detail::ParallelDiagnosticHandlerImpl::ThreadDiagnostic775     ThreadDiagnostic(size_t id, Diagnostic diag)
776         : id(id), diag(std::move(diag)) {}
operator <mlir::detail::ParallelDiagnosticHandlerImpl::ThreadDiagnostic777     bool operator<(const ThreadDiagnostic &rhs) const { return id < rhs.id; }
778 
779     /// The id for this diagnostic, this is used for ordering.
780     /// Note: This id corresponds to the ordered position of the current element
781     ///       being processed by a given thread.
782     size_t id;
783 
784     /// The diagnostic.
785     Diagnostic diag;
786   };
787 
ParallelDiagnosticHandlerImplmlir::detail::ParallelDiagnosticHandlerImpl788   ParallelDiagnosticHandlerImpl(MLIRContext *ctx) : handlerID(0), context(ctx) {
789     handlerID = ctx->getDiagEngine().registerHandler([this](Diagnostic &diag) {
790       uint64_t tid = llvm::get_threadid();
791       llvm::sys::SmartScopedLock<true> lock(mutex);
792 
793       // If this thread is not tracked, then return failure to let another
794       // handler process this diagnostic.
795       if (!threadToOrderID.count(tid))
796         return failure();
797 
798       // Append a new diagnostic.
799       diagnostics.emplace_back(threadToOrderID[tid], std::move(diag));
800       return success();
801     });
802   }
803 
~ParallelDiagnosticHandlerImplmlir::detail::ParallelDiagnosticHandlerImpl804   ~ParallelDiagnosticHandlerImpl() override {
805     // Erase this handler from the context.
806     context->getDiagEngine().eraseHandler(handlerID);
807 
808     // Early exit if there are no diagnostics, this is the common case.
809     if (diagnostics.empty())
810       return;
811 
812     // Emit the diagnostics back to the context.
813     emitDiagnostics([&](Diagnostic diag) {
814       return context->getDiagEngine().emit(std::move(diag));
815     });
816   }
817 
818   /// Utility method to emit any held diagnostics.
emitDiagnosticsmlir::detail::ParallelDiagnosticHandlerImpl819   void emitDiagnostics(std::function<void(Diagnostic)> emitFn) const {
820     // Stable sort all of the diagnostics that were emitted. This creates a
821     // deterministic ordering for the diagnostics based upon which order id they
822     // were emitted for.
823     std::stable_sort(diagnostics.begin(), diagnostics.end());
824 
825     // Emit each diagnostic to the context again.
826     for (ThreadDiagnostic &diag : diagnostics)
827       emitFn(std::move(diag.diag));
828   }
829 
830   /// Set the order id for the current thread.
setOrderIDForThreadmlir::detail::ParallelDiagnosticHandlerImpl831   void setOrderIDForThread(size_t orderID) {
832     uint64_t tid = llvm::get_threadid();
833     llvm::sys::SmartScopedLock<true> lock(mutex);
834     threadToOrderID[tid] = orderID;
835   }
836 
837   /// Remove the order id for the current thread.
eraseOrderIDForThreadmlir::detail::ParallelDiagnosticHandlerImpl838   void eraseOrderIDForThread() {
839     uint64_t tid = llvm::get_threadid();
840     llvm::sys::SmartScopedLock<true> lock(mutex);
841     threadToOrderID.erase(tid);
842   }
843 
844   /// Dump the current diagnostics that were inflight.
printmlir::detail::ParallelDiagnosticHandlerImpl845   void print(raw_ostream &os) const override {
846     // Early exit if there are no diagnostics, this is the common case.
847     if (diagnostics.empty())
848       return;
849 
850     os << "In-Flight Diagnostics:\n";
851     emitDiagnostics([&](Diagnostic diag) {
852       os.indent(4);
853 
854       // Print each diagnostic with the format:
855       //   "<location>: <kind>: <msg>"
856       if (!diag.getLocation().isa<UnknownLoc>())
857         os << diag.getLocation() << ": ";
858       switch (diag.getSeverity()) {
859       case DiagnosticSeverity::Error:
860         os << "error: ";
861         break;
862       case DiagnosticSeverity::Warning:
863         os << "warning: ";
864         break;
865       case DiagnosticSeverity::Note:
866         os << "note: ";
867         break;
868       case DiagnosticSeverity::Remark:
869         os << "remark: ";
870         break;
871       }
872       os << diag << '\n';
873     });
874   }
875 
876   /// A smart mutex to lock access to the internal state.
877   llvm::sys::SmartMutex<true> mutex;
878 
879   /// A mapping between the thread id and the current order id.
880   DenseMap<uint64_t, size_t> threadToOrderID;
881 
882   /// An unordered list of diagnostics that were emitted.
883   mutable std::vector<ThreadDiagnostic> diagnostics;
884 
885   /// The unique id for the parallel handler.
886   DiagnosticEngine::HandlerID handlerID;
887 
888   /// The context to emit the diagnostics to.
889   MLIRContext *context;
890 };
891 } // end namespace detail
892 } // end namespace mlir
893 
ParallelDiagnosticHandler(MLIRContext * ctx)894 ParallelDiagnosticHandler::ParallelDiagnosticHandler(MLIRContext *ctx)
895     : impl(new ParallelDiagnosticHandlerImpl(ctx)) {}
~ParallelDiagnosticHandler()896 ParallelDiagnosticHandler::~ParallelDiagnosticHandler() {}
897 
898 /// Set the order id for the current thread.
setOrderIDForThread(size_t orderID)899 void ParallelDiagnosticHandler::setOrderIDForThread(size_t orderID) {
900   impl->setOrderIDForThread(orderID);
901 }
902 
903 /// Remove the order id for the current thread. This removes the thread from
904 /// diagnostics tracking.
eraseOrderIDForThread()905 void ParallelDiagnosticHandler::eraseOrderIDForThread() {
906   impl->eraseOrderIDForThread();
907 }
908