1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Pattern wrapper class to simplify using a TableGen Record that defines an
// MLIR Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/TableGen/Pattern.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21
22 #define DEBUG_TYPE "mlir-tblgen-pattern"
23
24 using namespace mlir;
25 using namespace tblgen;
26
27 using llvm::formatv;
28
29 //===----------------------------------------------------------------------===//
30 // DagLeaf
31 //===----------------------------------------------------------------------===//
32
isUnspecified() const33 bool DagLeaf::isUnspecified() const {
34 return dyn_cast_or_null<llvm::UnsetInit>(def);
35 }
36
isOperandMatcher() const37 bool DagLeaf::isOperandMatcher() const {
38 // Operand matchers specify a type constraint.
39 return isSubClassOf("TypeConstraint");
40 }
41
isAttrMatcher() const42 bool DagLeaf::isAttrMatcher() const {
43 // Attribute matchers specify an attribute constraint.
44 return isSubClassOf("AttrConstraint");
45 }
46
isNativeCodeCall() const47 bool DagLeaf::isNativeCodeCall() const {
48 return isSubClassOf("NativeCodeCall");
49 }
50
isConstantAttr() const51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
52
isEnumAttrCase() const53 bool DagLeaf::isEnumAttrCase() const {
54 return isSubClassOf("EnumAttrCaseInfo");
55 }
56
isStringAttr() const57 bool DagLeaf::isStringAttr() const {
58 return isa<llvm::StringInit>(def);
59 }
60
getAsConstraint() const61 Constraint DagLeaf::getAsConstraint() const {
62 assert((isOperandMatcher() || isAttrMatcher()) &&
63 "the DAG leaf must be operand or attribute");
64 return Constraint(cast<llvm::DefInit>(def)->getDef());
65 }
66
getAsConstantAttr() const67 ConstantAttr DagLeaf::getAsConstantAttr() const {
68 assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69 return ConstantAttr(cast<llvm::DefInit>(def));
70 }
71
getAsEnumAttrCase() const72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
73 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74 return EnumAttrCase(cast<llvm::DefInit>(def));
75 }
76
getConditionTemplate() const77 std::string DagLeaf::getConditionTemplate() const {
78 return getAsConstraint().getConditionTemplate();
79 }
80
getNativeCodeTemplate() const81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
82 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
84 }
85
getNumReturnsOfNativeCode() const86 int DagLeaf::getNumReturnsOfNativeCode() const {
87 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
88 return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
89 }
90
getStringAttr() const91 std::string DagLeaf::getStringAttr() const {
92 assert(isStringAttr() && "the DAG leaf must be string attribute");
93 return def->getAsUnquotedString();
94 }
isSubClassOf(StringRef superclass) const95 bool DagLeaf::isSubClassOf(StringRef superclass) const {
96 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
97 return defInit->getDef()->isSubClassOf(superclass);
98 return false;
99 }
100
print(raw_ostream & os) const101 void DagLeaf::print(raw_ostream &os) const {
102 if (def)
103 def->print(os);
104 }
105
106 //===----------------------------------------------------------------------===//
107 // DagNode
108 //===----------------------------------------------------------------------===//
109
isNativeCodeCall() const110 bool DagNode::isNativeCodeCall() const {
111 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
112 return defInit->getDef()->isSubClassOf("NativeCodeCall");
113 return false;
114 }
115
isOperation() const116 bool DagNode::isOperation() const {
117 return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
118 }
119
getNativeCodeTemplate() const120 llvm::StringRef DagNode::getNativeCodeTemplate() const {
121 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
122 return cast<llvm::DefInit>(node->getOperator())
123 ->getDef()
124 ->getValueAsString("expression");
125 }
126
getNumReturnsOfNativeCode() const127 int DagNode::getNumReturnsOfNativeCode() const {
128 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
129 return cast<llvm::DefInit>(node->getOperator())
130 ->getDef()
131 ->getValueAsInt("numReturns");
132 }
133
getSymbol() const134 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
135
getDialectOp(RecordOperatorMap * mapper) const136 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
137 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
138 auto it = mapper->find(opDef);
139 if (it != mapper->end())
140 return *it->second;
141 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
142 .first->second;
143 }
144
getNumOps() const145 int DagNode::getNumOps() const {
146 int count = isReplaceWithValue() ? 0 : 1;
147 for (int i = 0, e = getNumArgs(); i != e; ++i) {
148 if (auto child = getArgAsNestedDag(i))
149 count += child.getNumOps();
150 }
151 return count;
152 }
153
getNumArgs() const154 int DagNode::getNumArgs() const { return node->getNumArgs(); }
155
isNestedDagArg(unsigned index) const156 bool DagNode::isNestedDagArg(unsigned index) const {
157 return isa<llvm::DagInit>(node->getArg(index));
158 }
159
getArgAsNestedDag(unsigned index) const160 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
161 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
162 }
163
getArgAsLeaf(unsigned index) const164 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
165 assert(!isNestedDagArg(index));
166 return DagLeaf(node->getArg(index));
167 }
168
getArgName(unsigned index) const169 StringRef DagNode::getArgName(unsigned index) const {
170 return node->getArgNameStr(index);
171 }
172
isReplaceWithValue() const173 bool DagNode::isReplaceWithValue() const {
174 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
175 return dagOpDef->getName() == "replaceWithValue";
176 }
177
isLocationDirective() const178 bool DagNode::isLocationDirective() const {
179 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
180 return dagOpDef->getName() == "location";
181 }
182
print(raw_ostream & os) const183 void DagNode::print(raw_ostream &os) const {
184 if (node)
185 node->print(os);
186 }
187
188 //===----------------------------------------------------------------------===//
189 // SymbolInfoMap
190 //===----------------------------------------------------------------------===//
191
getValuePackName(StringRef symbol,int * index)192 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
193 StringRef name, indexStr;
194 int idx = -1;
195 std::tie(name, indexStr) = symbol.rsplit("__");
196
197 if (indexStr.consumeInteger(10, idx)) {
198 // The second part is not an index; we return the whole symbol as-is.
199 return symbol;
200 }
201 if (index) {
202 *index = idx;
203 }
204 return name;
205 }
206
SymbolInfo(const Operator * op,SymbolInfo::Kind kind,Optional<DagAndConstant> dagAndConstant)207 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
208 Optional<DagAndConstant> dagAndConstant)
209 : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
210
getStaticValueCount() const211 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
212 switch (kind) {
213 case Kind::Attr:
214 case Kind::Operand:
215 case Kind::Value:
216 return 1;
217 case Kind::Result:
218 return op->getNumResults();
219 case Kind::MultipleValues:
220 return getSize();
221 }
222 llvm_unreachable("unknown kind");
223 }
224
getVarName(StringRef name) const225 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
226 return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
227 }
228
getVarDecl(StringRef name) const229 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
230 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
231 switch (kind) {
232 case Kind::Attr: {
233 if (op) {
234 auto type = op->getArg(getArgIndex())
235 .get<NamedAttribute *>()
236 ->attr.getStorageType();
237 return std::string(formatv("{0} {1};\n", type, name));
238 }
239 // TODO(suderman): Use a more exact type when available.
240 return std::string(formatv("Attribute {0};\n", name));
241 }
242 case Kind::Operand: {
243 // Use operand range for captured operands (to support potential variadic
244 // operands).
245 return std::string(
246 formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
247 getVarName(name)));
248 }
249 case Kind::Value: {
250 return std::string(formatv("::mlir::Value {0};\n", name));
251 }
252 case Kind::MultipleValues: {
253 // This is for the variable used in the source pattern. Each named value in
254 // source pattern will only be bound to a Value. The others in the result
255 // pattern may be associated with multiple Values as we will use `auto` to
256 // do the type inference.
257 return std::string(formatv(
258 "::mlir::Value {0}_raw; ::mlir::ValueRange {0}({0}_raw);\n", name));
259 }
260 case Kind::Result: {
261 // Use the op itself for captured results.
262 return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
263 }
264 }
265 llvm_unreachable("unknown kind");
266 }
267
getValueAndRangeUse(StringRef name,int index,const char * fmt,const char * separator) const268 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
269 StringRef name, int index, const char *fmt, const char *separator) const {
270 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
271 switch (kind) {
272 case Kind::Attr: {
273 assert(index < 0);
274 auto repl = formatv(fmt, name);
275 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
276 return std::string(repl);
277 }
278 case Kind::Operand: {
279 assert(index < 0);
280 auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
281 // If this operand is variadic, then return a range. Otherwise, return the
282 // value itself.
283 if (operand->isVariableLength()) {
284 auto repl = formatv(fmt, name);
285 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
286 return std::string(repl);
287 }
288 auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
289 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
290 return std::string(repl);
291 }
292 case Kind::Result: {
293 // If `index` is greater than zero, then we are referencing a specific
294 // result of a multi-result op. The result can still be variadic.
295 if (index >= 0) {
296 std::string v =
297 std::string(formatv("{0}.getODSResults({1})", name, index));
298 if (!op->getResult(index).isVariadic())
299 v = std::string(formatv("(*{0}.begin())", v));
300 auto repl = formatv(fmt, v);
301 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
302 return std::string(repl);
303 }
304
305 // If this op has no result at all but still we bind a symbol to it, it
306 // means we want to capture the op itself.
307 if (op->getNumResults() == 0) {
308 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
309 return std::string(name);
310 }
311
312 // We are referencing all results of the multi-result op. A specific result
313 // can either be a value or a range. Then join them with `separator`.
314 SmallVector<std::string, 4> values;
315 values.reserve(op->getNumResults());
316
317 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
318 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
319 if (!op->getResult(i).isVariadic()) {
320 v = std::string(formatv("(*{0}.begin())", v));
321 }
322 values.push_back(std::string(formatv(fmt, v)));
323 }
324 auto repl = llvm::join(values, separator);
325 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
326 return repl;
327 }
328 case Kind::Value: {
329 assert(index < 0);
330 assert(op == nullptr);
331 auto repl = formatv(fmt, name);
332 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
333 return std::string(repl);
334 }
335 case Kind::MultipleValues: {
336 assert(op == nullptr);
337 assert(index < getSize());
338 if (index >= 0) {
339 std::string repl =
340 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
341 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
342 return repl;
343 }
344 // If it doesn't specify certain element, unpack them all.
345 auto repl =
346 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
347 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
348 return std::string(repl);
349 }
350 }
351 llvm_unreachable("unknown kind");
352 }
353
getAllRangeUse(StringRef name,int index,const char * fmt,const char * separator) const354 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
355 StringRef name, int index, const char *fmt, const char *separator) const {
356 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
357 switch (kind) {
358 case Kind::Attr:
359 case Kind::Operand: {
360 assert(index < 0 && "only allowed for symbol bound to result");
361 auto repl = formatv(fmt, name);
362 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
363 return std::string(repl);
364 }
365 case Kind::Result: {
366 if (index >= 0) {
367 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
368 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
369 return std::string(repl);
370 }
371
372 // We are referencing all results of the multi-result op. Each result should
373 // have a value range, and then join them with `separator`.
374 SmallVector<std::string, 4> values;
375 values.reserve(op->getNumResults());
376
377 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
378 values.push_back(std::string(
379 formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
380 }
381 auto repl = llvm::join(values, separator);
382 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
383 return repl;
384 }
385 case Kind::Value: {
386 assert(index < 0 && "only allowed for symbol bound to result");
387 assert(op == nullptr);
388 auto repl = formatv(fmt, formatv("{{{0}}", name));
389 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
390 return std::string(repl);
391 }
392 case Kind::MultipleValues: {
393 assert(op == nullptr);
394 assert(index < getSize());
395 if (index >= 0) {
396 std::string repl =
397 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
398 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
399 return repl;
400 }
401 auto repl =
402 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
403 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
404 return std::string(repl);
405 }
406 }
407 llvm_unreachable("unknown kind");
408 }
409
bindOpArgument(DagNode node,StringRef symbol,const Operator & op,int argIndex)410 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
411 const Operator &op, int argIndex) {
412 StringRef name = getValuePackName(symbol);
413 if (name != symbol) {
414 auto error = formatv(
415 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
416 PrintFatalError(loc, error);
417 }
418
419 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
420 ? SymbolInfo::getAttr(&op, argIndex)
421 : SymbolInfo::getOperand(node, &op, argIndex);
422
423 std::string key = symbol.str();
424 if (symbolInfoMap.count(key)) {
425 // Only non unique name for the operand is supported.
426 if (symInfo.kind != SymbolInfo::Kind::Operand) {
427 return false;
428 }
429
430 // Cannot add new operand if there is already non operand with the same
431 // name.
432 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
433 return false;
434 }
435 }
436
437 symbolInfoMap.emplace(key, symInfo);
438 return true;
439 }
440
bindOpResult(StringRef symbol,const Operator & op)441 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
442 std::string name = getValuePackName(symbol).str();
443 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
444
445 return symbolInfoMap.count(inserted->first) == 1;
446 }
447
bindValues(StringRef symbol,int numValues)448 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
449 std::string name = getValuePackName(symbol).str();
450 if (numValues > 1)
451 return bindMultipleValues(name, numValues);
452 return bindValue(name);
453 }
454
bindValue(StringRef symbol)455 bool SymbolInfoMap::bindValue(StringRef symbol) {
456 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
457 return symbolInfoMap.count(inserted->first) == 1;
458 }
459
bindMultipleValues(StringRef symbol,int numValues)460 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
461 std::string name = getValuePackName(symbol).str();
462 auto inserted =
463 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
464 return symbolInfoMap.count(inserted->first) == 1;
465 }
466
bindAttr(StringRef symbol)467 bool SymbolInfoMap::bindAttr(StringRef symbol) {
468 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
469 return symbolInfoMap.count(inserted->first) == 1;
470 }
471
contains(StringRef symbol) const472 bool SymbolInfoMap::contains(StringRef symbol) const {
473 return find(symbol) != symbolInfoMap.end();
474 }
475
find(StringRef key) const476 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
477 std::string name = getValuePackName(key).str();
478
479 return symbolInfoMap.find(name);
480 }
481
482 SymbolInfoMap::const_iterator
findBoundSymbol(StringRef key,DagNode node,const Operator & op,int argIndex) const483 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
484 int argIndex) const {
485 std::string name = getValuePackName(key).str();
486 auto range = symbolInfoMap.equal_range(name);
487
488 const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex);
489
490 for (auto it = range.first; it != range.second; ++it)
491 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
492 return it;
493
494 return symbolInfoMap.end();
495 }
496
497 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
getRangeOfEqualElements(StringRef key)498 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
499 std::string name = getValuePackName(key).str();
500
501 return symbolInfoMap.equal_range(name);
502 }
503
count(StringRef key) const504 int SymbolInfoMap::count(StringRef key) const {
505 std::string name = getValuePackName(key).str();
506 return symbolInfoMap.count(name);
507 }
508
getStaticValueCount(StringRef symbol) const509 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
510 StringRef name = getValuePackName(symbol);
511 if (name != symbol) {
512 // If there is a trailing index inside symbol, it references just one
513 // static value.
514 return 1;
515 }
516 // Otherwise, find how many it represents by querying the symbol's info.
517 return find(name)->second.getStaticValueCount();
518 }
519
getValueAndRangeUse(StringRef symbol,const char * fmt,const char * separator) const520 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
521 const char *fmt,
522 const char *separator) const {
523 int index = -1;
524 StringRef name = getValuePackName(symbol, &index);
525
526 auto it = symbolInfoMap.find(name.str());
527 if (it == symbolInfoMap.end()) {
528 auto error = formatv("referencing unbound symbol '{0}'", symbol);
529 PrintFatalError(loc, error);
530 }
531
532 return it->second.getValueAndRangeUse(name, index, fmt, separator);
533 }
534
getAllRangeUse(StringRef symbol,const char * fmt,const char * separator) const535 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
536 const char *separator) const {
537 int index = -1;
538 StringRef name = getValuePackName(symbol, &index);
539
540 auto it = symbolInfoMap.find(name.str());
541 if (it == symbolInfoMap.end()) {
542 auto error = formatv("referencing unbound symbol '{0}'", symbol);
543 PrintFatalError(loc, error);
544 }
545
546 return it->second.getAllRangeUse(name, index, fmt, separator);
547 }
548
assignUniqueAlternativeNames()549 void SymbolInfoMap::assignUniqueAlternativeNames() {
550 llvm::StringSet<> usedNames;
551
552 for (auto symbolInfoIt = symbolInfoMap.begin();
553 symbolInfoIt != symbolInfoMap.end();) {
554 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
555 auto startRange = range.first;
556 auto endRange = range.second;
557
558 auto operandName = symbolInfoIt->first;
559 int startSearchIndex = 0;
560 for (++startRange; startRange != endRange; ++startRange) {
561 // Current operand name is not unique, find a unique one
562 // and set the alternative name.
563 for (int i = startSearchIndex;; ++i) {
564 std::string alternativeName = operandName + std::to_string(i);
565 if (!usedNames.contains(alternativeName) &&
566 symbolInfoMap.count(alternativeName) == 0) {
567 usedNames.insert(alternativeName);
568 startRange->second.alternativeName = alternativeName;
569 startSearchIndex = i + 1;
570
571 break;
572 }
573 }
574 }
575
576 symbolInfoIt = endRange;
577 }
578 }
579
580 //===----------------------------------------------------------------------===//
581 // Pattern
//===----------------------------------------------------------------------===//
583
Pattern(const llvm::Record * def,RecordOperatorMap * mapper)584 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
585 : def(*def), recordOpMap(mapper) {}
586
getSourcePattern() const587 DagNode Pattern::getSourcePattern() const {
588 return DagNode(def.getValueAsDag("sourcePattern"));
589 }
590
getNumResultPatterns() const591 int Pattern::getNumResultPatterns() const {
592 auto *results = def.getValueAsListInit("resultPatterns");
593 return results->size();
594 }
595
getResultPattern(unsigned index) const596 DagNode Pattern::getResultPattern(unsigned index) const {
597 auto *results = def.getValueAsListInit("resultPatterns");
598 return DagNode(cast<llvm::DagInit>(results->getElement(index)));
599 }
600
collectSourcePatternBoundSymbols(SymbolInfoMap & infoMap)601 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
602 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
603 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
604 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
605
606 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
607 infoMap.assignUniqueAlternativeNames();
608 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
609 }
610
collectResultPatternBoundSymbols(SymbolInfoMap & infoMap)611 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
612 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
613 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
614 auto pattern = getResultPattern(i);
615 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
616 }
617 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
618 }
619
getSourceRootOp()620 const Operator &Pattern::getSourceRootOp() {
621 return getSourcePattern().getDialectOp(recordOpMap);
622 }
623
getDialectOp(DagNode node)624 Operator &Pattern::getDialectOp(DagNode node) {
625 return node.getDialectOp(recordOpMap);
626 }
627
getConstraints() const628 std::vector<AppliedConstraint> Pattern::getConstraints() const {
629 auto *listInit = def.getValueAsListInit("constraints");
630 std::vector<AppliedConstraint> ret;
631 ret.reserve(listInit->size());
632
633 for (auto it : *listInit) {
634 auto *dagInit = dyn_cast<llvm::DagInit>(it);
635 if (!dagInit)
636 PrintFatalError(&def, "all elements in Pattern multi-entity "
637 "constraints should be DAG nodes");
638
639 std::vector<std::string> entities;
640 entities.reserve(dagInit->arg_size());
641 for (auto *argName : dagInit->getArgNames()) {
642 if (!argName) {
643 PrintFatalError(
644 &def,
645 "operands to additional constraints can only be symbol references");
646 }
647 entities.push_back(std::string(argName->getValue()));
648 }
649
650 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
651 dagInit->getNameStr(), std::move(entities));
652 }
653 return ret;
654 }
655
getBenefit() const656 int Pattern::getBenefit() const {
657 // The initial benefit value is a heuristic with number of ops in the source
658 // pattern.
659 int initBenefit = getSourcePattern().getNumOps();
660 llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
661 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
662 PrintFatalError(&def,
663 "The 'addBenefit' takes and only takes one integer value");
664 }
665 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
666 }
667
getLocation() const668 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
669 std::vector<std::pair<StringRef, unsigned>> result;
670 result.reserve(def.getLoc().size());
671 for (auto loc : def.getLoc()) {
672 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
673 assert(buf && "invalid source location");
674 result.emplace_back(
675 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
676 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
677 }
678 return result;
679 }
680
verifyBind(bool result,StringRef symbolName)681 void Pattern::verifyBind(bool result, StringRef symbolName) {
682 if (!result) {
683 auto err = formatv("symbol '{0}' bound more than once", symbolName);
684 PrintFatalError(&def, err);
685 }
686 }
687
collectBoundSymbols(DagNode tree,SymbolInfoMap & infoMap,bool isSrcPattern)688 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
689 bool isSrcPattern) {
690 auto treeName = tree.getSymbol();
691 auto numTreeArgs = tree.getNumArgs();
692
693 if (tree.isNativeCodeCall()) {
694 if (!treeName.empty()) {
695 if (!isSrcPattern) {
696 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
697 << treeName << '\n');
698 verifyBind(
699 infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
700 treeName);
701 } else {
702 PrintFatalError(&def,
703 formatv("binding symbol '{0}' to NativecodeCall in "
704 "MatchPattern is not supported",
705 treeName));
706 }
707 }
708
709 for (int i = 0; i != numTreeArgs; ++i) {
710 if (auto treeArg = tree.getArgAsNestedDag(i)) {
711 // This DAG node argument is a DAG node itself. Go inside recursively.
712 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
713 continue;
714 }
715
716 if (!isSrcPattern)
717 continue;
718
719 // We can only bind symbols to arguments in source pattern. Those
720 // symbols are referenced in result patterns.
721 auto treeArgName = tree.getArgName(i);
722
723 // `$_` is a special symbol meaning ignore the current argument.
724 if (!treeArgName.empty() && treeArgName != "_") {
725 DagLeaf leaf = tree.getArgAsLeaf(i);
726
727 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
728 if (leaf.isUnspecified()) {
729 // This is case of $c, a Value without any constraints.
730 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
731 } else {
732 auto constraint = leaf.getAsConstraint();
733 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
734 leaf.isConstantAttr() ||
735 constraint.getKind() == Constraint::Kind::CK_Attr;
736
737 if (isAttr) {
738 // This is case of $a, a binding to a certain attribute.
739 verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
740 continue;
741 }
742
743 // This is case of $b, a binding to a certain type.
744 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
745 }
746 }
747 }
748
749 return;
750 }
751
752 if (tree.isOperation()) {
753 auto &op = getDialectOp(tree);
754 auto numOpArgs = op.getNumArgs();
755
756 // The pattern might have the last argument specifying the location.
757 bool hasLocDirective = false;
758 if (numTreeArgs != 0) {
759 if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
760 hasLocDirective = lastArg.isLocationDirective();
761 }
762
763 if (numOpArgs != numTreeArgs - hasLocDirective) {
764 auto err = formatv("op '{0}' argument number mismatch: "
765 "{1} in pattern vs. {2} in definition",
766 op.getOperationName(), numTreeArgs, numOpArgs);
767 PrintFatalError(&def, err);
768 }
769
770 // The name attached to the DAG node's operator is for representing the
771 // results generated from this op. It should be remembered as bound results.
772 if (!treeName.empty()) {
773 LLVM_DEBUG(llvm::dbgs()
774 << "found symbol bound to op result: " << treeName << '\n');
775 verifyBind(infoMap.bindOpResult(treeName, op), treeName);
776 }
777
778 for (int i = 0; i != numTreeArgs; ++i) {
779 if (auto treeArg = tree.getArgAsNestedDag(i)) {
780 // This DAG node argument is a DAG node itself. Go inside recursively.
781 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
782 continue;
783 }
784
785 if (isSrcPattern) {
786 // We can only bind symbols to op arguments in source pattern. Those
787 // symbols are referenced in result patterns.
788 auto treeArgName = tree.getArgName(i);
789 // `$_` is a special symbol meaning ignore the current argument.
790 if (!treeArgName.empty() && treeArgName != "_") {
791 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
792 << treeArgName << '\n');
793 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i),
794 treeArgName);
795 }
796 }
797 }
798 return;
799 }
800
801 if (!treeName.empty()) {
802 PrintFatalError(
803 &def, formatv("binding symbol '{0}' to non-operation/native code call "
804 "unsupported right now",
805 treeName));
806 }
807 return;
808 }
809