1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
17 
18 using namespace mlir;
19 using namespace mlir::pdl;
20 
21 //===----------------------------------------------------------------------===//
22 // PDLDialect
23 //===----------------------------------------------------------------------===//
24 
initialize()25 void PDLDialect::initialize() {
26   addOperations<
27 #define GET_OP_LIST
28 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
29       >();
30   addTypes<
31 #define GET_TYPEDEF_LIST
32 #include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
33       >();
34 }
35 
36 /// Returns true if the given operation is used by a "binding" pdl operation
37 /// within the main matcher body of a `pdl.pattern`.
38 static LogicalResult
verifyHasBindingUseInMatcher(Operation * op,StringRef bindableContextStr="`pdl.operation`")39 verifyHasBindingUseInMatcher(Operation *op,
40                              StringRef bindableContextStr = "`pdl.operation`") {
41   // If the pattern is not a pattern, there is nothing to do.
42   if (!isa<PatternOp>(op->getParentOp()))
43     return success();
44   Block *matcherBlock = op->getBlock();
45   for (Operation *user : op->getUsers()) {
46     if (user->getBlock() != matcherBlock)
47       continue;
48     if (isa<AttributeOp, InputOp, OperationOp, RewriteOp>(user))
49       return success();
50   }
51   return op->emitOpError()
52          << "expected a bindable (i.e. " << bindableContextStr
53          << ") user when defined in the matcher body of a `pdl.pattern`";
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // pdl::ApplyConstraintOp
58 //===----------------------------------------------------------------------===//
59 
verify(ApplyConstraintOp op)60 static LogicalResult verify(ApplyConstraintOp op) {
61   if (op.getNumOperands() == 0)
62     return op.emitOpError("expected at least one argument");
63   return success();
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // pdl::AttributeOp
68 //===----------------------------------------------------------------------===//
69 
verify(AttributeOp op)70 static LogicalResult verify(AttributeOp op) {
71   Value attrType = op.type();
72   Optional<Attribute> attrValue = op.value();
73 
74   if (!attrValue && isa<RewriteOp>(op->getParentOp()))
75     return op.emitOpError("expected constant value when specified within a "
76                           "`pdl.rewrite`");
77   if (attrValue && attrType)
78     return op.emitOpError("expected only one of [`type`, `value`] to be set");
79   return verifyHasBindingUseInMatcher(op);
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // pdl::InputOp
84 //===----------------------------------------------------------------------===//
85 
verify(InputOp op)86 static LogicalResult verify(InputOp op) {
87   return verifyHasBindingUseInMatcher(op);
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // pdl::OperationOp
92 //===----------------------------------------------------------------------===//
93 
parseOperationOp(OpAsmParser & p,OperationState & state)94 static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) {
95   Builder &builder = p.getBuilder();
96 
97   // Parse the optional operation name.
98   bool startsWithOperands = succeeded(p.parseOptionalLParen());
99   bool startsWithAttributes =
100       !startsWithOperands && succeeded(p.parseOptionalLBrace());
101   bool startsWithOpName = false;
102   if (!startsWithAttributes && !startsWithOperands) {
103     StringAttr opName;
104     OptionalParseResult opNameResult =
105         p.parseOptionalAttribute(opName, "name", state.attributes);
106     startsWithOpName = opNameResult.hasValue();
107     if (startsWithOpName && failed(*opNameResult))
108       return failure();
109   }
110 
111   // Parse the operands.
112   SmallVector<OpAsmParser::OperandType, 4> operands;
113   if (startsWithOperands ||
114       (!startsWithAttributes && succeeded(p.parseOptionalLParen()))) {
115     if (p.parseOperandList(operands) || p.parseRParen() ||
116         p.resolveOperands(operands, builder.getType<ValueType>(),
117                           state.operands))
118       return failure();
119   }
120 
121   // Parse the attributes.
122   SmallVector<Attribute, 4> attrNames;
123   if (startsWithAttributes || succeeded(p.parseOptionalLBrace())) {
124     SmallVector<OpAsmParser::OperandType, 4> attrOps;
125     do {
126       StringAttr nameAttr;
127       OpAsmParser::OperandType operand;
128       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
129           p.parseOperand(operand))
130         return failure();
131       attrNames.push_back(nameAttr);
132       attrOps.push_back(operand);
133     } while (succeeded(p.parseOptionalComma()));
134 
135     if (p.parseRBrace() ||
136         p.resolveOperands(attrOps, builder.getType<AttributeType>(),
137                           state.operands))
138       return failure();
139   }
140   state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
141   state.addTypes(builder.getType<OperationType>());
142 
143   // Parse the result types.
144   SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
145   if (succeeded(p.parseOptionalArrow())) {
146     if (p.parseOperandList(opResultTypes) ||
147         p.resolveOperands(opResultTypes, builder.getType<TypeType>(),
148                           state.operands))
149       return failure();
150     state.types.append(opResultTypes.size(), builder.getType<ValueType>());
151   }
152 
153   if (p.parseOptionalAttrDict(state.attributes))
154     return failure();
155 
156   int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
157                                    static_cast<int32_t>(attrNames.size()),
158                                    static_cast<int32_t>(opResultTypes.size())};
159   state.addAttribute("operand_segment_sizes",
160                      builder.getI32VectorAttr(operandSegmentSizes));
161   return success();
162 }
163 
print(OpAsmPrinter & p,OperationOp op)164 static void print(OpAsmPrinter &p, OperationOp op) {
165   p << "pdl.operation ";
166   if (Optional<StringRef> name = op.name())
167     p << '"' << *name << '"';
168 
169   auto operandValues = op.operands();
170   if (!operandValues.empty())
171     p << '(' << operandValues << ')';
172 
173   // Emit the optional attributes.
174   ArrayAttr attrNames = op.attributeNames();
175   if (!attrNames.empty()) {
176     Operation::operand_range attrArgs = op.attributes();
177     p << " {";
178     interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
179                     [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
180     p << '}';
181   }
182 
183   // Print the result type constraints of the operation.
184   if (!op.results().empty())
185     p << " -> " << op.types();
186   p.printOptionalAttrDict(op.getAttrs(),
187                           {"attributeNames", "name", "operand_segment_sizes"});
188 }
189 
/// Verifies that the result types of this operation, defined within a
/// `pdl.rewrite`, can be inferred.
///
/// Succeeds if the operation handle itself is used as a replacement value of
/// an earlier-defined operation. Otherwise, each result type must satisfy one
/// of the following:
///  - it comes from a `pdl.create_native`;
///  - it is a `pdl.type` with a constant type;
///  - it is a `pdl.type` that also constrains an operation outside the
///    rewriter block;
///  - the corresponding result value is used as a replacement value.
/// Otherwise an error is emitted on `op` naming the first unconstrained
/// result index.
static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
                                                    ResultRange opResults,
                                                    OperandRange resultTypes) {
  // Returns true if the given use can be used to infer a type.
  Block *rewriterBlock = op->getBlock();
  auto canInferTypeFromUse = [&](OpOperand &use) {
    // If the use is within a ReplaceOp and isn't the operation being replaced
    // (i.e. is not the first operand of the replacement), we can infer a type.
    ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
    if (!replOpUser || use.getOperandNumber() == 0)
      return false;
    // Make sure the replaced operation was defined before this one.
    // NOTE(review): this assumes `operation()` is an op result; a block
    // argument would yield a null `replacedOp` here. Confirm that cannot occur.
    Operation *replacedOp = replOpUser.operation().getDefiningOp();
    return replacedOp->getBlock() != rewriterBlock ||
           replacedOp->isBeforeInBlock(op);
  };

  // Check to see if the uses of the operation itself can be used to infer
  // types.
  if (llvm::any_of(op.op().getUses(), canInferTypeFromUse))
    return success();

  // Otherwise, make sure each of the types can be inferred.
  for (int i : llvm::seq<int>(0, opResults.size())) {
    Operation *resultTypeOp = resultTypes[i].getDefiningOp();
    assert(resultTypeOp && "expected valid result type operation");

    // If the op was defined by a `create_native`, it is guaranteed to be
    // usable.
    if (isa<CreateNativeOp>(resultTypeOp))
      continue;

    // If the type is already constrained, there is nothing to do.
    // Any other producer must be a `pdl.type` (`cast` asserts otherwise).
    TypeOp typeOp = cast<TypeOp>(resultTypeOp);
    if (typeOp.type())
      continue;

    // If the type operation was defined in the matcher and constrains the
    // result of an input operation, it can be used.
    auto constrainsInputOp = [rewriterBlock](Operation *user) {
      return user->getBlock() != rewriterBlock && isa<OperationOp>(user);
    };
    if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp))
      continue;

    // Otherwise, check to see if any uses of the result can infer the type.
    if (llvm::any_of(opResults[i].getUses(), canInferTypeFromUse))
      continue;
    return op
        .emitOpError("must have inferable or constrained result types when "
                     "nested within `pdl.rewrite`")
        .attachNote()
        .append("result type #", i, " was not constrained");
  }
  return success();
}
248 
verify(OperationOp op)249 static LogicalResult verify(OperationOp op) {
250   bool isWithinRewrite = isa<RewriteOp>(op->getParentOp());
251   if (isWithinRewrite && !op.name())
252     return op.emitOpError("must have an operation name when nested within "
253                           "a `pdl.rewrite`");
254   ArrayAttr attributeNames = op.attributeNames();
255   auto attributeValues = op.attributes();
256   if (attributeNames.size() != attributeValues.size()) {
257     return op.emitOpError()
258            << "expected the same number of attribute values and attribute "
259               "names, got "
260            << attributeNames.size() << " names and " << attributeValues.size()
261            << " values";
262   }
263 
264   OperandRange resultTypes = op.types();
265   auto opResults = op.results();
266   if (resultTypes.size() != opResults.size()) {
267     return op.emitOpError() << "expected the same number of result values and "
268                                "result type constraints, got "
269                             << opResults.size() << " results and "
270                             << resultTypes.size() << " constraints";
271   }
272 
273   // If the operation is within a rewrite body and doesn't have type inference,
274   // ensure that the result types can be resolved.
275   if (isWithinRewrite && !op.hasTypeInference()) {
276     if (failed(verifyResultTypesAreInferrable(op, opResults, resultTypes)))
277       return failure();
278   }
279 
280   return verifyHasBindingUseInMatcher(op, "`pdl.operation` or `pdl.rewrite`");
281 }
282 
hasTypeInference()283 bool OperationOp::hasTypeInference() {
284   Optional<StringRef> opName = name();
285   if (!opName)
286     return false;
287 
288   OperationName name(*opName, getContext());
289   if (const AbstractOperation *op = name.getAbstractOperation())
290     return op->getInterface<InferTypeOpInterface>();
291   return false;
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // pdl::PatternOp
296 //===----------------------------------------------------------------------===//
297 
verify(PatternOp pattern)298 static LogicalResult verify(PatternOp pattern) {
299   Region &body = pattern.body();
300   auto *term = body.front().getTerminator();
301   if (!isa<RewriteOp>(term)) {
302     return pattern.emitOpError("expected body to terminate with `pdl.rewrite`")
303         .attachNote(term->getLoc())
304         .append("see terminator defined here");
305   }
306 
307   // Check that all values defined in the top-level pattern are referenced at
308   // least once in the source tree.
309   WalkResult result = body.walk([&](Operation *op) -> WalkResult {
310     if (!isa_and_nonnull<PDLDialect>(op->getDialect())) {
311       pattern
312           .emitOpError("expected only `pdl` operations within the pattern body")
313           .attachNote(op->getLoc())
314           .append("see non-`pdl` operation defined here");
315       return WalkResult::interrupt();
316     }
317     return WalkResult::advance();
318   });
319   return failure(result.wasInterrupted());
320 }
321 
build(OpBuilder & builder,OperationState & state,Optional<StringRef> rootKind,Optional<uint16_t> benefit,Optional<StringRef> name)322 void PatternOp::build(OpBuilder &builder, OperationState &state,
323                       Optional<StringRef> rootKind, Optional<uint16_t> benefit,
324                       Optional<StringRef> name) {
325   build(builder, state,
326         rootKind ? builder.getStringAttr(*rootKind) : StringAttr(),
327         builder.getI16IntegerAttr(benefit ? *benefit : 0),
328         name ? builder.getStringAttr(*name) : StringAttr());
329   builder.createBlock(state.addRegion());
330 }
331 
332 /// Returns the rewrite operation of this pattern.
getRewriter()333 RewriteOp PatternOp::getRewriter() {
334   return cast<RewriteOp>(body().front().getTerminator());
335 }
336 
337 /// Return the root operation kind that this pattern matches, or None if
338 /// there isn't a specific root.
getRootKind()339 Optional<StringRef> PatternOp::getRootKind() {
340   OperationOp rootOp = cast<OperationOp>(getRewriter().root().getDefiningOp());
341   return rootOp.name();
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // pdl::ReplaceOp
346 //===----------------------------------------------------------------------===//
347 
verify(ReplaceOp op)348 static LogicalResult verify(ReplaceOp op) {
349   auto sourceOp = cast<OperationOp>(op.operation().getDefiningOp());
350   auto sourceOpResults = sourceOp.results();
351   auto replValues = op.replValues();
352 
353   if (Value replOpVal = op.replOperation()) {
354     auto replOp = cast<OperationOp>(replOpVal.getDefiningOp());
355     auto replOpResults = replOp.results();
356     if (sourceOpResults.size() != replOpResults.size()) {
357       return op.emitOpError()
358              << "expected source operation to have the same number of results "
359                 "as the replacement operation, replacement operation provided "
360              << replOpResults.size() << " but expected "
361              << sourceOpResults.size();
362     }
363 
364     if (!replValues.empty()) {
365       return op.emitOpError() << "expected no replacement values to be provided"
366                                  " when the replacement operation is present";
367     }
368 
369     return success();
370   }
371 
372   if (sourceOpResults.size() != replValues.size()) {
373     return op.emitOpError()
374            << "expected source operation to have the same number of results "
375               "as the provided replacement values, found "
376            << replValues.size() << " replacement values but expected "
377            << sourceOpResults.size();
378   }
379 
380   return success();
381 }
382 
383 //===----------------------------------------------------------------------===//
384 // pdl::RewriteOp
385 //===----------------------------------------------------------------------===//
386 
verify(RewriteOp op)387 static LogicalResult verify(RewriteOp op) {
388   Region &rewriteRegion = op.body();
389 
390   // Handle the case where the rewrite is external.
391   if (op.name()) {
392     if (!rewriteRegion.empty()) {
393       return op.emitOpError()
394              << "expected rewrite region to be empty when rewrite is external";
395     }
396     return success();
397   }
398 
399   // Otherwise, check that the rewrite region only contains a single block.
400   if (rewriteRegion.empty()) {
401     return op.emitOpError() << "expected rewrite region to be non-empty if "
402                                "external name is not specified";
403   }
404 
405   // Check that no additional arguments were provided.
406   if (!op.externalArgs().empty()) {
407     return op.emitOpError() << "expected no external arguments when the "
408                                "rewrite is specified inline";
409   }
410   if (op.externalConstParams()) {
411     return op.emitOpError() << "expected no external constant parameters when "
412                                "the rewrite is specified inline";
413   }
414 
415   return success();
416 }
417 
418 //===----------------------------------------------------------------------===//
419 // pdl::TypeOp
420 //===----------------------------------------------------------------------===//
421 
verify(TypeOp op)422 static LogicalResult verify(TypeOp op) {
423   return verifyHasBindingUseInMatcher(
424       op, "`pdl.attribute`, `pdl.input`, or `pdl.operation`");
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // TableGen'd op method definitions
429 //===----------------------------------------------------------------------===//
430 
431 #define GET_OP_CLASSES
432 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
433 
434 //===----------------------------------------------------------------------===//
435 // TableGen'd type method definitions
436 //===----------------------------------------------------------------------===//
437 
438 #define GET_TYPEDEF_CLASSES
439 #include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
440 
parseType(DialectAsmParser & parser) const441 Type PDLDialect::parseType(DialectAsmParser &parser) const {
442   StringRef keyword;
443   if (parser.parseKeyword(&keyword))
444     return Type();
445   if (Type type = generatedTypeParser(getContext(), parser, keyword))
446     return type;
447 
448   parser.emitError(parser.getNameLoc(), "invalid 'pdl' type: `")
449       << keyword << "'";
450   return Type();
451 }
452 
printType(Type type,DialectAsmPrinter & printer) const453 void PDLDialect::printType(Type type, DialectAsmPrinter &printer) const {
454   if (failed(generatedTypePrinter(type, printer)))
455     llvm_unreachable("unknown 'pdl' type");
456 }
457