1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/Function.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Module.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/IR/Operation.h"
26 #include "mlir/IR/StandardTypes.h"
27 #include "mlir/Support/STLExtras.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/ScopedHashTable.h"
33 #include "llvm/ADT/SetVector.h"
34 #include "llvm/ADT/SmallString.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/StringSet.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Support/Regex.h"
39 #include "llvm/Support/SaveAndRestore.h"
40 using namespace mlir;
41 using namespace mlir::detail;
42
print(raw_ostream & os) const43 void Identifier::print(raw_ostream &os) const { os << str(); }
44
dump() const45 void Identifier::dump() const { print(llvm::errs()); }
46
print(raw_ostream & os) const47 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
48
dump() const49 void OperationName::dump() const { print(llvm::errs()); }
50
~DialectAsmPrinter()51 DialectAsmPrinter::~DialectAsmPrinter() {}
52
~OpAsmPrinter()53 OpAsmPrinter::~OpAsmPrinter() {}
54
55 //===--------------------------------------------------------------------===//
56 // Operation OpAsm interface.
57 //===--------------------------------------------------------------------===//
58
59 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
60 #include "mlir/IR/OpAsmInterface.cpp.inc"
61
62 //===----------------------------------------------------------------------===//
63 // OpPrintingFlags
64 //===----------------------------------------------------------------------===//
65
66 static llvm::cl::opt<unsigned> elideElementsAttrIfLarger(
67 "mlir-elide-elementsattrs-if-larger",
68 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
69 "more elements than the given upper limit"));
70
71 static llvm::cl::opt<bool>
72 printDebugInfoOpt("mlir-print-debuginfo",
73 llvm::cl::desc("Print debug info in MLIR output"),
74 llvm::cl::init(false));
75
76 static llvm::cl::opt<bool> printPrettyDebugInfoOpt(
77 "mlir-pretty-debuginfo",
78 llvm::cl::desc("Print pretty debug info in MLIR output"),
79 llvm::cl::init(false));
80
81 // Use the generic op output form in the operation printer even if the custom
82 // form is defined.
83 static llvm::cl::opt<bool>
84 printGenericOpFormOpt("mlir-print-op-generic",
85 llvm::cl::desc("Print the generic op form"),
86 llvm::cl::init(false), llvm::cl::Hidden);
87
88 static llvm::cl::opt<bool> printLocalScopeOpt(
89 "mlir-print-local-scope",
90 llvm::cl::desc("Print assuming in local scope by default"),
91 llvm::cl::init(false), llvm::cl::Hidden);
92
93 /// Initialize the printing flags with default supplied by the cl::opts above.
OpPrintingFlags()94 OpPrintingFlags::OpPrintingFlags()
95 : elementsAttrElementLimit(
96 elideElementsAttrIfLarger.getNumOccurrences()
97 ? Optional<int64_t>(elideElementsAttrIfLarger)
98 : Optional<int64_t>()),
99 printDebugInfoFlag(printDebugInfoOpt),
100 printDebugInfoPrettyFormFlag(printPrettyDebugInfoOpt),
101 printGenericOpFormFlag(printGenericOpFormOpt),
102 printLocalScope(printLocalScopeOpt) {}
103
104 /// Enable the elision of large elements attributes, by printing a '...'
105 /// instead of the element data, when the number of elements is greater than
106 /// `largeElementLimit`. Note: The IR generated with this option is not
107 /// parsable.
108 OpPrintingFlags &
elideLargeElementsAttrs(int64_t largeElementLimit)109 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
110 elementsAttrElementLimit = largeElementLimit;
111 return *this;
112 }
113
114 /// Enable printing of debug information. If 'prettyForm' is set to true,
115 /// debug information is printed in a more readable 'pretty' form.
enableDebugInfo(bool prettyForm)116 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
117 printDebugInfoFlag = true;
118 printDebugInfoPrettyFormFlag = prettyForm;
119 return *this;
120 }
121
122 /// Always print operations in the generic form.
printGenericOpForm()123 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
124 printGenericOpFormFlag = true;
125 return *this;
126 }
127
128 /// Use local scope when printing the operation. This allows for using the
129 /// printer in a more localized and thread-safe setting, but may not necessarily
130 /// be identical of what the IR will look like when dumping the full module.
useLocalScope()131 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
132 printLocalScope = true;
133 return *this;
134 }
135
136 /// Return if the given ElementsAttr should be elided.
shouldElideElementsAttr(ElementsAttr attr) const137 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
138 return elementsAttrElementLimit.hasValue() &&
139 *elementsAttrElementLimit < int64_t(attr.getNumElements());
140 }
141
142 /// Return if debug information should be printed.
shouldPrintDebugInfo() const143 bool OpPrintingFlags::shouldPrintDebugInfo() const {
144 return printDebugInfoFlag;
145 }
146
147 /// Return if debug information should be printed in the pretty form.
shouldPrintDebugInfoPrettyForm() const148 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
149 return printDebugInfoPrettyFormFlag;
150 }
151
152 /// Return if operations should be printed in the generic form.
shouldPrintGenericOpForm() const153 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
154 return printGenericOpFormFlag;
155 }
156
157 /// Return if the printer should use local scope when dumping the IR.
shouldUseLocalScope() const158 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
159
160 //===----------------------------------------------------------------------===//
161 // AliasState
162 //===----------------------------------------------------------------------===//
163
164 namespace {
165 /// This class manages the state for type and attribute aliases.
166 class AliasState {
167 public:
168 // Initialize the internal aliases.
169 void
170 initialize(Operation *op,
171 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
172
173 /// Return a name used for an attribute alias, or empty if there is no alias.
174 Twine getAttributeAlias(Attribute attr) const;
175
176 /// Print all of the referenced attribute aliases.
177 void printAttributeAliases(raw_ostream &os) const;
178
179 /// Return a string to use as an alias for the given type, or empty if there
180 /// is no alias recorded.
181 StringRef getTypeAlias(Type ty) const;
182
183 /// Print all of the referenced type aliases.
184 void printTypeAliases(raw_ostream &os) const;
185
186 private:
187 /// A special index constant used for non-kind attribute aliases.
188 enum { NonAttrKindAlias = -1 };
189
190 /// Record a reference to the given attribute.
191 void recordAttributeReference(Attribute attr);
192
193 /// Record a reference to the given type.
194 void recordTypeReference(Type ty);
195
196 // Visit functions.
197 void visitOperation(Operation *op);
198 void visitType(Type type);
199 void visitAttribute(Attribute attr);
200
201 /// Set of attributes known to be used within the module.
202 llvm::SetVector<Attribute> usedAttributes;
203
204 /// Mapping between attribute and a pair comprised of a base alias name and a
205 /// count suffix. If the suffix is set to -1, it is not displayed.
206 llvm::MapVector<Attribute, std::pair<StringRef, int>> attrToAlias;
207
208 /// Mapping between attribute kind and a pair comprised of a base alias name
209 /// and a unique list of attributes belonging to this kind sorted by location
210 /// seen in the module.
211 llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
212 attrKindToAlias;
213
214 /// Set of types known to be used within the module.
215 llvm::SetVector<Type> usedTypes;
216
217 /// A mapping between a type and a given alias.
218 DenseMap<Type, StringRef> typeToAlias;
219 };
220 } // end anonymous namespace
221
222 // Utility to generate a function to register a symbol alias.
canRegisterAlias(StringRef name,llvm::StringSet<> & usedAliases)223 static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
224 assert(!name.empty() && "expected alias name to be non-empty");
225 // TODO(riverriddle) Assert that the provided alias name can be lexed as
226 // an identifier.
227
228 // Check that the alias doesn't contain a '.' character and the name is not
229 // already in use.
230 return !name.contains('.') && usedAliases.insert(name).second;
231 }
232
initialize(Operation * op,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)233 void AliasState::initialize(
234 Operation *op,
235 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
236 // Track the identifiers in use for each symbol so that the same identifier
237 // isn't used twice.
238 llvm::StringSet<> usedAliases;
239
240 // Collect the set of aliases from each dialect.
241 SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
242 SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
243 SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
244
245 // AffineMap/Integer set have specific kind aliases.
246 attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
247 attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
248
249 for (auto &interface : interfaces) {
250 interface.getAttributeKindAliases(attributeKindAliases);
251 interface.getAttributeAliases(attributeAliases);
252 interface.getTypeAliases(typeAliases);
253 }
254
255 // Setup the attribute kind aliases.
256 StringRef alias;
257 unsigned attrKind;
258 for (auto &attrAliasPair : attributeKindAliases) {
259 std::tie(attrKind, alias) = attrAliasPair;
260 assert(!alias.empty() && "expected non-empty alias string");
261 if (!usedAliases.count(alias) && !alias.contains('.'))
262 attrKindToAlias.insert({attrKind, {alias, {}}});
263 }
264
265 // Clear the set of used identifiers so that the attribute kind aliases are
266 // just a prefix and not the full alias, i.e. there may be some overlap.
267 usedAliases.clear();
268
269 // Register the attribute aliases.
270 // Create a regex for the attribute kind alias names, these have a prefix with
271 // a counter appended to the end. We prevent normal aliases from having these
272 // names to avoid collisions.
273 llvm::Regex reservedAttrNames("[0-9]+$");
274
275 // Attribute value aliases.
276 Attribute attr;
277 for (auto &attrAliasPair : attributeAliases) {
278 std::tie(attr, alias) = attrAliasPair;
279 if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
280 attrToAlias.insert({attr, {alias, NonAttrKindAlias}});
281 }
282
283 // Clear the set of used identifiers as types can have the same identifiers as
284 // affine structures.
285 usedAliases.clear();
286
287 // Type aliases.
288 for (auto &typeAliasPair : typeAliases)
289 if (canRegisterAlias(typeAliasPair.second, usedAliases))
290 typeToAlias.insert(typeAliasPair);
291
292 // Traverse the given IR to generate the set of used attributes/types.
293 op->walk([&](Operation *op) { visitOperation(op); });
294 }
295
296 /// Return a name used for an attribute alias, or empty if there is no alias.
getAttributeAlias(Attribute attr) const297 Twine AliasState::getAttributeAlias(Attribute attr) const {
298 auto alias = attrToAlias.find(attr);
299 if (alias == attrToAlias.end())
300 return Twine();
301
302 // Return the alias for this attribute, along with the index if this was
303 // generated by a kind alias.
304 int kindIndex = alias->second.second;
305 return alias->second.first +
306 (kindIndex == NonAttrKindAlias ? Twine() : Twine(kindIndex));
307 }
308
309 /// Print all of the referenced attribute aliases.
printAttributeAliases(raw_ostream & os) const310 void AliasState::printAttributeAliases(raw_ostream &os) const {
311 auto printAlias = [&](StringRef alias, Attribute attr, int index) {
312 os << '#' << alias;
313 if (index != NonAttrKindAlias)
314 os << index;
315 os << " = " << attr << '\n';
316 };
317
318 // Print all of the attribute kind aliases.
319 for (auto &kindAlias : attrKindToAlias) {
320 auto &aliasAttrsPair = kindAlias.second;
321 for (unsigned i = 0, e = aliasAttrsPair.second.size(); i != e; ++i)
322 printAlias(aliasAttrsPair.first, aliasAttrsPair.second[i], i);
323 os << "\n";
324 }
325
326 // In a second pass print all of the remaining attribute aliases that aren't
327 // kind aliases.
328 for (Attribute attr : usedAttributes) {
329 auto alias = attrToAlias.find(attr);
330 if (alias != attrToAlias.end() && alias->second.second == NonAttrKindAlias)
331 printAlias(alias->second.first, attr, alias->second.second);
332 }
333 }
334
335 /// Return a string to use as an alias for the given type, or empty if there
336 /// is no alias recorded.
getTypeAlias(Type ty) const337 StringRef AliasState::getTypeAlias(Type ty) const {
338 return typeToAlias.lookup(ty);
339 }
340
341 /// Print all of the referenced type aliases.
printTypeAliases(raw_ostream & os) const342 void AliasState::printTypeAliases(raw_ostream &os) const {
343 for (Type type : usedTypes) {
344 auto alias = typeToAlias.find(type);
345 if (alias != typeToAlias.end())
346 os << '!' << alias->second << " = type " << type << '\n';
347 }
348 }
349
350 /// Record a reference to the given attribute.
recordAttributeReference(Attribute attr)351 void AliasState::recordAttributeReference(Attribute attr) {
352 // Don't recheck attributes that have already been seen or those that
353 // already have an alias.
354 if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
355 return;
356
357 // If this attribute kind has an alias, then record one for this attribute.
358 auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
359 if (alias == attrKindToAlias.end())
360 return;
361 std::pair<StringRef, int> attrAlias(alias->second.first,
362 alias->second.second.size());
363 attrToAlias.insert({attr, attrAlias});
364 alias->second.second.push_back(attr);
365 }
366
367 /// Record a reference to the given type.
recordTypeReference(Type ty)368 void AliasState::recordTypeReference(Type ty) { usedTypes.insert(ty); }
369
370 // TODO Support visiting other types/operations when implemented.
visitType(Type type)371 void AliasState::visitType(Type type) {
372 recordTypeReference(type);
373
374 if (auto funcType = type.dyn_cast<FunctionType>()) {
375 // Visit input and result types for functions.
376 for (auto input : funcType.getInputs())
377 visitType(input);
378 for (auto result : funcType.getResults())
379 visitType(result);
380 } else if (auto shapedType = type.dyn_cast<ShapedType>()) {
381 visitType(shapedType.getElementType());
382
383 // Visit affine maps in memref type.
384 if (auto memref = type.dyn_cast<MemRefType>())
385 for (auto map : memref.getAffineMaps())
386 recordAttributeReference(AffineMapAttr::get(map));
387 }
388 }
389
visitAttribute(Attribute attr)390 void AliasState::visitAttribute(Attribute attr) {
391 recordAttributeReference(attr);
392
393 if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
394 for (auto elt : arrayAttr.getValue())
395 visitAttribute(elt);
396 } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
397 visitType(typeAttr.getValue());
398 }
399 }
400
visitOperation(Operation * op)401 void AliasState::visitOperation(Operation *op) {
402 // Visit all the types used in the operation.
403 for (auto type : op->getOperandTypes())
404 visitType(type);
405 for (auto type : op->getResultTypes())
406 visitType(type);
407 for (auto ®ion : op->getRegions())
408 for (auto &block : region)
409 for (auto arg : block.getArguments())
410 visitType(arg.getType());
411
412 // Visit each of the attributes.
413 for (auto elt : op->getAttrs())
414 visitAttribute(elt.second);
415 }
416
417 //===----------------------------------------------------------------------===//
418 // SSANameState
419 //===----------------------------------------------------------------------===//
420
421 namespace {
422 /// This class manages the state of SSA value names.
423 class SSANameState {
424 public:
425 /// A sentinal value used for values with names set.
426 enum : unsigned { NameSentinel = ~0U };
427
428 SSANameState(Operation *op,
429 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
430
431 /// Print the SSA identifier for the given value to 'stream'. If
432 /// 'printResultNo' is true, it also presents the result number ('#' number)
433 /// of this value.
434 void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
435
436 /// Return the result indices for each of the result groups registered by this
437 /// operation, or empty if none exist.
438 ArrayRef<int> getOpResultGroups(Operation *op);
439
440 /// Get the ID for the given block.
441 unsigned getBlockID(Block *block);
442
443 /// Renumber the arguments for the specified region to the same names as the
444 /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
445 /// details.
446 void shadowRegionArgs(Region ®ion, ValueRange namesToUse);
447
448 private:
449 /// Number the SSA values within the given IR unit.
450 void numberValuesInRegion(
451 Region ®ion,
452 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
453 void numberValuesInBlock(
454 Block &block,
455 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
456 void numberValuesInOp(
457 Operation &op,
458 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
459
460 /// Given a result of an operation 'result', find the result group head
461 /// 'lookupValue' and the result of 'result' within that group in
462 /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
463 /// has more than 1 result.
464 void getResultIDAndNumber(OpResult result, Value &lookupValue,
465 Optional<int> &lookupResultNo) const;
466
467 /// Set a special value name for the given value.
468 void setValueName(Value value, StringRef name);
469
470 /// Uniques the given value name within the printer. If the given name
471 /// conflicts, it is automatically renamed.
472 StringRef uniqueValueName(StringRef name);
473
474 /// This is the value ID for each SSA value. If this returns NameSentinel,
475 /// then the valueID has an entry in valueNames.
476 DenseMap<Value, unsigned> valueIDs;
477 DenseMap<Value, StringRef> valueNames;
478
479 /// This is a map of operations that contain multiple named result groups,
480 /// i.e. there may be multiple names for the results of the operation. The
481 /// value of this map are the result numbers that start a result group.
482 DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
483
484 /// This is the block ID for each block in the current.
485 DenseMap<Block *, unsigned> blockIDs;
486
487 /// This keeps track of all of the non-numeric names that are in flight,
488 /// allowing us to check for duplicates.
489 /// Note: the value of the map is unused.
490 llvm::ScopedHashTable<StringRef, char> usedNames;
491 llvm::BumpPtrAllocator usedNameAllocator;
492
493 /// This is the next value ID to assign in numbering.
494 unsigned nextValueID = 0;
495 /// This is the next ID to assign to a region entry block argument.
496 unsigned nextArgumentID = 0;
497 /// This is the next ID to assign when a name conflict is detected.
498 unsigned nextConflictID = 0;
499 };
500 } // end anonymous namespace
501
SSANameState(Operation * op,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)502 SSANameState::SSANameState(
503 Operation *op,
504 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
505 llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
506 numberValuesInOp(*op, interfaces);
507
508 for (auto ®ion : op->getRegions())
509 numberValuesInRegion(region, interfaces);
510 }
511
printValueID(Value value,bool printResultNo,raw_ostream & stream) const512 void SSANameState::printValueID(Value value, bool printResultNo,
513 raw_ostream &stream) const {
514 if (!value) {
515 stream << "<<NULL>>";
516 return;
517 }
518
519 Optional<int> resultNo;
520 auto lookupValue = value;
521
522 // If this is an operation result, collect the head lookup value of the result
523 // group and the result number of 'result' within that group.
524 if (OpResult result = value.dyn_cast<OpResult>())
525 getResultIDAndNumber(result, lookupValue, resultNo);
526
527 auto it = valueIDs.find(lookupValue);
528 if (it == valueIDs.end()) {
529 stream << "<<UNKNOWN SSA VALUE>>";
530 return;
531 }
532
533 stream << '%';
534 if (it->second != NameSentinel) {
535 stream << it->second;
536 } else {
537 auto nameIt = valueNames.find(lookupValue);
538 assert(nameIt != valueNames.end() && "Didn't have a name entry?");
539 stream << nameIt->second;
540 }
541
542 if (resultNo.hasValue() && printResultNo)
543 stream << '#' << resultNo;
544 }
545
getOpResultGroups(Operation * op)546 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
547 auto it = opResultGroups.find(op);
548 return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
549 }
550
getBlockID(Block * block)551 unsigned SSANameState::getBlockID(Block *block) {
552 auto it = blockIDs.find(block);
553 return it != blockIDs.end() ? it->second : NameSentinel;
554 }
555
shadowRegionArgs(Region & region,ValueRange namesToUse)556 void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) {
557 assert(!region.empty() && "cannot shadow arguments of an empty region");
558 assert(region.front().getNumArguments() == namesToUse.size() &&
559 "incorrect number of names passed in");
560 assert(region.getParentOp()->isKnownIsolatedFromAbove() &&
561 "only KnownIsolatedFromAbove ops can shadow names");
562
563 SmallVector<char, 16> nameStr;
564 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
565 auto nameToUse = namesToUse[i];
566 if (nameToUse == nullptr)
567 continue;
568 auto nameToReplace = region.front().getArgument(i);
569
570 nameStr.clear();
571 llvm::raw_svector_ostream nameStream(nameStr);
572 printValueID(nameToUse, /*printResultNo=*/true, nameStream);
573
574 // Entry block arguments should already have a pretty "arg" name.
575 assert(valueIDs[nameToReplace] == NameSentinel);
576
577 // Use the name without the leading %.
578 auto name = StringRef(nameStream.str()).drop_front();
579
580 // Overwrite the name.
581 valueNames[nameToReplace] = name.copy(usedNameAllocator);
582 }
583 }
584
numberValuesInRegion(Region & region,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)585 void SSANameState::numberValuesInRegion(
586 Region ®ion,
587 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
588 // Save the current value ids to allow for numbering values in sibling regions
589 // the same.
590 llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
591 llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
592 llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);
593
594 // Push a new used names scope.
595 llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
596
597 // Number the values within this region in a breadth-first order.
598 unsigned nextBlockID = 0;
599 for (auto &block : region) {
600 // Each block gets a unique ID, and all of the operations within it get
601 // numbered as well.
602 blockIDs[&block] = nextBlockID++;
603 numberValuesInBlock(block, interfaces);
604 }
605
606 // After that we traverse the nested regions.
607 // TODO: Rework this loop to not use recursion.
608 for (auto &block : region) {
609 for (auto &op : block)
610 for (auto &nestedRegion : op.getRegions())
611 numberValuesInRegion(nestedRegion, interfaces);
612 }
613 }
614
numberValuesInBlock(Block & block,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)615 void SSANameState::numberValuesInBlock(
616 Block &block,
617 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
618 auto setArgNameFn = [&](Value arg, StringRef name) {
619 assert(!valueIDs.count(arg) && "arg numbered multiple times");
620 assert(arg.cast<BlockArgument>().getOwner() == &block &&
621 "arg not defined in 'block'");
622 setValueName(arg, name);
623 };
624
625 bool isEntryBlock = block.isEntryBlock();
626 if (isEntryBlock) {
627 if (auto *op = block.getParentOp()) {
628 if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
629 asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
630 }
631 }
632
633 // Number the block arguments. We give entry block arguments a special name
634 // 'arg'.
635 SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
636 llvm::raw_svector_ostream specialName(specialNameBuffer);
637 for (auto arg : block.getArguments()) {
638 if (valueIDs.count(arg))
639 continue;
640 if (isEntryBlock) {
641 specialNameBuffer.resize(strlen("arg"));
642 specialName << nextArgumentID++;
643 }
644 setValueName(arg, specialName.str());
645 }
646
647 // Number the operations in this block.
648 for (auto &op : block)
649 numberValuesInOp(op, interfaces);
650 }
651
numberValuesInOp(Operation & op,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)652 void SSANameState::numberValuesInOp(
653 Operation &op,
654 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
655 unsigned numResults = op.getNumResults();
656 if (numResults == 0)
657 return;
658 Value resultBegin = op.getResult(0);
659
660 // Function used to set the special result names for the operation.
661 SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
662 auto setResultNameFn = [&](Value result, StringRef name) {
663 assert(!valueIDs.count(result) && "result numbered multiple times");
664 assert(result.getDefiningOp() == &op && "result not defined by 'op'");
665 setValueName(result, name);
666
667 // Record the result number for groups not anchored at 0.
668 if (int resultNo = result.cast<OpResult>().getResultNumber())
669 resultGroups.push_back(resultNo);
670 };
671 if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
672 asmInterface.getAsmResultNames(setResultNameFn);
673 else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect()))
674 asmInterface->getAsmResultNames(&op, setResultNameFn);
675
676 // If the first result wasn't numbered, give it a default number.
677 if (valueIDs.try_emplace(resultBegin, nextValueID).second)
678 ++nextValueID;
679
680 // If this operation has multiple result groups, mark it.
681 if (resultGroups.size() != 1) {
682 llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
683 opResultGroups.try_emplace(&op, std::move(resultGroups));
684 }
685 }
686
getResultIDAndNumber(OpResult result,Value & lookupValue,Optional<int> & lookupResultNo) const687 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
688 Optional<int> &lookupResultNo) const {
689 Operation *owner = result.getOwner();
690 if (owner->getNumResults() == 1)
691 return;
692 int resultNo = result.getResultNumber();
693
694 // If this operation has multiple result groups, we will need to find the
695 // one corresponding to this result.
696 auto resultGroupIt = opResultGroups.find(owner);
697 if (resultGroupIt == opResultGroups.end()) {
698 // If not, just use the first result.
699 lookupResultNo = resultNo;
700 lookupValue = owner->getResult(0);
701 return;
702 }
703
704 // Find the correct index using a binary search, as the groups are ordered.
705 ArrayRef<int> resultGroups = resultGroupIt->second;
706 auto it = llvm::upper_bound(resultGroups, resultNo);
707 int groupResultNo = 0, groupSize = 0;
708
709 // If there are no smaller elements, the last result group is the lookup.
710 if (it == resultGroups.end()) {
711 groupResultNo = resultGroups.back();
712 groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
713 } else {
714 // Otherwise, the previous element is the lookup.
715 groupResultNo = *std::prev(it);
716 groupSize = *it - groupResultNo;
717 }
718
719 // We only record the result number for a group of size greater than 1.
720 if (groupSize != 1)
721 lookupResultNo = resultNo - groupResultNo;
722 lookupValue = owner->getResult(groupResultNo);
723 }
724
setValueName(Value value,StringRef name)725 void SSANameState::setValueName(Value value, StringRef name) {
726 // If the name is empty, the value uses the default numbering.
727 if (name.empty()) {
728 valueIDs[value] = nextValueID++;
729 return;
730 }
731
732 valueIDs[value] = NameSentinel;
733 valueNames[value] = uniqueValueName(name);
734 }
735
uniqueValueName(StringRef name)736 StringRef SSANameState::uniqueValueName(StringRef name) {
737 // Check to see if this name is already unique.
738 if (!usedNames.count(name)) {
739 name = name.copy(usedNameAllocator);
740 } else {
741 // Otherwise, we had a conflict - probe until we find a unique name. This
742 // is guaranteed to terminate (and usually in a single iteration) because it
743 // generates new names by incrementing nextConflictID.
744 SmallString<64> probeName(name);
745 probeName.push_back('_');
746 while (true) {
747 probeName.resize(name.size() + 1);
748 probeName += llvm::utostr(nextConflictID++);
749 if (!usedNames.count(probeName)) {
750 name = StringRef(probeName).copy(usedNameAllocator);
751 break;
752 }
753 }
754 }
755
756 usedNames.insert(name, char());
757 return name;
758 }
759
760 //===----------------------------------------------------------------------===//
761 // AsmState
762 //===----------------------------------------------------------------------===//
763
764 namespace mlir {
765 namespace detail {
766 class AsmStateImpl {
767 public:
AsmStateImpl(Operation * op)768 explicit AsmStateImpl(Operation *op)
769 : interfaces(op->getContext()), nameState(op, interfaces) {}
770
771 /// Initialize the alias state to enable the printing of aliases.
initializeAliases(Operation * op)772 void initializeAliases(Operation *op) {
773 aliasState.initialize(op, interfaces);
774 }
775
776 /// Get an instance of the OpAsmDialectInterface for the given dialect, or
777 /// null if one wasn't registered.
getOpAsmInterface(Dialect * dialect)778 const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
779 return interfaces.getInterfaceFor(dialect);
780 }
781
782 /// Get the state used for aliases.
getAliasState()783 AliasState &getAliasState() { return aliasState; }
784
785 /// Get the state used for SSA names.
getSSANameState()786 SSANameState &getSSANameState() { return nameState; }
787
788 private:
789 /// Collection of OpAsm interfaces implemented in the context.
790 DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
791
792 /// The state used for attribute and type aliases.
793 AliasState aliasState;
794
795 /// The state used for SSA value names.
796 SSANameState nameState;
797 };
798 } // end namespace detail
799 } // end namespace mlir
800
AsmState(Operation * op)801 AsmState::AsmState(Operation *op) : impl(std::make_unique<AsmStateImpl>(op)) {}
~AsmState()802 AsmState::~AsmState() {}
803
804 //===----------------------------------------------------------------------===//
805 // ModulePrinter
806 //===----------------------------------------------------------------------===//
807
808 namespace {
809 class ModulePrinter {
810 public:
ModulePrinter(raw_ostream & os,OpPrintingFlags flags=llvm::None,AsmStateImpl * state=nullptr)811 ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
812 AsmStateImpl *state = nullptr)
813 : os(os), printerFlags(flags), state(state) {}
ModulePrinter(ModulePrinter & printer)814 explicit ModulePrinter(ModulePrinter &printer)
815 : os(printer.os), printerFlags(printer.printerFlags),
816 state(printer.state) {}
817
818 /// Returns the output stream of the printer.
getStream()819 raw_ostream &getStream() { return os; }
820
821 template <typename Container, typename UnaryFunctor>
interleaveComma(const Container & c,UnaryFunctor each_fn) const822 inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
823 mlir::interleaveComma(c, os, each_fn);
824 }
825
826 /// Print the given attribute. If 'mayElideType' is true, some attributes are
827 /// printed without the type when the type matches the default used in the
828 /// parser (for example i64 is the default for integer attributes).
829 void printAttribute(Attribute attr, bool mayElideType = false);
830
831 void printType(Type type);
832 void printLocation(LocationAttr loc);
833
834 void printAffineMap(AffineMap map);
835 void
836 printAffineExpr(AffineExpr expr,
837 function_ref<void(unsigned, bool)> printValueName = nullptr);
838 void printAffineConstraint(AffineExpr expr, bool isEq);
839 void printIntegerSet(IntegerSet set);
840
841 protected:
842 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
843 ArrayRef<StringRef> elidedAttrs = {},
844 bool withKeyword = false);
845 void printTrailingLocation(Location loc);
846 void printLocationInternal(LocationAttr loc, bool pretty = false);
847 void printDenseElementsAttr(DenseElementsAttr attr);
848
849 void printDialectAttribute(Attribute attr);
850 void printDialectType(Type type);
851
852 /// This enum is used to represent the binding strength of the enclosing
853 /// context that an AffineExprStorage is being printed in, so we can
854 /// intelligently produce parens.
855 enum class BindingStrength {
856 Weak, // + and -
857 Strong, // All other binary operators.
858 };
859 void printAffineExprInternal(
860 AffineExpr expr, BindingStrength enclosingTightness,
861 function_ref<void(unsigned, bool)> printValueName = nullptr);
862
863 /// The output stream for the printer.
864 raw_ostream &os;
865
866 /// A set of flags to control the printer's behavior.
867 OpPrintingFlags printerFlags;
868
869 /// An optional printer state for the module.
870 AsmStateImpl *state;
871 };
872 } // end anonymous namespace
873
printTrailingLocation(Location loc)874 void ModulePrinter::printTrailingLocation(Location loc) {
875 // Check to see if we are printing debug information.
876 if (!printerFlags.shouldPrintDebugInfo())
877 return;
878
879 os << " ";
880 printLocation(loc);
881 }
882
printLocationInternal(LocationAttr loc,bool pretty)883 void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
884 switch (loc.getKind()) {
885 case StandardAttributes::OpaqueLocation:
886 printLocationInternal(loc.cast<OpaqueLoc>().getFallbackLocation(), pretty);
887 break;
888 case StandardAttributes::UnknownLocation:
889 if (pretty)
890 os << "[unknown]";
891 else
892 os << "unknown";
893 break;
894 case StandardAttributes::FileLineColLocation: {
895 auto fileLoc = loc.cast<FileLineColLoc>();
896 auto mayQuote = pretty ? "" : "\"";
897 os << mayQuote << fileLoc.getFilename() << mayQuote << ':'
898 << fileLoc.getLine() << ':' << fileLoc.getColumn();
899 break;
900 }
901 case StandardAttributes::NameLocation: {
902 auto nameLoc = loc.cast<NameLoc>();
903 os << '\"' << nameLoc.getName() << '\"';
904
905 // Print the child if it isn't unknown.
906 auto childLoc = nameLoc.getChildLoc();
907 if (!childLoc.isa<UnknownLoc>()) {
908 os << '(';
909 printLocationInternal(childLoc, pretty);
910 os << ')';
911 }
912 break;
913 }
914 case StandardAttributes::CallSiteLocation: {
915 auto callLocation = loc.cast<CallSiteLoc>();
916 auto caller = callLocation.getCaller();
917 auto callee = callLocation.getCallee();
918 if (!pretty)
919 os << "callsite(";
920 printLocationInternal(callee, pretty);
921 if (pretty) {
922 if (callee.isa<NameLoc>()) {
923 if (caller.isa<FileLineColLoc>()) {
924 os << " at ";
925 } else {
926 os << "\n at ";
927 }
928 } else {
929 os << "\n at ";
930 }
931 } else {
932 os << " at ";
933 }
934 printLocationInternal(caller, pretty);
935 if (!pretty)
936 os << ")";
937 break;
938 }
939 case StandardAttributes::FusedLocation: {
940 auto fusedLoc = loc.cast<FusedLoc>();
941 if (!pretty)
942 os << "fused";
943 if (auto metadata = fusedLoc.getMetadata())
944 os << '<' << metadata << '>';
945 os << '[';
946 interleave(
947 fusedLoc.getLocations(),
948 [&](Location loc) { printLocationInternal(loc, pretty); },
949 [&]() { os << ", "; });
950 os << ']';
951 break;
952 }
953 }
954 }
955
956 /// Print a floating point value in a way that the parser will be able to
957 /// round-trip losslessly.
printFloatValue(const APFloat & apValue,raw_ostream & os)958 static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
959 // We would like to output the FP constant value in exponential notation,
960 // but we cannot do this if doing so will lose precision. Check here to
961 // make sure that we only output it in exponential format if we can parse
962 // the value back and get the same value.
963 bool isInf = apValue.isInfinity();
964 bool isNaN = apValue.isNaN();
965 if (!isInf && !isNaN) {
966 SmallString<128> strValue;
967 apValue.toString(strValue, 6, 0, false);
968
969 // Check to make sure that the stringized number is not some string like
970 // "Inf" or NaN, that atof will accept, but the lexer will not. Check
971 // that the string matches the "[-+]?[0-9]" regex.
972 assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
973 ((strValue[0] == '-' || strValue[0] == '+') &&
974 (strValue[1] >= '0' && strValue[1] <= '9'))) &&
975 "[-+]?[0-9] regex does not match!");
976
977 // Parse back the stringized version and check that the value is equal
978 // (i.e., there is no precision loss). If it is not, use the default format
979 // of APFloat instead of the exponential notation.
980 if (!APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
981 strValue.clear();
982 apValue.toString(strValue);
983 }
984 os << strValue;
985 return;
986 }
987
988 // Print special values in hexadecimal format. The sign bit should be
989 // included in the literal.
990 SmallVector<char, 16> str;
991 APInt apInt = apValue.bitcastToAPInt();
992 apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
993 /*formatAsCLiteral=*/true);
994 os << str;
995 }
996
printLocation(LocationAttr loc)997 void ModulePrinter::printLocation(LocationAttr loc) {
998 if (printerFlags.shouldPrintDebugInfoPrettyForm()) {
999 printLocationInternal(loc, /*pretty=*/true);
1000 } else {
1001 os << "loc(";
1002 printLocationInternal(loc);
1003 os << ')';
1004 }
1005 }
1006
1007 /// Returns if the given dialect symbol data is simple enough to print in the
1008 /// pretty form, i.e. without the enclosing "".
isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName)1009 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1010 // The name must start with an identifier.
1011 if (symName.empty() || !isalpha(symName.front()))
1012 return false;
1013
1014 // Ignore all the characters that are valid in an identifier in the symbol
1015 // name.
1016 symName = symName.drop_while(
1017 [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1018 if (symName.empty())
1019 return true;
1020
1021 // If we got to an unexpected character, then it must be a <>. Check those
1022 // recursively.
1023 if (symName.front() != '<' || symName.back() != '>')
1024 return false;
1025
1026 SmallVector<char, 8> nestedPunctuation;
1027 do {
1028 // If we ran out of characters, then we had a punctuation mismatch.
1029 if (symName.empty())
1030 return false;
1031
1032 auto c = symName.front();
1033 symName = symName.drop_front();
1034
1035 switch (c) {
1036 // We never allow null characters. This is an EOF indicator for the lexer
1037 // which we could handle, but isn't important for any known dialect.
1038 case '\0':
1039 return false;
1040 case '<':
1041 case '[':
1042 case '(':
1043 case '{':
1044 nestedPunctuation.push_back(c);
1045 continue;
1046 case '-':
1047 // Treat `->` as a special token.
1048 if (!symName.empty() && symName.front() == '>') {
1049 symName = symName.drop_front();
1050 continue;
1051 }
1052 break;
1053 // Reject types with mismatched brackets.
1054 case '>':
1055 if (nestedPunctuation.pop_back_val() != '<')
1056 return false;
1057 break;
1058 case ']':
1059 if (nestedPunctuation.pop_back_val() != '[')
1060 return false;
1061 break;
1062 case ')':
1063 if (nestedPunctuation.pop_back_val() != '(')
1064 return false;
1065 break;
1066 case '}':
1067 if (nestedPunctuation.pop_back_val() != '{')
1068 return false;
1069 break;
1070 default:
1071 continue;
1072 }
1073
1074 // We're done when the punctuation is fully matched.
1075 } while (!nestedPunctuation.empty());
1076
1077 // If there were extra characters, then we failed.
1078 return symName.empty();
1079 }
1080
1081 /// Print the given dialect symbol to the stream.
printDialectSymbol(raw_ostream & os,StringRef symPrefix,StringRef dialectName,StringRef symString)1082 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1083 StringRef dialectName, StringRef symString) {
1084 os << symPrefix << dialectName;
1085
1086 // If this symbol name is simple enough, print it directly in pretty form,
1087 // otherwise, we print it as an escaped string.
1088 if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1089 os << '.' << symString;
1090 return;
1091 }
1092
1093 // TODO: escape the symbol name, it could contain " characters.
1094 os << "<\"" << symString << "\">";
1095 }
1096
1097 /// Returns if the given string can be represented as a bare identifier.
isBareIdentifier(StringRef name)1098 static bool isBareIdentifier(StringRef name) {
1099 assert(!name.empty() && "invalid name");
1100
1101 // By making this unsigned, the value passed in to isalnum will always be
1102 // in the range 0-255. This is important when building with MSVC because
1103 // its implementation will assert. This situation can arise when dealing
1104 // with UTF-8 multibyte characters.
1105 unsigned char firstChar = static_cast<unsigned char>(name[0]);
1106 if (!isalpha(firstChar) && firstChar != '_')
1107 return false;
1108 return llvm::all_of(name.drop_front(), [](unsigned char c) {
1109 return isalnum(c) || c == '_' || c == '$' || c == '.';
1110 });
1111 }
1112
1113 /// Print the given string as a symbol reference. A symbol reference is
1114 /// represented as a string prefixed with '@'. The reference is surrounded with
1115 /// ""'s and escaped if it has any special or non-printable characters in it.
printSymbolReference(StringRef symbolRef,raw_ostream & os)1116 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
1117 assert(!symbolRef.empty() && "expected valid symbol reference");
1118
1119 // If the symbol can be represented as a bare identifier, write it directly.
1120 if (isBareIdentifier(symbolRef)) {
1121 os << '@' << symbolRef;
1122 return;
1123 }
1124
1125 // Otherwise, output the reference wrapped in quotes with proper escaping.
1126 os << "@\"";
1127 printEscapedString(symbolRef, os);
1128 os << '"';
1129 }
1130
1131 // Print out a valid ElementsAttr that is succinct and can represent any
1132 // potential shape/type, for use when eliding a large ElementsAttr.
1133 //
1134 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1135 // hopefully alert readers to the fact that this has been elided.
1136 //
1137 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1138 // accept the string "elided". The first string must be a registered dialect
1139 // name and the latter must be a hex constant.
printElidedElementsAttr(raw_ostream & os)1140 static void printElidedElementsAttr(raw_ostream &os) {
1141 os << R"(opaque<"", "0xDEADBEEF">)";
1142 }
1143
printAttribute(Attribute attr,bool mayElideType)1144 void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
1145 if (!attr) {
1146 os << "<<NULL ATTRIBUTE>>";
1147 return;
1148 }
1149
1150 // Check for an alias for this attribute.
1151 if (state) {
1152 Twine alias = state->getAliasState().getAttributeAlias(attr);
1153 if (!alias.isTriviallyEmpty()) {
1154 os << '#' << alias;
1155 return;
1156 }
1157 }
1158
1159 switch (attr.getKind()) {
1160 default:
1161 return printDialectAttribute(attr);
1162
1163 case StandardAttributes::Opaque: {
1164 auto opaqueAttr = attr.cast<OpaqueAttr>();
1165 printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
1166 opaqueAttr.getAttrData());
1167 break;
1168 }
1169 case StandardAttributes::Unit:
1170 os << "unit";
1171 break;
1172 case StandardAttributes::Bool:
1173 os << (attr.cast<BoolAttr>().getValue() ? "true" : "false");
1174
1175 // BoolAttr always elides the type.
1176 return;
1177 case StandardAttributes::Dictionary:
1178 os << '{';
1179 interleaveComma(attr.cast<DictionaryAttr>().getValue(),
1180 [&](NamedAttribute attr) {
1181 os << attr.first;
1182
1183 // The value of a UnitAttr is elided within a dictionary.
1184 if (attr.second.isa<UnitAttr>())
1185 return;
1186
1187 os << " = ";
1188 printAttribute(attr.second);
1189 });
1190 os << '}';
1191 break;
1192 case StandardAttributes::Integer: {
1193 auto intAttr = attr.cast<IntegerAttr>();
1194 // Print all integer attributes as signed unless i1.
1195 bool isSigned = intAttr.getType().isIndex() ||
1196 intAttr.getType().getIntOrFloatBitWidth() != 1;
1197 intAttr.getValue().print(os, isSigned);
1198
1199 // IntegerAttr elides the type if I64.
1200 if (mayElideType && intAttr.getType().isInteger(64))
1201 return;
1202 break;
1203 }
1204 case StandardAttributes::Float: {
1205 auto floatAttr = attr.cast<FloatAttr>();
1206 printFloatValue(floatAttr.getValue(), os);
1207
1208 // FloatAttr elides the type if F64.
1209 if (mayElideType && floatAttr.getType().isF64())
1210 return;
1211 break;
1212 }
1213 case StandardAttributes::String:
1214 os << '"';
1215 printEscapedString(attr.cast<StringAttr>().getValue(), os);
1216 os << '"';
1217 break;
1218 case StandardAttributes::Array:
1219 os << '[';
1220 interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) {
1221 printAttribute(attr, /*mayElideType=*/true);
1222 });
1223 os << ']';
1224 break;
1225 case StandardAttributes::AffineMap:
1226 os << "affine_map<";
1227 attr.cast<AffineMapAttr>().getValue().print(os);
1228 os << '>';
1229
1230 // AffineMap always elides the type.
1231 return;
1232 case StandardAttributes::IntegerSet:
1233 os << "affine_set<";
1234 attr.cast<IntegerSetAttr>().getValue().print(os);
1235 os << '>';
1236
1237 // IntegerSet always elides the type.
1238 return;
1239 case StandardAttributes::Type:
1240 printType(attr.cast<TypeAttr>().getValue());
1241 break;
1242 case StandardAttributes::SymbolRef: {
1243 auto refAttr = attr.dyn_cast<SymbolRefAttr>();
1244 printSymbolReference(refAttr.getRootReference(), os);
1245 for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
1246 os << "::";
1247 printSymbolReference(nestedRef.getValue(), os);
1248 }
1249 break;
1250 }
1251 case StandardAttributes::OpaqueElements: {
1252 auto eltsAttr = attr.cast<OpaqueElementsAttr>();
1253 if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
1254 printElidedElementsAttr(os);
1255 break;
1256 }
1257 os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
1258 os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">";
1259 break;
1260 }
1261 case StandardAttributes::DenseElements: {
1262 auto eltsAttr = attr.cast<DenseElementsAttr>();
1263 if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
1264 printElidedElementsAttr(os);
1265 break;
1266 }
1267 os << "dense<";
1268 printDenseElementsAttr(eltsAttr);
1269 os << '>';
1270 break;
1271 }
1272 case StandardAttributes::SparseElements: {
1273 auto elementsAttr = attr.cast<SparseElementsAttr>();
1274 if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) ||
1275 printerFlags.shouldElideElementsAttr(elementsAttr.getValues())) {
1276 printElidedElementsAttr(os);
1277 break;
1278 }
1279 os << "sparse<";
1280 printDenseElementsAttr(elementsAttr.getIndices());
1281 os << ", ";
1282 printDenseElementsAttr(elementsAttr.getValues());
1283 os << '>';
1284 break;
1285 }
1286
1287 // Location attributes.
1288 case StandardAttributes::CallSiteLocation:
1289 case StandardAttributes::FileLineColLocation:
1290 case StandardAttributes::FusedLocation:
1291 case StandardAttributes::NameLocation:
1292 case StandardAttributes::OpaqueLocation:
1293 case StandardAttributes::UnknownLocation:
1294 printLocation(attr.cast<LocationAttr>());
1295 break;
1296 }
1297
1298 // Print the type if it isn't a 'none' type.
1299 auto attrType = attr.getType();
1300 if (!attrType.isa<NoneType>()) {
1301 os << " : ";
1302 printType(attrType);
1303 }
1304 }
1305
1306 /// Print the integer element of the given DenseElementsAttr at 'index'.
printDenseIntElement(DenseElementsAttr attr,raw_ostream & os,unsigned index)1307 static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
1308 unsigned index) {
1309 APInt value = *std::next(attr.int_value_begin(), index);
1310 if (value.getBitWidth() == 1)
1311 os << (value.getBoolValue() ? "true" : "false");
1312 else
1313 value.print(os, /*isSigned=*/true);
1314 }
1315
1316 /// Print the float element of the given DenseElementsAttr at 'index'.
printDenseFloatElement(DenseElementsAttr attr,raw_ostream & os,unsigned index)1317 static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
1318 unsigned index) {
1319 APFloat value = *std::next(attr.float_value_begin(), index);
1320 printFloatValue(value, os);
1321 }
1322
printDenseElementsAttr(DenseElementsAttr attr)1323 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
1324 auto type = attr.getType();
1325 auto shape = type.getShape();
1326 auto rank = type.getRank();
1327
1328 // The function used to print elements of this attribute.
1329 auto printEltFn = type.getElementType().isa<IntegerType>()
1330 ? printDenseIntElement
1331 : printDenseFloatElement;
1332
1333 // Special case for 0-d and splat tensors.
1334 if (attr.isSplat()) {
1335 printEltFn(attr, os, 0);
1336 return;
1337 }
1338
1339 // Special case for degenerate tensors.
1340 auto numElements = type.getNumElements();
1341 if (numElements == 0) {
1342 for (int i = 0; i < rank; ++i)
1343 os << '[';
1344 for (int i = 0; i < rank; ++i)
1345 os << ']';
1346 return;
1347 }
1348
1349 // We use a mixed-radix counter to iterate through the shape. When we bump a
1350 // non-least-significant digit, we emit a close bracket. When we next emit an
1351 // element we re-open all closed brackets.
1352
1353 // The mixed-radix counter, with radices in 'shape'.
1354 SmallVector<unsigned, 4> counter(rank, 0);
1355 // The number of brackets that have been opened and not closed.
1356 unsigned openBrackets = 0;
1357
1358 auto bumpCounter = [&]() {
1359 // Bump the least significant digit.
1360 ++counter[rank - 1];
1361 // Iterate backwards bubbling back the increment.
1362 for (unsigned i = rank - 1; i > 0; --i)
1363 if (counter[i] >= shape[i]) {
1364 // Index 'i' is rolled over. Bump (i-1) and close a bracket.
1365 counter[i] = 0;
1366 ++counter[i - 1];
1367 --openBrackets;
1368 os << ']';
1369 }
1370 };
1371
1372 for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
1373 if (idx != 0)
1374 os << ", ";
1375 while (openBrackets++ < rank)
1376 os << '[';
1377 openBrackets = rank;
1378 printEltFn(attr, os, idx);
1379 bumpCounter();
1380 }
1381 while (openBrackets-- > 0)
1382 os << ']';
1383 }
1384
printType(Type type)1385 void ModulePrinter::printType(Type type) {
1386 // Check for an alias for this type.
1387 if (state) {
1388 StringRef alias = state->getAliasState().getTypeAlias(type);
1389 if (!alias.empty()) {
1390 os << '!' << alias;
1391 return;
1392 }
1393 }
1394
1395 switch (type.getKind()) {
1396 default:
1397 return printDialectType(type);
1398
1399 case Type::Kind::Opaque: {
1400 auto opaqueTy = type.cast<OpaqueType>();
1401 printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
1402 opaqueTy.getTypeData());
1403 return;
1404 }
1405 case StandardTypes::Index:
1406 os << "index";
1407 return;
1408 case StandardTypes::BF16:
1409 os << "bf16";
1410 return;
1411 case StandardTypes::F16:
1412 os << "f16";
1413 return;
1414 case StandardTypes::F32:
1415 os << "f32";
1416 return;
1417 case StandardTypes::F64:
1418 os << "f64";
1419 return;
1420
1421 case StandardTypes::Integer: {
1422 auto integer = type.cast<IntegerType>();
1423 os << 'i' << integer.getWidth();
1424 return;
1425 }
1426 case Type::Kind::Function: {
1427 auto func = type.cast<FunctionType>();
1428 os << '(';
1429 interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
1430 os << ") -> ";
1431 auto results = func.getResults();
1432 if (results.size() == 1 && !results[0].isa<FunctionType>())
1433 os << results[0];
1434 else {
1435 os << '(';
1436 interleaveComma(results, [&](Type type) { printType(type); });
1437 os << ')';
1438 }
1439 return;
1440 }
1441 case StandardTypes::Vector: {
1442 auto v = type.cast<VectorType>();
1443 os << "vector<";
1444 for (auto dim : v.getShape())
1445 os << dim << 'x';
1446 os << v.getElementType() << '>';
1447 return;
1448 }
1449 case StandardTypes::RankedTensor: {
1450 auto v = type.cast<RankedTensorType>();
1451 os << "tensor<";
1452 for (auto dim : v.getShape()) {
1453 if (dim < 0)
1454 os << '?';
1455 else
1456 os << dim;
1457 os << 'x';
1458 }
1459 os << v.getElementType() << '>';
1460 return;
1461 }
1462 case StandardTypes::UnrankedTensor: {
1463 auto v = type.cast<UnrankedTensorType>();
1464 os << "tensor<*x";
1465 printType(v.getElementType());
1466 os << '>';
1467 return;
1468 }
1469 case StandardTypes::MemRef: {
1470 auto v = type.cast<MemRefType>();
1471 os << "memref<";
1472 for (auto dim : v.getShape()) {
1473 if (dim < 0)
1474 os << '?';
1475 else
1476 os << dim;
1477 os << 'x';
1478 }
1479 printType(v.getElementType());
1480 for (auto map : v.getAffineMaps()) {
1481 os << ", ";
1482 printAttribute(AffineMapAttr::get(map));
1483 }
1484 // Only print the memory space if it is the non-default one.
1485 if (v.getMemorySpace())
1486 os << ", " << v.getMemorySpace();
1487 os << '>';
1488 return;
1489 }
1490 case StandardTypes::UnrankedMemRef: {
1491 auto v = type.cast<UnrankedMemRefType>();
1492 os << "memref<*x";
1493 printType(v.getElementType());
1494 os << '>';
1495 return;
1496 }
1497 case StandardTypes::Complex:
1498 os << "complex<";
1499 printType(type.cast<ComplexType>().getElementType());
1500 os << '>';
1501 return;
1502 case StandardTypes::Tuple: {
1503 auto tuple = type.cast<TupleType>();
1504 os << "tuple<";
1505 interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); });
1506 os << '>';
1507 return;
1508 }
1509 case StandardTypes::None:
1510 os << "none";
1511 return;
1512 }
1513 }
1514
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs,bool withKeyword)1515 void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1516 ArrayRef<StringRef> elidedAttrs,
1517 bool withKeyword) {
1518 // If there are no attributes, then there is nothing to be done.
1519 if (attrs.empty())
1520 return;
1521
1522 // Filter out any attributes that shouldn't be included.
1523 SmallVector<NamedAttribute, 8> filteredAttrs(
1524 llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
1525 return !llvm::is_contained(elidedAttrs, attr.first.strref());
1526 }));
1527
1528 // If there are no attributes left to print after filtering, then we're done.
1529 if (filteredAttrs.empty())
1530 return;
1531
1532 // Print the 'attributes' keyword if necessary.
1533 if (withKeyword)
1534 os << " attributes";
1535
1536 // Otherwise, print them all out in braces.
1537 os << " {";
1538 interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
1539 os << attr.first;
1540
1541 // Pretty printing elides the attribute value for unit attributes.
1542 if (attr.second.isa<UnitAttr>())
1543 return;
1544
1545 os << " = ";
1546 printAttribute(attr.second);
1547 });
1548 os << '}';
1549 }
1550
1551 //===----------------------------------------------------------------------===//
1552 // CustomDialectAsmPrinter
1553 //===----------------------------------------------------------------------===//
1554
1555 namespace {
1556 /// This class provides the main specialization of the DialectAsmPrinter that is
1557 /// used to provide support for print attributes and types. This hooks allows
1558 /// for dialects to hook into the main ModulePrinter.
1559 struct CustomDialectAsmPrinter : public DialectAsmPrinter {
1560 public:
CustomDialectAsmPrinter__anonbb6a0fc11611::CustomDialectAsmPrinter1561 CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
~CustomDialectAsmPrinter__anonbb6a0fc11611::CustomDialectAsmPrinter1562 ~CustomDialectAsmPrinter() override {}
1563
getStream__anonbb6a0fc11611::CustomDialectAsmPrinter1564 raw_ostream &getStream() const override { return printer.getStream(); }
1565
1566 /// Print the given attribute to the stream.
printAttribute__anonbb6a0fc11611::CustomDialectAsmPrinter1567 void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
1568
1569 /// Print the given floating point value in a stablized form.
printFloat__anonbb6a0fc11611::CustomDialectAsmPrinter1570 void printFloat(const APFloat &value) override {
1571 printFloatValue(value, getStream());
1572 }
1573
1574 /// Print the given type to the stream.
printType__anonbb6a0fc11611::CustomDialectAsmPrinter1575 void printType(Type type) override { printer.printType(type); }
1576
1577 /// The main module printer.
1578 ModulePrinter &printer;
1579 };
1580 } // end anonymous namespace
1581
printDialectAttribute(Attribute attr)1582 void ModulePrinter::printDialectAttribute(Attribute attr) {
1583 auto &dialect = attr.getDialect();
1584
1585 // Ask the dialect to serialize the attribute to a string.
1586 std::string attrName;
1587 {
1588 llvm::raw_string_ostream attrNameStr(attrName);
1589 ModulePrinter subPrinter(attrNameStr, printerFlags, state);
1590 CustomDialectAsmPrinter printer(subPrinter);
1591 dialect.printAttribute(attr, printer);
1592 }
1593 printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
1594 }
1595
printDialectType(Type type)1596 void ModulePrinter::printDialectType(Type type) {
1597 auto &dialect = type.getDialect();
1598
1599 // Ask the dialect to serialize the type to a string.
1600 std::string typeName;
1601 {
1602 llvm::raw_string_ostream typeNameStr(typeName);
1603 ModulePrinter subPrinter(typeNameStr, printerFlags, state);
1604 CustomDialectAsmPrinter printer(subPrinter);
1605 dialect.printType(type, printer);
1606 }
1607 printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
1608 }
1609
1610 //===----------------------------------------------------------------------===//
1611 // Affine expressions and maps
1612 //===----------------------------------------------------------------------===//
1613
printAffineExpr(AffineExpr expr,function_ref<void (unsigned,bool)> printValueName)1614 void ModulePrinter::printAffineExpr(
1615 AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
1616 printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
1617 }
1618
printAffineExprInternal(AffineExpr expr,BindingStrength enclosingTightness,function_ref<void (unsigned,bool)> printValueName)1619 void ModulePrinter::printAffineExprInternal(
1620 AffineExpr expr, BindingStrength enclosingTightness,
1621 function_ref<void(unsigned, bool)> printValueName) {
1622 const char *binopSpelling = nullptr;
1623 switch (expr.getKind()) {
1624 case AffineExprKind::SymbolId: {
1625 unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
1626 if (printValueName)
1627 printValueName(pos, /*isSymbol=*/true);
1628 else
1629 os << 's' << pos;
1630 return;
1631 }
1632 case AffineExprKind::DimId: {
1633 unsigned pos = expr.cast<AffineDimExpr>().getPosition();
1634 if (printValueName)
1635 printValueName(pos, /*isSymbol=*/false);
1636 else
1637 os << 'd' << pos;
1638 return;
1639 }
1640 case AffineExprKind::Constant:
1641 os << expr.cast<AffineConstantExpr>().getValue();
1642 return;
1643 case AffineExprKind::Add:
1644 binopSpelling = " + ";
1645 break;
1646 case AffineExprKind::Mul:
1647 binopSpelling = " * ";
1648 break;
1649 case AffineExprKind::FloorDiv:
1650 binopSpelling = " floordiv ";
1651 break;
1652 case AffineExprKind::CeilDiv:
1653 binopSpelling = " ceildiv ";
1654 break;
1655 case AffineExprKind::Mod:
1656 binopSpelling = " mod ";
1657 break;
1658 }
1659
1660 auto binOp = expr.cast<AffineBinaryOpExpr>();
1661 AffineExpr lhsExpr = binOp.getLHS();
1662 AffineExpr rhsExpr = binOp.getRHS();
1663
1664 // Handle tightly binding binary operators.
1665 if (binOp.getKind() != AffineExprKind::Add) {
1666 if (enclosingTightness == BindingStrength::Strong)
1667 os << '(';
1668
1669 // Pretty print multiplication with -1.
1670 auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
1671 if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
1672 rhsConst.getValue() == -1) {
1673 os << "-";
1674 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
1675 if (enclosingTightness == BindingStrength::Strong)
1676 os << ')';
1677 return;
1678 }
1679
1680 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
1681
1682 os << binopSpelling;
1683 printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
1684
1685 if (enclosingTightness == BindingStrength::Strong)
1686 os << ')';
1687 return;
1688 }
1689
1690 // Print out special "pretty" forms for add.
1691 if (enclosingTightness == BindingStrength::Strong)
1692 os << '(';
1693
1694 // Pretty print addition to a product that has a negative operand as a
1695 // subtraction.
1696 if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
1697 if (rhs.getKind() == AffineExprKind::Mul) {
1698 AffineExpr rrhsExpr = rhs.getRHS();
1699 if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
1700 if (rrhs.getValue() == -1) {
1701 printAffineExprInternal(lhsExpr, BindingStrength::Weak,
1702 printValueName);
1703 os << " - ";
1704 if (rhs.getLHS().getKind() == AffineExprKind::Add) {
1705 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
1706 printValueName);
1707 } else {
1708 printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
1709 printValueName);
1710 }
1711
1712 if (enclosingTightness == BindingStrength::Strong)
1713 os << ')';
1714 return;
1715 }
1716
1717 if (rrhs.getValue() < -1) {
1718 printAffineExprInternal(lhsExpr, BindingStrength::Weak,
1719 printValueName);
1720 os << " - ";
1721 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
1722 printValueName);
1723 os << " * " << -rrhs.getValue();
1724 if (enclosingTightness == BindingStrength::Strong)
1725 os << ')';
1726 return;
1727 }
1728 }
1729 }
1730 }
1731
1732 // Pretty print addition to a negative number as a subtraction.
1733 if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
1734 if (rhsConst.getValue() < 0) {
1735 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
1736 os << " - " << -rhsConst.getValue();
1737 if (enclosingTightness == BindingStrength::Strong)
1738 os << ')';
1739 return;
1740 }
1741 }
1742
1743 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
1744
1745 os << " + ";
1746 printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
1747
1748 if (enclosingTightness == BindingStrength::Strong)
1749 os << ')';
1750 }
1751
printAffineConstraint(AffineExpr expr,bool isEq)1752 void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
1753 printAffineExprInternal(expr, BindingStrength::Weak);
1754 isEq ? os << " == 0" : os << " >= 0";
1755 }
1756
printAffineMap(AffineMap map)1757 void ModulePrinter::printAffineMap(AffineMap map) {
1758 // Dimension identifiers.
1759 os << '(';
1760 for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
1761 os << 'd' << i << ", ";
1762 if (map.getNumDims() >= 1)
1763 os << 'd' << map.getNumDims() - 1;
1764 os << ')';
1765
1766 // Symbolic identifiers.
1767 if (map.getNumSymbols() != 0) {
1768 os << '[';
1769 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
1770 os << 's' << i << ", ";
1771 if (map.getNumSymbols() >= 1)
1772 os << 's' << map.getNumSymbols() - 1;
1773 os << ']';
1774 }
1775
1776 // Result affine expressions.
1777 os << " -> (";
1778 interleaveComma(map.getResults(),
1779 [&](AffineExpr expr) { printAffineExpr(expr); });
1780 os << ')';
1781 }
1782
printIntegerSet(IntegerSet set)1783 void ModulePrinter::printIntegerSet(IntegerSet set) {
1784 // Dimension identifiers.
1785 os << '(';
1786 for (unsigned i = 1; i < set.getNumDims(); ++i)
1787 os << 'd' << i - 1 << ", ";
1788 if (set.getNumDims() >= 1)
1789 os << 'd' << set.getNumDims() - 1;
1790 os << ')';
1791
1792 // Symbolic identifiers.
1793 if (set.getNumSymbols() != 0) {
1794 os << '[';
1795 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
1796 os << 's' << i << ", ";
1797 if (set.getNumSymbols() >= 1)
1798 os << 's' << set.getNumSymbols() - 1;
1799 os << ']';
1800 }
1801
1802 // Print constraints.
1803 os << " : (";
1804 int numConstraints = set.getNumConstraints();
1805 for (int i = 1; i < numConstraints; ++i) {
1806 printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
1807 os << ", ";
1808 }
1809 if (numConstraints >= 1)
1810 printAffineConstraint(set.getConstraint(numConstraints - 1),
1811 set.isEq(numConstraints - 1));
1812 os << ')';
1813 }
1814
1815 //===----------------------------------------------------------------------===//
1816 // OperationPrinter
1817 //===----------------------------------------------------------------------===//
1818
1819 namespace {
1820 /// This class contains the logic for printing operations, regions, and blocks.
1821 class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
1822 public:
OperationPrinter(raw_ostream & os,OpPrintingFlags flags,AsmStateImpl & state)1823 explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
1824 AsmStateImpl &state)
1825 : ModulePrinter(os, flags, &state) {}
1826
1827 /// Print the given top-level module.
1828 void print(ModuleOp op);
1829 /// Print the given operation with its indent and location.
1830 void print(Operation *op);
1831 /// Print the bare location, not including indentation/location/etc.
1832 void printOperation(Operation *op);
1833 /// Print the given operation in the generic form.
1834 void printGenericOp(Operation *op) override;
1835
1836 /// Print the name of the given block.
1837 void printBlockName(Block *block);
1838
1839 /// Print the given block. If 'printBlockArgs' is false, the arguments of the
1840 /// block are not printed. If 'printBlockTerminator' is false, the terminator
1841 /// operation of the block is not printed.
1842 void print(Block *block, bool printBlockArgs = true,
1843 bool printBlockTerminator = true);
1844
1845 /// Print the ID of the given value, optionally with its result number.
1846 void printValueID(Value value, bool printResultNo = true) const;
1847
1848 //===--------------------------------------------------------------------===//
1849 // OpAsmPrinter methods
1850 //===--------------------------------------------------------------------===//
1851
1852 /// Return the current stream of the printer.
getStream() const1853 raw_ostream &getStream() const override { return os; }
1854
1855 /// Print the given type.
printType(Type type)1856 void printType(Type type) override { ModulePrinter::printType(type); }
1857
1858 /// Print the given attribute.
printAttribute(Attribute attr)1859 void printAttribute(Attribute attr) override {
1860 ModulePrinter::printAttribute(attr);
1861 }
1862
1863 /// Print the ID for the given value.
printOperand(Value value)1864 void printOperand(Value value) override { printValueID(value); }
1865
1866 /// Print an optional attribute dictionary with a given set of elided values.
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})1867 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1868 ArrayRef<StringRef> elidedAttrs = {}) override {
1869 ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
1870 }
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})1871 void printOptionalAttrDictWithKeyword(
1872 ArrayRef<NamedAttribute> attrs,
1873 ArrayRef<StringRef> elidedAttrs = {}) override {
1874 ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs,
1875 /*withKeyword=*/true);
1876 }
1877
1878 /// Print an operation successor with the operands used for the block
1879 /// arguments.
1880 void printSuccessorAndUseList(Operation *term, unsigned index) override;
1881
1882 /// Print the given region.
1883 void printRegion(Region ®ion, bool printEntryBlockArgs,
1884 bool printBlockTerminators) override;
1885
1886 /// Renumber the arguments for the specified region to the same names as the
1887 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
1888 /// operations. If any entry in namesToUse is null, the corresponding
1889 /// argument name is left alone.
shadowRegionArgs(Region & region,ValueRange namesToUse)1890 void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override {
1891 state->getSSANameState().shadowRegionArgs(region, namesToUse);
1892 }
1893
1894 /// Print the given affine map with the smybol and dimension operands printed
1895 /// inline with the map.
1896 void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
1897 ValueRange operands) override;
1898
1899 /// Print the given string as a symbol reference.
printSymbolName(StringRef symbolRef)1900 void printSymbolName(StringRef symbolRef) override {
1901 ::printSymbolReference(symbolRef, os);
1902 }
1903
1904 private:
1905 /// The number of spaces used for indenting nested operations.
1906 const static unsigned indentWidth = 2;
1907
1908 // This is the current indentation level for nested structures.
1909 unsigned currentIndent = 0;
1910 };
1911 } // end anonymous namespace
1912
print(ModuleOp op)1913 void OperationPrinter::print(ModuleOp op) {
1914 // Output the aliases at the top level.
1915 state->getAliasState().printAttributeAliases(os);
1916 state->getAliasState().printTypeAliases(os);
1917
1918 // Print the module.
1919 print(op.getOperation());
1920 }
1921
print(Operation * op)1922 void OperationPrinter::print(Operation *op) {
1923 os.indent(currentIndent);
1924 printOperation(op);
1925 printTrailingLocation(op->getLoc());
1926 }
1927
printOperation(Operation * op)1928 void OperationPrinter::printOperation(Operation *op) {
1929 if (size_t numResults = op->getNumResults()) {
1930 auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
1931 printValueID(op->getResult(resultNo), /*printResultNo=*/false);
1932 if (resultCount > 1)
1933 os << ':' << resultCount;
1934 };
1935
1936 // Check to see if this operation has multiple result groups.
1937 ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
1938 if (!resultGroups.empty()) {
1939 // Interleave the groups excluding the last one, this one will be handled
1940 // separately.
1941 interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
1942 printResultGroup(resultGroups[i],
1943 resultGroups[i + 1] - resultGroups[i]);
1944 });
1945 os << ", ";
1946 printResultGroup(resultGroups.back(), numResults - resultGroups.back());
1947
1948 } else {
1949 printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
1950 }
1951
1952 os << " = ";
1953 }
1954
1955 // If requested, always print the generic form.
1956 if (!printerFlags.shouldPrintGenericOpForm()) {
1957 // Check to see if this is a known operation. If so, use the registered
1958 // custom printer hook.
1959 if (auto *opInfo = op->getAbstractOperation()) {
1960 opInfo->printAssembly(op, *this);
1961 return;
1962 }
1963 }
1964
1965 // Otherwise print with the generic assembly form.
1966 printGenericOp(op);
1967 }
1968
printGenericOp(Operation * op)1969 void OperationPrinter::printGenericOp(Operation *op) {
1970 os << '"';
1971 printEscapedString(op->getName().getStringRef(), os);
1972 os << "\"(";
1973
1974 // Get the list of operands that are not successor operands.
1975 unsigned totalNumSuccessorOperands = 0;
1976 unsigned numSuccessors = op->getNumSuccessors();
1977 for (unsigned i = 0; i < numSuccessors; ++i)
1978 totalNumSuccessorOperands += op->getNumSuccessorOperands(i);
1979 unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands;
1980 interleaveComma(op->getOperands().take_front(numProperOperands),
1981 [&](Value value) { printValueID(value); });
1982
1983 os << ')';
1984
1985 // For terminators, print the list of successors and their operands.
1986 if (numSuccessors != 0) {
1987 os << '[';
1988 interleaveComma(llvm::seq<unsigned>(0, numSuccessors),
1989 [&](unsigned i) { printSuccessorAndUseList(op, i); });
1990 os << ']';
1991 }
1992
1993 // Print regions.
1994 if (op->getNumRegions() != 0) {
1995 os << " (";
1996 interleaveComma(op->getRegions(), [&](Region ®ion) {
1997 printRegion(region, /*printEntryBlockArgs=*/true,
1998 /*printBlockTerminators=*/true);
1999 });
2000 os << ')';
2001 }
2002
2003 auto attrs = op->getAttrs();
2004 printOptionalAttrDict(attrs);
2005
2006 // Print the type signature of the operation.
2007 os << " : ";
2008 printFunctionalType(op);
2009 }
2010
printBlockName(Block * block)2011 void OperationPrinter::printBlockName(Block *block) {
2012 auto id = state->getSSANameState().getBlockID(block);
2013 if (id != SSANameState::NameSentinel)
2014 os << "^bb" << id;
2015 else
2016 os << "^INVALIDBLOCK";
2017 }
2018
print(Block * block,bool printBlockArgs,bool printBlockTerminator)2019 void OperationPrinter::print(Block *block, bool printBlockArgs,
2020 bool printBlockTerminator) {
2021 // Print the block label and argument list if requested.
2022 if (printBlockArgs) {
2023 os.indent(currentIndent);
2024 printBlockName(block);
2025
2026 // Print the argument list if non-empty.
2027 if (!block->args_empty()) {
2028 os << '(';
2029 interleaveComma(block->getArguments(), [&](BlockArgument arg) {
2030 printValueID(arg);
2031 os << ": ";
2032 printType(arg.getType());
2033 });
2034 os << ')';
2035 }
2036 os << ':';
2037
2038 // Print out some context information about the predecessors of this block.
2039 if (!block->getParent()) {
2040 os << "\t// block is not in a region!";
2041 } else if (block->hasNoPredecessors()) {
2042 os << "\t// no predecessors";
2043 } else if (auto *pred = block->getSinglePredecessor()) {
2044 os << "\t// pred: ";
2045 printBlockName(pred);
2046 } else {
2047 // We want to print the predecessors in increasing numeric order, not in
2048 // whatever order the use-list is in, so gather and sort them.
2049 SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
2050 for (auto *pred : block->getPredecessors())
2051 predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
2052 llvm::array_pod_sort(predIDs.begin(), predIDs.end());
2053
2054 os << "\t// " << predIDs.size() << " preds: ";
2055
2056 interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
2057 printBlockName(pred.second);
2058 });
2059 }
2060 os << '\n';
2061 }
2062
2063 currentIndent += indentWidth;
2064 auto range = llvm::make_range(
2065 block->getOperations().begin(),
2066 std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1));
2067 for (auto &op : range) {
2068 print(&op);
2069 os << '\n';
2070 }
2071 currentIndent -= indentWidth;
2072 }
2073
printValueID(Value value,bool printResultNo) const2074 void OperationPrinter::printValueID(Value value, bool printResultNo) const {
2075 state->getSSANameState().printValueID(value, printResultNo, os);
2076 }
2077
printSuccessorAndUseList(Operation * term,unsigned index)2078 void OperationPrinter::printSuccessorAndUseList(Operation *term,
2079 unsigned index) {
2080 printBlockName(term->getSuccessor(index));
2081
2082 auto succOperands = term->getSuccessorOperands(index);
2083 if (succOperands.begin() == succOperands.end())
2084 return;
2085
2086 os << '(';
2087 interleaveComma(succOperands,
2088 [this](Value operand) { printValueID(operand); });
2089 os << " : ";
2090 interleaveComma(succOperands,
2091 [this](Value operand) { printType(operand.getType()); });
2092 os << ')';
2093 }
2094
printRegion(Region & region,bool printEntryBlockArgs,bool printBlockTerminators)2095 void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs,
2096 bool printBlockTerminators) {
2097 os << " {\n";
2098 if (!region.empty()) {
2099 auto *entryBlock = ®ion.front();
2100 print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0,
2101 printBlockTerminators);
2102 for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
2103 print(&b);
2104 }
2105 os.indent(currentIndent) << "}";
2106 }
2107
printAffineMapOfSSAIds(AffineMapAttr mapAttr,ValueRange operands)2108 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2109 ValueRange operands) {
2110 AffineMap map = mapAttr.getValue();
2111 unsigned numDims = map.getNumDims();
2112 auto printValueName = [&](unsigned pos, bool isSymbol) {
2113 unsigned index = isSymbol ? numDims + pos : pos;
2114 assert(index < operands.size());
2115 if (isSymbol)
2116 os << "symbol(";
2117 printValueID(operands[index]);
2118 if (isSymbol)
2119 os << ')';
2120 };
2121
2122 interleaveComma(map.getResults(), [&](AffineExpr expr) {
2123 printAffineExpr(expr, printValueName);
2124 });
2125 }
2126
2127 //===----------------------------------------------------------------------===//
2128 // print and dump methods
2129 //===----------------------------------------------------------------------===//
2130
print(raw_ostream & os) const2131 void Attribute::print(raw_ostream &os) const {
2132 ModulePrinter(os).printAttribute(*this);
2133 }
2134
dump() const2135 void Attribute::dump() const {
2136 print(llvm::errs());
2137 llvm::errs() << "\n";
2138 }
2139
print(raw_ostream & os)2140 void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); }
2141
dump()2142 void Type::dump() { print(llvm::errs()); }
2143
dump() const2144 void AffineMap::dump() const {
2145 print(llvm::errs());
2146 llvm::errs() << "\n";
2147 }
2148
dump() const2149 void IntegerSet::dump() const {
2150 print(llvm::errs());
2151 llvm::errs() << "\n";
2152 }
2153
print(raw_ostream & os) const2154 void AffineExpr::print(raw_ostream &os) const {
2155 if (expr == nullptr) {
2156 os << "null affine expr";
2157 return;
2158 }
2159 ModulePrinter(os).printAffineExpr(*this);
2160 }
2161
dump() const2162 void AffineExpr::dump() const {
2163 print(llvm::errs());
2164 llvm::errs() << "\n";
2165 }
2166
print(raw_ostream & os) const2167 void AffineMap::print(raw_ostream &os) const {
2168 if (map == nullptr) {
2169 os << "null affine map";
2170 return;
2171 }
2172 ModulePrinter(os).printAffineMap(*this);
2173 }
2174
print(raw_ostream & os) const2175 void IntegerSet::print(raw_ostream &os) const {
2176 ModulePrinter(os).printIntegerSet(*this);
2177 }
2178
print(raw_ostream & os)2179 void Value::print(raw_ostream &os) {
2180 if (auto *op = getDefiningOp())
2181 return op->print(os);
2182 // TODO: Improve this.
2183 assert(isa<BlockArgument>());
2184 os << "<block argument>\n";
2185 }
print(raw_ostream & os,AsmState & state)2186 void Value::print(raw_ostream &os, AsmState &state) {
2187 if (auto *op = getDefiningOp())
2188 return op->print(os, state);
2189
2190 // TODO: Improve this.
2191 assert(isa<BlockArgument>());
2192 os << "<block argument>\n";
2193 }
2194
dump()2195 void Value::dump() {
2196 print(llvm::errs());
2197 llvm::errs() << "\n";
2198 }
2199
printAsOperand(raw_ostream & os,AsmState & state)2200 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
2201 // TODO(riverriddle) This doesn't necessarily capture all potential cases.
2202 // Currently, region arguments can be shadowed when printing the main
2203 // operation. If the IR hasn't been printed, this will produce the old SSA
2204 // name and not the shadowed name.
2205 state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
2206 os);
2207 }
2208
print(raw_ostream & os,OpPrintingFlags flags)2209 void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
2210 // Handle top-level operations or local printing.
2211 if (!getParent() || flags.shouldUseLocalScope()) {
2212 AsmState state(this);
2213 OperationPrinter(os, flags, state.getImpl()).print(this);
2214 return;
2215 }
2216
2217 Operation *parentOp = getParentOp();
2218 if (!parentOp) {
2219 os << "<<UNLINKED OPERATION>>\n";
2220 return;
2221 }
2222 // Get the top-level op.
2223 while (auto *nextOp = parentOp->getParentOp())
2224 parentOp = nextOp;
2225
2226 AsmState state(parentOp);
2227 print(os, state, flags);
2228 }
print(raw_ostream & os,AsmState & state,OpPrintingFlags flags)2229 void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
2230 OperationPrinter(os, flags, state.getImpl()).print(this);
2231 }
2232
dump()2233 void Operation::dump() {
2234 print(llvm::errs(), OpPrintingFlags().useLocalScope());
2235 llvm::errs() << "\n";
2236 }
2237
print(raw_ostream & os)2238 void Block::print(raw_ostream &os) {
2239 Operation *parentOp = getParentOp();
2240 if (!parentOp) {
2241 os << "<<UNLINKED BLOCK>>\n";
2242 return;
2243 }
2244 // Get the top-level op.
2245 while (auto *nextOp = parentOp->getParentOp())
2246 parentOp = nextOp;
2247
2248 AsmState state(parentOp);
2249 print(os, state);
2250 }
print(raw_ostream & os,AsmState & state)2251 void Block::print(raw_ostream &os, AsmState &state) {
2252 OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
2253 }
2254
dump()2255 void Block::dump() { print(llvm::errs()); }
2256
2257 /// Print out the name of the block without printing its body.
printAsOperand(raw_ostream & os,bool printType)2258 void Block::printAsOperand(raw_ostream &os, bool printType) {
2259 Operation *parentOp = getParentOp();
2260 if (!parentOp) {
2261 os << "<<UNLINKED BLOCK>>\n";
2262 return;
2263 }
2264 // Get the top-level op.
2265 while (auto *nextOp = parentOp->getParentOp())
2266 parentOp = nextOp;
2267
2268 AsmState state(parentOp);
2269 printAsOperand(os, state);
2270 }
printAsOperand(raw_ostream & os,AsmState & state)2271 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
2272 OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
2273 printer.printBlockName(this);
2274 }
2275
print(raw_ostream & os,OpPrintingFlags flags)2276 void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
2277 AsmState state(*this);
2278
2279 // Don't populate aliases when printing at local scope.
2280 if (!flags.shouldUseLocalScope())
2281 state.getImpl().initializeAliases(*this);
2282 print(os, state, flags);
2283 }
print(raw_ostream & os,AsmState & state,OpPrintingFlags flags)2284 void ModuleOp::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
2285 OperationPrinter(os, flags, state.getImpl()).print(*this);
2286 }
2287
dump()2288 void ModuleOp::dump() { print(llvm::errs()); }
2289