//===- Pattern.cpp - Pattern wrapper class --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Pattern wrapper class to simplify using a TableGen Record that defines an
// MLIR Pattern.
//
//===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Pattern.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 #define DEBUG_TYPE "mlir-tblgen-pattern"
23 
24 using namespace mlir;
25 using namespace tblgen;
26 
27 using llvm::formatv;
28 
//===----------------------------------------------------------------------===//
// DagLeaf
//===----------------------------------------------------------------------===//
32 
isUnspecified() const33 bool DagLeaf::isUnspecified() const {
34   return dyn_cast_or_null<llvm::UnsetInit>(def);
35 }
36 
isOperandMatcher() const37 bool DagLeaf::isOperandMatcher() const {
38   // Operand matchers specify a type constraint.
39   return isSubClassOf("TypeConstraint");
40 }
41 
isAttrMatcher() const42 bool DagLeaf::isAttrMatcher() const {
43   // Attribute matchers specify an attribute constraint.
44   return isSubClassOf("AttrConstraint");
45 }
46 
isNativeCodeCall() const47 bool DagLeaf::isNativeCodeCall() const {
48   return isSubClassOf("NativeCodeCall");
49 }
50 
isConstantAttr() const51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
52 
isEnumAttrCase() const53 bool DagLeaf::isEnumAttrCase() const {
54   return isSubClassOf("EnumAttrCaseInfo");
55 }
56 
isStringAttr() const57 bool DagLeaf::isStringAttr() const {
58   return isa<llvm::StringInit>(def);
59 }
60 
getAsConstraint() const61 Constraint DagLeaf::getAsConstraint() const {
62   assert((isOperandMatcher() || isAttrMatcher()) &&
63          "the DAG leaf must be operand or attribute");
64   return Constraint(cast<llvm::DefInit>(def)->getDef());
65 }
66 
getAsConstantAttr() const67 ConstantAttr DagLeaf::getAsConstantAttr() const {
68   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69   return ConstantAttr(cast<llvm::DefInit>(def));
70 }
71 
getAsEnumAttrCase() const72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
73   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74   return EnumAttrCase(cast<llvm::DefInit>(def));
75 }
76 
getConditionTemplate() const77 std::string DagLeaf::getConditionTemplate() const {
78   return getAsConstraint().getConditionTemplate();
79 }
80 
getNativeCodeTemplate() const81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
82   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
84 }
85 
getStringAttr() const86 std::string DagLeaf::getStringAttr() const {
87   assert(isStringAttr() && "the DAG leaf must be string attribute");
88   return def->getAsUnquotedString();
89 }
isSubClassOf(StringRef superclass) const90 bool DagLeaf::isSubClassOf(StringRef superclass) const {
91   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
92     return defInit->getDef()->isSubClassOf(superclass);
93   return false;
94 }
95 
print(raw_ostream & os) const96 void DagLeaf::print(raw_ostream &os) const {
97   if (def)
98     def->print(os);
99 }
100 
//===----------------------------------------------------------------------===//
// DagNode
//===----------------------------------------------------------------------===//
104 
isNativeCodeCall() const105 bool DagNode::isNativeCodeCall() const {
106   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
107     return defInit->getDef()->isSubClassOf("NativeCodeCall");
108   return false;
109 }
110 
isOperation() const111 bool DagNode::isOperation() const {
112   return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
113 }
114 
getNativeCodeTemplate() const115 llvm::StringRef DagNode::getNativeCodeTemplate() const {
116   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
117   return cast<llvm::DefInit>(node->getOperator())
118       ->getDef()
119       ->getValueAsString("expression");
120 }
121 
getSymbol() const122 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
123 
getDialectOp(RecordOperatorMap * mapper) const124 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
125   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
126   auto it = mapper->find(opDef);
127   if (it != mapper->end())
128     return *it->second;
129   return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
130               .first->second;
131 }
132 
getNumOps() const133 int DagNode::getNumOps() const {
134   int count = isReplaceWithValue() ? 0 : 1;
135   for (int i = 0, e = getNumArgs(); i != e; ++i) {
136     if (auto child = getArgAsNestedDag(i))
137       count += child.getNumOps();
138   }
139   return count;
140 }
141 
getNumArgs() const142 int DagNode::getNumArgs() const { return node->getNumArgs(); }
143 
isNestedDagArg(unsigned index) const144 bool DagNode::isNestedDagArg(unsigned index) const {
145   return isa<llvm::DagInit>(node->getArg(index));
146 }
147 
getArgAsNestedDag(unsigned index) const148 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
149   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
150 }
151 
getArgAsLeaf(unsigned index) const152 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
153   assert(!isNestedDagArg(index));
154   return DagLeaf(node->getArg(index));
155 }
156 
getArgName(unsigned index) const157 StringRef DagNode::getArgName(unsigned index) const {
158   return node->getArgNameStr(index);
159 }
160 
isReplaceWithValue() const161 bool DagNode::isReplaceWithValue() const {
162   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
163   return dagOpDef->getName() == "replaceWithValue";
164 }
165 
isLocationDirective() const166 bool DagNode::isLocationDirective() const {
167   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
168   return dagOpDef->getName() == "location";
169 }
170 
print(raw_ostream & os) const171 void DagNode::print(raw_ostream &os) const {
172   if (node)
173     node->print(os);
174 }
175 
//===----------------------------------------------------------------------===//
// SymbolInfoMap
//===----------------------------------------------------------------------===//
179 
getValuePackName(StringRef symbol,int * index)180 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
181   StringRef name, indexStr;
182   int idx = -1;
183   std::tie(name, indexStr) = symbol.rsplit("__");
184 
185   if (indexStr.consumeInteger(10, idx)) {
186     // The second part is not an index; we return the whole symbol as-is.
187     return symbol;
188   }
189   if (index) {
190     *index = idx;
191   }
192   return name;
193 }
194 
SymbolInfo(const Operator * op,SymbolInfo::Kind kind,Optional<int> index)195 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
196                                       Optional<int> index)
197     : op(op), kind(kind), argIndex(index) {}
198 
getStaticValueCount() const199 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
200   switch (kind) {
201   case Kind::Attr:
202   case Kind::Operand:
203   case Kind::Value:
204     return 1;
205   case Kind::Result:
206     return op->getNumResults();
207   }
208   llvm_unreachable("unknown kind");
209 }
210 
getVarName(StringRef name) const211 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
212   return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
213 }
214 
getVarDecl(StringRef name) const215 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
216   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
217   switch (kind) {
218   case Kind::Attr: {
219     if (op) {
220       auto type =
221           op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
222       return std::string(formatv("{0} {1};\n", type, name));
223     }
224     // TODO(suderman): Use a more exact type when available.
225     return std::string(formatv("Attribute {0};\n", name));
226   }
227   case Kind::Operand: {
228     // Use operand range for captured operands (to support potential variadic
229     // operands).
230     return std::string(
231         formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
232                 getVarName(name)));
233   }
234   case Kind::Value: {
235     return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
236   }
237   case Kind::Result: {
238     // Use the op itself for captured results.
239     return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
240   }
241   }
242   llvm_unreachable("unknown kind");
243 }
244 
getValueAndRangeUse(StringRef name,int index,const char * fmt,const char * separator) const245 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
246     StringRef name, int index, const char *fmt, const char *separator) const {
247   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
248   switch (kind) {
249   case Kind::Attr: {
250     assert(index < 0);
251     auto repl = formatv(fmt, name);
252     LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
253     return std::string(repl);
254   }
255   case Kind::Operand: {
256     assert(index < 0);
257     auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
258     // If this operand is variadic, then return a range. Otherwise, return the
259     // value itself.
260     if (operand->isVariableLength()) {
261       auto repl = formatv(fmt, name);
262       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
263       return std::string(repl);
264     }
265     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
266     LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
267     return std::string(repl);
268   }
269   case Kind::Result: {
270     // If `index` is greater than zero, then we are referencing a specific
271     // result of a multi-result op. The result can still be variadic.
272     if (index >= 0) {
273       std::string v =
274           std::string(formatv("{0}.getODSResults({1})", name, index));
275       if (!op->getResult(index).isVariadic())
276         v = std::string(formatv("(*{0}.begin())", v));
277       auto repl = formatv(fmt, v);
278       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
279       return std::string(repl);
280     }
281 
282     // If this op has no result at all but still we bind a symbol to it, it
283     // means we want to capture the op itself.
284     if (op->getNumResults() == 0) {
285       LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
286       return std::string(name);
287     }
288 
289     // We are referencing all results of the multi-result op. A specific result
290     // can either be a value or a range. Then join them with `separator`.
291     SmallVector<std::string, 4> values;
292     values.reserve(op->getNumResults());
293 
294     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
295       std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
296       if (!op->getResult(i).isVariadic()) {
297         v = std::string(formatv("(*{0}.begin())", v));
298       }
299       values.push_back(std::string(formatv(fmt, v)));
300     }
301     auto repl = llvm::join(values, separator);
302     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
303     return repl;
304   }
305   case Kind::Value: {
306     assert(index < 0);
307     assert(op == nullptr);
308     auto repl = formatv(fmt, name);
309     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
310     return std::string(repl);
311   }
312   }
313   llvm_unreachable("unknown kind");
314 }
315 
getAllRangeUse(StringRef name,int index,const char * fmt,const char * separator) const316 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
317     StringRef name, int index, const char *fmt, const char *separator) const {
318   LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
319   switch (kind) {
320   case Kind::Attr:
321   case Kind::Operand: {
322     assert(index < 0 && "only allowed for symbol bound to result");
323     auto repl = formatv(fmt, name);
324     LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
325     return std::string(repl);
326   }
327   case Kind::Result: {
328     if (index >= 0) {
329       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
330       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
331       return std::string(repl);
332     }
333 
334     // We are referencing all results of the multi-result op. Each result should
335     // have a value range, and then join them with `separator`.
336     SmallVector<std::string, 4> values;
337     values.reserve(op->getNumResults());
338 
339     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
340       values.push_back(std::string(
341           formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
342     }
343     auto repl = llvm::join(values, separator);
344     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
345     return repl;
346   }
347   case Kind::Value: {
348     assert(index < 0 && "only allowed for symbol bound to result");
349     assert(op == nullptr);
350     auto repl = formatv(fmt, formatv("{{{0}}", name));
351     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
352     return std::string(repl);
353   }
354   }
355   llvm_unreachable("unknown kind");
356 }
357 
bindOpArgument(StringRef symbol,const Operator & op,int argIndex)358 bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
359                                    int argIndex) {
360   StringRef name = getValuePackName(symbol);
361   if (name != symbol) {
362     auto error = formatv(
363         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
364     PrintFatalError(loc, error);
365   }
366 
367   auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
368                      ? SymbolInfo::getAttr(&op, argIndex)
369                      : SymbolInfo::getOperand(&op, argIndex);
370 
371   std::string key = symbol.str();
372   if (symbolInfoMap.count(key)) {
373     // Only non unique name for the operand is supported.
374     if (symInfo.kind != SymbolInfo::Kind::Operand) {
375       return false;
376     }
377 
378     // Cannot add new operand if there is already non operand with the same
379     // name.
380     if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
381       return false;
382     }
383   }
384 
385   symbolInfoMap.emplace(key, symInfo);
386   return true;
387 }
388 
bindOpResult(StringRef symbol,const Operator & op)389 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
390   std::string name = getValuePackName(symbol).str();
391   auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
392 
393   return symbolInfoMap.count(inserted->first) == 1;
394 }
395 
bindValue(StringRef symbol)396 bool SymbolInfoMap::bindValue(StringRef symbol) {
397   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
398   return symbolInfoMap.count(inserted->first) == 1;
399 }
400 
bindAttr(StringRef symbol)401 bool SymbolInfoMap::bindAttr(StringRef symbol) {
402   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
403   return symbolInfoMap.count(inserted->first) == 1;
404 }
405 
contains(StringRef symbol) const406 bool SymbolInfoMap::contains(StringRef symbol) const {
407   return find(symbol) != symbolInfoMap.end();
408 }
409 
find(StringRef key) const410 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
411   std::string name = getValuePackName(key).str();
412 
413   return symbolInfoMap.find(name);
414 }
415 
416 SymbolInfoMap::const_iterator
findBoundSymbol(StringRef key,const Operator & op,int argIndex) const417 SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
418                                int argIndex) const {
419   std::string name = getValuePackName(key).str();
420   auto range = symbolInfoMap.equal_range(name);
421 
422   for (auto it = range.first; it != range.second; ++it) {
423     if (it->second.op == &op && it->second.argIndex == argIndex) {
424       return it;
425     }
426   }
427 
428   return symbolInfoMap.end();
429 }
430 
431 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
getRangeOfEqualElements(StringRef key)432 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
433   std::string name = getValuePackName(key).str();
434 
435   return symbolInfoMap.equal_range(name);
436 }
437 
count(StringRef key) const438 int SymbolInfoMap::count(StringRef key) const {
439   std::string name = getValuePackName(key).str();
440   return symbolInfoMap.count(name);
441 }
442 
getStaticValueCount(StringRef symbol) const443 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
444   StringRef name = getValuePackName(symbol);
445   if (name != symbol) {
446     // If there is a trailing index inside symbol, it references just one
447     // static value.
448     return 1;
449   }
450   // Otherwise, find how many it represents by querying the symbol's info.
451   return find(name)->second.getStaticValueCount();
452 }
453 
getValueAndRangeUse(StringRef symbol,const char * fmt,const char * separator) const454 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
455                                                const char *fmt,
456                                                const char *separator) const {
457   int index = -1;
458   StringRef name = getValuePackName(symbol, &index);
459 
460   auto it = symbolInfoMap.find(name.str());
461   if (it == symbolInfoMap.end()) {
462     auto error = formatv("referencing unbound symbol '{0}'", symbol);
463     PrintFatalError(loc, error);
464   }
465 
466   return it->second.getValueAndRangeUse(name, index, fmt, separator);
467 }
468 
getAllRangeUse(StringRef symbol,const char * fmt,const char * separator) const469 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
470                                           const char *separator) const {
471   int index = -1;
472   StringRef name = getValuePackName(symbol, &index);
473 
474   auto it = symbolInfoMap.find(name.str());
475   if (it == symbolInfoMap.end()) {
476     auto error = formatv("referencing unbound symbol '{0}'", symbol);
477     PrintFatalError(loc, error);
478   }
479 
480   return it->second.getAllRangeUse(name, index, fmt, separator);
481 }
482 
assignUniqueAlternativeNames()483 void SymbolInfoMap::assignUniqueAlternativeNames() {
484   llvm::StringSet<> usedNames;
485 
486   for (auto symbolInfoIt = symbolInfoMap.begin();
487        symbolInfoIt != symbolInfoMap.end();) {
488     auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
489     auto startRange = range.first;
490     auto endRange = range.second;
491 
492     auto operandName = symbolInfoIt->first;
493     int startSearchIndex = 0;
494     for (++startRange; startRange != endRange; ++startRange) {
495       // Current operand name is not unique, find a unique one
496       // and set the alternative name.
497       for (int i = startSearchIndex;; ++i) {
498         std::string alternativeName = operandName + std::to_string(i);
499         if (!usedNames.contains(alternativeName) &&
500             symbolInfoMap.count(alternativeName) == 0) {
501           usedNames.insert(alternativeName);
502           startRange->second.alternativeName = alternativeName;
503           startSearchIndex = i + 1;
504 
505           break;
506         }
507       }
508     }
509 
510     symbolInfoIt = endRange;
511   }
512 }
513 
//===----------------------------------------------------------------------===//
// Pattern
//===----------------------------------------------------------------------===//
517 
Pattern(const llvm::Record * def,RecordOperatorMap * mapper)518 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
519     : def(*def), recordOpMap(mapper) {}
520 
getSourcePattern() const521 DagNode Pattern::getSourcePattern() const {
522   return DagNode(def.getValueAsDag("sourcePattern"));
523 }
524 
getNumResultPatterns() const525 int Pattern::getNumResultPatterns() const {
526   auto *results = def.getValueAsListInit("resultPatterns");
527   return results->size();
528 }
529 
getResultPattern(unsigned index) const530 DagNode Pattern::getResultPattern(unsigned index) const {
531   auto *results = def.getValueAsListInit("resultPatterns");
532   return DagNode(cast<llvm::DagInit>(results->getElement(index)));
533 }
534 
collectSourcePatternBoundSymbols(SymbolInfoMap & infoMap)535 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
536   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
537   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
538   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
539 
540   LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
541   infoMap.assignUniqueAlternativeNames();
542   LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
543 }
544 
collectResultPatternBoundSymbols(SymbolInfoMap & infoMap)545 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
546   LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
547   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
548     auto pattern = getResultPattern(i);
549     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
550   }
551   LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
552 }
553 
getSourceRootOp()554 const Operator &Pattern::getSourceRootOp() {
555   return getSourcePattern().getDialectOp(recordOpMap);
556 }
557 
getDialectOp(DagNode node)558 Operator &Pattern::getDialectOp(DagNode node) {
559   return node.getDialectOp(recordOpMap);
560 }
561 
getConstraints() const562 std::vector<AppliedConstraint> Pattern::getConstraints() const {
563   auto *listInit = def.getValueAsListInit("constraints");
564   std::vector<AppliedConstraint> ret;
565   ret.reserve(listInit->size());
566 
567   for (auto it : *listInit) {
568     auto *dagInit = dyn_cast<llvm::DagInit>(it);
569     if (!dagInit)
570       PrintFatalError(&def, "all elements in Pattern multi-entity "
571                             "constraints should be DAG nodes");
572 
573     std::vector<std::string> entities;
574     entities.reserve(dagInit->arg_size());
575     for (auto *argName : dagInit->getArgNames()) {
576       if (!argName) {
577         PrintFatalError(
578             &def,
579             "operands to additional constraints can only be symbol references");
580       }
581       entities.push_back(std::string(argName->getValue()));
582     }
583 
584     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
585                      dagInit->getNameStr(), std::move(entities));
586   }
587   return ret;
588 }
589 
getBenefit() const590 int Pattern::getBenefit() const {
591   // The initial benefit value is a heuristic with number of ops in the source
592   // pattern.
593   int initBenefit = getSourcePattern().getNumOps();
594   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
595   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
596     PrintFatalError(&def,
597                     "The 'addBenefit' takes and only takes one integer value");
598   }
599   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
600 }
601 
getLocation() const602 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
603   std::vector<std::pair<StringRef, unsigned>> result;
604   result.reserve(def.getLoc().size());
605   for (auto loc : def.getLoc()) {
606     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
607     assert(buf && "invalid source location");
608     result.emplace_back(
609         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
610         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
611   }
612   return result;
613 }
614 
verifyBind(bool result,StringRef symbolName)615 void Pattern::verifyBind(bool result, StringRef symbolName) {
616   if (!result) {
617     auto err = formatv("symbol '{0}' bound more than once", symbolName);
618     PrintFatalError(&def, err);
619   }
620 }
621 
collectBoundSymbols(DagNode tree,SymbolInfoMap & infoMap,bool isSrcPattern)622 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
623                                   bool isSrcPattern) {
624   auto treeName = tree.getSymbol();
625   auto numTreeArgs = tree.getNumArgs();
626 
627   if (tree.isNativeCodeCall()) {
628     if (!treeName.empty()) {
629       PrintFatalError(
630           &def,
631           formatv(
632               "binding symbol '{0}' to native code call unsupported right now",
633               treeName));
634     }
635 
636     for (int i = 0; i != numTreeArgs; ++i) {
637       if (auto treeArg = tree.getArgAsNestedDag(i)) {
638         // This DAG node argument is a DAG node itself. Go inside recursively.
639         collectBoundSymbols(treeArg, infoMap, isSrcPattern);
640         continue;
641       }
642 
643       if (!isSrcPattern)
644         continue;
645 
646       // We can only bind symbols to arguments in source pattern. Those
647       // symbols are referenced in result patterns.
648       auto treeArgName = tree.getArgName(i);
649 
650       // `$_` is a special symbol meaning ignore the current argument.
651       if (!treeArgName.empty() && treeArgName != "_") {
652         if (tree.isNestedDagArg(i)) {
653           auto err = formatv("cannot bind '{0}' for nested native call arg",
654                              treeArgName);
655           PrintFatalError(&def, err);
656         }
657 
658         DagLeaf leaf = tree.getArgAsLeaf(i);
659         auto constraint = leaf.getAsConstraint();
660         bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
661                       leaf.isConstantAttr() ||
662                       constraint.getKind() == Constraint::Kind::CK_Attr;
663 
664         if (isAttr) {
665           verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
666           continue;
667         }
668 
669         verifyBind(infoMap.bindValue(treeArgName), treeArgName);
670       }
671     }
672 
673     return;
674   }
675 
676   if (tree.isOperation()) {
677     auto &op = getDialectOp(tree);
678     auto numOpArgs = op.getNumArgs();
679 
680     // The pattern might have the last argument specifying the location.
681     bool hasLocDirective = false;
682     if (numTreeArgs != 0) {
683       if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
684         hasLocDirective = lastArg.isLocationDirective();
685     }
686 
687     if (numOpArgs != numTreeArgs - hasLocDirective) {
688       auto err = formatv("op '{0}' argument number mismatch: "
689                          "{1} in pattern vs. {2} in definition",
690                          op.getOperationName(), numTreeArgs, numOpArgs);
691       PrintFatalError(&def, err);
692     }
693 
694     // The name attached to the DAG node's operator is for representing the
695     // results generated from this op. It should be remembered as bound results.
696     if (!treeName.empty()) {
697       LLVM_DEBUG(llvm::dbgs()
698                  << "found symbol bound to op result: " << treeName << '\n');
699       verifyBind(infoMap.bindOpResult(treeName, op), treeName);
700     }
701 
702     for (int i = 0; i != numTreeArgs; ++i) {
703       if (auto treeArg = tree.getArgAsNestedDag(i)) {
704         // This DAG node argument is a DAG node itself. Go inside recursively.
705         collectBoundSymbols(treeArg, infoMap, isSrcPattern);
706         continue;
707       }
708 
709       if (isSrcPattern) {
710         // We can only bind symbols to op arguments in source pattern. Those
711         // symbols are referenced in result patterns.
712         auto treeArgName = tree.getArgName(i);
713         // `$_` is a special symbol meaning ignore the current argument.
714         if (!treeArgName.empty() && treeArgName != "_") {
715           LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
716                                   << treeArgName << '\n');
717           verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName);
718         }
719       }
720     }
721     return;
722   }
723 
724   if (!treeName.empty()) {
725     PrintFatalError(
726         &def, formatv("binding symbol '{0}' to non-operation/native code call "
727                       "unsupported right now",
728                       treeName));
729   }
730   return;
731 }
732