1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/Function.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Module.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/IR/Operation.h"
26 #include "mlir/IR/StandardTypes.h"
27 #include "mlir/Support/STLExtras.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/ScopedHashTable.h"
33 #include "llvm/ADT/SetVector.h"
34 #include "llvm/ADT/SmallString.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/StringSet.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Support/Regex.h"
39 #include "llvm/Support/SaveAndRestore.h"
40 using namespace mlir;
41 using namespace mlir::detail;
42 
print(raw_ostream & os) const43 void Identifier::print(raw_ostream &os) const { os << str(); }
44 
dump() const45 void Identifier::dump() const { print(llvm::errs()); }
46 
print(raw_ostream & os) const47 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
48 
dump() const49 void OperationName::dump() const { print(llvm::errs()); }
50 
~DialectAsmPrinter()51 DialectAsmPrinter::~DialectAsmPrinter() {}
52 
~OpAsmPrinter()53 OpAsmPrinter::~OpAsmPrinter() {}
54 
55 //===--------------------------------------------------------------------===//
56 // Operation OpAsm interface.
57 //===--------------------------------------------------------------------===//
58 
59 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
60 #include "mlir/IR/OpAsmInterface.cpp.inc"
61 
62 //===----------------------------------------------------------------------===//
63 // OpPrintingFlags
64 //===----------------------------------------------------------------------===//
65 
66 static llvm::cl::opt<unsigned> elideElementsAttrIfLarger(
67     "mlir-elide-elementsattrs-if-larger",
68     llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
69                    "more elements than the given upper limit"));
70 
71 static llvm::cl::opt<bool>
72     printDebugInfoOpt("mlir-print-debuginfo",
73                       llvm::cl::desc("Print debug info in MLIR output"),
74                       llvm::cl::init(false));
75 
76 static llvm::cl::opt<bool> printPrettyDebugInfoOpt(
77     "mlir-pretty-debuginfo",
78     llvm::cl::desc("Print pretty debug info in MLIR output"),
79     llvm::cl::init(false));
80 
81 // Use the generic op output form in the operation printer even if the custom
82 // form is defined.
83 static llvm::cl::opt<bool>
84     printGenericOpFormOpt("mlir-print-op-generic",
85                           llvm::cl::desc("Print the generic op form"),
86                           llvm::cl::init(false), llvm::cl::Hidden);
87 
88 static llvm::cl::opt<bool> printLocalScopeOpt(
89     "mlir-print-local-scope",
90     llvm::cl::desc("Print assuming in local scope by default"),
91     llvm::cl::init(false), llvm::cl::Hidden);
92 
93 /// Initialize the printing flags with default supplied by the cl::opts above.
OpPrintingFlags()94 OpPrintingFlags::OpPrintingFlags()
95     : elementsAttrElementLimit(
96           elideElementsAttrIfLarger.getNumOccurrences()
97               ? Optional<int64_t>(elideElementsAttrIfLarger)
98               : Optional<int64_t>()),
99       printDebugInfoFlag(printDebugInfoOpt),
100       printDebugInfoPrettyFormFlag(printPrettyDebugInfoOpt),
101       printGenericOpFormFlag(printGenericOpFormOpt),
102       printLocalScope(printLocalScopeOpt) {}
103 
104 /// Enable the elision of large elements attributes, by printing a '...'
105 /// instead of the element data, when the number of elements is greater than
106 /// `largeElementLimit`. Note: The IR generated with this option is not
107 /// parsable.
108 OpPrintingFlags &
elideLargeElementsAttrs(int64_t largeElementLimit)109 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
110   elementsAttrElementLimit = largeElementLimit;
111   return *this;
112 }
113 
114 /// Enable printing of debug information. If 'prettyForm' is set to true,
115 /// debug information is printed in a more readable 'pretty' form.
enableDebugInfo(bool prettyForm)116 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
117   printDebugInfoFlag = true;
118   printDebugInfoPrettyFormFlag = prettyForm;
119   return *this;
120 }
121 
122 /// Always print operations in the generic form.
printGenericOpForm()123 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
124   printGenericOpFormFlag = true;
125   return *this;
126 }
127 
128 /// Use local scope when printing the operation. This allows for using the
129 /// printer in a more localized and thread-safe setting, but may not necessarily
130 /// be identical of what the IR will look like when dumping the full module.
useLocalScope()131 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
132   printLocalScope = true;
133   return *this;
134 }
135 
136 /// Return if the given ElementsAttr should be elided.
shouldElideElementsAttr(ElementsAttr attr) const137 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
138   return elementsAttrElementLimit.hasValue() &&
139          *elementsAttrElementLimit < int64_t(attr.getNumElements());
140 }
141 
142 /// Return if debug information should be printed.
shouldPrintDebugInfo() const143 bool OpPrintingFlags::shouldPrintDebugInfo() const {
144   return printDebugInfoFlag;
145 }
146 
147 /// Return if debug information should be printed in the pretty form.
shouldPrintDebugInfoPrettyForm() const148 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
149   return printDebugInfoPrettyFormFlag;
150 }
151 
152 /// Return if operations should be printed in the generic form.
shouldPrintGenericOpForm() const153 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
154   return printGenericOpFormFlag;
155 }
156 
157 /// Return if the printer should use local scope when dumping the IR.
shouldUseLocalScope() const158 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
159 
160 //===----------------------------------------------------------------------===//
161 // AliasState
162 //===----------------------------------------------------------------------===//
163 
164 namespace {
165 /// This class manages the state for type and attribute aliases.
166 class AliasState {
167 public:
168   // Initialize the internal aliases.
169   void
170   initialize(Operation *op,
171              DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
172 
173   /// Return a name used for an attribute alias, or empty if there is no alias.
174   Twine getAttributeAlias(Attribute attr) const;
175 
176   /// Print all of the referenced attribute aliases.
177   void printAttributeAliases(raw_ostream &os) const;
178 
179   /// Return a string to use as an alias for the given type, or empty if there
180   /// is no alias recorded.
181   StringRef getTypeAlias(Type ty) const;
182 
183   /// Print all of the referenced type aliases.
184   void printTypeAliases(raw_ostream &os) const;
185 
186 private:
187   /// A special index constant used for non-kind attribute aliases.
188   enum { NonAttrKindAlias = -1 };
189 
190   /// Record a reference to the given attribute.
191   void recordAttributeReference(Attribute attr);
192 
193   /// Record a reference to the given type.
194   void recordTypeReference(Type ty);
195 
196   // Visit functions.
197   void visitOperation(Operation *op);
198   void visitType(Type type);
199   void visitAttribute(Attribute attr);
200 
201   /// Set of attributes known to be used within the module.
202   llvm::SetVector<Attribute> usedAttributes;
203 
204   /// Mapping between attribute and a pair comprised of a base alias name and a
205   /// count suffix. If the suffix is set to -1, it is not displayed.
206   llvm::MapVector<Attribute, std::pair<StringRef, int>> attrToAlias;
207 
208   /// Mapping between attribute kind and a pair comprised of a base alias name
209   /// and a unique list of attributes belonging to this kind sorted by location
210   /// seen in the module.
211   llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
212       attrKindToAlias;
213 
214   /// Set of types known to be used within the module.
215   llvm::SetVector<Type> usedTypes;
216 
217   /// A mapping between a type and a given alias.
218   DenseMap<Type, StringRef> typeToAlias;
219 };
220 } // end anonymous namespace
221 
222 // Utility to generate a function to register a symbol alias.
canRegisterAlias(StringRef name,llvm::StringSet<> & usedAliases)223 static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
224   assert(!name.empty() && "expected alias name to be non-empty");
225   // TODO(riverriddle) Assert that the provided alias name can be lexed as
226   // an identifier.
227 
228   // Check that the alias doesn't contain a '.' character and the name is not
229   // already in use.
230   return !name.contains('.') && usedAliases.insert(name).second;
231 }
232 
initialize(Operation * op,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)233 void AliasState::initialize(
234     Operation *op,
235     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
236   // Track the identifiers in use for each symbol so that the same identifier
237   // isn't used twice.
238   llvm::StringSet<> usedAliases;
239 
240   // Collect the set of aliases from each dialect.
241   SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
242   SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
243   SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
244 
245   // AffineMap/Integer set have specific kind aliases.
246   attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
247   attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
248 
249   for (auto &interface : interfaces) {
250     interface.getAttributeKindAliases(attributeKindAliases);
251     interface.getAttributeAliases(attributeAliases);
252     interface.getTypeAliases(typeAliases);
253   }
254 
255   // Setup the attribute kind aliases.
256   StringRef alias;
257   unsigned attrKind;
258   for (auto &attrAliasPair : attributeKindAliases) {
259     std::tie(attrKind, alias) = attrAliasPair;
260     assert(!alias.empty() && "expected non-empty alias string");
261     if (!usedAliases.count(alias) && !alias.contains('.'))
262       attrKindToAlias.insert({attrKind, {alias, {}}});
263   }
264 
265   // Clear the set of used identifiers so that the attribute kind aliases are
266   // just a prefix and not the full alias, i.e. there may be some overlap.
267   usedAliases.clear();
268 
269   // Register the attribute aliases.
270   // Create a regex for the attribute kind alias names, these have a prefix with
271   // a counter appended to the end. We prevent normal aliases from having these
272   // names to avoid collisions.
273   llvm::Regex reservedAttrNames("[0-9]+$");
274 
275   // Attribute value aliases.
276   Attribute attr;
277   for (auto &attrAliasPair : attributeAliases) {
278     std::tie(attr, alias) = attrAliasPair;
279     if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
280       attrToAlias.insert({attr, {alias, NonAttrKindAlias}});
281   }
282 
283   // Clear the set of used identifiers as types can have the same identifiers as
284   // affine structures.
285   usedAliases.clear();
286 
287   // Type aliases.
288   for (auto &typeAliasPair : typeAliases)
289     if (canRegisterAlias(typeAliasPair.second, usedAliases))
290       typeToAlias.insert(typeAliasPair);
291 
292   // Traverse the given IR to generate the set of used attributes/types.
293   op->walk([&](Operation *op) { visitOperation(op); });
294 }
295 
296 /// Return a name used for an attribute alias, or empty if there is no alias.
getAttributeAlias(Attribute attr) const297 Twine AliasState::getAttributeAlias(Attribute attr) const {
298   auto alias = attrToAlias.find(attr);
299   if (alias == attrToAlias.end())
300     return Twine();
301 
302   // Return the alias for this attribute, along with the index if this was
303   // generated by a kind alias.
304   int kindIndex = alias->second.second;
305   return alias->second.first +
306          (kindIndex == NonAttrKindAlias ? Twine() : Twine(kindIndex));
307 }
308 
309 /// Print all of the referenced attribute aliases.
printAttributeAliases(raw_ostream & os) const310 void AliasState::printAttributeAliases(raw_ostream &os) const {
311   auto printAlias = [&](StringRef alias, Attribute attr, int index) {
312     os << '#' << alias;
313     if (index != NonAttrKindAlias)
314       os << index;
315     os << " = " << attr << '\n';
316   };
317 
318   // Print all of the attribute kind aliases.
319   for (auto &kindAlias : attrKindToAlias) {
320     auto &aliasAttrsPair = kindAlias.second;
321     for (unsigned i = 0, e = aliasAttrsPair.second.size(); i != e; ++i)
322       printAlias(aliasAttrsPair.first, aliasAttrsPair.second[i], i);
323     os << "\n";
324   }
325 
326   // In a second pass print all of the remaining attribute aliases that aren't
327   // kind aliases.
328   for (Attribute attr : usedAttributes) {
329     auto alias = attrToAlias.find(attr);
330     if (alias != attrToAlias.end() && alias->second.second == NonAttrKindAlias)
331       printAlias(alias->second.first, attr, alias->second.second);
332   }
333 }
334 
335 /// Return a string to use as an alias for the given type, or empty if there
336 /// is no alias recorded.
getTypeAlias(Type ty) const337 StringRef AliasState::getTypeAlias(Type ty) const {
338   return typeToAlias.lookup(ty);
339 }
340 
341 /// Print all of the referenced type aliases.
printTypeAliases(raw_ostream & os) const342 void AliasState::printTypeAliases(raw_ostream &os) const {
343   for (Type type : usedTypes) {
344     auto alias = typeToAlias.find(type);
345     if (alias != typeToAlias.end())
346       os << '!' << alias->second << " = type " << type << '\n';
347   }
348 }
349 
350 /// Record a reference to the given attribute.
recordAttributeReference(Attribute attr)351 void AliasState::recordAttributeReference(Attribute attr) {
352   // Don't recheck attributes that have already been seen or those that
353   // already have an alias.
354   if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
355     return;
356 
357   // If this attribute kind has an alias, then record one for this attribute.
358   auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
359   if (alias == attrKindToAlias.end())
360     return;
361   std::pair<StringRef, int> attrAlias(alias->second.first,
362                                       alias->second.second.size());
363   attrToAlias.insert({attr, attrAlias});
364   alias->second.second.push_back(attr);
365 }
366 
367 /// Record a reference to the given type.
recordTypeReference(Type ty)368 void AliasState::recordTypeReference(Type ty) { usedTypes.insert(ty); }
369 
370 // TODO Support visiting other types/operations when implemented.
visitType(Type type)371 void AliasState::visitType(Type type) {
372   recordTypeReference(type);
373 
374   if (auto funcType = type.dyn_cast<FunctionType>()) {
375     // Visit input and result types for functions.
376     for (auto input : funcType.getInputs())
377       visitType(input);
378     for (auto result : funcType.getResults())
379       visitType(result);
380   } else if (auto shapedType = type.dyn_cast<ShapedType>()) {
381     visitType(shapedType.getElementType());
382 
383     // Visit affine maps in memref type.
384     if (auto memref = type.dyn_cast<MemRefType>())
385       for (auto map : memref.getAffineMaps())
386         recordAttributeReference(AffineMapAttr::get(map));
387   }
388 }
389 
visitAttribute(Attribute attr)390 void AliasState::visitAttribute(Attribute attr) {
391   recordAttributeReference(attr);
392 
393   if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
394     for (auto elt : arrayAttr.getValue())
395       visitAttribute(elt);
396   } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
397     visitType(typeAttr.getValue());
398   }
399 }
400 
visitOperation(Operation * op)401 void AliasState::visitOperation(Operation *op) {
402   // Visit all the types used in the operation.
403   for (auto type : op->getOperandTypes())
404     visitType(type);
405   for (auto type : op->getResultTypes())
406     visitType(type);
407   for (auto &region : op->getRegions())
408     for (auto &block : region)
409       for (auto arg : block.getArguments())
410         visitType(arg.getType());
411 
412   // Visit each of the attributes.
413   for (auto elt : op->getAttrs())
414     visitAttribute(elt.second);
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // SSANameState
419 //===----------------------------------------------------------------------===//
420 
421 namespace {
422 /// This class manages the state of SSA value names.
423 class SSANameState {
424 public:
425   /// A sentinal value used for values with names set.
426   enum : unsigned { NameSentinel = ~0U };
427 
428   SSANameState(Operation *op,
429                DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
430 
431   /// Print the SSA identifier for the given value to 'stream'. If
432   /// 'printResultNo' is true, it also presents the result number ('#' number)
433   /// of this value.
434   void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
435 
436   /// Return the result indices for each of the result groups registered by this
437   /// operation, or empty if none exist.
438   ArrayRef<int> getOpResultGroups(Operation *op);
439 
440   /// Get the ID for the given block.
441   unsigned getBlockID(Block *block);
442 
443   /// Renumber the arguments for the specified region to the same names as the
444   /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
445   /// details.
446   void shadowRegionArgs(Region &region, ValueRange namesToUse);
447 
448 private:
449   /// Number the SSA values within the given IR unit.
450   void numberValuesInRegion(
451       Region &region,
452       DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
453   void numberValuesInBlock(
454       Block &block,
455       DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
456   void numberValuesInOp(
457       Operation &op,
458       DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
459 
460   /// Given a result of an operation 'result', find the result group head
461   /// 'lookupValue' and the result of 'result' within that group in
462   /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
463   /// has more than 1 result.
464   void getResultIDAndNumber(OpResult result, Value &lookupValue,
465                             Optional<int> &lookupResultNo) const;
466 
467   /// Set a special value name for the given value.
468   void setValueName(Value value, StringRef name);
469 
470   /// Uniques the given value name within the printer. If the given name
471   /// conflicts, it is automatically renamed.
472   StringRef uniqueValueName(StringRef name);
473 
474   /// This is the value ID for each SSA value. If this returns NameSentinel,
475   /// then the valueID has an entry in valueNames.
476   DenseMap<Value, unsigned> valueIDs;
477   DenseMap<Value, StringRef> valueNames;
478 
479   /// This is a map of operations that contain multiple named result groups,
480   /// i.e. there may be multiple names for the results of the operation. The
481   /// value of this map are the result numbers that start a result group.
482   DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
483 
484   /// This is the block ID for each block in the current.
485   DenseMap<Block *, unsigned> blockIDs;
486 
487   /// This keeps track of all of the non-numeric names that are in flight,
488   /// allowing us to check for duplicates.
489   /// Note: the value of the map is unused.
490   llvm::ScopedHashTable<StringRef, char> usedNames;
491   llvm::BumpPtrAllocator usedNameAllocator;
492 
493   /// This is the next value ID to assign in numbering.
494   unsigned nextValueID = 0;
495   /// This is the next ID to assign to a region entry block argument.
496   unsigned nextArgumentID = 0;
497   /// This is the next ID to assign when a name conflict is detected.
498   unsigned nextConflictID = 0;
499 };
500 } // end anonymous namespace
501 
SSANameState(Operation * op,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)502 SSANameState::SSANameState(
503     Operation *op,
504     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
505   llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
506   numberValuesInOp(*op, interfaces);
507 
508   for (auto &region : op->getRegions())
509     numberValuesInRegion(region, interfaces);
510 }
511 
printValueID(Value value,bool printResultNo,raw_ostream & stream) const512 void SSANameState::printValueID(Value value, bool printResultNo,
513                                 raw_ostream &stream) const {
514   if (!value) {
515     stream << "<<NULL>>";
516     return;
517   }
518 
519   Optional<int> resultNo;
520   auto lookupValue = value;
521 
522   // If this is an operation result, collect the head lookup value of the result
523   // group and the result number of 'result' within that group.
524   if (OpResult result = value.dyn_cast<OpResult>())
525     getResultIDAndNumber(result, lookupValue, resultNo);
526 
527   auto it = valueIDs.find(lookupValue);
528   if (it == valueIDs.end()) {
529     stream << "<<UNKNOWN SSA VALUE>>";
530     return;
531   }
532 
533   stream << '%';
534   if (it->second != NameSentinel) {
535     stream << it->second;
536   } else {
537     auto nameIt = valueNames.find(lookupValue);
538     assert(nameIt != valueNames.end() && "Didn't have a name entry?");
539     stream << nameIt->second;
540   }
541 
542   if (resultNo.hasValue() && printResultNo)
543     stream << '#' << resultNo;
544 }
545 
getOpResultGroups(Operation * op)546 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
547   auto it = opResultGroups.find(op);
548   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
549 }
550 
getBlockID(Block * block)551 unsigned SSANameState::getBlockID(Block *block) {
552   auto it = blockIDs.find(block);
553   return it != blockIDs.end() ? it->second : NameSentinel;
554 }
555 
shadowRegionArgs(Region & region,ValueRange namesToUse)556 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
557   assert(!region.empty() && "cannot shadow arguments of an empty region");
558   assert(region.front().getNumArguments() == namesToUse.size() &&
559          "incorrect number of names passed in");
560   assert(region.getParentOp()->isKnownIsolatedFromAbove() &&
561          "only KnownIsolatedFromAbove ops can shadow names");
562 
563   SmallVector<char, 16> nameStr;
564   for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
565     auto nameToUse = namesToUse[i];
566     if (nameToUse == nullptr)
567       continue;
568     auto nameToReplace = region.front().getArgument(i);
569 
570     nameStr.clear();
571     llvm::raw_svector_ostream nameStream(nameStr);
572     printValueID(nameToUse, /*printResultNo=*/true, nameStream);
573 
574     // Entry block arguments should already have a pretty "arg" name.
575     assert(valueIDs[nameToReplace] == NameSentinel);
576 
577     // Use the name without the leading %.
578     auto name = StringRef(nameStream.str()).drop_front();
579 
580     // Overwrite the name.
581     valueNames[nameToReplace] = name.copy(usedNameAllocator);
582   }
583 }
584 
numberValuesInRegion(Region & region,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)585 void SSANameState::numberValuesInRegion(
586     Region &region,
587     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
588   // Save the current value ids to allow for numbering values in sibling regions
589   // the same.
590   llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
591   llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
592   llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);
593 
594   // Push a new used names scope.
595   llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
596 
597   // Number the values within this region in a breadth-first order.
598   unsigned nextBlockID = 0;
599   for (auto &block : region) {
600     // Each block gets a unique ID, and all of the operations within it get
601     // numbered as well.
602     blockIDs[&block] = nextBlockID++;
603     numberValuesInBlock(block, interfaces);
604   }
605 
606   // After that we traverse the nested regions.
607   // TODO: Rework this loop to not use recursion.
608   for (auto &block : region) {
609     for (auto &op : block)
610       for (auto &nestedRegion : op.getRegions())
611         numberValuesInRegion(nestedRegion, interfaces);
612   }
613 }
614 
numberValuesInBlock(Block & block,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)615 void SSANameState::numberValuesInBlock(
616     Block &block,
617     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
618   auto setArgNameFn = [&](Value arg, StringRef name) {
619     assert(!valueIDs.count(arg) && "arg numbered multiple times");
620     assert(arg.cast<BlockArgument>().getOwner() == &block &&
621            "arg not defined in 'block'");
622     setValueName(arg, name);
623   };
624 
625   bool isEntryBlock = block.isEntryBlock();
626   if (isEntryBlock) {
627     if (auto *op = block.getParentOp()) {
628       if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
629         asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
630     }
631   }
632 
633   // Number the block arguments. We give entry block arguments a special name
634   // 'arg'.
635   SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
636   llvm::raw_svector_ostream specialName(specialNameBuffer);
637   for (auto arg : block.getArguments()) {
638     if (valueIDs.count(arg))
639       continue;
640     if (isEntryBlock) {
641       specialNameBuffer.resize(strlen("arg"));
642       specialName << nextArgumentID++;
643     }
644     setValueName(arg, specialName.str());
645   }
646 
647   // Number the operations in this block.
648   for (auto &op : block)
649     numberValuesInOp(op, interfaces);
650 }
651 
numberValuesInOp(Operation & op,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)652 void SSANameState::numberValuesInOp(
653     Operation &op,
654     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
655   unsigned numResults = op.getNumResults();
656   if (numResults == 0)
657     return;
658   Value resultBegin = op.getResult(0);
659 
660   // Function used to set the special result names for the operation.
661   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
662   auto setResultNameFn = [&](Value result, StringRef name) {
663     assert(!valueIDs.count(result) && "result numbered multiple times");
664     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
665     setValueName(result, name);
666 
667     // Record the result number for groups not anchored at 0.
668     if (int resultNo = result.cast<OpResult>().getResultNumber())
669       resultGroups.push_back(resultNo);
670   };
671   if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
672     asmInterface.getAsmResultNames(setResultNameFn);
673   else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect()))
674     asmInterface->getAsmResultNames(&op, setResultNameFn);
675 
676   // If the first result wasn't numbered, give it a default number.
677   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
678     ++nextValueID;
679 
680   // If this operation has multiple result groups, mark it.
681   if (resultGroups.size() != 1) {
682     llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
683     opResultGroups.try_emplace(&op, std::move(resultGroups));
684   }
685 }
686 
getResultIDAndNumber(OpResult result,Value & lookupValue,Optional<int> & lookupResultNo) const687 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
688                                         Optional<int> &lookupResultNo) const {
689   Operation *owner = result.getOwner();
690   if (owner->getNumResults() == 1)
691     return;
692   int resultNo = result.getResultNumber();
693 
694   // If this operation has multiple result groups, we will need to find the
695   // one corresponding to this result.
696   auto resultGroupIt = opResultGroups.find(owner);
697   if (resultGroupIt == opResultGroups.end()) {
698     // If not, just use the first result.
699     lookupResultNo = resultNo;
700     lookupValue = owner->getResult(0);
701     return;
702   }
703 
704   // Find the correct index using a binary search, as the groups are ordered.
705   ArrayRef<int> resultGroups = resultGroupIt->second;
706   auto it = llvm::upper_bound(resultGroups, resultNo);
707   int groupResultNo = 0, groupSize = 0;
708 
709   // If there are no smaller elements, the last result group is the lookup.
710   if (it == resultGroups.end()) {
711     groupResultNo = resultGroups.back();
712     groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
713   } else {
714     // Otherwise, the previous element is the lookup.
715     groupResultNo = *std::prev(it);
716     groupSize = *it - groupResultNo;
717   }
718 
719   // We only record the result number for a group of size greater than 1.
720   if (groupSize != 1)
721     lookupResultNo = resultNo - groupResultNo;
722   lookupValue = owner->getResult(groupResultNo);
723 }
724 
setValueName(Value value,StringRef name)725 void SSANameState::setValueName(Value value, StringRef name) {
726   // If the name is empty, the value uses the default numbering.
727   if (name.empty()) {
728     valueIDs[value] = nextValueID++;
729     return;
730   }
731 
732   valueIDs[value] = NameSentinel;
733   valueNames[value] = uniqueValueName(name);
734 }
735 
uniqueValueName(StringRef name)736 StringRef SSANameState::uniqueValueName(StringRef name) {
737   // Check to see if this name is already unique.
738   if (!usedNames.count(name)) {
739     name = name.copy(usedNameAllocator);
740   } else {
741     // Otherwise, we had a conflict - probe until we find a unique name. This
742     // is guaranteed to terminate (and usually in a single iteration) because it
743     // generates new names by incrementing nextConflictID.
744     SmallString<64> probeName(name);
745     probeName.push_back('_');
746     while (true) {
747       probeName.resize(name.size() + 1);
748       probeName += llvm::utostr(nextConflictID++);
749       if (!usedNames.count(probeName)) {
750         name = StringRef(probeName).copy(usedNameAllocator);
751         break;
752       }
753     }
754   }
755 
756   usedNames.insert(name, char());
757   return name;
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // AsmState
762 //===----------------------------------------------------------------------===//
763 
764 namespace mlir {
765 namespace detail {
/// Bundles the state shared while printing: the collection of
/// OpAsmDialectInterface instances, the alias state and the SSA name state.
class AsmStateImpl {
public:
  /// Build the interface collection from the context of `op` and construct
  /// the SSA name state from `op` and that collection. The alias state is
  /// default-constructed here; it is only populated by initializeAliases().
  explicit AsmStateImpl(Operation *op)
      : interfaces(op->getContext()), nameState(op, interfaces) {}

  /// Initialize the alias state to enable the printing of aliases.
  void initializeAliases(Operation *op) {
    aliasState.initialize(op, interfaces);
  }

  /// Get an instance of the OpAsmDialectInterface for the given dialect, or
  /// null if one wasn't registered.
  const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
    return interfaces.getInterfaceFor(dialect);
  }

  /// Get the state used for aliases.
  AliasState &getAliasState() { return aliasState; }

  /// Get the state used for SSA names.
  SSANameState &getSSANameState() { return nameState; }

private:
  /// Collection of OpAsm interfaces implemented in the context.
  /// Declared before `nameState`, whose constructor receives it.
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// The state used for attribute and type aliases.
  AliasState aliasState;

  /// The state used for SSA value names.
  SSANameState nameState;
};
798 } // end namespace detail
799 } // end namespace mlir
800 
AsmState(Operation * op)801 AsmState::AsmState(Operation *op) : impl(std::make_unique<AsmStateImpl>(op)) {}
~AsmState()802 AsmState::~AsmState() {}
803 
804 //===----------------------------------------------------------------------===//
805 // ModulePrinter
806 //===----------------------------------------------------------------------===//
807 
808 namespace {
/// Printer for attributes, types, locations and affine constructs. It writes
/// to a caller-provided stream, honours a set of printing flags, and may
/// optionally consult an AsmStateImpl (used for aliases when non-null).
class ModulePrinter {
public:
  /// Create a printer writing to `os`. `state` may be null, in which case no
  /// alias lookups are performed.
  ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
                AsmStateImpl *state = nullptr)
      : os(os), printerFlags(flags), state(state) {}
  /// Create a printer sharing the stream, flags and state of `printer`.
  explicit ModulePrinter(ModulePrinter &printer)
      : os(printer.os), printerFlags(printer.printerFlags),
        state(printer.state) {}

  /// Returns the output stream of the printer.
  raw_ostream &getStream() { return os; }

  /// Forward to mlir::interleaveComma, invoking `each_fn` on every element of
  /// `c` with this printer's stream used for the separators.
  template <typename Container, typename UnaryFunctor>
  inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
    mlir::interleaveComma(c, os, each_fn);
  }

  /// Print the given attribute. If 'mayElideType' is true, some attributes are
  /// printed without the type when the type matches the default used in the
  /// parser (for example i64 is the default for integer attributes).
  void printAttribute(Attribute attr, bool mayElideType = false);

  /// Print the given type, using its alias when the printer state has one.
  void printType(Type type);
  /// Print the given location, either in pretty form or wrapped in `loc(...)`
  /// depending on the printer flags.
  void printLocation(LocationAttr loc);

  void printAffineMap(AffineMap map);
  /// Print an affine expression. When `printValueName` is provided it is
  /// called with (position, isSymbol) to print dimension and symbol
  /// identifiers; otherwise they are printed as `d<pos>` / `s<pos>`.
  void
  printAffineExpr(AffineExpr expr,
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
  void printAffineConstraint(AffineExpr expr, bool isEq);
  void printIntegerSet(IntegerSet set);

protected:
  /// Print `attrs` as a braced dictionary, skipping those named in
  /// `elidedAttrs`; nothing is printed if no attribute remains. When
  /// `withKeyword` is true the dictionary is prefixed with ` attributes`.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {},
                             bool withKeyword = false);
  /// Print a space followed by `loc`, only if debug info printing is enabled.
  void printTrailingLocation(Location loc);
  /// Print the body of a location without the enclosing `loc(...)`.
  void printLocationInternal(LocationAttr loc, bool pretty = false);
  /// Print the element list of a dense elements attribute (without the
  /// surrounding `dense<` ... `>`).
  void printDenseElementsAttr(DenseElementsAttr attr);

  /// Print an attribute / type by asking its dialect to serialize it.
  void printDialectAttribute(Attribute attr);
  void printDialectType(Type type);

  /// This enum is used to represent the binding strength of the enclosing
  /// context that an AffineExprStorage is being printed in, so we can
  /// intelligently produce parens.
  enum class BindingStrength {
    Weak,   // + and -
    Strong, // All other binary operators.
  };
  void printAffineExprInternal(
      AffineExpr expr, BindingStrength enclosingTightness,
      function_ref<void(unsigned, bool)> printValueName = nullptr);

  /// The output stream for the printer.
  raw_ostream &os;

  /// A set of flags to control the printer's behavior.
  OpPrintingFlags printerFlags;

  /// An optional printer state for the module.
  AsmStateImpl *state;
};
872 } // end anonymous namespace
873 
printTrailingLocation(Location loc)874 void ModulePrinter::printTrailingLocation(Location loc) {
875   // Check to see if we are printing debug information.
876   if (!printerFlags.shouldPrintDebugInfo())
877     return;
878 
879   os << " ";
880   printLocation(loc);
881 }
882 
printLocationInternal(LocationAttr loc,bool pretty)883 void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
884   switch (loc.getKind()) {
885   case StandardAttributes::OpaqueLocation:
886     printLocationInternal(loc.cast<OpaqueLoc>().getFallbackLocation(), pretty);
887     break;
888   case StandardAttributes::UnknownLocation:
889     if (pretty)
890       os << "[unknown]";
891     else
892       os << "unknown";
893     break;
894   case StandardAttributes::FileLineColLocation: {
895     auto fileLoc = loc.cast<FileLineColLoc>();
896     auto mayQuote = pretty ? "" : "\"";
897     os << mayQuote << fileLoc.getFilename() << mayQuote << ':'
898        << fileLoc.getLine() << ':' << fileLoc.getColumn();
899     break;
900   }
901   case StandardAttributes::NameLocation: {
902     auto nameLoc = loc.cast<NameLoc>();
903     os << '\"' << nameLoc.getName() << '\"';
904 
905     // Print the child if it isn't unknown.
906     auto childLoc = nameLoc.getChildLoc();
907     if (!childLoc.isa<UnknownLoc>()) {
908       os << '(';
909       printLocationInternal(childLoc, pretty);
910       os << ')';
911     }
912     break;
913   }
914   case StandardAttributes::CallSiteLocation: {
915     auto callLocation = loc.cast<CallSiteLoc>();
916     auto caller = callLocation.getCaller();
917     auto callee = callLocation.getCallee();
918     if (!pretty)
919       os << "callsite(";
920     printLocationInternal(callee, pretty);
921     if (pretty) {
922       if (callee.isa<NameLoc>()) {
923         if (caller.isa<FileLineColLoc>()) {
924           os << " at ";
925         } else {
926           os << "\n at ";
927         }
928       } else {
929         os << "\n at ";
930       }
931     } else {
932       os << " at ";
933     }
934     printLocationInternal(caller, pretty);
935     if (!pretty)
936       os << ")";
937     break;
938   }
939   case StandardAttributes::FusedLocation: {
940     auto fusedLoc = loc.cast<FusedLoc>();
941     if (!pretty)
942       os << "fused";
943     if (auto metadata = fusedLoc.getMetadata())
944       os << '<' << metadata << '>';
945     os << '[';
946     interleave(
947         fusedLoc.getLocations(),
948         [&](Location loc) { printLocationInternal(loc, pretty); },
949         [&]() { os << ", "; });
950     os << ']';
951     break;
952   }
953   }
954 }
955 
956 /// Print a floating point value in a way that the parser will be able to
957 /// round-trip losslessly.
printFloatValue(const APFloat & apValue,raw_ostream & os)958 static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
959   // We would like to output the FP constant value in exponential notation,
960   // but we cannot do this if doing so will lose precision.  Check here to
961   // make sure that we only output it in exponential format if we can parse
962   // the value back and get the same value.
963   bool isInf = apValue.isInfinity();
964   bool isNaN = apValue.isNaN();
965   if (!isInf && !isNaN) {
966     SmallString<128> strValue;
967     apValue.toString(strValue, 6, 0, false);
968 
969     // Check to make sure that the stringized number is not some string like
970     // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
971     // that the string matches the "[-+]?[0-9]" regex.
972     assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
973             ((strValue[0] == '-' || strValue[0] == '+') &&
974              (strValue[1] >= '0' && strValue[1] <= '9'))) &&
975            "[-+]?[0-9] regex does not match!");
976 
977     // Parse back the stringized version and check that the value is equal
978     // (i.e., there is no precision loss). If it is not, use the default format
979     // of APFloat instead of the exponential notation.
980     if (!APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
981       strValue.clear();
982       apValue.toString(strValue);
983     }
984     os << strValue;
985     return;
986   }
987 
988   // Print special values in hexadecimal format.  The sign bit should be
989   // included in the literal.
990   SmallVector<char, 16> str;
991   APInt apInt = apValue.bitcastToAPInt();
992   apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
993                  /*formatAsCLiteral=*/true);
994   os << str;
995 }
996 
printLocation(LocationAttr loc)997 void ModulePrinter::printLocation(LocationAttr loc) {
998   if (printerFlags.shouldPrintDebugInfoPrettyForm()) {
999     printLocationInternal(loc, /*pretty=*/true);
1000   } else {
1001     os << "loc(";
1002     printLocationInternal(loc);
1003     os << ')';
1004   }
1005 }
1006 
1007 /// Returns if the given dialect symbol data is simple enough to print in the
1008 /// pretty form, i.e. without the enclosing "".
isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName)1009 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1010   // The name must start with an identifier.
1011   if (symName.empty() || !isalpha(symName.front()))
1012     return false;
1013 
1014   // Ignore all the characters that are valid in an identifier in the symbol
1015   // name.
1016   symName = symName.drop_while(
1017       [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1018   if (symName.empty())
1019     return true;
1020 
1021   // If we got to an unexpected character, then it must be a <>.  Check those
1022   // recursively.
1023   if (symName.front() != '<' || symName.back() != '>')
1024     return false;
1025 
1026   SmallVector<char, 8> nestedPunctuation;
1027   do {
1028     // If we ran out of characters, then we had a punctuation mismatch.
1029     if (symName.empty())
1030       return false;
1031 
1032     auto c = symName.front();
1033     symName = symName.drop_front();
1034 
1035     switch (c) {
1036     // We never allow null characters. This is an EOF indicator for the lexer
1037     // which we could handle, but isn't important for any known dialect.
1038     case '\0':
1039       return false;
1040     case '<':
1041     case '[':
1042     case '(':
1043     case '{':
1044       nestedPunctuation.push_back(c);
1045       continue;
1046     case '-':
1047       // Treat `->` as a special token.
1048       if (!symName.empty() && symName.front() == '>') {
1049         symName = symName.drop_front();
1050         continue;
1051       }
1052       break;
1053     // Reject types with mismatched brackets.
1054     case '>':
1055       if (nestedPunctuation.pop_back_val() != '<')
1056         return false;
1057       break;
1058     case ']':
1059       if (nestedPunctuation.pop_back_val() != '[')
1060         return false;
1061       break;
1062     case ')':
1063       if (nestedPunctuation.pop_back_val() != '(')
1064         return false;
1065       break;
1066     case '}':
1067       if (nestedPunctuation.pop_back_val() != '{')
1068         return false;
1069       break;
1070     default:
1071       continue;
1072     }
1073 
1074     // We're done when the punctuation is fully matched.
1075   } while (!nestedPunctuation.empty());
1076 
1077   // If there were extra characters, then we failed.
1078   return symName.empty();
1079 }
1080 
1081 /// Print the given dialect symbol to the stream.
printDialectSymbol(raw_ostream & os,StringRef symPrefix,StringRef dialectName,StringRef symString)1082 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1083                                StringRef dialectName, StringRef symString) {
1084   os << symPrefix << dialectName;
1085 
1086   // If this symbol name is simple enough, print it directly in pretty form,
1087   // otherwise, we print it as an escaped string.
1088   if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1089     os << '.' << symString;
1090     return;
1091   }
1092 
1093   // TODO: escape the symbol name, it could contain " characters.
1094   os << "<\"" << symString << "\">";
1095 }
1096 
1097 /// Returns if the given string can be represented as a bare identifier.
isBareIdentifier(StringRef name)1098 static bool isBareIdentifier(StringRef name) {
1099   assert(!name.empty() && "invalid name");
1100 
1101   // By making this unsigned, the value passed in to isalnum will always be
1102   // in the range 0-255. This is important when building with MSVC because
1103   // its implementation will assert. This situation can arise when dealing
1104   // with UTF-8 multibyte characters.
1105   unsigned char firstChar = static_cast<unsigned char>(name[0]);
1106   if (!isalpha(firstChar) && firstChar != '_')
1107     return false;
1108   return llvm::all_of(name.drop_front(), [](unsigned char c) {
1109     return isalnum(c) || c == '_' || c == '$' || c == '.';
1110   });
1111 }
1112 
1113 /// Print the given string as a symbol reference. A symbol reference is
1114 /// represented as a string prefixed with '@'. The reference is surrounded with
1115 /// ""'s and escaped if it has any special or non-printable characters in it.
printSymbolReference(StringRef symbolRef,raw_ostream & os)1116 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
1117   assert(!symbolRef.empty() && "expected valid symbol reference");
1118 
1119   // If the symbol can be represented as a bare identifier, write it directly.
1120   if (isBareIdentifier(symbolRef)) {
1121     os << '@' << symbolRef;
1122     return;
1123   }
1124 
1125   // Otherwise, output the reference wrapped in quotes with proper escaping.
1126   os << "@\"";
1127   printEscapedString(symbolRef, os);
1128   os << '"';
1129 }
1130 
1131 // Print out a valid ElementsAttr that is succinct and can represent any
1132 // potential shape/type, for use when eliding a large ElementsAttr.
1133 //
1134 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1135 // hopefully alert readers to the fact that this has been elided.
1136 //
1137 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1138 // accept the string "elided". The first string must be a registered dialect
1139 // name and the latter must be a hex constant.
printElidedElementsAttr(raw_ostream & os)1140 static void printElidedElementsAttr(raw_ostream &os) {
1141   os << R"(opaque<"", "0xDEADBEEF">)";
1142 }
1143 
printAttribute(Attribute attr,bool mayElideType)1144 void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
1145   if (!attr) {
1146     os << "<<NULL ATTRIBUTE>>";
1147     return;
1148   }
1149 
1150   // Check for an alias for this attribute.
1151   if (state) {
1152     Twine alias = state->getAliasState().getAttributeAlias(attr);
1153     if (!alias.isTriviallyEmpty()) {
1154       os << '#' << alias;
1155       return;
1156     }
1157   }
1158 
1159   switch (attr.getKind()) {
1160   default:
1161     return printDialectAttribute(attr);
1162 
1163   case StandardAttributes::Opaque: {
1164     auto opaqueAttr = attr.cast<OpaqueAttr>();
1165     printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
1166                        opaqueAttr.getAttrData());
1167     break;
1168   }
1169   case StandardAttributes::Unit:
1170     os << "unit";
1171     break;
1172   case StandardAttributes::Bool:
1173     os << (attr.cast<BoolAttr>().getValue() ? "true" : "false");
1174 
1175     // BoolAttr always elides the type.
1176     return;
1177   case StandardAttributes::Dictionary:
1178     os << '{';
1179     interleaveComma(attr.cast<DictionaryAttr>().getValue(),
1180                     [&](NamedAttribute attr) {
1181                       os << attr.first;
1182 
1183                       // The value of a UnitAttr is elided within a dictionary.
1184                       if (attr.second.isa<UnitAttr>())
1185                         return;
1186 
1187                       os << " = ";
1188                       printAttribute(attr.second);
1189                     });
1190     os << '}';
1191     break;
1192   case StandardAttributes::Integer: {
1193     auto intAttr = attr.cast<IntegerAttr>();
1194     // Print all integer attributes as signed unless i1.
1195     bool isSigned = intAttr.getType().isIndex() ||
1196                     intAttr.getType().getIntOrFloatBitWidth() != 1;
1197     intAttr.getValue().print(os, isSigned);
1198 
1199     // IntegerAttr elides the type if I64.
1200     if (mayElideType && intAttr.getType().isInteger(64))
1201       return;
1202     break;
1203   }
1204   case StandardAttributes::Float: {
1205     auto floatAttr = attr.cast<FloatAttr>();
1206     printFloatValue(floatAttr.getValue(), os);
1207 
1208     // FloatAttr elides the type if F64.
1209     if (mayElideType && floatAttr.getType().isF64())
1210       return;
1211     break;
1212   }
1213   case StandardAttributes::String:
1214     os << '"';
1215     printEscapedString(attr.cast<StringAttr>().getValue(), os);
1216     os << '"';
1217     break;
1218   case StandardAttributes::Array:
1219     os << '[';
1220     interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) {
1221       printAttribute(attr, /*mayElideType=*/true);
1222     });
1223     os << ']';
1224     break;
1225   case StandardAttributes::AffineMap:
1226     os << "affine_map<";
1227     attr.cast<AffineMapAttr>().getValue().print(os);
1228     os << '>';
1229 
1230     // AffineMap always elides the type.
1231     return;
1232   case StandardAttributes::IntegerSet:
1233     os << "affine_set<";
1234     attr.cast<IntegerSetAttr>().getValue().print(os);
1235     os << '>';
1236 
1237     // IntegerSet always elides the type.
1238     return;
1239   case StandardAttributes::Type:
1240     printType(attr.cast<TypeAttr>().getValue());
1241     break;
1242   case StandardAttributes::SymbolRef: {
1243     auto refAttr = attr.dyn_cast<SymbolRefAttr>();
1244     printSymbolReference(refAttr.getRootReference(), os);
1245     for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
1246       os << "::";
1247       printSymbolReference(nestedRef.getValue(), os);
1248     }
1249     break;
1250   }
1251   case StandardAttributes::OpaqueElements: {
1252     auto eltsAttr = attr.cast<OpaqueElementsAttr>();
1253     if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
1254       printElidedElementsAttr(os);
1255       break;
1256     }
1257     os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
1258     os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">";
1259     break;
1260   }
1261   case StandardAttributes::DenseElements: {
1262     auto eltsAttr = attr.cast<DenseElementsAttr>();
1263     if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
1264       printElidedElementsAttr(os);
1265       break;
1266     }
1267     os << "dense<";
1268     printDenseElementsAttr(eltsAttr);
1269     os << '>';
1270     break;
1271   }
1272   case StandardAttributes::SparseElements: {
1273     auto elementsAttr = attr.cast<SparseElementsAttr>();
1274     if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) ||
1275         printerFlags.shouldElideElementsAttr(elementsAttr.getValues())) {
1276       printElidedElementsAttr(os);
1277       break;
1278     }
1279     os << "sparse<";
1280     printDenseElementsAttr(elementsAttr.getIndices());
1281     os << ", ";
1282     printDenseElementsAttr(elementsAttr.getValues());
1283     os << '>';
1284     break;
1285   }
1286 
1287   // Location attributes.
1288   case StandardAttributes::CallSiteLocation:
1289   case StandardAttributes::FileLineColLocation:
1290   case StandardAttributes::FusedLocation:
1291   case StandardAttributes::NameLocation:
1292   case StandardAttributes::OpaqueLocation:
1293   case StandardAttributes::UnknownLocation:
1294     printLocation(attr.cast<LocationAttr>());
1295     break;
1296   }
1297 
1298   // Print the type if it isn't a 'none' type.
1299   auto attrType = attr.getType();
1300   if (!attrType.isa<NoneType>()) {
1301     os << " : ";
1302     printType(attrType);
1303   }
1304 }
1305 
1306 /// Print the integer element of the given DenseElementsAttr at 'index'.
printDenseIntElement(DenseElementsAttr attr,raw_ostream & os,unsigned index)1307 static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
1308                                  unsigned index) {
1309   APInt value = *std::next(attr.int_value_begin(), index);
1310   if (value.getBitWidth() == 1)
1311     os << (value.getBoolValue() ? "true" : "false");
1312   else
1313     value.print(os, /*isSigned=*/true);
1314 }
1315 
1316 /// Print the float element of the given DenseElementsAttr at 'index'.
printDenseFloatElement(DenseElementsAttr attr,raw_ostream & os,unsigned index)1317 static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
1318                                    unsigned index) {
1319   APFloat value = *std::next(attr.float_value_begin(), index);
1320   printFloatValue(value, os);
1321 }
1322 
/// Print the elements of `attr` as a (possibly nested) bracketed list, one
/// bracket level per dimension of the attribute's shape. Splat attributes are
/// printed as a single bare element, and attributes with zero elements as
/// `rank` empty bracket pairs. The surrounding `dense<` ... `>` is emitted by
/// the caller.
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
  auto type = attr.getType();
  auto shape = type.getShape();
  auto rank = type.getRank();

  // The function used to print elements of this attribute.
  // Integer element types use the integer printer; everything else is routed
  // to the float printer.
  auto printEltFn = type.getElementType().isa<IntegerType>()
                        ? printDenseIntElement
                        : printDenseFloatElement;

  // Special case for 0-d and splat tensors.
  if (attr.isSplat()) {
    printEltFn(attr, os, 0);
    return;
  }

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0) {
    for (int i = 0; i < rank; ++i)
      os << '[';
    for (int i = 0; i < rank; ++i)
      os << ']';
    return;
  }

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto bumpCounter = [&]() {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    // Digit 0 is never reset here; its bracket is closed by the final loop
    // at the end of the function.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open every bracket that was closed by the previous bump. The
    // post-increment overshoots by one on exit, hence the reset below.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(attr, os, idx);
    bumpCounter();
  }
  // Close whatever brackets are still open after the last element.
  while (openBrackets-- > 0)
    os << ']';
}
1384 
printType(Type type)1385 void ModulePrinter::printType(Type type) {
1386   // Check for an alias for this type.
1387   if (state) {
1388     StringRef alias = state->getAliasState().getTypeAlias(type);
1389     if (!alias.empty()) {
1390       os << '!' << alias;
1391       return;
1392     }
1393   }
1394 
1395   switch (type.getKind()) {
1396   default:
1397     return printDialectType(type);
1398 
1399   case Type::Kind::Opaque: {
1400     auto opaqueTy = type.cast<OpaqueType>();
1401     printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
1402                        opaqueTy.getTypeData());
1403     return;
1404   }
1405   case StandardTypes::Index:
1406     os << "index";
1407     return;
1408   case StandardTypes::BF16:
1409     os << "bf16";
1410     return;
1411   case StandardTypes::F16:
1412     os << "f16";
1413     return;
1414   case StandardTypes::F32:
1415     os << "f32";
1416     return;
1417   case StandardTypes::F64:
1418     os << "f64";
1419     return;
1420 
1421   case StandardTypes::Integer: {
1422     auto integer = type.cast<IntegerType>();
1423     os << 'i' << integer.getWidth();
1424     return;
1425   }
1426   case Type::Kind::Function: {
1427     auto func = type.cast<FunctionType>();
1428     os << '(';
1429     interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
1430     os << ") -> ";
1431     auto results = func.getResults();
1432     if (results.size() == 1 && !results[0].isa<FunctionType>())
1433       os << results[0];
1434     else {
1435       os << '(';
1436       interleaveComma(results, [&](Type type) { printType(type); });
1437       os << ')';
1438     }
1439     return;
1440   }
1441   case StandardTypes::Vector: {
1442     auto v = type.cast<VectorType>();
1443     os << "vector<";
1444     for (auto dim : v.getShape())
1445       os << dim << 'x';
1446     os << v.getElementType() << '>';
1447     return;
1448   }
1449   case StandardTypes::RankedTensor: {
1450     auto v = type.cast<RankedTensorType>();
1451     os << "tensor<";
1452     for (auto dim : v.getShape()) {
1453       if (dim < 0)
1454         os << '?';
1455       else
1456         os << dim;
1457       os << 'x';
1458     }
1459     os << v.getElementType() << '>';
1460     return;
1461   }
1462   case StandardTypes::UnrankedTensor: {
1463     auto v = type.cast<UnrankedTensorType>();
1464     os << "tensor<*x";
1465     printType(v.getElementType());
1466     os << '>';
1467     return;
1468   }
1469   case StandardTypes::MemRef: {
1470     auto v = type.cast<MemRefType>();
1471     os << "memref<";
1472     for (auto dim : v.getShape()) {
1473       if (dim < 0)
1474         os << '?';
1475       else
1476         os << dim;
1477       os << 'x';
1478     }
1479     printType(v.getElementType());
1480     for (auto map : v.getAffineMaps()) {
1481       os << ", ";
1482       printAttribute(AffineMapAttr::get(map));
1483     }
1484     // Only print the memory space if it is the non-default one.
1485     if (v.getMemorySpace())
1486       os << ", " << v.getMemorySpace();
1487     os << '>';
1488     return;
1489   }
1490   case StandardTypes::UnrankedMemRef: {
1491     auto v = type.cast<UnrankedMemRefType>();
1492     os << "memref<*x";
1493     printType(v.getElementType());
1494     os << '>';
1495     return;
1496   }
1497   case StandardTypes::Complex:
1498     os << "complex<";
1499     printType(type.cast<ComplexType>().getElementType());
1500     os << '>';
1501     return;
1502   case StandardTypes::Tuple: {
1503     auto tuple = type.cast<TupleType>();
1504     os << "tuple<";
1505     interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); });
1506     os << '>';
1507     return;
1508   }
1509   case StandardTypes::None:
1510     os << "none";
1511     return;
1512   }
1513 }
1514 
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs,bool withKeyword)1515 void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1516                                           ArrayRef<StringRef> elidedAttrs,
1517                                           bool withKeyword) {
1518   // If there are no attributes, then there is nothing to be done.
1519   if (attrs.empty())
1520     return;
1521 
1522   // Filter out any attributes that shouldn't be included.
1523   SmallVector<NamedAttribute, 8> filteredAttrs(
1524       llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
1525         return !llvm::is_contained(elidedAttrs, attr.first.strref());
1526       }));
1527 
1528   // If there are no attributes left to print after filtering, then we're done.
1529   if (filteredAttrs.empty())
1530     return;
1531 
1532   // Print the 'attributes' keyword if necessary.
1533   if (withKeyword)
1534     os << " attributes";
1535 
1536   // Otherwise, print them all out in braces.
1537   os << " {";
1538   interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
1539     os << attr.first;
1540 
1541     // Pretty printing elides the attribute value for unit attributes.
1542     if (attr.second.isa<UnitAttr>())
1543       return;
1544 
1545     os << " = ";
1546     printAttribute(attr.second);
1547   });
1548   os << '}';
1549 }
1550 
1551 //===----------------------------------------------------------------------===//
1552 // CustomDialectAsmPrinter
1553 //===----------------------------------------------------------------------===//
1554 
1555 namespace {
1556 /// This class provides the main specialization of the DialectAsmPrinter that is
1557 /// used to provide support for print attributes and types. This hooks allows
1558 /// for dialects to hook into the main ModulePrinter.
1559 struct CustomDialectAsmPrinter : public DialectAsmPrinter {
1560 public:
CustomDialectAsmPrinter__anonbb6a0fc11611::CustomDialectAsmPrinter1561   CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
~CustomDialectAsmPrinter__anonbb6a0fc11611::CustomDialectAsmPrinter1562   ~CustomDialectAsmPrinter() override {}
1563 
getStream__anonbb6a0fc11611::CustomDialectAsmPrinter1564   raw_ostream &getStream() const override { return printer.getStream(); }
1565 
1566   /// Print the given attribute to the stream.
printAttribute__anonbb6a0fc11611::CustomDialectAsmPrinter1567   void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
1568 
1569   /// Print the given floating point value in a stablized form.
printFloat__anonbb6a0fc11611::CustomDialectAsmPrinter1570   void printFloat(const APFloat &value) override {
1571     printFloatValue(value, getStream());
1572   }
1573 
1574   /// Print the given type to the stream.
printType__anonbb6a0fc11611::CustomDialectAsmPrinter1575   void printType(Type type) override { printer.printType(type); }
1576 
1577   /// The main module printer.
1578   ModulePrinter &printer;
1579 };
1580 } // end anonymous namespace
1581 
printDialectAttribute(Attribute attr)1582 void ModulePrinter::printDialectAttribute(Attribute attr) {
1583   auto &dialect = attr.getDialect();
1584 
1585   // Ask the dialect to serialize the attribute to a string.
1586   std::string attrName;
1587   {
1588     llvm::raw_string_ostream attrNameStr(attrName);
1589     ModulePrinter subPrinter(attrNameStr, printerFlags, state);
1590     CustomDialectAsmPrinter printer(subPrinter);
1591     dialect.printAttribute(attr, printer);
1592   }
1593   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
1594 }
1595 
printDialectType(Type type)1596 void ModulePrinter::printDialectType(Type type) {
1597   auto &dialect = type.getDialect();
1598 
1599   // Ask the dialect to serialize the type to a string.
1600   std::string typeName;
1601   {
1602     llvm::raw_string_ostream typeNameStr(typeName);
1603     ModulePrinter subPrinter(typeNameStr, printerFlags, state);
1604     CustomDialectAsmPrinter printer(subPrinter);
1605     dialect.printType(type, printer);
1606   }
1607   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
1608 }
1609 
1610 //===----------------------------------------------------------------------===//
1611 // Affine expressions and maps
1612 //===----------------------------------------------------------------------===//
1613 
printAffineExpr(AffineExpr expr,function_ref<void (unsigned,bool)> printValueName)1614 void ModulePrinter::printAffineExpr(
1615     AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
1616   printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
1617 }
1618 
printAffineExprInternal(AffineExpr expr,BindingStrength enclosingTightness,function_ref<void (unsigned,bool)> printValueName)1619 void ModulePrinter::printAffineExprInternal(
1620     AffineExpr expr, BindingStrength enclosingTightness,
1621     function_ref<void(unsigned, bool)> printValueName) {
1622   const char *binopSpelling = nullptr;
1623   switch (expr.getKind()) {
1624   case AffineExprKind::SymbolId: {
1625     unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
1626     if (printValueName)
1627       printValueName(pos, /*isSymbol=*/true);
1628     else
1629       os << 's' << pos;
1630     return;
1631   }
1632   case AffineExprKind::DimId: {
1633     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
1634     if (printValueName)
1635       printValueName(pos, /*isSymbol=*/false);
1636     else
1637       os << 'd' << pos;
1638     return;
1639   }
1640   case AffineExprKind::Constant:
1641     os << expr.cast<AffineConstantExpr>().getValue();
1642     return;
1643   case AffineExprKind::Add:
1644     binopSpelling = " + ";
1645     break;
1646   case AffineExprKind::Mul:
1647     binopSpelling = " * ";
1648     break;
1649   case AffineExprKind::FloorDiv:
1650     binopSpelling = " floordiv ";
1651     break;
1652   case AffineExprKind::CeilDiv:
1653     binopSpelling = " ceildiv ";
1654     break;
1655   case AffineExprKind::Mod:
1656     binopSpelling = " mod ";
1657     break;
1658   }
1659 
1660   auto binOp = expr.cast<AffineBinaryOpExpr>();
1661   AffineExpr lhsExpr = binOp.getLHS();
1662   AffineExpr rhsExpr = binOp.getRHS();
1663 
1664   // Handle tightly binding binary operators.
1665   if (binOp.getKind() != AffineExprKind::Add) {
1666     if (enclosingTightness == BindingStrength::Strong)
1667       os << '(';
1668 
1669     // Pretty print multiplication with -1.
1670     auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
1671     if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
1672         rhsConst.getValue() == -1) {
1673       os << "-";
1674       printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
1675       if (enclosingTightness == BindingStrength::Strong)
1676         os << ')';
1677       return;
1678     }
1679 
1680     printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
1681 
1682     os << binopSpelling;
1683     printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
1684 
1685     if (enclosingTightness == BindingStrength::Strong)
1686       os << ')';
1687     return;
1688   }
1689 
1690   // Print out special "pretty" forms for add.
1691   if (enclosingTightness == BindingStrength::Strong)
1692     os << '(';
1693 
1694   // Pretty print addition to a product that has a negative operand as a
1695   // subtraction.
1696   if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
1697     if (rhs.getKind() == AffineExprKind::Mul) {
1698       AffineExpr rrhsExpr = rhs.getRHS();
1699       if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
1700         if (rrhs.getValue() == -1) {
1701           printAffineExprInternal(lhsExpr, BindingStrength::Weak,
1702                                   printValueName);
1703           os << " - ";
1704           if (rhs.getLHS().getKind() == AffineExprKind::Add) {
1705             printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
1706                                     printValueName);
1707           } else {
1708             printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
1709                                     printValueName);
1710           }
1711 
1712           if (enclosingTightness == BindingStrength::Strong)
1713             os << ')';
1714           return;
1715         }
1716 
1717         if (rrhs.getValue() < -1) {
1718           printAffineExprInternal(lhsExpr, BindingStrength::Weak,
1719                                   printValueName);
1720           os << " - ";
1721           printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
1722                                   printValueName);
1723           os << " * " << -rrhs.getValue();
1724           if (enclosingTightness == BindingStrength::Strong)
1725             os << ')';
1726           return;
1727         }
1728       }
1729     }
1730   }
1731 
1732   // Pretty print addition to a negative number as a subtraction.
1733   if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
1734     if (rhsConst.getValue() < 0) {
1735       printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
1736       os << " - " << -rhsConst.getValue();
1737       if (enclosingTightness == BindingStrength::Strong)
1738         os << ')';
1739       return;
1740     }
1741   }
1742 
1743   printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
1744 
1745   os << " + ";
1746   printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
1747 
1748   if (enclosingTightness == BindingStrength::Strong)
1749     os << ')';
1750 }
1751 
printAffineConstraint(AffineExpr expr,bool isEq)1752 void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
1753   printAffineExprInternal(expr, BindingStrength::Weak);
1754   isEq ? os << " == 0" : os << " >= 0";
1755 }
1756 
printAffineMap(AffineMap map)1757 void ModulePrinter::printAffineMap(AffineMap map) {
1758   // Dimension identifiers.
1759   os << '(';
1760   for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
1761     os << 'd' << i << ", ";
1762   if (map.getNumDims() >= 1)
1763     os << 'd' << map.getNumDims() - 1;
1764   os << ')';
1765 
1766   // Symbolic identifiers.
1767   if (map.getNumSymbols() != 0) {
1768     os << '[';
1769     for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
1770       os << 's' << i << ", ";
1771     if (map.getNumSymbols() >= 1)
1772       os << 's' << map.getNumSymbols() - 1;
1773     os << ']';
1774   }
1775 
1776   // Result affine expressions.
1777   os << " -> (";
1778   interleaveComma(map.getResults(),
1779                   [&](AffineExpr expr) { printAffineExpr(expr); });
1780   os << ')';
1781 }
1782 
printIntegerSet(IntegerSet set)1783 void ModulePrinter::printIntegerSet(IntegerSet set) {
1784   // Dimension identifiers.
1785   os << '(';
1786   for (unsigned i = 1; i < set.getNumDims(); ++i)
1787     os << 'd' << i - 1 << ", ";
1788   if (set.getNumDims() >= 1)
1789     os << 'd' << set.getNumDims() - 1;
1790   os << ')';
1791 
1792   // Symbolic identifiers.
1793   if (set.getNumSymbols() != 0) {
1794     os << '[';
1795     for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
1796       os << 's' << i << ", ";
1797     if (set.getNumSymbols() >= 1)
1798       os << 's' << set.getNumSymbols() - 1;
1799     os << ']';
1800   }
1801 
1802   // Print constraints.
1803   os << " : (";
1804   int numConstraints = set.getNumConstraints();
1805   for (int i = 1; i < numConstraints; ++i) {
1806     printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
1807     os << ", ";
1808   }
1809   if (numConstraints >= 1)
1810     printAffineConstraint(set.getConstraint(numConstraints - 1),
1811                           set.isEq(numConstraints - 1));
1812   os << ')';
1813 }
1814 
1815 //===----------------------------------------------------------------------===//
1816 // OperationPrinter
1817 //===----------------------------------------------------------------------===//
1818 
1819 namespace {
1820 /// This class contains the logic for printing operations, regions, and blocks.
1821 class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
1822 public:
OperationPrinter(raw_ostream & os,OpPrintingFlags flags,AsmStateImpl & state)1823   explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
1824                             AsmStateImpl &state)
1825       : ModulePrinter(os, flags, &state) {}
1826 
1827   /// Print the given top-level module.
1828   void print(ModuleOp op);
1829   /// Print the given operation with its indent and location.
1830   void print(Operation *op);
1831   /// Print the bare location, not including indentation/location/etc.
1832   void printOperation(Operation *op);
1833   /// Print the given operation in the generic form.
1834   void printGenericOp(Operation *op) override;
1835 
1836   /// Print the name of the given block.
1837   void printBlockName(Block *block);
1838 
1839   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
1840   /// block are not printed. If 'printBlockTerminator' is false, the terminator
1841   /// operation of the block is not printed.
1842   void print(Block *block, bool printBlockArgs = true,
1843              bool printBlockTerminator = true);
1844 
1845   /// Print the ID of the given value, optionally with its result number.
1846   void printValueID(Value value, bool printResultNo = true) const;
1847 
1848   //===--------------------------------------------------------------------===//
1849   // OpAsmPrinter methods
1850   //===--------------------------------------------------------------------===//
1851 
1852   /// Return the current stream of the printer.
getStream() const1853   raw_ostream &getStream() const override { return os; }
1854 
1855   /// Print the given type.
printType(Type type)1856   void printType(Type type) override { ModulePrinter::printType(type); }
1857 
1858   /// Print the given attribute.
printAttribute(Attribute attr)1859   void printAttribute(Attribute attr) override {
1860     ModulePrinter::printAttribute(attr);
1861   }
1862 
1863   /// Print the ID for the given value.
printOperand(Value value)1864   void printOperand(Value value) override { printValueID(value); }
1865 
1866   /// Print an optional attribute dictionary with a given set of elided values.
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})1867   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1868                              ArrayRef<StringRef> elidedAttrs = {}) override {
1869     ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
1870   }
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})1871   void printOptionalAttrDictWithKeyword(
1872       ArrayRef<NamedAttribute> attrs,
1873       ArrayRef<StringRef> elidedAttrs = {}) override {
1874     ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs,
1875                                          /*withKeyword=*/true);
1876   }
1877 
1878   /// Print an operation successor with the operands used for the block
1879   /// arguments.
1880   void printSuccessorAndUseList(Operation *term, unsigned index) override;
1881 
1882   /// Print the given region.
1883   void printRegion(Region &region, bool printEntryBlockArgs,
1884                    bool printBlockTerminators) override;
1885 
1886   /// Renumber the arguments for the specified region to the same names as the
1887   /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
1888   /// operations. If any entry in namesToUse is null, the corresponding
1889   /// argument name is left alone.
shadowRegionArgs(Region & region,ValueRange namesToUse)1890   void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
1891     state->getSSANameState().shadowRegionArgs(region, namesToUse);
1892   }
1893 
1894   /// Print the given affine map with the smybol and dimension operands printed
1895   /// inline with the map.
1896   void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
1897                               ValueRange operands) override;
1898 
1899   /// Print the given string as a symbol reference.
printSymbolName(StringRef symbolRef)1900   void printSymbolName(StringRef symbolRef) override {
1901     ::printSymbolReference(symbolRef, os);
1902   }
1903 
1904 private:
1905   /// The number of spaces used for indenting nested operations.
1906   const static unsigned indentWidth = 2;
1907 
1908   // This is the current indentation level for nested structures.
1909   unsigned currentIndent = 0;
1910 };
1911 } // end anonymous namespace
1912 
print(ModuleOp op)1913 void OperationPrinter::print(ModuleOp op) {
1914   // Output the aliases at the top level.
1915   state->getAliasState().printAttributeAliases(os);
1916   state->getAliasState().printTypeAliases(os);
1917 
1918   // Print the module.
1919   print(op.getOperation());
1920 }
1921 
print(Operation * op)1922 void OperationPrinter::print(Operation *op) {
1923   os.indent(currentIndent);
1924   printOperation(op);
1925   printTrailingLocation(op->getLoc());
1926 }
1927 
printOperation(Operation * op)1928 void OperationPrinter::printOperation(Operation *op) {
1929   if (size_t numResults = op->getNumResults()) {
1930     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
1931       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
1932       if (resultCount > 1)
1933         os << ':' << resultCount;
1934     };
1935 
1936     // Check to see if this operation has multiple result groups.
1937     ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
1938     if (!resultGroups.empty()) {
1939       // Interleave the groups excluding the last one, this one will be handled
1940       // separately.
1941       interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
1942         printResultGroup(resultGroups[i],
1943                          resultGroups[i + 1] - resultGroups[i]);
1944       });
1945       os << ", ";
1946       printResultGroup(resultGroups.back(), numResults - resultGroups.back());
1947 
1948     } else {
1949       printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
1950     }
1951 
1952     os << " = ";
1953   }
1954 
1955   // If requested, always print the generic form.
1956   if (!printerFlags.shouldPrintGenericOpForm()) {
1957     // Check to see if this is a known operation.  If so, use the registered
1958     // custom printer hook.
1959     if (auto *opInfo = op->getAbstractOperation()) {
1960       opInfo->printAssembly(op, *this);
1961       return;
1962     }
1963   }
1964 
1965   // Otherwise print with the generic assembly form.
1966   printGenericOp(op);
1967 }
1968 
printGenericOp(Operation * op)1969 void OperationPrinter::printGenericOp(Operation *op) {
1970   os << '"';
1971   printEscapedString(op->getName().getStringRef(), os);
1972   os << "\"(";
1973 
1974   // Get the list of operands that are not successor operands.
1975   unsigned totalNumSuccessorOperands = 0;
1976   unsigned numSuccessors = op->getNumSuccessors();
1977   for (unsigned i = 0; i < numSuccessors; ++i)
1978     totalNumSuccessorOperands += op->getNumSuccessorOperands(i);
1979   unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands;
1980   interleaveComma(op->getOperands().take_front(numProperOperands),
1981                   [&](Value value) { printValueID(value); });
1982 
1983   os << ')';
1984 
1985   // For terminators, print the list of successors and their operands.
1986   if (numSuccessors != 0) {
1987     os << '[';
1988     interleaveComma(llvm::seq<unsigned>(0, numSuccessors),
1989                     [&](unsigned i) { printSuccessorAndUseList(op, i); });
1990     os << ']';
1991   }
1992 
1993   // Print regions.
1994   if (op->getNumRegions() != 0) {
1995     os << " (";
1996     interleaveComma(op->getRegions(), [&](Region &region) {
1997       printRegion(region, /*printEntryBlockArgs=*/true,
1998                   /*printBlockTerminators=*/true);
1999     });
2000     os << ')';
2001   }
2002 
2003   auto attrs = op->getAttrs();
2004   printOptionalAttrDict(attrs);
2005 
2006   // Print the type signature of the operation.
2007   os << " : ";
2008   printFunctionalType(op);
2009 }
2010 
printBlockName(Block * block)2011 void OperationPrinter::printBlockName(Block *block) {
2012   auto id = state->getSSANameState().getBlockID(block);
2013   if (id != SSANameState::NameSentinel)
2014     os << "^bb" << id;
2015   else
2016     os << "^INVALIDBLOCK";
2017 }
2018 
print(Block * block,bool printBlockArgs,bool printBlockTerminator)2019 void OperationPrinter::print(Block *block, bool printBlockArgs,
2020                              bool printBlockTerminator) {
2021   // Print the block label and argument list if requested.
2022   if (printBlockArgs) {
2023     os.indent(currentIndent);
2024     printBlockName(block);
2025 
2026     // Print the argument list if non-empty.
2027     if (!block->args_empty()) {
2028       os << '(';
2029       interleaveComma(block->getArguments(), [&](BlockArgument arg) {
2030         printValueID(arg);
2031         os << ": ";
2032         printType(arg.getType());
2033       });
2034       os << ')';
2035     }
2036     os << ':';
2037 
2038     // Print out some context information about the predecessors of this block.
2039     if (!block->getParent()) {
2040       os << "\t// block is not in a region!";
2041     } else if (block->hasNoPredecessors()) {
2042       os << "\t// no predecessors";
2043     } else if (auto *pred = block->getSinglePredecessor()) {
2044       os << "\t// pred: ";
2045       printBlockName(pred);
2046     } else {
2047       // We want to print the predecessors in increasing numeric order, not in
2048       // whatever order the use-list is in, so gather and sort them.
2049       SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
2050       for (auto *pred : block->getPredecessors())
2051         predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
2052       llvm::array_pod_sort(predIDs.begin(), predIDs.end());
2053 
2054       os << "\t// " << predIDs.size() << " preds: ";
2055 
2056       interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
2057         printBlockName(pred.second);
2058       });
2059     }
2060     os << '\n';
2061   }
2062 
2063   currentIndent += indentWidth;
2064   auto range = llvm::make_range(
2065       block->getOperations().begin(),
2066       std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1));
2067   for (auto &op : range) {
2068     print(&op);
2069     os << '\n';
2070   }
2071   currentIndent -= indentWidth;
2072 }
2073 
printValueID(Value value,bool printResultNo) const2074 void OperationPrinter::printValueID(Value value, bool printResultNo) const {
2075   state->getSSANameState().printValueID(value, printResultNo, os);
2076 }
2077 
printSuccessorAndUseList(Operation * term,unsigned index)2078 void OperationPrinter::printSuccessorAndUseList(Operation *term,
2079                                                 unsigned index) {
2080   printBlockName(term->getSuccessor(index));
2081 
2082   auto succOperands = term->getSuccessorOperands(index);
2083   if (succOperands.begin() == succOperands.end())
2084     return;
2085 
2086   os << '(';
2087   interleaveComma(succOperands,
2088                   [this](Value operand) { printValueID(operand); });
2089   os << " : ";
2090   interleaveComma(succOperands,
2091                   [this](Value operand) { printType(operand.getType()); });
2092   os << ')';
2093 }
2094 
printRegion(Region & region,bool printEntryBlockArgs,bool printBlockTerminators)2095 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
2096                                    bool printBlockTerminators) {
2097   os << " {\n";
2098   if (!region.empty()) {
2099     auto *entryBlock = &region.front();
2100     print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0,
2101           printBlockTerminators);
2102     for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
2103       print(&b);
2104   }
2105   os.indent(currentIndent) << "}";
2106 }
2107 
printAffineMapOfSSAIds(AffineMapAttr mapAttr,ValueRange operands)2108 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2109                                               ValueRange operands) {
2110   AffineMap map = mapAttr.getValue();
2111   unsigned numDims = map.getNumDims();
2112   auto printValueName = [&](unsigned pos, bool isSymbol) {
2113     unsigned index = isSymbol ? numDims + pos : pos;
2114     assert(index < operands.size());
2115     if (isSymbol)
2116       os << "symbol(";
2117     printValueID(operands[index]);
2118     if (isSymbol)
2119       os << ')';
2120   };
2121 
2122   interleaveComma(map.getResults(), [&](AffineExpr expr) {
2123     printAffineExpr(expr, printValueName);
2124   });
2125 }
2126 
2127 //===----------------------------------------------------------------------===//
2128 // print and dump methods
2129 //===----------------------------------------------------------------------===//
2130 
print(raw_ostream & os) const2131 void Attribute::print(raw_ostream &os) const {
2132   ModulePrinter(os).printAttribute(*this);
2133 }
2134 
dump() const2135 void Attribute::dump() const {
2136   print(llvm::errs());
2137   llvm::errs() << "\n";
2138 }
2139 
print(raw_ostream & os)2140 void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); }
2141 
dump()2142 void Type::dump() { print(llvm::errs()); }
2143 
dump() const2144 void AffineMap::dump() const {
2145   print(llvm::errs());
2146   llvm::errs() << "\n";
2147 }
2148 
dump() const2149 void IntegerSet::dump() const {
2150   print(llvm::errs());
2151   llvm::errs() << "\n";
2152 }
2153 
print(raw_ostream & os) const2154 void AffineExpr::print(raw_ostream &os) const {
2155   if (expr == nullptr) {
2156     os << "null affine expr";
2157     return;
2158   }
2159   ModulePrinter(os).printAffineExpr(*this);
2160 }
2161 
dump() const2162 void AffineExpr::dump() const {
2163   print(llvm::errs());
2164   llvm::errs() << "\n";
2165 }
2166 
print(raw_ostream & os) const2167 void AffineMap::print(raw_ostream &os) const {
2168   if (map == nullptr) {
2169     os << "null affine map";
2170     return;
2171   }
2172   ModulePrinter(os).printAffineMap(*this);
2173 }
2174 
print(raw_ostream & os) const2175 void IntegerSet::print(raw_ostream &os) const {
2176   ModulePrinter(os).printIntegerSet(*this);
2177 }
2178 
print(raw_ostream & os)2179 void Value::print(raw_ostream &os) {
2180   if (auto *op = getDefiningOp())
2181     return op->print(os);
2182   // TODO: Improve this.
2183   assert(isa<BlockArgument>());
2184   os << "<block argument>\n";
2185 }
print(raw_ostream & os,AsmState & state)2186 void Value::print(raw_ostream &os, AsmState &state) {
2187   if (auto *op = getDefiningOp())
2188     return op->print(os, state);
2189 
2190   // TODO: Improve this.
2191   assert(isa<BlockArgument>());
2192   os << "<block argument>\n";
2193 }
2194 
dump()2195 void Value::dump() {
2196   print(llvm::errs());
2197   llvm::errs() << "\n";
2198 }
2199 
printAsOperand(raw_ostream & os,AsmState & state)2200 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
2201   // TODO(riverriddle) This doesn't necessarily capture all potential cases.
2202   // Currently, region arguments can be shadowed when printing the main
2203   // operation. If the IR hasn't been printed, this will produce the old SSA
2204   // name and not the shadowed name.
2205   state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
2206                                                  os);
2207 }
2208 
print(raw_ostream & os,OpPrintingFlags flags)2209 void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
2210   // Handle top-level operations or local printing.
2211   if (!getParent() || flags.shouldUseLocalScope()) {
2212     AsmState state(this);
2213     OperationPrinter(os, flags, state.getImpl()).print(this);
2214     return;
2215   }
2216 
2217   Operation *parentOp = getParentOp();
2218   if (!parentOp) {
2219     os << "<<UNLINKED OPERATION>>\n";
2220     return;
2221   }
2222   // Get the top-level op.
2223   while (auto *nextOp = parentOp->getParentOp())
2224     parentOp = nextOp;
2225 
2226   AsmState state(parentOp);
2227   print(os, state, flags);
2228 }
print(raw_ostream & os,AsmState & state,OpPrintingFlags flags)2229 void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
2230   OperationPrinter(os, flags, state.getImpl()).print(this);
2231 }
2232 
dump()2233 void Operation::dump() {
2234   print(llvm::errs(), OpPrintingFlags().useLocalScope());
2235   llvm::errs() << "\n";
2236 }
2237 
print(raw_ostream & os)2238 void Block::print(raw_ostream &os) {
2239   Operation *parentOp = getParentOp();
2240   if (!parentOp) {
2241     os << "<<UNLINKED BLOCK>>\n";
2242     return;
2243   }
2244   // Get the top-level op.
2245   while (auto *nextOp = parentOp->getParentOp())
2246     parentOp = nextOp;
2247 
2248   AsmState state(parentOp);
2249   print(os, state);
2250 }
print(raw_ostream & os,AsmState & state)2251 void Block::print(raw_ostream &os, AsmState &state) {
2252   OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
2253 }
2254 
dump()2255 void Block::dump() { print(llvm::errs()); }
2256 
2257 /// Print out the name of the block without printing its body.
printAsOperand(raw_ostream & os,bool printType)2258 void Block::printAsOperand(raw_ostream &os, bool printType) {
2259   Operation *parentOp = getParentOp();
2260   if (!parentOp) {
2261     os << "<<UNLINKED BLOCK>>\n";
2262     return;
2263   }
2264   // Get the top-level op.
2265   while (auto *nextOp = parentOp->getParentOp())
2266     parentOp = nextOp;
2267 
2268   AsmState state(parentOp);
2269   printAsOperand(os, state);
2270 }
printAsOperand(raw_ostream & os,AsmState & state)2271 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
2272   OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
2273   printer.printBlockName(this);
2274 }
2275 
print(raw_ostream & os,OpPrintingFlags flags)2276 void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
2277   AsmState state(*this);
2278 
2279   // Don't populate aliases when printing at local scope.
2280   if (!flags.shouldUseLocalScope())
2281     state.getImpl().initializeAliases(*this);
2282   print(os, state, flags);
2283 }
print(raw_ostream & os,AsmState & state,OpPrintingFlags flags)2284 void ModuleOp::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
2285   OperationPrinter(os, flags, state.getImpl()).print(*this);
2286 }
2287 
dump()2288 void ModuleOp::dump() { print(llvm::errs()); }
2289