1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Pattern wrapper class to simplify using a TableGen Record that defines an
// MLIR Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Pattern.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 #define DEBUG_TYPE "mlir-tblgen-pattern"
23 
24 using namespace mlir;
25 using namespace tblgen;
26 
27 using llvm::formatv;
28 
29 //===----------------------------------------------------------------------===//
30 // DagLeaf
31 //===----------------------------------------------------------------------===//
32 
isUnspecified() const33 bool DagLeaf::isUnspecified() const {
34   return dyn_cast_or_null<llvm::UnsetInit>(def);
35 }
36 
isOperandMatcher() const37 bool DagLeaf::isOperandMatcher() const {
38   // Operand matchers specify a type constraint.
39   return isSubClassOf("TypeConstraint");
40 }
41 
isAttrMatcher() const42 bool DagLeaf::isAttrMatcher() const {
43   // Attribute matchers specify an attribute constraint.
44   return isSubClassOf("AttrConstraint");
45 }
46 
isNativeCodeCall() const47 bool DagLeaf::isNativeCodeCall() const {
48   return isSubClassOf("NativeCodeCall");
49 }
50 
isConstantAttr() const51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
52 
isEnumAttrCase() const53 bool DagLeaf::isEnumAttrCase() const {
54   return isSubClassOf("EnumAttrCaseInfo");
55 }
56 
isStringAttr() const57 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
58 
getAsConstraint() const59 Constraint DagLeaf::getAsConstraint() const {
60   assert((isOperandMatcher() || isAttrMatcher()) &&
61          "the DAG leaf must be operand or attribute");
62   return Constraint(cast<llvm::DefInit>(def)->getDef());
63 }
64 
getAsConstantAttr() const65 ConstantAttr DagLeaf::getAsConstantAttr() const {
66   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
67   return ConstantAttr(cast<llvm::DefInit>(def));
68 }
69 
getAsEnumAttrCase() const70 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
71   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
72   return EnumAttrCase(cast<llvm::DefInit>(def));
73 }
74 
getConditionTemplate() const75 std::string DagLeaf::getConditionTemplate() const {
76   return getAsConstraint().getConditionTemplate();
77 }
78 
getNativeCodeTemplate() const79 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
80   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
81   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
82 }
83 
getNumReturnsOfNativeCode() const84 int DagLeaf::getNumReturnsOfNativeCode() const {
85   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
86   return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
87 }
88 
getStringAttr() const89 std::string DagLeaf::getStringAttr() const {
90   assert(isStringAttr() && "the DAG leaf must be string attribute");
91   return def->getAsUnquotedString();
92 }
isSubClassOf(StringRef superclass) const93 bool DagLeaf::isSubClassOf(StringRef superclass) const {
94   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
95     return defInit->getDef()->isSubClassOf(superclass);
96   return false;
97 }
98 
print(raw_ostream & os) const99 void DagLeaf::print(raw_ostream &os) const {
100   if (def)
101     def->print(os);
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // DagNode
106 //===----------------------------------------------------------------------===//
107 
isNativeCodeCall() const108 bool DagNode::isNativeCodeCall() const {
109   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
110     return defInit->getDef()->isSubClassOf("NativeCodeCall");
111   return false;
112 }
113 
isOperation() const114 bool DagNode::isOperation() const {
115   return !isNativeCodeCall() && !isReplaceWithValue() &&
116          !isLocationDirective() && !isReturnTypeDirective();
117 }
118 
getNativeCodeTemplate() const119 llvm::StringRef DagNode::getNativeCodeTemplate() const {
120   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
121   return cast<llvm::DefInit>(node->getOperator())
122       ->getDef()
123       ->getValueAsString("expression");
124 }
125 
getNumReturnsOfNativeCode() const126 int DagNode::getNumReturnsOfNativeCode() const {
127   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
128   return cast<llvm::DefInit>(node->getOperator())
129       ->getDef()
130       ->getValueAsInt("numReturns");
131 }
132 
getSymbol() const133 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
134 
getDialectOp(RecordOperatorMap * mapper) const135 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
136   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
137   auto it = mapper->find(opDef);
138   if (it != mapper->end())
139     return *it->second;
140   return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
141               .first->second;
142 }
143 
getNumOps() const144 int DagNode::getNumOps() const {
145   int count = isReplaceWithValue() ? 0 : 1;
146   for (int i = 0, e = getNumArgs(); i != e; ++i) {
147     if (auto child = getArgAsNestedDag(i))
148       count += child.getNumOps();
149   }
150   return count;
151 }
152 
getNumArgs() const153 int DagNode::getNumArgs() const { return node->getNumArgs(); }
154 
isNestedDagArg(unsigned index) const155 bool DagNode::isNestedDagArg(unsigned index) const {
156   return isa<llvm::DagInit>(node->getArg(index));
157 }
158 
getArgAsNestedDag(unsigned index) const159 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
160   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
161 }
162 
getArgAsLeaf(unsigned index) const163 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
164   assert(!isNestedDagArg(index));
165   return DagLeaf(node->getArg(index));
166 }
167 
getArgName(unsigned index) const168 StringRef DagNode::getArgName(unsigned index) const {
169   return node->getArgNameStr(index);
170 }
171 
isReplaceWithValue() const172 bool DagNode::isReplaceWithValue() const {
173   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
174   return dagOpDef->getName() == "replaceWithValue";
175 }
176 
isLocationDirective() const177 bool DagNode::isLocationDirective() const {
178   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
179   return dagOpDef->getName() == "location";
180 }
181 
isReturnTypeDirective() const182 bool DagNode::isReturnTypeDirective() const {
183   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
184   return dagOpDef->getName() == "returnType";
185 }
186 
print(raw_ostream & os) const187 void DagNode::print(raw_ostream &os) const {
188   if (node)
189     node->print(os);
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // SymbolInfoMap
194 //===----------------------------------------------------------------------===//
195 
getValuePackName(StringRef symbol,int * index)196 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
197   StringRef name, indexStr;
198   int idx = -1;
199   std::tie(name, indexStr) = symbol.rsplit("__");
200 
201   if (indexStr.consumeInteger(10, idx)) {
202     // The second part is not an index; we return the whole symbol as-is.
203     return symbol;
204   }
205   if (index) {
206     *index = idx;
207   }
208   return name;
209 }
210 
SymbolInfo(const Operator * op,SymbolInfo::Kind kind,Optional<DagAndConstant> dagAndConstant)211 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
212                                       Optional<DagAndConstant> dagAndConstant)
213     : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
214 
getStaticValueCount() const215 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
216   switch (kind) {
217   case Kind::Attr:
218   case Kind::Operand:
219   case Kind::Value:
220     return 1;
221   case Kind::Result:
222     return op->getNumResults();
223   case Kind::MultipleValues:
224     return getSize();
225   }
226   llvm_unreachable("unknown kind");
227 }
228 
getVarName(StringRef name) const229 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
230   return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
231 }
232 
getVarTypeStr(StringRef name) const233 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
234   LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
235   switch (kind) {
236   case Kind::Attr: {
237     if (op)
238       return op->getArg(getArgIndex())
239           .get<NamedAttribute *>()
240           ->attr.getStorageType()
241           .str();
242     // TODO(suderman): Use a more exact type when available.
243     return "Attribute";
244   }
245   case Kind::Operand: {
246     // Use operand range for captured operands (to support potential variadic
247     // operands).
248     return "::mlir::Operation::operand_range";
249   }
250   case Kind::Value: {
251     return "::mlir::Value";
252   }
253   case Kind::MultipleValues: {
254     return "::mlir::ValueRange";
255   }
256   case Kind::Result: {
257     // Use the op itself for captured results.
258     return op->getQualCppClassName();
259   }
260   }
261   llvm_unreachable("unknown kind");
262 }
263 
getVarDecl(StringRef name) const264 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
265   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
266   std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
267   return std::string(
268       formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
269 }
270 
getArgDecl(StringRef name) const271 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
272   LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
273   return std::string(
274       formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
275 }
276 
/// Returns the C++ expression(s) that reference this symbol's value(s) in
/// generated code, each substituted for `{0}` in `fmt`.
///
/// - `index` selects a specific result (Kind::Result) or element
///   (Kind::MultipleValues). A negative `index` means no specific element was
///   requested. Attr, Operand and Value kinds assert that `index` is negative.
/// - Non-variadic operands/results are unwrapped from their range with
///   `(*<range>.begin())`; variadic ones are left as ranges.
/// - When all results of an op with results are referenced, the per-result
///   strings are joined with `separator`.
std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
    StringRef name, int index, const char *fmt, const char *separator) const {
  LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
  switch (kind) {
  case Kind::Attr: {
    assert(index < 0);
    auto repl = formatv(fmt, name);
    LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
    return std::string(repl);
  }
  case Kind::Operand: {
    assert(index < 0);
    auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
    // If this operand is variadic, then return a range. Otherwise, return the
    // value itself.
    if (operand->isVariableLength()) {
      auto repl = formatv(fmt, name);
      LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
      return std::string(repl);
    }
    auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
    LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
    return std::string(repl);
  }
  case Kind::Result: {
    // If `index` is non-negative, then we are referencing a specific result
    // of a multi-result op. The result can still be variadic.
    if (index >= 0) {
      std::string v =
          std::string(formatv("{0}.getODSResults({1})", name, index));
      if (!op->getResult(index).isVariadic())
        v = std::string(formatv("(*{0}.begin())", v));
      auto repl = formatv(fmt, v);
      LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
      return std::string(repl);
    }

    // If this op has no result at all but still we bind a symbol to it, it
    // means we want to capture the op itself.
    if (op->getNumResults() == 0) {
      LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
      return std::string(name);
    }

    // We are referencing all results of the multi-result op. A specific result
    // can either be a value or a range. Then join them with `separator`.
    SmallVector<std::string, 4> values;
    values.reserve(op->getNumResults());

    for (int i = 0, e = op->getNumResults(); i < e; ++i) {
      std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
      if (!op->getResult(i).isVariadic()) {
        v = std::string(formatv("(*{0}.begin())", v));
      }
      values.push_back(std::string(formatv(fmt, v)));
    }
    auto repl = llvm::join(values, separator);
    LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
    return repl;
  }
  case Kind::Value: {
    assert(index < 0);
    assert(op == nullptr);
    auto repl = formatv(fmt, name);
    LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
    return std::string(repl);
  }
  case Kind::MultipleValues: {
    assert(op == nullptr);
    assert(index < getSize());
    // A specific element is referenced by subscripting the captured values.
    if (index >= 0) {
      std::string repl =
          formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
      LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
      return repl;
    }
    // If it doesn't specify certain element, unpack them all.
    auto repl =
        formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
    LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
    return std::string(repl);
  }
  }
  llvm_unreachable("unknown kind");
}
362 
/// Returns the C++ expression(s) that reference this symbol as range(s) in
/// generated code, each substituted for `{0}` in `fmt`.
///
/// Unlike `getValueAndRangeUse`, results are never unwrapped to a single value:
/// each result is referenced as `<name>.getODSResults(<i>)`. When all results
/// of an op are referenced, the per-result strings are joined with
/// `separator`. A negative `index` means no specific result/element was
/// requested. Attr, Operand and Value kinds assert that `index` is negative.
std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
    StringRef name, int index, const char *fmt, const char *separator) const {
  LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
  switch (kind) {
  case Kind::Attr:
  case Kind::Operand: {
    assert(index < 0 && "only allowed for symbol bound to result");
    auto repl = formatv(fmt, name);
    LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
    return std::string(repl);
  }
  case Kind::Result: {
    if (index >= 0) {
      auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
      LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
      return std::string(repl);
    }

    // We are referencing all results of the multi-result op. Each result should
    // have a value range, and then join them with `separator`.
    SmallVector<std::string, 4> values;
    values.reserve(op->getNumResults());

    for (int i = 0, e = op->getNumResults(); i < e; ++i) {
      values.push_back(std::string(
          formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
    }
    auto repl = llvm::join(values, separator);
    LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
    return repl;
  }
  case Kind::Value: {
    assert(index < 0 && "only allowed for symbol bound to result");
    assert(op == nullptr);
    // `{{` is an escaped brace, so this yields `{<name>}`. Presumably this
    // brace-wraps the single value so it can initialize a range (confirm).
    auto repl = formatv(fmt, formatv("{{{0}}", name));
    LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
    return std::string(repl);
  }
  case Kind::MultipleValues: {
    assert(op == nullptr);
    assert(index < getSize());
    if (index >= 0) {
      std::string repl =
          formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
      LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
      return repl;
    }
    auto repl =
        formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
    LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
    return std::string(repl);
  }
  }
  llvm_unreachable("unknown kind");
}
418 
bindOpArgument(DagNode node,StringRef symbol,const Operator & op,int argIndex)419 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
420                                    const Operator &op, int argIndex) {
421   StringRef name = getValuePackName(symbol);
422   if (name != symbol) {
423     auto error = formatv(
424         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
425     PrintFatalError(loc, error);
426   }
427 
428   auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
429                      ? SymbolInfo::getAttr(&op, argIndex)
430                      : SymbolInfo::getOperand(node, &op, argIndex);
431 
432   std::string key = symbol.str();
433   if (symbolInfoMap.count(key)) {
434     // Only non unique name for the operand is supported.
435     if (symInfo.kind != SymbolInfo::Kind::Operand) {
436       return false;
437     }
438 
439     // Cannot add new operand if there is already non operand with the same
440     // name.
441     if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
442       return false;
443     }
444   }
445 
446   symbolInfoMap.emplace(key, symInfo);
447   return true;
448 }
449 
bindOpResult(StringRef symbol,const Operator & op)450 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
451   std::string name = getValuePackName(symbol).str();
452   auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
453 
454   return symbolInfoMap.count(inserted->first) == 1;
455 }
456 
bindValues(StringRef symbol,int numValues)457 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
458   std::string name = getValuePackName(symbol).str();
459   if (numValues > 1)
460     return bindMultipleValues(name, numValues);
461   return bindValue(name);
462 }
463 
bindValue(StringRef symbol)464 bool SymbolInfoMap::bindValue(StringRef symbol) {
465   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
466   return symbolInfoMap.count(inserted->first) == 1;
467 }
468 
bindMultipleValues(StringRef symbol,int numValues)469 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
470   std::string name = getValuePackName(symbol).str();
471   auto inserted =
472       symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
473   return symbolInfoMap.count(inserted->first) == 1;
474 }
475 
bindAttr(StringRef symbol)476 bool SymbolInfoMap::bindAttr(StringRef symbol) {
477   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
478   return symbolInfoMap.count(inserted->first) == 1;
479 }
480 
contains(StringRef symbol) const481 bool SymbolInfoMap::contains(StringRef symbol) const {
482   return find(symbol) != symbolInfoMap.end();
483 }
484 
find(StringRef key) const485 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
486   std::string name = getValuePackName(key).str();
487 
488   return symbolInfoMap.find(name);
489 }
490 
491 SymbolInfoMap::const_iterator
findBoundSymbol(StringRef key,DagNode node,const Operator & op,int argIndex) const492 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
493                                int argIndex) const {
494   return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex));
495 }
496 
497 SymbolInfoMap::const_iterator
findBoundSymbol(StringRef key,SymbolInfo symbolInfo) const498 SymbolInfoMap::findBoundSymbol(StringRef key, SymbolInfo symbolInfo) const {
499   std::string name = getValuePackName(key).str();
500   auto range = symbolInfoMap.equal_range(name);
501 
502   for (auto it = range.first; it != range.second; ++it)
503     if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
504       return it;
505 
506   return symbolInfoMap.end();
507 }
508 
509 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
getRangeOfEqualElements(StringRef key)510 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
511   std::string name = getValuePackName(key).str();
512 
513   return symbolInfoMap.equal_range(name);
514 }
515 
count(StringRef key) const516 int SymbolInfoMap::count(StringRef key) const {
517   std::string name = getValuePackName(key).str();
518   return symbolInfoMap.count(name);
519 }
520 
getStaticValueCount(StringRef symbol) const521 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
522   StringRef name = getValuePackName(symbol);
523   if (name != symbol) {
524     // If there is a trailing index inside symbol, it references just one
525     // static value.
526     return 1;
527   }
528   // Otherwise, find how many it represents by querying the symbol's info.
529   return find(name)->second.getStaticValueCount();
530 }
531 
getValueAndRangeUse(StringRef symbol,const char * fmt,const char * separator) const532 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
533                                                const char *fmt,
534                                                const char *separator) const {
535   int index = -1;
536   StringRef name = getValuePackName(symbol, &index);
537 
538   auto it = symbolInfoMap.find(name.str());
539   if (it == symbolInfoMap.end()) {
540     auto error = formatv("referencing unbound symbol '{0}'", symbol);
541     PrintFatalError(loc, error);
542   }
543 
544   return it->second.getValueAndRangeUse(name, index, fmt, separator);
545 }
546 
getAllRangeUse(StringRef symbol,const char * fmt,const char * separator) const547 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
548                                           const char *separator) const {
549   int index = -1;
550   StringRef name = getValuePackName(symbol, &index);
551 
552   auto it = symbolInfoMap.find(name.str());
553   if (it == symbolInfoMap.end()) {
554     auto error = formatv("referencing unbound symbol '{0}'", symbol);
555     PrintFatalError(loc, error);
556   }
557 
558   return it->second.getAllRangeUse(name, index, fmt, separator);
559 }
560 
assignUniqueAlternativeNames()561 void SymbolInfoMap::assignUniqueAlternativeNames() {
562   llvm::StringSet<> usedNames;
563 
564   for (auto symbolInfoIt = symbolInfoMap.begin();
565        symbolInfoIt != symbolInfoMap.end();) {
566     auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
567     auto startRange = range.first;
568     auto endRange = range.second;
569 
570     auto operandName = symbolInfoIt->first;
571     int startSearchIndex = 0;
572     for (++startRange; startRange != endRange; ++startRange) {
573       // Current operand name is not unique, find a unique one
574       // and set the alternative name.
575       for (int i = startSearchIndex;; ++i) {
576         std::string alternativeName = operandName + std::to_string(i);
577         if (!usedNames.contains(alternativeName) &&
578             symbolInfoMap.count(alternativeName) == 0) {
579           usedNames.insert(alternativeName);
580           startRange->second.alternativeName = alternativeName;
581           startSearchIndex = i + 1;
582 
583           break;
584         }
585       }
586     }
587 
588     symbolInfoIt = endRange;
589   }
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // Pattern
//===----------------------------------------------------------------------===//
595 
Pattern(const llvm::Record * def,RecordOperatorMap * mapper)596 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
597     : def(*def), recordOpMap(mapper) {}
598 
getSourcePattern() const599 DagNode Pattern::getSourcePattern() const {
600   return DagNode(def.getValueAsDag("sourcePattern"));
601 }
602 
getNumResultPatterns() const603 int Pattern::getNumResultPatterns() const {
604   auto *results = def.getValueAsListInit("resultPatterns");
605   return results->size();
606 }
607 
getResultPattern(unsigned index) const608 DagNode Pattern::getResultPattern(unsigned index) const {
609   auto *results = def.getValueAsListInit("resultPatterns");
610   return DagNode(cast<llvm::DagInit>(results->getElement(index)));
611 }
612 
collectSourcePatternBoundSymbols(SymbolInfoMap & infoMap)613 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
614   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
615   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
616   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
617 
618   LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
619   infoMap.assignUniqueAlternativeNames();
620   LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
621 }
622 
collectResultPatternBoundSymbols(SymbolInfoMap & infoMap)623 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
624   LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
625   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
626     auto pattern = getResultPattern(i);
627     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
628   }
629   LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
630 }
631 
getSourceRootOp()632 const Operator &Pattern::getSourceRootOp() {
633   return getSourcePattern().getDialectOp(recordOpMap);
634 }
635 
getDialectOp(DagNode node)636 Operator &Pattern::getDialectOp(DagNode node) {
637   return node.getDialectOp(recordOpMap);
638 }
639 
getConstraints() const640 std::vector<AppliedConstraint> Pattern::getConstraints() const {
641   auto *listInit = def.getValueAsListInit("constraints");
642   std::vector<AppliedConstraint> ret;
643   ret.reserve(listInit->size());
644 
645   for (auto it : *listInit) {
646     auto *dagInit = dyn_cast<llvm::DagInit>(it);
647     if (!dagInit)
648       PrintFatalError(&def, "all elements in Pattern multi-entity "
649                             "constraints should be DAG nodes");
650 
651     std::vector<std::string> entities;
652     entities.reserve(dagInit->arg_size());
653     for (auto *argName : dagInit->getArgNames()) {
654       if (!argName) {
655         PrintFatalError(
656             &def,
657             "operands to additional constraints can only be symbol references");
658       }
659       entities.push_back(std::string(argName->getValue()));
660     }
661 
662     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
663                      dagInit->getNameStr(), std::move(entities));
664   }
665   return ret;
666 }
667 
getBenefit() const668 int Pattern::getBenefit() const {
669   // The initial benefit value is a heuristic with number of ops in the source
670   // pattern.
671   int initBenefit = getSourcePattern().getNumOps();
672   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
673   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
674     PrintFatalError(&def,
675                     "The 'addBenefit' takes and only takes one integer value");
676   }
677   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
678 }
679 
getLocation() const680 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
681   std::vector<std::pair<StringRef, unsigned>> result;
682   result.reserve(def.getLoc().size());
683   for (auto loc : def.getLoc()) {
684     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
685     assert(buf && "invalid source location");
686     result.emplace_back(
687         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
688         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
689   }
690   return result;
691 }
692 
verifyBind(bool result,StringRef symbolName)693 void Pattern::verifyBind(bool result, StringRef symbolName) {
694   if (!result) {
695     auto err = formatv("symbol '{0}' bound more than once", symbolName);
696     PrintFatalError(&def, err);
697   }
698 }
699 
/// Recursively walks the DAG `tree` and records every symbol it binds into
/// `infoMap`. Binding conflicts are reported through `verifyBind`.
///
/// - NativeCodeCall nodes: a symbol on the node itself is accepted only in a
///   result pattern. It is bound via `bindValues` with the call's
///   `numReturns`; in the source pattern it is a fatal error. In the source
///   pattern, symbols on leaf arguments are bound as attributes or values,
///   depending on the leaf.
/// - Operation nodes: the op's argument count must match the pattern's
///   argument count minus the counted `location`/`returnType` directives. A
///   symbol on the node is bound to the op's results. In the source pattern,
///   symbols on leaf arguments are bound to the corresponding op arguments.
/// - Any other node (`replaceWithValue`, `location`, `returnType`) must not
///   carry a symbol.
///
/// Nested DAG arguments of NativeCodeCall and operation nodes are visited
/// recursively in both source and result patterns. `$_` ignores an argument.
void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
                                  bool isSrcPattern) {
  auto treeName = tree.getSymbol();
  auto numTreeArgs = tree.getNumArgs();

  if (tree.isNativeCodeCall()) {
    if (!treeName.empty()) {
      if (!isSrcPattern) {
        LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
                                << treeName << '\n');
        verifyBind(
            infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
            treeName);
      } else {
        PrintFatalError(&def,
                        formatv("binding symbol '{0}' to NativecodeCall in "
                                "MatchPattern is not supported",
                                treeName));
      }
    }

    for (int i = 0; i != numTreeArgs; ++i) {
      if (auto treeArg = tree.getArgAsNestedDag(i)) {
        // This DAG node argument is a DAG node itself. Go inside recursively.
        collectBoundSymbols(treeArg, infoMap, isSrcPattern);
        continue;
      }

      if (!isSrcPattern)
        continue;

      // We can only bind symbols to arguments in source pattern. Those
      // symbols are referenced in result patterns.
      auto treeArgName = tree.getArgName(i);

      // `$_` is a special symbol meaning ignore the current argument.
      if (!treeArgName.empty() && treeArgName != "_") {
        DagLeaf leaf = tree.getArgAsLeaf(i);

        // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
        if (leaf.isUnspecified()) {
          // This is case of $c, a Value without any constraints.
          verifyBind(infoMap.bindValue(treeArgName), treeArgName);
        } else {
          // NOTE(review): `getAsConstraint()` asserts that the leaf is an
          // operand or attribute matcher, and it runs before the
          // `isEnumAttrCase()` / `isConstantAttr()` checks below. A leaf that
          // is neither matcher kind (e.g. a bare string) would trip that
          // assert here. Confirm such leaves cannot reach this point.
          auto constraint = leaf.getAsConstraint();
          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
                        leaf.isConstantAttr() ||
                        constraint.getKind() == Constraint::Kind::CK_Attr;

          if (isAttr) {
            // This is case of $a, a binding to a certain attribute.
            verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
            continue;
          }

          // This is case of $b, a binding to a certain type.
          verifyBind(infoMap.bindValue(treeArgName), treeArgName);
        }
      }
    }

    return;
  }

  if (tree.isOperation()) {
    auto &op = getDialectOp(tree);
    auto numOpArgs = op.getNumArgs();

    // The pattern might have trailing directives.
    // NOTE(review): leaf (non-DAG) arguments neither count nor stop this
    // backwards scan. A directive that appears before a leaf argument is
    // therefore still counted. Confirm whether only strictly trailing
    // directives should count.
    int numDirectives = 0;
    for (int i = numTreeArgs - 1; i >= 0; --i) {
      if (auto dagArg = tree.getArgAsNestedDag(i)) {
        if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
          ++numDirectives;
        else
          break;
      }
    }

    // NOTE(review): the message reports `numTreeArgs`, which still includes
    // the directives excluded from the comparison.
    if (numOpArgs != numTreeArgs - numDirectives) {
      auto err = formatv("op '{0}' argument number mismatch: "
                         "{1} in pattern vs. {2} in definition",
                         op.getOperationName(), numTreeArgs, numOpArgs);
      PrintFatalError(&def, err);
    }

    // The name attached to the DAG node's operator is for representing the
    // results generated from this op. It should be remembered as bound results.
    if (!treeName.empty()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "found symbol bound to op result: " << treeName << '\n');
      verifyBind(infoMap.bindOpResult(treeName, op), treeName);
    }

    for (int i = 0; i != numTreeArgs; ++i) {
      if (auto treeArg = tree.getArgAsNestedDag(i)) {
        // This DAG node argument is a DAG node itself. Go inside recursively.
        collectBoundSymbols(treeArg, infoMap, isSrcPattern);
        continue;
      }

      if (isSrcPattern) {
        // We can only bind symbols to op arguments in source pattern. Those
        // symbols are referenced in result patterns.
        auto treeArgName = tree.getArgName(i);
        // `$_` is a special symbol meaning ignore the current argument.
        if (!treeArgName.empty() && treeArgName != "_") {
          LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
                                  << treeArgName << '\n');
          verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i),
                     treeArgName);
        }
      }
    }
    return;
  }

  // Directive nodes (replaceWithValue / location / returnType) reach here and
  // may not carry a symbol.
  if (!treeName.empty()) {
    PrintFatalError(
        &def, formatv("binding symbol '{0}' to non-operation/native code call "
                      "unsupported right now",
                      treeName));
  }
}
824