1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Pattern wrapper class to simplify using a TableGen Record defining an MLIR
// Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/TableGen/Pattern.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21
22 #define DEBUG_TYPE "mlir-tblgen-pattern"
23
24 using namespace mlir;
25 using namespace tblgen;
26
27 using llvm::formatv;
28
29 //===----------------------------------------------------------------------===//
30 // DagLeaf
31 //===----------------------------------------------------------------------===//
32
isUnspecified() const33 bool DagLeaf::isUnspecified() const {
34 return dyn_cast_or_null<llvm::UnsetInit>(def);
35 }
36
isOperandMatcher() const37 bool DagLeaf::isOperandMatcher() const {
38 // Operand matchers specify a type constraint.
39 return isSubClassOf("TypeConstraint");
40 }
41
isAttrMatcher() const42 bool DagLeaf::isAttrMatcher() const {
43 // Attribute matchers specify an attribute constraint.
44 return isSubClassOf("AttrConstraint");
45 }
46
isNativeCodeCall() const47 bool DagLeaf::isNativeCodeCall() const {
48 return isSubClassOf("NativeCodeCall");
49 }
50
isConstantAttr() const51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
52
isEnumAttrCase() const53 bool DagLeaf::isEnumAttrCase() const {
54 return isSubClassOf("EnumAttrCaseInfo");
55 }
56
isStringAttr() const57 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
58
getAsConstraint() const59 Constraint DagLeaf::getAsConstraint() const {
60 assert((isOperandMatcher() || isAttrMatcher()) &&
61 "the DAG leaf must be operand or attribute");
62 return Constraint(cast<llvm::DefInit>(def)->getDef());
63 }
64
getAsConstantAttr() const65 ConstantAttr DagLeaf::getAsConstantAttr() const {
66 assert(isConstantAttr() && "the DAG leaf must be constant attribute");
67 return ConstantAttr(cast<llvm::DefInit>(def));
68 }
69
getAsEnumAttrCase() const70 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
71 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
72 return EnumAttrCase(cast<llvm::DefInit>(def));
73 }
74
getConditionTemplate() const75 std::string DagLeaf::getConditionTemplate() const {
76 return getAsConstraint().getConditionTemplate();
77 }
78
getNativeCodeTemplate() const79 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
80 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
81 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
82 }
83
getNumReturnsOfNativeCode() const84 int DagLeaf::getNumReturnsOfNativeCode() const {
85 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
86 return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
87 }
88
getStringAttr() const89 std::string DagLeaf::getStringAttr() const {
90 assert(isStringAttr() && "the DAG leaf must be string attribute");
91 return def->getAsUnquotedString();
92 }
isSubClassOf(StringRef superclass) const93 bool DagLeaf::isSubClassOf(StringRef superclass) const {
94 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
95 return defInit->getDef()->isSubClassOf(superclass);
96 return false;
97 }
98
print(raw_ostream & os) const99 void DagLeaf::print(raw_ostream &os) const {
100 if (def)
101 def->print(os);
102 }
103
104 //===----------------------------------------------------------------------===//
105 // DagNode
106 //===----------------------------------------------------------------------===//
107
isNativeCodeCall() const108 bool DagNode::isNativeCodeCall() const {
109 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
110 return defInit->getDef()->isSubClassOf("NativeCodeCall");
111 return false;
112 }
113
isOperation() const114 bool DagNode::isOperation() const {
115 return !isNativeCodeCall() && !isReplaceWithValue() &&
116 !isLocationDirective() && !isReturnTypeDirective();
117 }
118
getNativeCodeTemplate() const119 llvm::StringRef DagNode::getNativeCodeTemplate() const {
120 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
121 return cast<llvm::DefInit>(node->getOperator())
122 ->getDef()
123 ->getValueAsString("expression");
124 }
125
getNumReturnsOfNativeCode() const126 int DagNode::getNumReturnsOfNativeCode() const {
127 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
128 return cast<llvm::DefInit>(node->getOperator())
129 ->getDef()
130 ->getValueAsInt("numReturns");
131 }
132
getSymbol() const133 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
134
getDialectOp(RecordOperatorMap * mapper) const135 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
136 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
137 auto it = mapper->find(opDef);
138 if (it != mapper->end())
139 return *it->second;
140 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
141 .first->second;
142 }
143
getNumOps() const144 int DagNode::getNumOps() const {
145 int count = isReplaceWithValue() ? 0 : 1;
146 for (int i = 0, e = getNumArgs(); i != e; ++i) {
147 if (auto child = getArgAsNestedDag(i))
148 count += child.getNumOps();
149 }
150 return count;
151 }
152
getNumArgs() const153 int DagNode::getNumArgs() const { return node->getNumArgs(); }
154
isNestedDagArg(unsigned index) const155 bool DagNode::isNestedDagArg(unsigned index) const {
156 return isa<llvm::DagInit>(node->getArg(index));
157 }
158
getArgAsNestedDag(unsigned index) const159 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
160 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
161 }
162
getArgAsLeaf(unsigned index) const163 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
164 assert(!isNestedDagArg(index));
165 return DagLeaf(node->getArg(index));
166 }
167
getArgName(unsigned index) const168 StringRef DagNode::getArgName(unsigned index) const {
169 return node->getArgNameStr(index);
170 }
171
isReplaceWithValue() const172 bool DagNode::isReplaceWithValue() const {
173 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
174 return dagOpDef->getName() == "replaceWithValue";
175 }
176
isLocationDirective() const177 bool DagNode::isLocationDirective() const {
178 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
179 return dagOpDef->getName() == "location";
180 }
181
isReturnTypeDirective() const182 bool DagNode::isReturnTypeDirective() const {
183 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
184 return dagOpDef->getName() == "returnType";
185 }
186
print(raw_ostream & os) const187 void DagNode::print(raw_ostream &os) const {
188 if (node)
189 node->print(os);
190 }
191
192 //===----------------------------------------------------------------------===//
193 // SymbolInfoMap
194 //===----------------------------------------------------------------------===//
195
getValuePackName(StringRef symbol,int * index)196 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
197 StringRef name, indexStr;
198 int idx = -1;
199 std::tie(name, indexStr) = symbol.rsplit("__");
200
201 if (indexStr.consumeInteger(10, idx)) {
202 // The second part is not an index; we return the whole symbol as-is.
203 return symbol;
204 }
205 if (index) {
206 *index = idx;
207 }
208 return name;
209 }
210
SymbolInfo(const Operator * op,SymbolInfo::Kind kind,Optional<DagAndConstant> dagAndConstant)211 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
212 Optional<DagAndConstant> dagAndConstant)
213 : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
214
getStaticValueCount() const215 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
216 switch (kind) {
217 case Kind::Attr:
218 case Kind::Operand:
219 case Kind::Value:
220 return 1;
221 case Kind::Result:
222 return op->getNumResults();
223 case Kind::MultipleValues:
224 return getSize();
225 }
226 llvm_unreachable("unknown kind");
227 }
228
getVarName(StringRef name) const229 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
230 return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
231 }
232
getVarTypeStr(StringRef name) const233 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
234 LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
235 switch (kind) {
236 case Kind::Attr: {
237 if (op)
238 return op->getArg(getArgIndex())
239 .get<NamedAttribute *>()
240 ->attr.getStorageType()
241 .str();
242 // TODO(suderman): Use a more exact type when available.
243 return "Attribute";
244 }
245 case Kind::Operand: {
246 // Use operand range for captured operands (to support potential variadic
247 // operands).
248 return "::mlir::Operation::operand_range";
249 }
250 case Kind::Value: {
251 return "::mlir::Value";
252 }
253 case Kind::MultipleValues: {
254 return "::mlir::ValueRange";
255 }
256 case Kind::Result: {
257 // Use the op itself for captured results.
258 return op->getQualCppClassName();
259 }
260 }
261 llvm_unreachable("unknown kind");
262 }
263
getVarDecl(StringRef name) const264 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
265 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
266 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
267 return std::string(
268 formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
269 }
270
getArgDecl(StringRef name) const271 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
272 LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
273 return std::string(
274 formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
275 }
276
getValueAndRangeUse(StringRef name,int index,const char * fmt,const char * separator) const277 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
278 StringRef name, int index, const char *fmt, const char *separator) const {
279 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
280 switch (kind) {
281 case Kind::Attr: {
282 assert(index < 0);
283 auto repl = formatv(fmt, name);
284 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
285 return std::string(repl);
286 }
287 case Kind::Operand: {
288 assert(index < 0);
289 auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
290 // If this operand is variadic, then return a range. Otherwise, return the
291 // value itself.
292 if (operand->isVariableLength()) {
293 auto repl = formatv(fmt, name);
294 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
295 return std::string(repl);
296 }
297 auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
298 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
299 return std::string(repl);
300 }
301 case Kind::Result: {
302 // If `index` is greater than zero, then we are referencing a specific
303 // result of a multi-result op. The result can still be variadic.
304 if (index >= 0) {
305 std::string v =
306 std::string(formatv("{0}.getODSResults({1})", name, index));
307 if (!op->getResult(index).isVariadic())
308 v = std::string(formatv("(*{0}.begin())", v));
309 auto repl = formatv(fmt, v);
310 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
311 return std::string(repl);
312 }
313
314 // If this op has no result at all but still we bind a symbol to it, it
315 // means we want to capture the op itself.
316 if (op->getNumResults() == 0) {
317 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
318 return std::string(name);
319 }
320
321 // We are referencing all results of the multi-result op. A specific result
322 // can either be a value or a range. Then join them with `separator`.
323 SmallVector<std::string, 4> values;
324 values.reserve(op->getNumResults());
325
326 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
327 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
328 if (!op->getResult(i).isVariadic()) {
329 v = std::string(formatv("(*{0}.begin())", v));
330 }
331 values.push_back(std::string(formatv(fmt, v)));
332 }
333 auto repl = llvm::join(values, separator);
334 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
335 return repl;
336 }
337 case Kind::Value: {
338 assert(index < 0);
339 assert(op == nullptr);
340 auto repl = formatv(fmt, name);
341 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
342 return std::string(repl);
343 }
344 case Kind::MultipleValues: {
345 assert(op == nullptr);
346 assert(index < getSize());
347 if (index >= 0) {
348 std::string repl =
349 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
350 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
351 return repl;
352 }
353 // If it doesn't specify certain element, unpack them all.
354 auto repl =
355 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
356 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
357 return std::string(repl);
358 }
359 }
360 llvm_unreachable("unknown kind");
361 }
362
getAllRangeUse(StringRef name,int index,const char * fmt,const char * separator) const363 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
364 StringRef name, int index, const char *fmt, const char *separator) const {
365 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
366 switch (kind) {
367 case Kind::Attr:
368 case Kind::Operand: {
369 assert(index < 0 && "only allowed for symbol bound to result");
370 auto repl = formatv(fmt, name);
371 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
372 return std::string(repl);
373 }
374 case Kind::Result: {
375 if (index >= 0) {
376 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
377 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
378 return std::string(repl);
379 }
380
381 // We are referencing all results of the multi-result op. Each result should
382 // have a value range, and then join them with `separator`.
383 SmallVector<std::string, 4> values;
384 values.reserve(op->getNumResults());
385
386 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
387 values.push_back(std::string(
388 formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
389 }
390 auto repl = llvm::join(values, separator);
391 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
392 return repl;
393 }
394 case Kind::Value: {
395 assert(index < 0 && "only allowed for symbol bound to result");
396 assert(op == nullptr);
397 auto repl = formatv(fmt, formatv("{{{0}}", name));
398 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
399 return std::string(repl);
400 }
401 case Kind::MultipleValues: {
402 assert(op == nullptr);
403 assert(index < getSize());
404 if (index >= 0) {
405 std::string repl =
406 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
407 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
408 return repl;
409 }
410 auto repl =
411 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
412 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
413 return std::string(repl);
414 }
415 }
416 llvm_unreachable("unknown kind");
417 }
418
bindOpArgument(DagNode node,StringRef symbol,const Operator & op,int argIndex)419 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
420 const Operator &op, int argIndex) {
421 StringRef name = getValuePackName(symbol);
422 if (name != symbol) {
423 auto error = formatv(
424 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
425 PrintFatalError(loc, error);
426 }
427
428 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
429 ? SymbolInfo::getAttr(&op, argIndex)
430 : SymbolInfo::getOperand(node, &op, argIndex);
431
432 std::string key = symbol.str();
433 if (symbolInfoMap.count(key)) {
434 // Only non unique name for the operand is supported.
435 if (symInfo.kind != SymbolInfo::Kind::Operand) {
436 return false;
437 }
438
439 // Cannot add new operand if there is already non operand with the same
440 // name.
441 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
442 return false;
443 }
444 }
445
446 symbolInfoMap.emplace(key, symInfo);
447 return true;
448 }
449
bindOpResult(StringRef symbol,const Operator & op)450 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
451 std::string name = getValuePackName(symbol).str();
452 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
453
454 return symbolInfoMap.count(inserted->first) == 1;
455 }
456
bindValues(StringRef symbol,int numValues)457 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
458 std::string name = getValuePackName(symbol).str();
459 if (numValues > 1)
460 return bindMultipleValues(name, numValues);
461 return bindValue(name);
462 }
463
bindValue(StringRef symbol)464 bool SymbolInfoMap::bindValue(StringRef symbol) {
465 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
466 return symbolInfoMap.count(inserted->first) == 1;
467 }
468
bindMultipleValues(StringRef symbol,int numValues)469 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
470 std::string name = getValuePackName(symbol).str();
471 auto inserted =
472 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
473 return symbolInfoMap.count(inserted->first) == 1;
474 }
475
bindAttr(StringRef symbol)476 bool SymbolInfoMap::bindAttr(StringRef symbol) {
477 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
478 return symbolInfoMap.count(inserted->first) == 1;
479 }
480
contains(StringRef symbol) const481 bool SymbolInfoMap::contains(StringRef symbol) const {
482 return find(symbol) != symbolInfoMap.end();
483 }
484
find(StringRef key) const485 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
486 std::string name = getValuePackName(key).str();
487
488 return symbolInfoMap.find(name);
489 }
490
491 SymbolInfoMap::const_iterator
findBoundSymbol(StringRef key,DagNode node,const Operator & op,int argIndex) const492 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
493 int argIndex) const {
494 return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex));
495 }
496
497 SymbolInfoMap::const_iterator
findBoundSymbol(StringRef key,SymbolInfo symbolInfo) const498 SymbolInfoMap::findBoundSymbol(StringRef key, SymbolInfo symbolInfo) const {
499 std::string name = getValuePackName(key).str();
500 auto range = symbolInfoMap.equal_range(name);
501
502 for (auto it = range.first; it != range.second; ++it)
503 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
504 return it;
505
506 return symbolInfoMap.end();
507 }
508
509 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
getRangeOfEqualElements(StringRef key)510 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
511 std::string name = getValuePackName(key).str();
512
513 return symbolInfoMap.equal_range(name);
514 }
515
count(StringRef key) const516 int SymbolInfoMap::count(StringRef key) const {
517 std::string name = getValuePackName(key).str();
518 return symbolInfoMap.count(name);
519 }
520
getStaticValueCount(StringRef symbol) const521 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
522 StringRef name = getValuePackName(symbol);
523 if (name != symbol) {
524 // If there is a trailing index inside symbol, it references just one
525 // static value.
526 return 1;
527 }
528 // Otherwise, find how many it represents by querying the symbol's info.
529 return find(name)->second.getStaticValueCount();
530 }
531
getValueAndRangeUse(StringRef symbol,const char * fmt,const char * separator) const532 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
533 const char *fmt,
534 const char *separator) const {
535 int index = -1;
536 StringRef name = getValuePackName(symbol, &index);
537
538 auto it = symbolInfoMap.find(name.str());
539 if (it == symbolInfoMap.end()) {
540 auto error = formatv("referencing unbound symbol '{0}'", symbol);
541 PrintFatalError(loc, error);
542 }
543
544 return it->second.getValueAndRangeUse(name, index, fmt, separator);
545 }
546
getAllRangeUse(StringRef symbol,const char * fmt,const char * separator) const547 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
548 const char *separator) const {
549 int index = -1;
550 StringRef name = getValuePackName(symbol, &index);
551
552 auto it = symbolInfoMap.find(name.str());
553 if (it == symbolInfoMap.end()) {
554 auto error = formatv("referencing unbound symbol '{0}'", symbol);
555 PrintFatalError(loc, error);
556 }
557
558 return it->second.getAllRangeUse(name, index, fmt, separator);
559 }
560
assignUniqueAlternativeNames()561 void SymbolInfoMap::assignUniqueAlternativeNames() {
562 llvm::StringSet<> usedNames;
563
564 for (auto symbolInfoIt = symbolInfoMap.begin();
565 symbolInfoIt != symbolInfoMap.end();) {
566 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
567 auto startRange = range.first;
568 auto endRange = range.second;
569
570 auto operandName = symbolInfoIt->first;
571 int startSearchIndex = 0;
572 for (++startRange; startRange != endRange; ++startRange) {
573 // Current operand name is not unique, find a unique one
574 // and set the alternative name.
575 for (int i = startSearchIndex;; ++i) {
576 std::string alternativeName = operandName + std::to_string(i);
577 if (!usedNames.contains(alternativeName) &&
578 symbolInfoMap.count(alternativeName) == 0) {
579 usedNames.insert(alternativeName);
580 startRange->second.alternativeName = alternativeName;
581 startSearchIndex = i + 1;
582
583 break;
584 }
585 }
586 }
587
588 symbolInfoIt = endRange;
589 }
590 }
591
592 //===----------------------------------------------------------------------===//
593 // Pattern
//===----------------------------------------------------------------------===//
595
Pattern(const llvm::Record * def,RecordOperatorMap * mapper)596 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
597 : def(*def), recordOpMap(mapper) {}
598
getSourcePattern() const599 DagNode Pattern::getSourcePattern() const {
600 return DagNode(def.getValueAsDag("sourcePattern"));
601 }
602
getNumResultPatterns() const603 int Pattern::getNumResultPatterns() const {
604 auto *results = def.getValueAsListInit("resultPatterns");
605 return results->size();
606 }
607
getResultPattern(unsigned index) const608 DagNode Pattern::getResultPattern(unsigned index) const {
609 auto *results = def.getValueAsListInit("resultPatterns");
610 return DagNode(cast<llvm::DagInit>(results->getElement(index)));
611 }
612
collectSourcePatternBoundSymbols(SymbolInfoMap & infoMap)613 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
614 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
615 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
616 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
617
618 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
619 infoMap.assignUniqueAlternativeNames();
620 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
621 }
622
collectResultPatternBoundSymbols(SymbolInfoMap & infoMap)623 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
624 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
625 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
626 auto pattern = getResultPattern(i);
627 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
628 }
629 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
630 }
631
getSourceRootOp()632 const Operator &Pattern::getSourceRootOp() {
633 return getSourcePattern().getDialectOp(recordOpMap);
634 }
635
getDialectOp(DagNode node)636 Operator &Pattern::getDialectOp(DagNode node) {
637 return node.getDialectOp(recordOpMap);
638 }
639
getConstraints() const640 std::vector<AppliedConstraint> Pattern::getConstraints() const {
641 auto *listInit = def.getValueAsListInit("constraints");
642 std::vector<AppliedConstraint> ret;
643 ret.reserve(listInit->size());
644
645 for (auto it : *listInit) {
646 auto *dagInit = dyn_cast<llvm::DagInit>(it);
647 if (!dagInit)
648 PrintFatalError(&def, "all elements in Pattern multi-entity "
649 "constraints should be DAG nodes");
650
651 std::vector<std::string> entities;
652 entities.reserve(dagInit->arg_size());
653 for (auto *argName : dagInit->getArgNames()) {
654 if (!argName) {
655 PrintFatalError(
656 &def,
657 "operands to additional constraints can only be symbol references");
658 }
659 entities.push_back(std::string(argName->getValue()));
660 }
661
662 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
663 dagInit->getNameStr(), std::move(entities));
664 }
665 return ret;
666 }
667
getBenefit() const668 int Pattern::getBenefit() const {
669 // The initial benefit value is a heuristic with number of ops in the source
670 // pattern.
671 int initBenefit = getSourcePattern().getNumOps();
672 llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
673 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
674 PrintFatalError(&def,
675 "The 'addBenefit' takes and only takes one integer value");
676 }
677 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
678 }
679
getLocation() const680 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
681 std::vector<std::pair<StringRef, unsigned>> result;
682 result.reserve(def.getLoc().size());
683 for (auto loc : def.getLoc()) {
684 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
685 assert(buf && "invalid source location");
686 result.emplace_back(
687 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
688 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
689 }
690 return result;
691 }
692
verifyBind(bool result,StringRef symbolName)693 void Pattern::verifyBind(bool result, StringRef symbolName) {
694 if (!result) {
695 auto err = formatv("symbol '{0}' bound more than once", symbolName);
696 PrintFatalError(&def, err);
697 }
698 }
699
collectBoundSymbols(DagNode tree,SymbolInfoMap & infoMap,bool isSrcPattern)700 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
701 bool isSrcPattern) {
702 auto treeName = tree.getSymbol();
703 auto numTreeArgs = tree.getNumArgs();
704
705 if (tree.isNativeCodeCall()) {
706 if (!treeName.empty()) {
707 if (!isSrcPattern) {
708 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
709 << treeName << '\n');
710 verifyBind(
711 infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
712 treeName);
713 } else {
714 PrintFatalError(&def,
715 formatv("binding symbol '{0}' to NativecodeCall in "
716 "MatchPattern is not supported",
717 treeName));
718 }
719 }
720
721 for (int i = 0; i != numTreeArgs; ++i) {
722 if (auto treeArg = tree.getArgAsNestedDag(i)) {
723 // This DAG node argument is a DAG node itself. Go inside recursively.
724 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
725 continue;
726 }
727
728 if (!isSrcPattern)
729 continue;
730
731 // We can only bind symbols to arguments in source pattern. Those
732 // symbols are referenced in result patterns.
733 auto treeArgName = tree.getArgName(i);
734
735 // `$_` is a special symbol meaning ignore the current argument.
736 if (!treeArgName.empty() && treeArgName != "_") {
737 DagLeaf leaf = tree.getArgAsLeaf(i);
738
739 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
740 if (leaf.isUnspecified()) {
741 // This is case of $c, a Value without any constraints.
742 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
743 } else {
744 auto constraint = leaf.getAsConstraint();
745 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
746 leaf.isConstantAttr() ||
747 constraint.getKind() == Constraint::Kind::CK_Attr;
748
749 if (isAttr) {
750 // This is case of $a, a binding to a certain attribute.
751 verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
752 continue;
753 }
754
755 // This is case of $b, a binding to a certain type.
756 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
757 }
758 }
759 }
760
761 return;
762 }
763
764 if (tree.isOperation()) {
765 auto &op = getDialectOp(tree);
766 auto numOpArgs = op.getNumArgs();
767
768 // The pattern might have trailing directives.
769 int numDirectives = 0;
770 for (int i = numTreeArgs - 1; i >= 0; --i) {
771 if (auto dagArg = tree.getArgAsNestedDag(i)) {
772 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
773 ++numDirectives;
774 else
775 break;
776 }
777 }
778
779 if (numOpArgs != numTreeArgs - numDirectives) {
780 auto err = formatv("op '{0}' argument number mismatch: "
781 "{1} in pattern vs. {2} in definition",
782 op.getOperationName(), numTreeArgs, numOpArgs);
783 PrintFatalError(&def, err);
784 }
785
786 // The name attached to the DAG node's operator is for representing the
787 // results generated from this op. It should be remembered as bound results.
788 if (!treeName.empty()) {
789 LLVM_DEBUG(llvm::dbgs()
790 << "found symbol bound to op result: " << treeName << '\n');
791 verifyBind(infoMap.bindOpResult(treeName, op), treeName);
792 }
793
794 for (int i = 0; i != numTreeArgs; ++i) {
795 if (auto treeArg = tree.getArgAsNestedDag(i)) {
796 // This DAG node argument is a DAG node itself. Go inside recursively.
797 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
798 continue;
799 }
800
801 if (isSrcPattern) {
802 // We can only bind symbols to op arguments in source pattern. Those
803 // symbols are referenced in result patterns.
804 auto treeArgName = tree.getArgName(i);
805 // `$_` is a special symbol meaning ignore the current argument.
806 if (!treeArgName.empty() && treeArgName != "_") {
807 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
808 << treeArgName << '\n');
809 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i),
810 treeArgName);
811 }
812 }
813 }
814 return;
815 }
816
817 if (!treeName.empty()) {
818 PrintFatalError(
819 &def, formatv("binding symbol '{0}' to non-operation/native code call "
820 "unsupported right now",
821 treeName));
822 }
823 }
824