1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/SubElementInterfaces.h"
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/MapVector.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/ScopedHashTable.h"
31 #include "llvm/ADT/SetVector.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/StringExtras.h"
34 #include "llvm/ADT/StringSet.h"
35 #include "llvm/ADT/TypeSwitch.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/Endian.h"
38 #include "llvm/Support/Regex.h"
39 #include "llvm/Support/SaveAndRestore.h"
40 
41 #include <tuple>
42 
43 using namespace mlir;
44 using namespace mlir::detail;
45 
print(raw_ostream & os) const46 void Identifier::print(raw_ostream &os) const { os << str(); }
47 
dump() const48 void Identifier::dump() const { print(llvm::errs()); }
49 
print(raw_ostream & os) const50 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
51 
dump() const52 void OperationName::dump() const { print(llvm::errs()); }
53 
~DialectAsmPrinter()54 DialectAsmPrinter::~DialectAsmPrinter() {}
55 
56 //===--------------------------------------------------------------------===//
57 // OpAsmPrinter
58 //===--------------------------------------------------------------------===//
59 
~OpAsmPrinter()60 OpAsmPrinter::~OpAsmPrinter() {}
61 
printFunctionalType(Operation * op)62 void OpAsmPrinter::printFunctionalType(Operation *op) {
63   auto &os = getStream();
64   os << '(';
65   llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
66     // Print the types of null values as <<NULL TYPE>>.
67     *this << (operand ? operand.getType() : Type());
68   });
69   os << ") -> ";
70 
71   // Print the result list.  We don't parenthesize single result types unless
72   // it is a function (avoiding a grammar ambiguity).
73   bool wrapped = op->getNumResults() != 1;
74   if (!wrapped && op->getResult(0).getType() &&
75       op->getResult(0).getType().isa<FunctionType>())
76     wrapped = true;
77 
78   if (wrapped)
79     os << '(';
80 
81   llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
82     // Print the types of null values as <<NULL TYPE>>.
83     *this << (result ? result.getType() : Type());
84   });
85 
86   if (wrapped)
87     os << ')';
88 }
89 
90 //===--------------------------------------------------------------------===//
91 // Operation OpAsm interface.
92 //===--------------------------------------------------------------------===//
93 
94 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
95 #include "mlir/IR/OpAsmInterface.cpp.inc"
96 
97 //===----------------------------------------------------------------------===//
98 // OpPrintingFlags
99 //===----------------------------------------------------------------------===//
100 
101 namespace {
102 /// This struct contains command line options that can be used to initialize
103 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
104 /// for global command line options.
105 struct AsmPrinterOptions {
106   llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
107       "mlir-print-elementsattrs-with-hex-if-larger",
108       llvm::cl::desc(
109           "Print DenseElementsAttrs with a hex string that have "
110           "more elements than the given upper limit (use -1 to disable)")};
111 
112   llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
113       "mlir-elide-elementsattrs-if-larger",
114       llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
115                      "more elements than the given upper limit")};
116 
117   llvm::cl::opt<bool> printDebugInfoOpt{
118       "mlir-print-debuginfo", llvm::cl::init(false),
119       llvm::cl::desc("Print debug info in MLIR output")};
120 
121   llvm::cl::opt<bool> printPrettyDebugInfoOpt{
122       "mlir-pretty-debuginfo", llvm::cl::init(false),
123       llvm::cl::desc("Print pretty debug info in MLIR output")};
124 
125   // Use the generic op output form in the operation printer even if the custom
126   // form is defined.
127   llvm::cl::opt<bool> printGenericOpFormOpt{
128       "mlir-print-op-generic", llvm::cl::init(false),
129       llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
130 
131   llvm::cl::opt<bool> printLocalScopeOpt{
132       "mlir-print-local-scope", llvm::cl::init(false),
133       llvm::cl::desc("Print assuming in local scope by default"),
134       llvm::cl::Hidden};
135 };
136 } // end anonymous namespace
137 
138 static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
139 
140 /// Register a set of useful command-line options that can be used to configure
141 /// various flags within the AsmPrinter.
registerAsmPrinterCLOptions()142 void mlir::registerAsmPrinterCLOptions() {
143   // Make sure that the options struct has been initialized.
144   *clOptions;
145 }
146 
147 /// Initialize the printing flags with default supplied by the cl::opts above.
OpPrintingFlags()148 OpPrintingFlags::OpPrintingFlags()
149     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
150       printGenericOpFormFlag(false), printLocalScope(false) {
151   // Initialize based upon command line options, if they are available.
152   if (!clOptions.isConstructed())
153     return;
154   if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
155     elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
156   printDebugInfoFlag = clOptions->printDebugInfoOpt;
157   printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
158   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
159   printLocalScope = clOptions->printLocalScopeOpt;
160 }
161 
162 /// Enable the elision of large elements attributes, by printing a '...'
163 /// instead of the element data, when the number of elements is greater than
164 /// `largeElementLimit`. Note: The IR generated with this option is not
165 /// parsable.
166 OpPrintingFlags &
elideLargeElementsAttrs(int64_t largeElementLimit)167 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
168   elementsAttrElementLimit = largeElementLimit;
169   return *this;
170 }
171 
172 /// Enable printing of debug information. If 'prettyForm' is set to true,
173 /// debug information is printed in a more readable 'pretty' form.
enableDebugInfo(bool prettyForm)174 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
175   printDebugInfoFlag = true;
176   printDebugInfoPrettyFormFlag = prettyForm;
177   return *this;
178 }
179 
180 /// Always print operations in the generic form.
printGenericOpForm()181 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
182   printGenericOpFormFlag = true;
183   return *this;
184 }
185 
186 /// Use local scope when printing the operation. This allows for using the
187 /// printer in a more localized and thread-safe setting, but may not necessarily
188 /// be identical of what the IR will look like when dumping the full module.
useLocalScope()189 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
190   printLocalScope = true;
191   return *this;
192 }
193 
194 /// Return if the given ElementsAttr should be elided.
shouldElideElementsAttr(ElementsAttr attr) const195 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
196   return elementsAttrElementLimit.hasValue() &&
197          *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
198          !attr.isa<SplatElementsAttr>();
199 }
200 
201 /// Return the size limit for printing large ElementsAttr.
getLargeElementsAttrLimit() const202 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
203   return elementsAttrElementLimit;
204 }
205 
206 /// Return if debug information should be printed.
shouldPrintDebugInfo() const207 bool OpPrintingFlags::shouldPrintDebugInfo() const {
208   return printDebugInfoFlag;
209 }
210 
211 /// Return if debug information should be printed in the pretty form.
shouldPrintDebugInfoPrettyForm() const212 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
213   return printDebugInfoPrettyFormFlag;
214 }
215 
216 /// Return if operations should be printed in the generic form.
shouldPrintGenericOpForm() const217 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
218   return printGenericOpFormFlag;
219 }
220 
221 /// Return if the printer should use local scope when dumping the IR.
shouldUseLocalScope() const222 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
223 
224 /// Returns true if an ElementsAttr with the given number of elements should be
225 /// printed with hex.
shouldPrintElementsAttrWithHex(int64_t numElements)226 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
227   // Check to see if a command line option was provided for the limit.
228   if (clOptions.isConstructed()) {
229     if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
230       // -1 is used to disable hex printing.
231       if (clOptions->printElementsAttrWithHexIfLarger == -1)
232         return false;
233       return numElements > clOptions->printElementsAttrWithHexIfLarger;
234     }
235   }
236 
237   // Otherwise, default to printing with hex if the number of elements is >100.
238   return numElements > 100;
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // NewLineCounter
243 //===----------------------------------------------------------------------===//
244 
245 namespace {
246 /// This class is a simple formatter that emits a new line when inputted into a
247 /// stream, that enables counting the number of newlines emitted. This class
248 /// should be used whenever emitting newlines in the printer.
249 struct NewLineCounter {
250   unsigned curLine = 1;
251 };
252 } // end anonymous namespace
253 
operator <<(raw_ostream & os,NewLineCounter & newLine)254 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
255   ++newLine.curLine;
256   return os << '\n';
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // AliasInitializer
261 //===----------------------------------------------------------------------===//
262 
263 namespace {
264 /// This class represents a specific instance of a symbol Alias.
265 class SymbolAlias {
266 public:
SymbolAlias(StringRef name,bool isDeferrable)267   SymbolAlias(StringRef name, bool isDeferrable)
268       : name(name), suffixIndex(0), hasSuffixIndex(false),
269         isDeferrable(isDeferrable) {}
SymbolAlias(StringRef name,uint32_t suffixIndex,bool isDeferrable)270   SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
271       : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
272         isDeferrable(isDeferrable) {}
273 
274   /// Print this alias to the given stream.
print(raw_ostream & os) const275   void print(raw_ostream &os) const {
276     os << name;
277     if (hasSuffixIndex)
278       os << suffixIndex;
279   }
280 
281   /// Returns true if this alias supports deferred resolution when parsing.
canBeDeferred() const282   bool canBeDeferred() const { return isDeferrable; }
283 
284 private:
285   /// The main name of the alias.
286   StringRef name;
287   /// The optional suffix index of the alias, if multiple aliases had the same
288   /// name.
289   uint32_t suffixIndex : 30;
290   /// A flag indicating whether this alias has a suffix or not.
291   bool hasSuffixIndex : 1;
292   /// A flag indicating whether this alias may be deferred or not.
293   bool isDeferrable : 1;
294 };
295 
296 /// This class represents a utility that initializes the set of attribute and
297 /// type aliases, without the need to store the extra information within the
298 /// main AliasState class or pass it around via function arguments.
299 class AliasInitializer {
300 public:
AliasInitializer(DialectInterfaceCollection<OpAsmDialectInterface> & interfaces,llvm::BumpPtrAllocator & aliasAllocator)301   AliasInitializer(
302       DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
303       llvm::BumpPtrAllocator &aliasAllocator)
304       : interfaces(interfaces), aliasAllocator(aliasAllocator),
305         aliasOS(aliasBuffer) {}
306 
307   void initialize(Operation *op, const OpPrintingFlags &printerFlags,
308                   llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
309                   llvm::MapVector<Type, SymbolAlias> &typeToAlias);
310 
311   /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
312   /// set to true if the originator of this attribute can resolve the alias
313   /// after parsing has completed (e.g. in the case of operation locations).
314   void visit(Attribute attr, bool canBeDeferred = false);
315 
316   /// Visit the given type to see if it has an alias.
317   void visit(Type type);
318 
319 private:
320   /// Try to generate an alias for the provided symbol. If an alias is
321   /// generated, the provided alias mapping and reverse mapping are updated.
322   /// Returns success if an alias was generated, failure otherwise.
323   template <typename T>
324   LogicalResult
325   generateAlias(T symbol,
326                 llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);
327 
328   /// The set of asm interfaces within the context.
329   DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
330 
331   /// Mapping between an alias and the set of symbols mapped to it.
332   llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
333   llvm::MapVector<StringRef, std::vector<Type>> aliasToType;
334 
335   /// An allocator used for alias names.
336   llvm::BumpPtrAllocator &aliasAllocator;
337 
338   /// The set of visited attributes.
339   DenseSet<Attribute> visitedAttributes;
340 
341   /// The set of attributes that have aliases *and* can be deferred.
342   DenseSet<Attribute> deferrableAttributes;
343 
344   /// The set of visited types.
345   DenseSet<Type> visitedTypes;
346 
347   /// Storage and stream used when generating an alias.
348   SmallString<32> aliasBuffer;
349   llvm::raw_svector_ostream aliasOS;
350 };
351 
352 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
353 /// and merely collects the attributes and types that *would* be printed in a
354 /// normal print invocation so that we can generate proper aliases. This allows
355 /// for us to generate aliases only for the attributes and types that would be
356 /// in the output, and trims down unnecessary output.
357 class DummyAliasOperationPrinter : private OpAsmPrinter {
358 public:
DummyAliasOperationPrinter(const OpPrintingFlags & printerFlags,AliasInitializer & initializer)359   explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
360                                       AliasInitializer &initializer)
361       : printerFlags(printerFlags), initializer(initializer) {}
362 
363   /// Print the given operation.
print(Operation * op)364   void print(Operation *op) {
365     // Visit the operation location.
366     if (printerFlags.shouldPrintDebugInfo())
367       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
368 
369     // If requested, always print the generic form.
370     if (!printerFlags.shouldPrintGenericOpForm()) {
371       // Check to see if this is a known operation.  If so, use the registered
372       // custom printer hook.
373       if (auto *opInfo = op->getAbstractOperation()) {
374         opInfo->printAssembly(op, *this);
375         return;
376       }
377     }
378 
379     // Otherwise print with the generic assembly form.
380     printGenericOp(op);
381   }
382 
383 private:
384   /// Print the given operation in the generic form.
printGenericOp(Operation * op)385   void printGenericOp(Operation *op) override {
386     // Consider nested operations for aliases.
387     if (op->getNumRegions() != 0) {
388       for (Region &region : op->getRegions())
389         printRegion(region, /*printEntryBlockArgs=*/true,
390                     /*printBlockTerminators=*/true);
391     }
392 
393     // Visit all the types used in the operation.
394     for (Type type : op->getOperandTypes())
395       printType(type);
396     for (Type type : op->getResultTypes())
397       printType(type);
398 
399     // Consider the attributes of the operation for aliases.
400     for (const NamedAttribute &attr : op->getAttrs())
401       printAttribute(attr.second);
402   }
403 
404   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
405   /// block are not printed. If 'printBlockTerminator' is false, the terminator
406   /// operation of the block is not printed.
print(Block * block,bool printBlockArgs=true,bool printBlockTerminator=true)407   void print(Block *block, bool printBlockArgs = true,
408              bool printBlockTerminator = true) {
409     // Consider the types of the block arguments for aliases if 'printBlockArgs'
410     // is set to true.
411     if (printBlockArgs) {
412       for (BlockArgument arg : block->getArguments()) {
413         printType(arg.getType());
414 
415         // Visit the argument location.
416         if (printerFlags.shouldPrintDebugInfo())
417           // TODO: Allow deferring argument locations.
418           initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
419       }
420     }
421 
422     // Consider the operations within this block, ignoring the terminator if
423     // requested.
424     bool hasTerminator =
425         !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
426     auto range = llvm::make_range(
427         block->begin(),
428         std::prev(block->end(),
429                   (!hasTerminator || printBlockTerminator) ? 0 : 1));
430     for (Operation &op : range)
431       print(&op);
432   }
433 
434   /// Print the given region.
printRegion(Region & region,bool printEntryBlockArgs,bool printBlockTerminators,bool printEmptyBlock=false)435   void printRegion(Region &region, bool printEntryBlockArgs,
436                    bool printBlockTerminators,
437                    bool printEmptyBlock = false) override {
438     if (region.empty())
439       return;
440 
441     auto *entryBlock = &region.front();
442     print(entryBlock, printEntryBlockArgs, printBlockTerminators);
443     for (Block &b : llvm::drop_begin(region, 1))
444       print(&b);
445   }
446 
printRegionArgument(BlockArgument arg,ArrayRef<NamedAttribute> argAttrs,bool omitType)447   void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
448                            bool omitType) override {
449     printType(arg.getType());
450     // Visit the argument location.
451     if (printerFlags.shouldPrintDebugInfo())
452       // TODO: Allow deferring argument locations.
453       initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
454   }
455 
456   /// Consider the given type to be printed for an alias.
printType(Type type)457   void printType(Type type) override { initializer.visit(type); }
458 
459   /// Consider the given attribute to be printed for an alias.
printAttribute(Attribute attr)460   void printAttribute(Attribute attr) override { initializer.visit(attr); }
printAttributeWithoutType(Attribute attr)461   void printAttributeWithoutType(Attribute attr) override {
462     printAttribute(attr);
463   }
464 
465   /// Print the given set of attributes with names not included within
466   /// 'elidedAttrs'.
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})467   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
468                              ArrayRef<StringRef> elidedAttrs = {}) override {
469     if (attrs.empty())
470       return;
471     if (elidedAttrs.empty()) {
472       for (const NamedAttribute &attr : attrs)
473         printAttribute(attr.second);
474       return;
475     }
476     llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
477                                                   elidedAttrs.end());
478     for (const NamedAttribute &attr : attrs)
479       if (!elidedAttrsSet.contains(attr.first.strref()))
480         printAttribute(attr.second);
481   }
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})482   void printOptionalAttrDictWithKeyword(
483       ArrayRef<NamedAttribute> attrs,
484       ArrayRef<StringRef> elidedAttrs = {}) override {
485     printOptionalAttrDict(attrs, elidedAttrs);
486   }
487 
488   /// Return a null stream as the output stream, this will ignore any data fed
489   /// to it.
getStream() const490   raw_ostream &getStream() const override { return os; }
491 
492   /// The following are hooks of `OpAsmPrinter` that are not necessary for
493   /// determining potential aliases.
printAffineMapOfSSAIds(AffineMapAttr,ValueRange)494   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
printAffineExprOfSSAIds(AffineExpr,ValueRange,ValueRange)495   void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
printNewline()496   void printNewline() override {}
printOperand(Value)497   void printOperand(Value) override {}
printOperand(Value,raw_ostream & os)498   void printOperand(Value, raw_ostream &os) override {
499     // Users expect the output string to have at least the prefixed % to signal
500     // a value name. To maintain this invariant, emit a name even if it is
501     // guaranteed to go unused.
502     os << "%";
503   }
printSymbolName(StringRef)504   void printSymbolName(StringRef) override {}
printSuccessor(Block *)505   void printSuccessor(Block *) override {}
printSuccessorAndUseList(Block *,ValueRange)506   void printSuccessorAndUseList(Block *, ValueRange) override {}
shadowRegionArgs(Region &,ValueRange)507   void shadowRegionArgs(Region &, ValueRange) override {}
508 
509   /// The printer flags to use when determining potential aliases.
510   const OpPrintingFlags &printerFlags;
511 
512   /// The initializer to use when identifying aliases.
513   AliasInitializer &initializer;
514 
515   /// A dummy output stream.
516   mutable llvm::raw_null_ostream os;
517 };
518 } // end anonymous namespace
519 
520 /// Sanitize the given name such that it can be used as a valid identifier. If
521 /// the string needs to be modified in any way, the provided buffer is used to
522 /// store the new copy,
sanitizeIdentifier(StringRef name,SmallString<16> & buffer,StringRef allowedPunctChars="$._-",bool allowTrailingDigit=true)523 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
524                                     StringRef allowedPunctChars = "$._-",
525                                     bool allowTrailingDigit = true) {
526   assert(!name.empty() && "Shouldn't have an empty name here");
527 
528   auto copyNameToBuffer = [&] {
529     for (char ch : name) {
530       if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
531         buffer.push_back(ch);
532       else if (ch == ' ')
533         buffer.push_back('_');
534       else
535         buffer.append(llvm::utohexstr((unsigned char)ch));
536     }
537   };
538 
539   // Check to see if this name is valid. If it starts with a digit, then it
540   // could conflict with the autogenerated numeric ID's, so add an underscore
541   // prefix to avoid problems.
542   if (isdigit(name[0])) {
543     buffer.push_back('_');
544     copyNameToBuffer();
545     return buffer;
546   }
547 
548   // If the name ends with a trailing digit, add a '_' to avoid potential
549   // conflicts with autogenerated ID's.
550   if (!allowTrailingDigit && isdigit(name.back())) {
551     copyNameToBuffer();
552     buffer.push_back('_');
553     return buffer;
554   }
555 
556   // Check to see that the name consists of only valid identifier characters.
557   for (char ch : name) {
558     if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
559       copyNameToBuffer();
560       return buffer;
561     }
562   }
563 
564   // If there are no invalid characters, return the original name.
565   return name;
566 }
567 
568 /// Given a collection of aliases and symbols, initialize a mapping from a
569 /// symbol to a given alias.
570 template <typename T>
571 static void
initializeAliases(llvm::MapVector<StringRef,std::vector<T>> & aliasToSymbol,llvm::MapVector<T,SymbolAlias> & symbolToAlias,DenseSet<T> * deferrableAliases=nullptr)572 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
573                   llvm::MapVector<T, SymbolAlias> &symbolToAlias,
574                   DenseSet<T> *deferrableAliases = nullptr) {
575   std::vector<std::pair<StringRef, std::vector<T>>> aliases =
576       aliasToSymbol.takeVector();
577   llvm::array_pod_sort(aliases.begin(), aliases.end(),
578                        [](const auto *lhs, const auto *rhs) {
579                          return lhs->first.compare(rhs->first);
580                        });
581 
582   for (auto &it : aliases) {
583     // If there is only one instance for this alias, use the name directly.
584     if (it.second.size() == 1) {
585       T symbol = it.second.front();
586       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
587       symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
588       continue;
589     }
590     // Otherwise, add the index to the name.
591     for (int i = 0, e = it.second.size(); i < e; ++i) {
592       T symbol = it.second[i];
593       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
594       symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
595     }
596   }
597 }
598 
initialize(Operation * op,const OpPrintingFlags & printerFlags,llvm::MapVector<Attribute,SymbolAlias> & attrToAlias,llvm::MapVector<Type,SymbolAlias> & typeToAlias)599 void AliasInitializer::initialize(
600     Operation *op, const OpPrintingFlags &printerFlags,
601     llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
602     llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
603   // Use a dummy printer when walking the IR so that we can collect the
604   // attributes/types that will actually be used during printing when
605   // considering aliases.
606   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
607   aliasPrinter.print(op);
608 
609   // Initialize the aliases sorted by name.
610   initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
611   initializeAliases(aliasToType, typeToAlias);
612 }
613 
visit(Attribute attr,bool canBeDeferred)614 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
615   if (!visitedAttributes.insert(attr).second) {
616     // If this attribute already has an alias and this instance can't be
617     // deferred, make sure that the alias isn't deferred.
618     if (!canBeDeferred)
619       deferrableAttributes.erase(attr);
620     return;
621   }
622 
623   // Try to generate an alias for this attribute.
624   if (succeeded(generateAlias(attr, aliasToAttr))) {
625     if (canBeDeferred)
626       deferrableAttributes.insert(attr);
627     return;
628   }
629 
630   // Check for any sub elements.
631   if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) {
632     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
633                                         [&](Type type) { visit(type); });
634   }
635 }
636 
visit(Type type)637 void AliasInitializer::visit(Type type) {
638   if (!visitedTypes.insert(type).second)
639     return;
640 
641   // Try to generate an alias for this type.
642   if (succeeded(generateAlias(type, aliasToType)))
643     return;
644 
645   // Check for any sub elements.
646   if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) {
647     subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
648                                         [&](Type type) { visit(type); });
649   }
650 }
651 
652 template <typename T>
generateAlias(T symbol,llvm::MapVector<StringRef,std::vector<T>> & aliasToSymbol)653 LogicalResult AliasInitializer::generateAlias(
654     T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
655   SmallString<16> tempBuffer;
656   for (const auto &interface : interfaces) {
657     if (failed(interface.getAlias(symbol, aliasOS)))
658       continue;
659     StringRef name = aliasOS.str();
660     assert(!name.empty() && "expected valid alias name");
661     name = sanitizeIdentifier(name, tempBuffer, /*allowedPunctChars=*/"$_-",
662                               /*allowTrailingDigit=*/false);
663     name = name.copy(aliasAllocator);
664 
665     aliasToSymbol[name].push_back(symbol);
666     aliasBuffer.clear();
667     return success();
668   }
669   return failure();
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // AliasState
674 //===----------------------------------------------------------------------===//
675 
namespace {
/// This class manages the state for type and attribute aliases. It owns the
/// attribute->alias and type->alias tables filled in by `initialize`, and uses
/// them both to print alias references and to print alias definitions.
class AliasState {
public:
  /// Initialize the internal aliases by running an `AliasInitializer` over
  /// `op`, which fills in the attribute and type alias tables.
  void
  initialize(Operation *op, const OpPrintingFlags &printerFlags,
             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Get an alias for the given attribute if it has one and print it in `os`
  /// (as `#` followed by the alias). Returns success if an alias was printed,
  /// failure otherwise.
  LogicalResult getAlias(Attribute attr, raw_ostream &os) const;

  /// Get an alias for the given type if it has one and print it in `os` (as
  /// `!` followed by the alias). Returns success if an alias was printed,
  /// failure otherwise.
  LogicalResult getAlias(Type ty, raw_ostream &os) const;

  /// Print the definitions of all referenced aliases that cannot be resolved
  /// in a deferred manner.
  void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/false);
  }

  /// Print the definitions of all referenced aliases that support deferred
  /// resolution.
  void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/true);
  }

private:
  /// Print the definition of every alias whose `canBeDeferred()` equals
  /// `isDeferred`. Attribute aliases are printed first, then type aliases.
  void printAliases(raw_ostream &os, NewLineCounter &newLine,
                    bool isDeferred) const;

  /// Mapping between attribute and alias. This is a MapVector, so iteration
  /// (and therefore printing order) follows insertion order.
  llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
  /// Mapping between type and alias (insertion-ordered, like `attrToAlias`).
  llvm::MapVector<Type, SymbolAlias> typeToAlias;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator aliasAllocator;
};
} // end anonymous namespace
719 
initialize(Operation * op,const OpPrintingFlags & printerFlags,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)720 void AliasState::initialize(
721     Operation *op, const OpPrintingFlags &printerFlags,
722     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
723   AliasInitializer initializer(interfaces, aliasAllocator);
724   initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
725 }
726 
getAlias(Attribute attr,raw_ostream & os) const727 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
728   auto it = attrToAlias.find(attr);
729   if (it == attrToAlias.end())
730     return failure();
731   it->second.print(os << '#');
732   return success();
733 }
734 
getAlias(Type ty,raw_ostream & os) const735 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
736   auto it = typeToAlias.find(ty);
737   if (it == typeToAlias.end())
738     return failure();
739 
740   it->second.print(os << '!');
741   return success();
742 }
743 
printAliases(raw_ostream & os,NewLineCounter & newLine,bool isDeferred) const744 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
745                               bool isDeferred) const {
746   auto filterFn = [=](const auto &aliasIt) {
747     return aliasIt.second.canBeDeferred() == isDeferred;
748   };
749   for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
750     it.second.print(os << '#');
751     os << " = " << it.first << newLine;
752   }
753   for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
754     it.second.print(os << '!');
755     os << " = type " << it.first << newLine;
756   }
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // SSANameState
761 //===----------------------------------------------------------------------===//
762 
namespace {
/// This class manages the state of SSA value names. It gives every SSA value
/// either a numeric ID or a uniqued textual name, gives each block an ID, and
/// records result groups for operations whose results are named in groups.
/// All numbering is performed by the constructor.
class SSANameState {
public:
  /// A sentinel value used for values with names set. `getBlockID` also
  /// returns it for blocks that were never numbered.
  enum : unsigned { NameSentinel = ~0U };

  /// Number all values and blocks nested within `op`. Custom names are
  /// requested from the OpAsm interfaces unless `printerFlags` selects the
  /// generic op form.
  SSANameState(Operation *op, const OpPrintingFlags &printerFlags,
               DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Print the SSA identifier for the given value to 'stream' (`%` followed by
  /// either its number or its name). If 'printResultNo' is true, it also
  /// presents the result number ('#' number) of this value within its result
  /// group, when that group has more than one result. Null and unnumbered
  /// values print as placeholder text.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Return the result indices for each of the result groups registered by this
  /// operation, or empty if none exist.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Get the ID for the given block, or NameSentinel if it was not numbered.
  unsigned getBlockID(Block *block);

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
  /// details.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the SSA values within the given IR unit.
  void numberValuesInRegion(Region &region);
  void numberValuesInBlock(Block &block);
  void numberValuesInOp(Operation &op);

  /// Given a result of an operation 'result', find the result group head
  /// 'lookupValue' and the result of 'result' within that group in
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
  /// has more than 1 result.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            Optional<int> &lookupResultNo) const;

  /// Set a special value name for the given value. An empty name means the
  /// value receives the next numeric ID instead.
  void setValueName(Value value, StringRef name);

  /// Uniques the given value name within the printer. If the given name
  /// conflicts, it is automatically renamed by appending `_<N>`.
  StringRef uniqueValueName(StringRef name);

  /// This is the value ID for each SSA value. If this returns NameSentinel,
  /// then the valueID has an entry in valueNames.
  DenseMap<Value, unsigned> valueIDs;
  DenseMap<Value, StringRef> valueNames;

  /// This is a map of operations that contain multiple named result groups,
  /// i.e. there may be multiple names for the results of the operation. The
  /// value of this map are the result numbers that start a result group,
  /// sorted ascending and always starting with 0.
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;

  /// The ID of each block. IDs are assigned in order within each region,
  /// starting from 0, so they are only unique within a single region.
  DenseMap<Block *, unsigned> blockIDs;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  /// Note: the value of the map is unused.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  /// Owns the storage for uniqued names (and shadowed argument names).
  llvm::BumpPtrAllocator usedNameAllocator;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the next ID to assign to a region entry block argument.
  unsigned nextArgumentID = 0;
  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;

  /// These are the printing flags. They control, e.g., whether to print in
  /// generic form.
  OpPrintingFlags printerFlags;

  /// The OpAsm dialect interfaces consulted for custom value names.
  DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
};
} // end anonymous namespace
843 
/// Number every SSA value and block nested within `op`. Nested regions are
/// walked with an explicit worklist (used as a stack) rather than recursion.
SSANameState::SSANameState(
    Operation *op, const OpPrintingFlags &printerFlags,
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces)
    : printerFlags(printerFlags), interfaces(interfaces) {
  // The counters are rewound repeatedly below. These savers restore them to
  // their entry values when the constructor returns.
  llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
  llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and `usedNames` scoped HashTable. This information is
  // carried from the parent region: it is captured when a region is queued and
  // reinstated when the region is popped.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for UsedNamesScopeTy. The scopes are placement-new'd into it and
  // destroyed explicitly, always starting from the current (innermost) scope,
  // because their lifetimes do not follow C++ lexical scoping in the worklist
  // loop. The memory itself is released when `allocator` goes out of scope.
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  // Queue the top-level regions. Their contexts capture the counters before
  // `op`'s own results are numbered.
  SmallVector<NamingContext, 8> nameContext;
  for (Region &region : op->getRegions())
    nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
                                          nextConflictID, topLevelNamesScope));

  numberValuesInOp(*op);

  while (!nameContext.empty()) {
    Region *region;
    UsedNamesScopeTy *parentScope;
    // Resume numbering from the counters captured when this region was queued.
    std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) =
        nameContext.pop_back_val();

    // When we switch from one subtree to another, pop the scopes (no longer
    // needed) until the parent scope is current. Destroying a scope removes the
    // names inserted under it, so names from previously processed sibling
    // subtrees no longer count as used.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    // Add a scope for the current region.
    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(*region);

    // Queue the regions nested directly inside this region's operations. Each
    // captures the counters as they stand after this whole region has been
    // numbered, so sibling nested regions start from the same counter values.
    // NOTE: the loop variables intentionally shadow `op` and `region` above.
    for (Operation &op : region->getOps())
      for (Region &region : op.getRegions())
        nameContext.push_back(std::make_tuple(&region, nextValueID,
                                              nextArgumentID, nextConflictID,
                                              curNamesScope));
  }

  // Manually remove all the scopes.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}
904 
printValueID(Value value,bool printResultNo,raw_ostream & stream) const905 void SSANameState::printValueID(Value value, bool printResultNo,
906                                 raw_ostream &stream) const {
907   if (!value) {
908     stream << "<<NULL>>";
909     return;
910   }
911 
912   Optional<int> resultNo;
913   auto lookupValue = value;
914 
915   // If this is an operation result, collect the head lookup value of the result
916   // group and the result number of 'result' within that group.
917   if (OpResult result = value.dyn_cast<OpResult>())
918     getResultIDAndNumber(result, lookupValue, resultNo);
919 
920   auto it = valueIDs.find(lookupValue);
921   if (it == valueIDs.end()) {
922     stream << "<<UNKNOWN SSA VALUE>>";
923     return;
924   }
925 
926   stream << '%';
927   if (it->second != NameSentinel) {
928     stream << it->second;
929   } else {
930     auto nameIt = valueNames.find(lookupValue);
931     assert(nameIt != valueNames.end() && "Didn't have a name entry?");
932     stream << nameIt->second;
933   }
934 
935   if (resultNo.hasValue() && printResultNo)
936     stream << '#' << resultNo;
937 }
938 
getOpResultGroups(Operation * op)939 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
940   auto it = opResultGroups.find(op);
941   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
942 }
943 
getBlockID(Block * block)944 unsigned SSANameState::getBlockID(Block *block) {
945   auto it = blockIDs.find(block);
946   return it != blockIDs.end() ? it->second : NameSentinel;
947 }
948 
shadowRegionArgs(Region & region,ValueRange namesToUse)949 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
950   assert(!region.empty() && "cannot shadow arguments of an empty region");
951   assert(region.getNumArguments() == namesToUse.size() &&
952          "incorrect number of names passed in");
953   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
954          "only KnownIsolatedFromAbove ops can shadow names");
955 
956   SmallVector<char, 16> nameStr;
957   for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
958     auto nameToUse = namesToUse[i];
959     if (nameToUse == nullptr)
960       continue;
961     auto nameToReplace = region.getArgument(i);
962 
963     nameStr.clear();
964     llvm::raw_svector_ostream nameStream(nameStr);
965     printValueID(nameToUse, /*printResultNo=*/true, nameStream);
966 
967     // Entry block arguments should already have a pretty "arg" name.
968     assert(valueIDs[nameToReplace] == NameSentinel);
969 
970     // Use the name without the leading %.
971     auto name = StringRef(nameStream.str()).drop_front();
972 
973     // Overwrite the name.
974     valueNames[nameToReplace] = name.copy(usedNameAllocator);
975   }
976 }
977 
numberValuesInRegion(Region & region)978 void SSANameState::numberValuesInRegion(Region &region) {
979   // Number the values within this region in a breadth-first order.
980   unsigned nextBlockID = 0;
981   for (auto &block : region) {
982     // Each block gets a unique ID, and all of the operations within it get
983     // numbered as well.
984     blockIDs[&block] = nextBlockID++;
985     numberValuesInBlock(block);
986   }
987 }
988 
numberValuesInBlock(Block & block)989 void SSANameState::numberValuesInBlock(Block &block) {
990   auto setArgNameFn = [&](Value arg, StringRef name) {
991     assert(!valueIDs.count(arg) && "arg numbered multiple times");
992     assert(arg.cast<BlockArgument>().getOwner() == &block &&
993            "arg not defined in 'block'");
994     setValueName(arg, name);
995   };
996 
997   bool isEntryBlock = block.isEntryBlock();
998   if (isEntryBlock && !printerFlags.shouldPrintGenericOpForm()) {
999     if (auto *op = block.getParentOp()) {
1000       if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
1001         asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
1002     }
1003   }
1004 
1005   // Number the block arguments. We give entry block arguments a special name
1006   // 'arg'.
1007   SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
1008   llvm::raw_svector_ostream specialName(specialNameBuffer);
1009   for (auto arg : block.getArguments()) {
1010     if (valueIDs.count(arg))
1011       continue;
1012     if (isEntryBlock) {
1013       specialNameBuffer.resize(strlen("arg"));
1014       specialName << nextArgumentID++;
1015     }
1016     setValueName(arg, specialName.str());
1017   }
1018 
1019   // Number the operations in this block.
1020   for (auto &op : block)
1021     numberValuesInOp(op);
1022 }
1023 
numberValuesInOp(Operation & op)1024 void SSANameState::numberValuesInOp(Operation &op) {
1025   unsigned numResults = op.getNumResults();
1026   if (numResults == 0)
1027     return;
1028   Value resultBegin = op.getResult(0);
1029 
1030   // Function used to set the special result names for the operation.
1031   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1032   auto setResultNameFn = [&](Value result, StringRef name) {
1033     assert(!valueIDs.count(result) && "result numbered multiple times");
1034     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1035     setValueName(result, name);
1036 
1037     // Record the result number for groups not anchored at 0.
1038     if (int resultNo = result.cast<OpResult>().getResultNumber())
1039       resultGroups.push_back(resultNo);
1040   };
1041   if (!printerFlags.shouldPrintGenericOpForm()) {
1042     if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
1043       asmInterface.getAsmResultNames(setResultNameFn);
1044     else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect()))
1045       asmInterface->getAsmResultNames(&op, setResultNameFn);
1046   }
1047 
1048   // If the first result wasn't numbered, give it a default number.
1049   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
1050     ++nextValueID;
1051 
1052   // If this operation has multiple result groups, mark it.
1053   if (resultGroups.size() != 1) {
1054     llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
1055     opResultGroups.try_emplace(&op, std::move(resultGroups));
1056   }
1057 }
1058 
getResultIDAndNumber(OpResult result,Value & lookupValue,Optional<int> & lookupResultNo) const1059 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
1060                                         Optional<int> &lookupResultNo) const {
1061   Operation *owner = result.getOwner();
1062   if (owner->getNumResults() == 1)
1063     return;
1064   int resultNo = result.getResultNumber();
1065 
1066   // If this operation has multiple result groups, we will need to find the
1067   // one corresponding to this result.
1068   auto resultGroupIt = opResultGroups.find(owner);
1069   if (resultGroupIt == opResultGroups.end()) {
1070     // If not, just use the first result.
1071     lookupResultNo = resultNo;
1072     lookupValue = owner->getResult(0);
1073     return;
1074   }
1075 
1076   // Find the correct index using a binary search, as the groups are ordered.
1077   ArrayRef<int> resultGroups = resultGroupIt->second;
1078   auto it = llvm::upper_bound(resultGroups, resultNo);
1079   int groupResultNo = 0, groupSize = 0;
1080 
1081   // If there are no smaller elements, the last result group is the lookup.
1082   if (it == resultGroups.end()) {
1083     groupResultNo = resultGroups.back();
1084     groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
1085   } else {
1086     // Otherwise, the previous element is the lookup.
1087     groupResultNo = *std::prev(it);
1088     groupSize = *it - groupResultNo;
1089   }
1090 
1091   // We only record the result number for a group of size greater than 1.
1092   if (groupSize != 1)
1093     lookupResultNo = resultNo - groupResultNo;
1094   lookupValue = owner->getResult(groupResultNo);
1095 }
1096 
setValueName(Value value,StringRef name)1097 void SSANameState::setValueName(Value value, StringRef name) {
1098   // If the name is empty, the value uses the default numbering.
1099   if (name.empty()) {
1100     valueIDs[value] = nextValueID++;
1101     return;
1102   }
1103 
1104   valueIDs[value] = NameSentinel;
1105   valueNames[value] = uniqueValueName(name);
1106 }
1107 
uniqueValueName(StringRef name)1108 StringRef SSANameState::uniqueValueName(StringRef name) {
1109   SmallString<16> tmpBuffer;
1110   name = sanitizeIdentifier(name, tmpBuffer);
1111 
1112   // Check to see if this name is already unique.
1113   if (!usedNames.count(name)) {
1114     name = name.copy(usedNameAllocator);
1115   } else {
1116     // Otherwise, we had a conflict - probe until we find a unique name. This
1117     // is guaranteed to terminate (and usually in a single iteration) because it
1118     // generates new names by incrementing nextConflictID.
1119     SmallString<64> probeName(name);
1120     probeName.push_back('_');
1121     while (true) {
1122       probeName += llvm::utostr(nextConflictID++);
1123       if (!usedNames.count(probeName)) {
1124         name = probeName.str().copy(usedNameAllocator);
1125         break;
1126       }
1127       probeName.resize(name.size() + 1);
1128     }
1129   }
1130 
1131   usedNames.insert(name, char());
1132   return name;
1133 }
1134 
1135 //===----------------------------------------------------------------------===//
1136 // AsmState
1137 //===----------------------------------------------------------------------===//
1138 
namespace mlir {
namespace detail {
/// The state shared while printing an operation: the OpAsm dialect interfaces,
/// attribute/type alias tables, SSA value and block names, the printing flags,
/// and an optional map that receives the printed location of each operation.
class AsmStateImpl {
public:
  /// Build the state for `op`. SSA names are computed immediately (by the
  /// SSANameState constructor). Aliases are only computed by
  /// `initializeAliases`.
  /// NOTE: `nameState` is constructed from the `printerFlags` *parameter*
  /// (which shadows the member of the same name) and from `interfaces`, so
  /// `interfaces` must stay declared before `nameState`.
  explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(op->getContext()), nameState(op, printerFlags, interfaces),
        printerFlags(printerFlags), locationMap(locationMap) {}

  /// Initialize the alias state to enable the printing of aliases.
  void initializeAliases(Operation *op) {
    aliasState.initialize(op, printerFlags, interfaces);
  }

  /// Get an instance of the OpAsmDialectInterface for the given dialect, or
  /// null if one wasn't registered.
  const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
    return interfaces.getInterfaceFor(dialect);
  }

  /// Get the state used for aliases.
  AliasState &getAliasState() { return aliasState; }

  /// Get the state used for SSA names.
  SSANameState &getSSANameState() { return nameState; }

  /// Register the location, line and column, within the buffer that the given
  /// operation was printed at. This is a no-op when no location map was given.
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
    if (locationMap)
      (*locationMap)[op] = std::make_pair(line, col);
  }

private:
  /// Collection of OpAsm interfaces implemented in the context.
  /// Must be declared before `nameState`, which uses it during construction.
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// The state used for attribute and type aliases.
  AliasState aliasState;

  /// The state used for SSA value names.
  SSANameState nameState;

  /// Flags that control op output.
  OpPrintingFlags printerFlags;

  /// An optional location map to be populated (may be null).
  AsmState::LocationMap *locationMap;
};
} // end namespace detail
} // end namespace mlir
1190 
AsmState(Operation * op,const OpPrintingFlags & printerFlags,LocationMap * locationMap)1191 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
1192                    LocationMap *locationMap)
1193     : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {}
~AsmState()1194 AsmState::~AsmState() {}
1195 
1196 //===----------------------------------------------------------------------===//
1197 // ModulePrinter
1198 //===----------------------------------------------------------------------===//
1199 
namespace {
/// Prints attributes, types, locations, affine maps/expressions and integer
/// sets to a raw_ostream. An optional AsmStateImpl provides shared printer
/// state, such as aliases.
class ModulePrinter {
public:
  /// Create a printer writing to `os` with the given flags. `state` may be
  /// null; printLocation, for one, then skips alias lookup.
  ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
                AsmStateImpl *state = nullptr)
      : os(os), printerFlags(flags), state(state) {}
  /// Create a printer sharing the stream, flags and state of `printer`. The
  /// new-line counter is not copied; this printer gets its own
  /// default-initialized counter.
  explicit ModulePrinter(ModulePrinter &printer)
      : os(printer.os), printerFlags(printer.printerFlags),
        state(printer.state) {}

  /// Returns the output stream of the printer.
  raw_ostream &getStream() { return os; }

  /// Print the elements of `c` to the stream, separated by ", ", invoking
  /// `each_fn` on each element.
  template <typename Container, typename UnaryFunctor>
  inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
    llvm::interleaveComma(c, os, each_fn);
  }

  /// This enum describes the different kinds of elision for the type of an
  /// attribute when printing it.
  enum class AttrTypeElision {
    /// The type must not be elided,
    Never,
    /// The type may be elided when it matches the default used in the parser
    /// (for example i64 is the default for integer attributes).
    May,
    /// The type must be elided.
    Must
  };

  /// Print the given attribute.
  void printAttribute(Attribute attr,
                      AttrTypeElision typeElision = AttrTypeElision::Never);

  /// Print the given type.
  void printType(Type type);

  /// Print the given location to the stream. If `allowAlias` is true, this
  /// allows for the internal location to use an attribute alias.
  void printLocation(LocationAttr loc, bool allowAlias = false);

  /// Print affine constructs. `printValueName`, when provided, is a callback
  /// presumably used to render dimension/symbol identifiers (its uses are
  /// outside this view; verify against the definition).
  void printAffineMap(AffineMap map);
  void
  printAffineExpr(AffineExpr expr,
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
  void printAffineConstraint(AffineExpr expr, bool isEq);
  void printIntegerSet(IntegerSet set);

protected:
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {},
                             bool withKeyword = false);
  void printNamedAttribute(NamedAttribute attr);
  /// Print " " followed by `loc`, but only when debug info printing is enabled.
  void printTrailingLocation(Location loc, bool allowAlias = true);
  /// Print `loc` without the surrounding `loc(...)`, in the parseable form or,
  /// when `pretty` is true, the human-readable form.
  void printLocationInternal(LocationAttr loc, bool pretty = false);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);

  /// Print a dense string elements attribute.
  void printDenseStringElementsAttr(DenseStringElementsAttr attr);

  /// Print a dense int or float elements attribute. If 'allowHex' is true, a
  /// hex string is used instead of individual elements when the elements attr
  /// is large.
  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                     bool allowHex);

  void printDialectAttribute(Attribute attr);
  void printDialectType(Type type);

  /// This enum is used to represent the binding strength of the enclosing
  /// context that an AffineExprStorage is being printed in, so we can
  /// intelligently produce parens.
  enum class BindingStrength {
    Weak,   // + and -
    Strong, // All other binary operators.
  };
  void printAffineExprInternal(
      AffineExpr expr, BindingStrength enclosingTightness,
      function_ref<void(unsigned, bool)> printValueName = nullptr);

  /// The output stream for the printer.
  raw_ostream &os;

  /// A set of flags to control the printer's behavior.
  OpPrintingFlags printerFlags;

  /// An optional printer state for the module (may be null).
  AsmStateImpl *state;

  /// A tracker for the number of new lines emitted during printing.
  NewLineCounter newLine;
};
} // end anonymous namespace
1294 
printTrailingLocation(Location loc,bool allowAlias)1295 void ModulePrinter::printTrailingLocation(Location loc, bool allowAlias) {
1296   // Check to see if we are printing debug information.
1297   if (!printerFlags.shouldPrintDebugInfo())
1298     return;
1299 
1300   os << " ";
1301   printLocation(loc, /*allowAlias=*/allowAlias);
1302 }
1303 
printLocationInternal(LocationAttr loc,bool pretty)1304 void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
1305   TypeSwitch<LocationAttr>(loc)
1306       .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1307         printLocationInternal(loc.getFallbackLocation(), pretty);
1308       })
1309       .Case<UnknownLoc>([&](UnknownLoc loc) {
1310         if (pretty)
1311           os << "[unknown]";
1312         else
1313           os << "unknown";
1314       })
1315       .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1316         if (pretty) {
1317           os << loc.getFilename();
1318         } else {
1319           os << "\"";
1320           printEscapedString(loc.getFilename(), os);
1321           os << "\"";
1322         }
1323         os << ':' << loc.getLine() << ':' << loc.getColumn();
1324       })
1325       .Case<NameLoc>([&](NameLoc loc) {
1326         os << '\"';
1327         printEscapedString(loc.getName(), os);
1328         os << '\"';
1329 
1330         // Print the child if it isn't unknown.
1331         auto childLoc = loc.getChildLoc();
1332         if (!childLoc.isa<UnknownLoc>()) {
1333           os << '(';
1334           printLocationInternal(childLoc, pretty);
1335           os << ')';
1336         }
1337       })
1338       .Case<CallSiteLoc>([&](CallSiteLoc loc) {
1339         Location caller = loc.getCaller();
1340         Location callee = loc.getCallee();
1341         if (!pretty)
1342           os << "callsite(";
1343         printLocationInternal(callee, pretty);
1344         if (pretty) {
1345           if (callee.isa<NameLoc>()) {
1346             if (caller.isa<FileLineColLoc>()) {
1347               os << " at ";
1348             } else {
1349               os << newLine << " at ";
1350             }
1351           } else {
1352             os << newLine << " at ";
1353           }
1354         } else {
1355           os << " at ";
1356         }
1357         printLocationInternal(caller, pretty);
1358         if (!pretty)
1359           os << ")";
1360       })
1361       .Case<FusedLoc>([&](FusedLoc loc) {
1362         if (!pretty)
1363           os << "fused";
1364         if (Attribute metadata = loc.getMetadata())
1365           os << '<' << metadata << '>';
1366         os << '[';
1367         interleave(
1368             loc.getLocations(),
1369             [&](Location loc) { printLocationInternal(loc, pretty); },
1370             [&]() { os << ", "; });
1371         os << ']';
1372       });
1373 }
1374 
1375 /// Print a floating point value in a way that the parser will be able to
1376 /// round-trip losslessly.
printFloatValue(const APFloat & apValue,raw_ostream & os)1377 static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
1378   // We would like to output the FP constant value in exponential notation,
1379   // but we cannot do this if doing so will lose precision.  Check here to
1380   // make sure that we only output it in exponential format if we can parse
1381   // the value back and get the same value.
1382   bool isInf = apValue.isInfinity();
1383   bool isNaN = apValue.isNaN();
1384   if (!isInf && !isNaN) {
1385     SmallString<128> strValue;
1386     apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
1387                      /*TruncateZero=*/false);
1388 
1389     // Check to make sure that the stringized number is not some string like
1390     // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
1391     // that the string matches the "[-+]?[0-9]" regex.
1392     assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
1393             ((strValue[0] == '-' || strValue[0] == '+') &&
1394              (strValue[1] >= '0' && strValue[1] <= '9'))) &&
1395            "[-+]?[0-9] regex does not match!");
1396 
1397     // Parse back the stringized version and check that the value is equal
1398     // (i.e., there is no precision loss).
1399     if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
1400       os << strValue;
1401       return;
1402     }
1403 
1404     // If it is not, use the default format of APFloat instead of the
1405     // exponential notation.
1406     strValue.clear();
1407     apValue.toString(strValue);
1408 
1409     // Make sure that we can parse the default form as a float.
1410     if (strValue.str().contains('.')) {
1411       os << strValue;
1412       return;
1413     }
1414   }
1415 
1416   // Print special values in hexadecimal format. The sign bit should be included
1417   // in the literal.
1418   SmallVector<char, 16> str;
1419   APInt apInt = apValue.bitcastToAPInt();
1420   apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
1421                  /*formatAsCLiteral=*/true);
1422   os << str;
1423 }
1424 
printLocation(LocationAttr loc,bool allowAlias)1425 void ModulePrinter::printLocation(LocationAttr loc, bool allowAlias) {
1426   if (printerFlags.shouldPrintDebugInfoPrettyForm())
1427     return printLocationInternal(loc, /*pretty=*/true);
1428 
1429   os << "loc(";
1430   if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
1431     printLocationInternal(loc);
1432   os << ')';
1433 }
1434 
1435 /// Returns true if the given dialect symbol data is simple enough to print in
1436 /// the pretty form, i.e. without the enclosing "".
isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName)1437 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1438   // The name must start with an identifier.
1439   if (symName.empty() || !isalpha(symName.front()))
1440     return false;
1441 
1442   // Ignore all the characters that are valid in an identifier in the symbol
1443   // name.
1444   symName = symName.drop_while(
1445       [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1446   if (symName.empty())
1447     return true;
1448 
1449   // If we got to an unexpected character, then it must be a <>.  Check those
1450   // recursively.
1451   if (symName.front() != '<' || symName.back() != '>')
1452     return false;
1453 
1454   SmallVector<char, 8> nestedPunctuation;
1455   do {
1456     // If we ran out of characters, then we had a punctuation mismatch.
1457     if (symName.empty())
1458       return false;
1459 
1460     auto c = symName.front();
1461     symName = symName.drop_front();
1462 
1463     switch (c) {
1464     // We never allow null characters. This is an EOF indicator for the lexer
1465     // which we could handle, but isn't important for any known dialect.
1466     case '\0':
1467       return false;
1468     case '<':
1469     case '[':
1470     case '(':
1471     case '{':
1472       nestedPunctuation.push_back(c);
1473       continue;
1474     case '-':
1475       // Treat `->` as a special token.
1476       if (!symName.empty() && symName.front() == '>') {
1477         symName = symName.drop_front();
1478         continue;
1479       }
1480       break;
1481     // Reject types with mismatched brackets.
1482     case '>':
1483       if (nestedPunctuation.pop_back_val() != '<')
1484         return false;
1485       break;
1486     case ']':
1487       if (nestedPunctuation.pop_back_val() != '[')
1488         return false;
1489       break;
1490     case ')':
1491       if (nestedPunctuation.pop_back_val() != '(')
1492         return false;
1493       break;
1494     case '}':
1495       if (nestedPunctuation.pop_back_val() != '{')
1496         return false;
1497       break;
1498     default:
1499       continue;
1500     }
1501 
1502     // We're done when the punctuation is fully matched.
1503   } while (!nestedPunctuation.empty());
1504 
1505   // If there were extra characters, then we failed.
1506   return symName.empty();
1507 }
1508 
1509 /// Print the given dialect symbol to the stream.
printDialectSymbol(raw_ostream & os,StringRef symPrefix,StringRef dialectName,StringRef symString)1510 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1511                                StringRef dialectName, StringRef symString) {
1512   os << symPrefix << dialectName;
1513 
1514   // If this symbol name is simple enough, print it directly in pretty form,
1515   // otherwise, we print it as an escaped string.
1516   if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1517     os << '.' << symString;
1518     return;
1519   }
1520 
1521   os << "<\"";
1522   llvm::printEscapedString(symString, os);
1523   os << "\">";
1524 }
1525 
1526 /// Returns true if the given string can be represented as a bare identifier.
isBareIdentifier(StringRef name)1527 static bool isBareIdentifier(StringRef name) {
1528   assert(!name.empty() && "invalid name");
1529 
1530   // By making this unsigned, the value passed in to isalnum will always be
1531   // in the range 0-255. This is important when building with MSVC because
1532   // its implementation will assert. This situation can arise when dealing
1533   // with UTF-8 multibyte characters.
1534   unsigned char firstChar = static_cast<unsigned char>(name[0]);
1535   if (!isalpha(firstChar) && firstChar != '_')
1536     return false;
1537   return llvm::all_of(name.drop_front(), [](unsigned char c) {
1538     return isalnum(c) || c == '_' || c == '$' || c == '.';
1539   });
1540 }
1541 
1542 /// Print the given string as a symbol reference. A symbol reference is
1543 /// represented as a string prefixed with '@'. The reference is surrounded with
1544 /// ""'s and escaped if it has any special or non-printable characters in it.
printSymbolReference(StringRef symbolRef,raw_ostream & os)1545 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
1546   assert(!symbolRef.empty() && "expected valid symbol reference");
1547 
1548   // If the symbol can be represented as a bare identifier, write it directly.
1549   if (isBareIdentifier(symbolRef)) {
1550     os << '@' << symbolRef;
1551     return;
1552   }
1553 
1554   // Otherwise, output the reference wrapped in quotes with proper escaping.
1555   os << "@\"";
1556   printEscapedString(symbolRef, os);
1557   os << '"';
1558 }
1559 
1560 // Print out a valid ElementsAttr that is succinct and can represent any
1561 // potential shape/type, for use when eliding a large ElementsAttr.
1562 //
1563 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1564 // hopefully alert readers to the fact that this has been elided.
1565 //
1566 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1567 // accept the string "elided". The first string must be a registered dialect
1568 // name and the latter must be a hex constant.
printElidedElementsAttr(raw_ostream & os)1569 static void printElidedElementsAttr(raw_ostream &os) {
1570   os << R"(opaque<"_", "0xDEADBEEF">)";
1571 }
1572 
/// Print the given attribute. If the alias state provides an alias, only the
/// alias is printed.
///
/// Builtin attributes are printed in their textual form below; any other
/// attribute is forwarded to `printDialectAttribute`. After the value, the
/// attribute's type is appended as ` : <type>` unless one of these holds:
///   - `typeElision` is `AttrTypeElision::Must`;
///   - the type is NoneType;
///   - the attribute kind returns early to elide it. These kinds are unit,
///     signless i1 integers, affine maps and sets, and dialect attributes.
///     Signless i64 integers and f64 floats also return early when
///     `typeElision` is `AttrTypeElision::May`.
void ModulePrinter::printAttribute(Attribute attr,
                                   AttrTypeElision typeElision) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  // Try to print an alias for this attribute.
  if (state && succeeded(state->getAliasState().getAlias(attr, os)))
    return;

  auto attrType = attr.getType();
  if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
    // Printed as `#dialect.data` or `#dialect<"data">`.
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                       opaqueAttr.getAttrData());
  } else if (attr.isa<UnitAttr>()) {
    os << "unit";
    return;
  } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
    os << '{';
    interleaveComma(dictAttr.getValue(),
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';

  } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
    if (attrType.isSignlessInteger(1)) {
      os << (intAttr.getValue().getBoolValue() ? "true" : "false");

      // Boolean integer attributes always elide the type.
      return;
    }

    // Only print attributes as unsigned if they are explicitly unsigned or are
    // signless 1-bit values.  Indexes, signed values, and multi-bit signless
    // values print as signed.
    // NOTE(review): signless i1 has already returned above, so the
    // `isSignlessInteger(1)` operand below is never true here; in practice
    // only explicitly unsigned types print as unsigned.
    bool isUnsigned =
        attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
    intAttr.getValue().print(os, !isUnsigned);

    // IntegerAttr elides the type if I64.
    if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
      return;

  } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
    printFloatValue(floatAttr.getValue(), os);

    // FloatAttr elides the type if F64.
    if (typeElision == AttrTypeElision::May && attrType.isF64())
      return;

  } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
    os << '"';
    printEscapedString(strAttr.getValue(), os);
    os << '"';

  } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
    // Array elements may elide their own types.
    os << '[';
    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
      printAttribute(attr, AttrTypeElision::May);
    });
    os << ']';

  } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
    os << "affine_map<";
    affineMapAttr.getValue().print(os);
    os << '>';

    // AffineMap always elides the type.
    return;

  } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
    os << "affine_set<";
    integerSetAttr.getValue().print(os);
    os << '>';

    // IntegerSet always elides the type.
    return;

  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
    printType(typeAttr.getValue());

  } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
    // Printed as `@root::@nested1::@nested2...`.
    printSymbolReference(refAttr.getRootReference(), os);
    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
      os << "::";
      printSymbolReference(nestedRef.getValue(), os);
    }

  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "opaque<\"" << opaqueAttr.getDialect() << "\", \"0x"
         << llvm::toHex(opaqueAttr.getValue()) << "\">";
    }

  } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
      os << '>';
    }

  } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseStringElementsAttr(strEltAttr);
      os << '>';
    }

  } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
    // Elide the whole attribute if either its indices or values would be.
    if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
        printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
      printElidedElementsAttr(os);
    } else {
      os << "sparse<";
      DenseIntElementsAttr indices = sparseEltAttr.getIndices();
      // An attribute with no indices prints as an empty `sparse<>`.
      if (indices.getNumElements() != 0) {
        printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
        os << ", ";
        printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
      }
      os << '>';
    }

  } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
    printLocation(locAttr);

  } else {
    // Non-builtin attributes are printed by their dialect; no type suffix.
    return printDialectAttribute(attr);
  }

  // Don't print the type if we must elide it, or if it is a None type.
  if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
    os << " : ";
    printType(attrType);
  }
}
1715 
1716 /// Print the integer element of a DenseElementsAttr.
printDenseIntElement(const APInt & value,raw_ostream & os,bool isSigned)1717 static void printDenseIntElement(const APInt &value, raw_ostream &os,
1718                                  bool isSigned) {
1719   if (value.getBitWidth() == 1)
1720     os << (value.getBoolValue() ? "true" : "false");
1721   else
1722     value.print(os, isSigned);
1723 }
1724 
/// Print the elements of a dense elements attribute with shape `type`.
///
/// `printEltFn(i)` is called to print the i-th element. Indices are visited
/// linearly and laid out with the last dimension varying fastest (row-major).
/// The output takes one of three forms:
///   - a splat prints only element 0, with no brackets;
///   - a shape with zero elements prints nothing;
///   - otherwise every dimension is wrapped in `[...]` and elements are
///     separated by ", ". For example, a 2x2 shape prints as
///     `[[e0, e1], [e2, e3]]`.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  // NOTE(review): the non-splat path below indexes `counter[rank - 1]`, so it
  // assumes rank >= 1; callers presumably always report 0-d attributes as
  // splats -- confirm.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto shape = type.getShape();
  auto bumpCounter = [&] {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open every bracket closed by the previous bump. The post-increment
    // leaves `openBrackets` one past `rank`, so it is reset explicitly.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  // Close whatever brackets remain open after the last element.
  while (openBrackets-- > 0)
    os << ']';
}
1774 
printDenseElementsAttr(DenseElementsAttr attr,bool allowHex)1775 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
1776                                            bool allowHex) {
1777   if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
1778     return printDenseStringElementsAttr(stringAttr);
1779 
1780   printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
1781                                 allowHex);
1782 }
1783 
printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,bool allowHex)1784 void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
1785                                                   bool allowHex) {
1786   auto type = attr.getType();
1787   auto elementType = type.getElementType();
1788 
1789   // Check to see if we should format this attribute as a hex string.
1790   auto numElements = type.getNumElements();
1791   if (!attr.isSplat() && allowHex &&
1792       shouldPrintElementsAttrWithHex(numElements)) {
1793     ArrayRef<char> rawData = attr.getRawData();
1794     if (llvm::support::endian::system_endianness() ==
1795         llvm::support::endianness::big) {
1796       // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
1797       // machines. It is converted here to print in LE format.
1798       SmallVector<char, 64> outDataVec(rawData.size());
1799       MutableArrayRef<char> convRawData(outDataVec);
1800       DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1801           rawData, convRawData, type);
1802       os << '"' << "0x"
1803          << llvm::toHex(StringRef(convRawData.data(), convRawData.size()))
1804          << "\"";
1805     } else {
1806       os << '"' << "0x"
1807          << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\"";
1808     }
1809 
1810     return;
1811   }
1812 
1813   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1814     Type complexElementType = complexTy.getElementType();
1815     // Note: The if and else below had a common lambda function which invoked
1816     // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
1817     // and hence was replaced.
1818     if (complexElementType.isa<IntegerType>()) {
1819       bool isSigned = !complexElementType.isUnsignedInteger();
1820       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1821         auto complexValue = *(attr.getComplexIntValues().begin() + index);
1822         os << "(";
1823         printDenseIntElement(complexValue.real(), os, isSigned);
1824         os << ",";
1825         printDenseIntElement(complexValue.imag(), os, isSigned);
1826         os << ")";
1827       });
1828     } else {
1829       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1830         auto complexValue = *(attr.getComplexFloatValues().begin() + index);
1831         os << "(";
1832         printFloatValue(complexValue.real(), os);
1833         os << ",";
1834         printFloatValue(complexValue.imag(), os);
1835         os << ")";
1836       });
1837     }
1838   } else if (elementType.isIntOrIndex()) {
1839     bool isSigned = !elementType.isUnsignedInteger();
1840     auto intValues = attr.getIntValues();
1841     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1842       printDenseIntElement(*(intValues.begin() + index), os, isSigned);
1843     });
1844   } else {
1845     assert(elementType.isa<FloatType>() && "unexpected element type");
1846     auto floatValues = attr.getFloatValues();
1847     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1848       printFloatValue(*(floatValues.begin() + index), os);
1849     });
1850   }
1851 }
1852 
printDenseStringElementsAttr(DenseStringElementsAttr attr)1853 void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
1854   ArrayRef<StringRef> data = attr.getRawStringData();
1855   auto printFn = [&](unsigned index) {
1856     os << "\"";
1857     printEscapedString(data[index], os);
1858     os << "\"";
1859   };
1860   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
1861 }
1862 
printType(Type type)1863 void ModulePrinter::printType(Type type) {
1864   if (!type) {
1865     os << "<<NULL TYPE>>";
1866     return;
1867   }
1868 
1869   // Try to print an alias for this type.
1870   if (state && succeeded(state->getAliasState().getAlias(type, os)))
1871     return;
1872 
1873   TypeSwitch<Type>(type)
1874       .Case<OpaqueType>([&](OpaqueType opaqueTy) {
1875         printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
1876                            opaqueTy.getTypeData());
1877       })
1878       .Case<IndexType>([&](Type) { os << "index"; })
1879       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
1880       .Case<Float16Type>([&](Type) { os << "f16"; })
1881       .Case<Float32Type>([&](Type) { os << "f32"; })
1882       .Case<Float64Type>([&](Type) { os << "f64"; })
1883       .Case<Float80Type>([&](Type) { os << "f80"; })
1884       .Case<Float128Type>([&](Type) { os << "f128"; })
1885       .Case<IntegerType>([&](IntegerType integerTy) {
1886         if (integerTy.isSigned())
1887           os << 's';
1888         else if (integerTy.isUnsigned())
1889           os << 'u';
1890         os << 'i' << integerTy.getWidth();
1891       })
1892       .Case<FunctionType>([&](FunctionType funcTy) {
1893         os << '(';
1894         interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
1895         os << ") -> ";
1896         ArrayRef<Type> results = funcTy.getResults();
1897         if (results.size() == 1 && !results[0].isa<FunctionType>()) {
1898           os << results[0];
1899         } else {
1900           os << '(';
1901           interleaveComma(results, [&](Type ty) { printType(ty); });
1902           os << ')';
1903         }
1904       })
1905       .Case<VectorType>([&](VectorType vectorTy) {
1906         os << "vector<";
1907         for (int64_t dim : vectorTy.getShape())
1908           os << dim << 'x';
1909         os << vectorTy.getElementType() << '>';
1910       })
1911       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
1912         os << "tensor<";
1913         for (int64_t dim : tensorTy.getShape()) {
1914           if (ShapedType::isDynamic(dim))
1915             os << '?';
1916           else
1917             os << dim;
1918           os << 'x';
1919         }
1920         os << tensorTy.getElementType();
1921         // Only print the encoding attribute value if set.
1922         if (tensorTy.getEncoding()) {
1923           os << ", ";
1924           printAttribute(tensorTy.getEncoding());
1925         }
1926         os << '>';
1927       })
1928       .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
1929         os << "tensor<*x";
1930         printType(tensorTy.getElementType());
1931         os << '>';
1932       })
1933       .Case<MemRefType>([&](MemRefType memrefTy) {
1934         os << "memref<";
1935         for (int64_t dim : memrefTy.getShape()) {
1936           if (ShapedType::isDynamic(dim))
1937             os << '?';
1938           else
1939             os << dim;
1940           os << 'x';
1941         }
1942         printType(memrefTy.getElementType());
1943         for (auto map : memrefTy.getAffineMaps()) {
1944           os << ", ";
1945           printAttribute(AffineMapAttr::get(map));
1946         }
1947         // Only print the memory space if it is the non-default one.
1948         if (memrefTy.getMemorySpace()) {
1949           os << ", ";
1950           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
1951         }
1952         os << '>';
1953       })
1954       .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
1955         os << "memref<*x";
1956         printType(memrefTy.getElementType());
1957         // Only print the memory space if it is the non-default one.
1958         if (memrefTy.getMemorySpace()) {
1959           os << ", ";
1960           printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
1961         }
1962         os << '>';
1963       })
1964       .Case<ComplexType>([&](ComplexType complexTy) {
1965         os << "complex<";
1966         printType(complexTy.getElementType());
1967         os << '>';
1968       })
1969       .Case<TupleType>([&](TupleType tupleTy) {
1970         os << "tuple<";
1971         interleaveComma(tupleTy.getTypes(),
1972                         [&](Type type) { printType(type); });
1973         os << '>';
1974       })
1975       .Case<NoneType>([&](Type) { os << "none"; })
1976       .Default([&](Type type) { return printDialectType(type); });
1977 }
1978 
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs,bool withKeyword)1979 void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1980                                           ArrayRef<StringRef> elidedAttrs,
1981                                           bool withKeyword) {
1982   // If there are no attributes, then there is nothing to be done.
1983   if (attrs.empty())
1984     return;
1985 
1986   // Functor used to print a filtered attribute list.
1987   auto printFilteredAttributesFn = [&](auto filteredAttrs) {
1988     // Print the 'attributes' keyword if necessary.
1989     if (withKeyword)
1990       os << " attributes";
1991 
1992     // Otherwise, print them all out in braces.
1993     os << " {";
1994     interleaveComma(filteredAttrs,
1995                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
1996     os << '}';
1997   };
1998 
1999   // If no attributes are elided, we can directly print with no filtering.
2000   if (elidedAttrs.empty())
2001     return printFilteredAttributesFn(attrs);
2002 
2003   // Otherwise, filter out any attributes that shouldn't be included.
2004   llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2005                                                 elidedAttrs.end());
2006   auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
2007     return !elidedAttrsSet.contains(attr.first.strref());
2008   });
2009   if (!filteredAttrs.empty())
2010     printFilteredAttributesFn(filteredAttrs);
2011 }
2012 
printNamedAttribute(NamedAttribute attr)2013 void ModulePrinter::printNamedAttribute(NamedAttribute attr) {
2014   if (isBareIdentifier(attr.first)) {
2015     os << attr.first;
2016   } else {
2017     os << '"';
2018     printEscapedString(attr.first.strref(), os);
2019     os << '"';
2020   }
2021 
2022   // Pretty printing elides the attribute value for unit attributes.
2023   if (attr.second.isa<UnitAttr>())
2024     return;
2025 
2026   os << " = ";
2027   printAttribute(attr.second);
2028 }
2029 
2030 //===----------------------------------------------------------------------===//
2031 // CustomDialectAsmPrinter
2032 //===----------------------------------------------------------------------===//
2033 
2034 namespace {
2035 /// This class provides the main specialization of the DialectAsmPrinter that is
2036 /// used to provide support for print attributes and types. This hooks allows
2037 /// for dialects to hook into the main ModulePrinter.
2038 struct CustomDialectAsmPrinter : public DialectAsmPrinter {
2039 public:
CustomDialectAsmPrinter__anon6cb58f063e11::CustomDialectAsmPrinter2040   CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
~CustomDialectAsmPrinter__anon6cb58f063e11::CustomDialectAsmPrinter2041   ~CustomDialectAsmPrinter() override {}
2042 
getStream__anon6cb58f063e11::CustomDialectAsmPrinter2043   raw_ostream &getStream() const override { return printer.getStream(); }
2044 
2045   /// Print the given attribute to the stream.
printAttribute__anon6cb58f063e11::CustomDialectAsmPrinter2046   void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
2047 
2048   /// Print the given floating point value in a stablized form.
printFloat__anon6cb58f063e11::CustomDialectAsmPrinter2049   void printFloat(const APFloat &value) override {
2050     printFloatValue(value, getStream());
2051   }
2052 
2053   /// Print the given type to the stream.
printType__anon6cb58f063e11::CustomDialectAsmPrinter2054   void printType(Type type) override { printer.printType(type); }
2055 
2056   /// The main module printer.
2057   ModulePrinter &printer;
2058 };
2059 } // end anonymous namespace
2060 
printDialectAttribute(Attribute attr)2061 void ModulePrinter::printDialectAttribute(Attribute attr) {
2062   auto &dialect = attr.getDialect();
2063 
2064   // Ask the dialect to serialize the attribute to a string.
2065   std::string attrName;
2066   {
2067     llvm::raw_string_ostream attrNameStr(attrName);
2068     ModulePrinter subPrinter(attrNameStr, printerFlags, state);
2069     CustomDialectAsmPrinter printer(subPrinter);
2070     dialect.printAttribute(attr, printer);
2071   }
2072   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
2073 }
2074 
printDialectType(Type type)2075 void ModulePrinter::printDialectType(Type type) {
2076   auto &dialect = type.getDialect();
2077 
2078   // Ask the dialect to serialize the type to a string.
2079   std::string typeName;
2080   {
2081     llvm::raw_string_ostream typeNameStr(typeName);
2082     ModulePrinter subPrinter(typeNameStr, printerFlags, state);
2083     CustomDialectAsmPrinter printer(subPrinter);
2084     dialect.printType(type, printer);
2085   }
2086   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
2087 }
2088 
2089 //===----------------------------------------------------------------------===//
2090 // Affine expressions and maps
2091 //===----------------------------------------------------------------------===//
2092 
printAffineExpr(AffineExpr expr,function_ref<void (unsigned,bool)> printValueName)2093 void ModulePrinter::printAffineExpr(
2094     AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
2095   printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
2096 }
2097 
printAffineExprInternal(AffineExpr expr,BindingStrength enclosingTightness,function_ref<void (unsigned,bool)> printValueName)2098 void ModulePrinter::printAffineExprInternal(
2099     AffineExpr expr, BindingStrength enclosingTightness,
2100     function_ref<void(unsigned, bool)> printValueName) {
2101   const char *binopSpelling = nullptr;
2102   switch (expr.getKind()) {
2103   case AffineExprKind::SymbolId: {
2104     unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
2105     if (printValueName)
2106       printValueName(pos, /*isSymbol=*/true);
2107     else
2108       os << 's' << pos;
2109     return;
2110   }
2111   case AffineExprKind::DimId: {
2112     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2113     if (printValueName)
2114       printValueName(pos, /*isSymbol=*/false);
2115     else
2116       os << 'd' << pos;
2117     return;
2118   }
2119   case AffineExprKind::Constant:
2120     os << expr.cast<AffineConstantExpr>().getValue();
2121     return;
2122   case AffineExprKind::Add:
2123     binopSpelling = " + ";
2124     break;
2125   case AffineExprKind::Mul:
2126     binopSpelling = " * ";
2127     break;
2128   case AffineExprKind::FloorDiv:
2129     binopSpelling = " floordiv ";
2130     break;
2131   case AffineExprKind::CeilDiv:
2132     binopSpelling = " ceildiv ";
2133     break;
2134   case AffineExprKind::Mod:
2135     binopSpelling = " mod ";
2136     break;
2137   }
2138 
2139   auto binOp = expr.cast<AffineBinaryOpExpr>();
2140   AffineExpr lhsExpr = binOp.getLHS();
2141   AffineExpr rhsExpr = binOp.getRHS();
2142 
2143   // Handle tightly binding binary operators.
2144   if (binOp.getKind() != AffineExprKind::Add) {
2145     if (enclosingTightness == BindingStrength::Strong)
2146       os << '(';
2147 
2148     // Pretty print multiplication with -1.
2149     auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
2150     if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
2151         rhsConst.getValue() == -1) {
2152       os << "-";
2153       printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2154       if (enclosingTightness == BindingStrength::Strong)
2155         os << ')';
2156       return;
2157     }
2158 
2159     printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2160 
2161     os << binopSpelling;
2162     printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
2163 
2164     if (enclosingTightness == BindingStrength::Strong)
2165       os << ')';
2166     return;
2167   }
2168 
2169   // Print out special "pretty" forms for add.
2170   if (enclosingTightness == BindingStrength::Strong)
2171     os << '(';
2172 
2173   // Pretty print addition to a product that has a negative operand as a
2174   // subtraction.
2175   if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
2176     if (rhs.getKind() == AffineExprKind::Mul) {
2177       AffineExpr rrhsExpr = rhs.getRHS();
2178       if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
2179         if (rrhs.getValue() == -1) {
2180           printAffineExprInternal(lhsExpr, BindingStrength::Weak,
2181                                   printValueName);
2182           os << " - ";
2183           if (rhs.getLHS().getKind() == AffineExprKind::Add) {
2184             printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
2185                                     printValueName);
2186           } else {
2187             printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
2188                                     printValueName);
2189           }
2190 
2191           if (enclosingTightness == BindingStrength::Strong)
2192             os << ')';
2193           return;
2194         }
2195 
2196         if (rrhs.getValue() < -1) {
2197           printAffineExprInternal(lhsExpr, BindingStrength::Weak,
2198                                   printValueName);
2199           os << " - ";
2200           printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
2201                                   printValueName);
2202           os << " * " << -rrhs.getValue();
2203           if (enclosingTightness == BindingStrength::Strong)
2204             os << ')';
2205           return;
2206         }
2207       }
2208     }
2209   }
2210 
2211   // Pretty print addition to a negative number as a subtraction.
2212   if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
2213     if (rhsConst.getValue() < 0) {
2214       printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
2215       os << " - " << -rhsConst.getValue();
2216       if (enclosingTightness == BindingStrength::Strong)
2217         os << ')';
2218       return;
2219     }
2220   }
2221 
2222   printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
2223 
2224   os << " + ";
2225   printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
2226 
2227   if (enclosingTightness == BindingStrength::Strong)
2228     os << ')';
2229 }
2230 
printAffineConstraint(AffineExpr expr,bool isEq)2231 void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
2232   printAffineExprInternal(expr, BindingStrength::Weak);
2233   isEq ? os << " == 0" : os << " >= 0";
2234 }
2235 
printAffineMap(AffineMap map)2236 void ModulePrinter::printAffineMap(AffineMap map) {
2237   // Dimension identifiers.
2238   os << '(';
2239   for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2240     os << 'd' << i << ", ";
2241   if (map.getNumDims() >= 1)
2242     os << 'd' << map.getNumDims() - 1;
2243   os << ')';
2244 
2245   // Symbolic identifiers.
2246   if (map.getNumSymbols() != 0) {
2247     os << '[';
2248     for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2249       os << 's' << i << ", ";
2250     if (map.getNumSymbols() >= 1)
2251       os << 's' << map.getNumSymbols() - 1;
2252     os << ']';
2253   }
2254 
2255   // Result affine expressions.
2256   os << " -> (";
2257   interleaveComma(map.getResults(),
2258                   [&](AffineExpr expr) { printAffineExpr(expr); });
2259   os << ')';
2260 }
2261 
printIntegerSet(IntegerSet set)2262 void ModulePrinter::printIntegerSet(IntegerSet set) {
2263   // Dimension identifiers.
2264   os << '(';
2265   for (unsigned i = 1; i < set.getNumDims(); ++i)
2266     os << 'd' << i - 1 << ", ";
2267   if (set.getNumDims() >= 1)
2268     os << 'd' << set.getNumDims() - 1;
2269   os << ')';
2270 
2271   // Symbolic identifiers.
2272   if (set.getNumSymbols() != 0) {
2273     os << '[';
2274     for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2275       os << 's' << i << ", ";
2276     if (set.getNumSymbols() >= 1)
2277       os << 's' << set.getNumSymbols() - 1;
2278     os << ']';
2279   }
2280 
2281   // Print constraints.
2282   os << " : (";
2283   int numConstraints = set.getNumConstraints();
2284   for (int i = 1; i < numConstraints; ++i) {
2285     printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2286     os << ", ";
2287   }
2288   if (numConstraints >= 1)
2289     printAffineConstraint(set.getConstraint(numConstraints - 1),
2290                           set.isEq(numConstraints - 1));
2291   os << ')';
2292 }
2293 
2294 //===----------------------------------------------------------------------===//
2295 // OperationPrinter
2296 //===----------------------------------------------------------------------===//
2297 
2298 namespace {
2299 /// This class contains the logic for printing operations, regions, and blocks.
2300 class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
2301 public:
OperationPrinter(raw_ostream & os,OpPrintingFlags flags,AsmStateImpl & state)2302   explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
2303                             AsmStateImpl &state)
2304       : ModulePrinter(os, flags, &state) {}
2305 
2306   /// Print the given top-level operation.
2307   void printTopLevelOperation(Operation *op);
2308 
2309   /// Print the given operation with its indent and location.
2310   void print(Operation *op);
2311   /// Print the bare location, not including indentation/location/etc.
2312   void printOperation(Operation *op);
2313   /// Print the given operation in the generic form.
2314   void printGenericOp(Operation *op) override;
2315 
2316   /// Print the name of the given block.
2317   void printBlockName(Block *block);
2318 
2319   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
2320   /// block are not printed. If 'printBlockTerminator' is false, the terminator
2321   /// operation of the block is not printed.
2322   void print(Block *block, bool printBlockArgs = true,
2323              bool printBlockTerminator = true);
2324 
2325   /// Print the ID of the given value, optionally with its result number.
2326   void printValueID(Value value, bool printResultNo = true,
2327                     raw_ostream *streamOverride = nullptr) const;
2328 
2329   //===--------------------------------------------------------------------===//
2330   // OpAsmPrinter methods
2331   //===--------------------------------------------------------------------===//
2332 
2333   /// Return the current stream of the printer.
getStream() const2334   raw_ostream &getStream() const override { return os; }
2335 
2336   /// Print a newline and indent the printer to the start of the current
2337   /// operation.
printNewline()2338   void printNewline() override {
2339     os << newLine;
2340     os.indent(currentIndent);
2341   }
2342 
2343   /// Print the given type.
printType(Type type)2344   void printType(Type type) override { ModulePrinter::printType(type); }
2345 
2346   /// Print the given attribute.
printAttribute(Attribute attr)2347   void printAttribute(Attribute attr) override {
2348     ModulePrinter::printAttribute(attr);
2349   }
2350 
2351   /// Print the given attribute without its type. The corresponding parser must
2352   /// provide a valid type for the attribute.
printAttributeWithoutType(Attribute attr)2353   void printAttributeWithoutType(Attribute attr) override {
2354     ModulePrinter::printAttribute(attr, AttrTypeElision::Must);
2355   }
2356 
2357   /// Print a block argument in the usual format of:
2358   ///   %ssaName : type {attr1=42} loc("here")
2359   /// where location printing is controlled by the standard internal option.
2360   /// You may pass omitType=true to not print a type, and pass an empty
2361   /// attribute list if you don't care for attributes.
2362   void printRegionArgument(BlockArgument arg,
2363                            ArrayRef<NamedAttribute> argAttrs = {},
2364                            bool omitType = false) override;
2365 
2366   /// Print the ID for the given value.
printOperand(Value value)2367   void printOperand(Value value) override { printValueID(value); }
printOperand(Value value,raw_ostream & os)2368   void printOperand(Value value, raw_ostream &os) override {
2369     printValueID(value, /*printResultNo=*/true, &os);
2370   }
2371 
2372   /// Print an optional attribute dictionary with a given set of elided values.
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})2373   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2374                              ArrayRef<StringRef> elidedAttrs = {}) override {
2375     ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
2376   }
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})2377   void printOptionalAttrDictWithKeyword(
2378       ArrayRef<NamedAttribute> attrs,
2379       ArrayRef<StringRef> elidedAttrs = {}) override {
2380     ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs,
2381                                          /*withKeyword=*/true);
2382   }
2383 
2384   /// Print the given successor.
2385   void printSuccessor(Block *successor) override;
2386 
2387   /// Print an operation successor with the operands used for the block
2388   /// arguments.
2389   void printSuccessorAndUseList(Block *successor,
2390                                 ValueRange succOperands) override;
2391 
2392   /// Print the given region.
2393   void printRegion(Region &region, bool printEntryBlockArgs,
2394                    bool printBlockTerminators, bool printEmptyBlock) override;
2395 
2396   /// Renumber the arguments for the specified region to the same names as the
2397   /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
2398   /// operations. If any entry in namesToUse is null, the corresponding
2399   /// argument name is left alone.
shadowRegionArgs(Region & region,ValueRange namesToUse)2400   void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
2401     state->getSSANameState().shadowRegionArgs(region, namesToUse);
2402   }
2403 
2404   /// Print the given affine map with the symbol and dimension operands printed
2405   /// inline with the map.
2406   void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2407                               ValueRange operands) override;
2408 
2409   /// Print the given affine expression with the symbol and dimension operands
2410   /// printed inline with the expression.
2411   void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
2412                                ValueRange symOperands) override;
2413 
2414   /// Print the given string as a symbol reference.
printSymbolName(StringRef symbolRef)2415   void printSymbolName(StringRef symbolRef) override {
2416     ::printSymbolReference(symbolRef, os);
2417   }
2418 
2419 private:
2420   /// The number of spaces used for indenting nested operations.
2421   const static unsigned indentWidth = 2;
2422 
2423   // This is the current indentation level for nested structures.
2424   unsigned currentIndent = 0;
2425 };
2426 } // end anonymous namespace
2427 
printTopLevelOperation(Operation * op)2428 void OperationPrinter::printTopLevelOperation(Operation *op) {
2429   // Output the aliases at the top level that can't be deferred.
2430   state->getAliasState().printNonDeferredAliases(os, newLine);
2431 
2432   // Print the module.
2433   print(op);
2434   os << newLine;
2435 
2436   // Output the aliases at the top level that can be deferred.
2437   state->getAliasState().printDeferredAliases(os, newLine);
2438 }
2439 
2440 /// Print a block argument in the usual format of:
2441 ///   %ssaName : type {attr1=42} loc("here")
2442 /// where location printing is controlled by the standard internal option.
2443 /// You may pass omitType=true to not print a type, and pass an empty
2444 /// attribute list if you don't care for attributes.
printRegionArgument(BlockArgument arg,ArrayRef<NamedAttribute> argAttrs,bool omitType)2445 void OperationPrinter::printRegionArgument(BlockArgument arg,
2446                                            ArrayRef<NamedAttribute> argAttrs,
2447                                            bool omitType) {
2448   printOperand(arg);
2449   if (!omitType) {
2450     os << ": ";
2451     printType(arg.getType());
2452   }
2453   printOptionalAttrDict(argAttrs);
2454   // TODO: We should allow location aliases on block arguments.
2455   printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2456 }
2457 
print(Operation * op)2458 void OperationPrinter::print(Operation *op) {
2459   // Track the location of this operation.
2460   state->registerOperationLocation(op, newLine.curLine, currentIndent);
2461 
2462   os.indent(currentIndent);
2463   printOperation(op);
2464   printTrailingLocation(op->getLoc());
2465 }
2466 
printOperation(Operation * op)2467 void OperationPrinter::printOperation(Operation *op) {
2468   if (size_t numResults = op->getNumResults()) {
2469     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
2470       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
2471       if (resultCount > 1)
2472         os << ':' << resultCount;
2473     };
2474 
2475     // Check to see if this operation has multiple result groups.
2476     ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
2477     if (!resultGroups.empty()) {
2478       // Interleave the groups excluding the last one, this one will be handled
2479       // separately.
2480       interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
2481         printResultGroup(resultGroups[i],
2482                          resultGroups[i + 1] - resultGroups[i]);
2483       });
2484       os << ", ";
2485       printResultGroup(resultGroups.back(), numResults - resultGroups.back());
2486 
2487     } else {
2488       printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
2489     }
2490 
2491     os << " = ";
2492   }
2493 
2494   // If requested, always print the generic form.
2495   if (!printerFlags.shouldPrintGenericOpForm()) {
2496     // Check to see if this is a known operation.  If so, use the registered
2497     // custom printer hook.
2498     if (auto *opInfo = op->getAbstractOperation()) {
2499       opInfo->printAssembly(op, *this);
2500       return;
2501     }
2502     // Otherwise try to dispatch to the dialect, if available.
2503     if (Dialect *dialect = op->getDialect()) {
2504       if (succeeded(dialect->printOperation(op, *this)))
2505         return;
2506     }
2507   }
2508 
2509   // Otherwise print with the generic assembly form.
2510   printGenericOp(op);
2511 }
2512 
printGenericOp(Operation * op)2513 void OperationPrinter::printGenericOp(Operation *op) {
2514   os << '"';
2515   printEscapedString(op->getName().getStringRef(), os);
2516   os << "\"(";
2517   interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
2518   os << ')';
2519 
2520   // For terminators, print the list of successors and their operands.
2521   if (op->getNumSuccessors() != 0) {
2522     os << '[';
2523     interleaveComma(op->getSuccessors(),
2524                     [&](Block *successor) { printBlockName(successor); });
2525     os << ']';
2526   }
2527 
2528   // Print regions.
2529   if (op->getNumRegions() != 0) {
2530     os << " (";
2531     interleaveComma(op->getRegions(), [&](Region &region) {
2532       printRegion(region, /*printEntryBlockArgs=*/true,
2533                   /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
2534     });
2535     os << ')';
2536   }
2537 
2538   auto attrs = op->getAttrs();
2539   printOptionalAttrDict(attrs);
2540 
2541   // Print the type signature of the operation.
2542   os << " : ";
2543   printFunctionalType(op);
2544 }
2545 
printBlockName(Block * block)2546 void OperationPrinter::printBlockName(Block *block) {
2547   auto id = state->getSSANameState().getBlockID(block);
2548   if (id != SSANameState::NameSentinel)
2549     os << "^bb" << id;
2550   else
2551     os << "^INVALIDBLOCK";
2552 }
2553 
print(Block * block,bool printBlockArgs,bool printBlockTerminator)2554 void OperationPrinter::print(Block *block, bool printBlockArgs,
2555                              bool printBlockTerminator) {
2556   // Print the block label and argument list if requested.
2557   if (printBlockArgs) {
2558     os.indent(currentIndent);
2559     printBlockName(block);
2560 
2561     // Print the argument list if non-empty.
2562     if (!block->args_empty()) {
2563       os << '(';
2564       interleaveComma(block->getArguments(), [&](BlockArgument arg) {
2565         printValueID(arg);
2566         os << ": ";
2567         printType(arg.getType());
2568         // TODO: We should allow location aliases on block arguments.
2569         printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
2570       });
2571       os << ')';
2572     }
2573     os << ':';
2574 
2575     // Print out some context information about the predecessors of this block.
2576     if (!block->getParent()) {
2577       os << "  // block is not in a region!";
2578     } else if (block->hasNoPredecessors()) {
2579       os << "  // no predecessors";
2580     } else if (auto *pred = block->getSinglePredecessor()) {
2581       os << "  // pred: ";
2582       printBlockName(pred);
2583     } else {
2584       // We want to print the predecessors in increasing numeric order, not in
2585       // whatever order the use-list is in, so gather and sort them.
2586       SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
2587       for (auto *pred : block->getPredecessors())
2588         predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
2589       llvm::array_pod_sort(predIDs.begin(), predIDs.end());
2590 
2591       os << "  // " << predIDs.size() << " preds: ";
2592 
2593       interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
2594         printBlockName(pred.second);
2595       });
2596     }
2597     os << newLine;
2598   }
2599 
2600   currentIndent += indentWidth;
2601   bool hasTerminator =
2602       !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
2603   auto range = llvm::make_range(
2604       block->begin(),
2605       std::prev(block->end(),
2606                 (!hasTerminator || printBlockTerminator) ? 0 : 1));
2607   for (auto &op : range) {
2608     print(&op);
2609     os << newLine;
2610   }
2611   currentIndent -= indentWidth;
2612 }
2613 
printValueID(Value value,bool printResultNo,raw_ostream * streamOverride) const2614 void OperationPrinter::printValueID(Value value, bool printResultNo,
2615                                     raw_ostream *streamOverride) const {
2616   state->getSSANameState().printValueID(value, printResultNo,
2617                                         streamOverride ? *streamOverride : os);
2618 }
2619 
printSuccessor(Block * successor)2620 void OperationPrinter::printSuccessor(Block *successor) {
2621   printBlockName(successor);
2622 }
2623 
printSuccessorAndUseList(Block * successor,ValueRange succOperands)2624 void OperationPrinter::printSuccessorAndUseList(Block *successor,
2625                                                 ValueRange succOperands) {
2626   printBlockName(successor);
2627   if (succOperands.empty())
2628     return;
2629 
2630   os << '(';
2631   interleaveComma(succOperands,
2632                   [this](Value operand) { printValueID(operand); });
2633   os << " : ";
2634   interleaveComma(succOperands,
2635                   [this](Value operand) { printType(operand.getType()); });
2636   os << ')';
2637 }
2638 
printRegion(Region & region,bool printEntryBlockArgs,bool printBlockTerminators,bool printEmptyBlock)2639 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
2640                                    bool printBlockTerminators,
2641                                    bool printEmptyBlock) {
2642   os << " {" << newLine;
2643   if (!region.empty()) {
2644     auto *entryBlock = &region.front();
2645     // Force printing the block header if printEmptyBlock is set and the block
2646     // is empty or if printEntryBlockArgs is set and there are arguments to
2647     // print.
2648     bool shouldAlwaysPrintBlockHeader =
2649         (printEmptyBlock && entryBlock->empty()) ||
2650         (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
2651     print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
2652     for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
2653       print(&b);
2654   }
2655   os.indent(currentIndent) << "}";
2656 }
2657 
printAffineMapOfSSAIds(AffineMapAttr mapAttr,ValueRange operands)2658 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2659                                               ValueRange operands) {
2660   AffineMap map = mapAttr.getValue();
2661   unsigned numDims = map.getNumDims();
2662   auto printValueName = [&](unsigned pos, bool isSymbol) {
2663     unsigned index = isSymbol ? numDims + pos : pos;
2664     assert(index < operands.size());
2665     if (isSymbol)
2666       os << "symbol(";
2667     printValueID(operands[index]);
2668     if (isSymbol)
2669       os << ')';
2670   };
2671 
2672   interleaveComma(map.getResults(), [&](AffineExpr expr) {
2673     printAffineExpr(expr, printValueName);
2674   });
2675 }
2676 
printAffineExprOfSSAIds(AffineExpr expr,ValueRange dimOperands,ValueRange symOperands)2677 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
2678                                                ValueRange dimOperands,
2679                                                ValueRange symOperands) {
2680   auto printValueName = [&](unsigned pos, bool isSymbol) {
2681     if (!isSymbol)
2682       return printValueID(dimOperands[pos]);
2683     os << "symbol(";
2684     printValueID(symOperands[pos]);
2685     os << ')';
2686   };
2687   printAffineExpr(expr, printValueName);
2688 }
2689 
2690 //===----------------------------------------------------------------------===//
2691 // print and dump methods
2692 //===----------------------------------------------------------------------===//
2693 
print(raw_ostream & os) const2694 void Attribute::print(raw_ostream &os) const {
2695   ModulePrinter(os).printAttribute(*this);
2696 }
2697 
dump() const2698 void Attribute::dump() const {
2699   print(llvm::errs());
2700   llvm::errs() << "\n";
2701 }
2702 
print(raw_ostream & os) const2703 void Type::print(raw_ostream &os) const { ModulePrinter(os).printType(*this); }
2704 
dump() const2705 void Type::dump() const { print(llvm::errs()); }
2706 
dump() const2707 void AffineMap::dump() const {
2708   print(llvm::errs());
2709   llvm::errs() << "\n";
2710 }
2711 
dump() const2712 void IntegerSet::dump() const {
2713   print(llvm::errs());
2714   llvm::errs() << "\n";
2715 }
2716 
print(raw_ostream & os) const2717 void AffineExpr::print(raw_ostream &os) const {
2718   if (!expr) {
2719     os << "<<NULL AFFINE EXPR>>";
2720     return;
2721   }
2722   ModulePrinter(os).printAffineExpr(*this);
2723 }
2724 
dump() const2725 void AffineExpr::dump() const {
2726   print(llvm::errs());
2727   llvm::errs() << "\n";
2728 }
2729 
print(raw_ostream & os) const2730 void AffineMap::print(raw_ostream &os) const {
2731   if (!map) {
2732     os << "<<NULL AFFINE MAP>>";
2733     return;
2734   }
2735   ModulePrinter(os).printAffineMap(*this);
2736 }
2737 
print(raw_ostream & os) const2738 void IntegerSet::print(raw_ostream &os) const {
2739   ModulePrinter(os).printIntegerSet(*this);
2740 }
2741 
print(raw_ostream & os)2742 void Value::print(raw_ostream &os) {
2743   if (auto *op = getDefiningOp())
2744     return op->print(os);
2745   // TODO: Improve BlockArgument print'ing.
2746   BlockArgument arg = this->cast<BlockArgument>();
2747   os << "<block argument> of type '" << arg.getType()
2748      << "' at index: " << arg.getArgNumber();
2749 }
print(raw_ostream & os,AsmState & state)2750 void Value::print(raw_ostream &os, AsmState &state) {
2751   if (auto *op = getDefiningOp())
2752     return op->print(os, state);
2753 
2754   // TODO: Improve BlockArgument print'ing.
2755   BlockArgument arg = this->cast<BlockArgument>();
2756   os << "<block argument> of type '" << arg.getType()
2757      << "' at index: " << arg.getArgNumber();
2758 }
2759 
dump()2760 void Value::dump() {
2761   print(llvm::errs());
2762   llvm::errs() << "\n";
2763 }
2764 
printAsOperand(raw_ostream & os,AsmState & state)2765 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
2766   // TODO: This doesn't necessarily capture all potential cases.
2767   // Currently, region arguments can be shadowed when printing the main
2768   // operation. If the IR hasn't been printed, this will produce the old SSA
2769   // name and not the shadowed name.
2770   state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
2771                                                  os);
2772 }
2773 
print(raw_ostream & os,const OpPrintingFlags & printerFlags)2774 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
2775   // If this is a top level operation, we also print aliases.
2776   if (!getParent() && !printerFlags.shouldUseLocalScope()) {
2777     AsmState state(this, printerFlags);
2778     state.getImpl().initializeAliases(this);
2779     print(os, state, printerFlags);
2780     return;
2781   }
2782 
2783   // Find the operation to number from based upon the provided flags.
2784   Operation *op = this;
2785   bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
2786   do {
2787     // If we are printing local scope, stop at the first operation that is
2788     // isolated from above.
2789     if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
2790       break;
2791 
2792     // Otherwise, traverse up to the next parent.
2793     Operation *parentOp = op->getParentOp();
2794     if (!parentOp)
2795       break;
2796     op = parentOp;
2797   } while (true);
2798 
2799   AsmState state(op, printerFlags);
2800   print(os, state, printerFlags);
2801 }
print(raw_ostream & os,AsmState & state,const OpPrintingFlags & flags)2802 void Operation::print(raw_ostream &os, AsmState &state,
2803                       const OpPrintingFlags &flags) {
2804   OperationPrinter printer(os, flags, state.getImpl());
2805   if (!getParent() && !flags.shouldUseLocalScope())
2806     printer.printTopLevelOperation(this);
2807   else
2808     printer.print(this);
2809 }
2810 
dump()2811 void Operation::dump() {
2812   print(llvm::errs(), OpPrintingFlags().useLocalScope());
2813   llvm::errs() << "\n";
2814 }
2815 
print(raw_ostream & os)2816 void Block::print(raw_ostream &os) {
2817   Operation *parentOp = getParentOp();
2818   if (!parentOp) {
2819     os << "<<UNLINKED BLOCK>>\n";
2820     return;
2821   }
2822   // Get the top-level op.
2823   while (auto *nextOp = parentOp->getParentOp())
2824     parentOp = nextOp;
2825 
2826   AsmState state(parentOp);
2827   print(os, state);
2828 }
print(raw_ostream & os,AsmState & state)2829 void Block::print(raw_ostream &os, AsmState &state) {
2830   OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
2831 }
2832 
dump()2833 void Block::dump() { print(llvm::errs()); }
2834 
2835 /// Print out the name of the block without printing its body.
printAsOperand(raw_ostream & os,bool printType)2836 void Block::printAsOperand(raw_ostream &os, bool printType) {
2837   Operation *parentOp = getParentOp();
2838   if (!parentOp) {
2839     os << "<<UNLINKED BLOCK>>\n";
2840     return;
2841   }
2842   AsmState state(parentOp);
2843   printAsOperand(os, state);
2844 }
printAsOperand(raw_ostream & os,AsmState & state)2845 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
2846   OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
2847   printer.printBlockName(this);
2848 }
2849