1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDL/IR/PDL.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "llvm/ADT/StringSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::pdl;
18 
19 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
20 
21 //===----------------------------------------------------------------------===//
22 // PDLDialect
23 //===----------------------------------------------------------------------===//
24 
/// Dialect initialization hook: registers the dialect's operations and then
/// calls `registerTypes()`.
void PDLDialect::initialize() {
  // With GET_OP_LIST defined, the included file expands in place and forms the
  // template argument list passed to `addOperations`.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
      >();
  // NOTE(review): `registerTypes` is defined outside of this file; presumably
  // it registers the `pdl` types -- confirm against its definition.
  registerTypes();
}
32 
33 //===----------------------------------------------------------------------===//
34 // PDL Operations
35 //===----------------------------------------------------------------------===//
36 
37 /// Returns true if the given operation is used by a "binding" pdl operation
38 /// within the main matcher body of a `pdl.pattern`.
39 static bool hasBindingUseInMatcher(Operation *op, Block *matcherBlock) {
40   for (OpOperand &use : op->getUses()) {
41     Operation *user = use.getOwner();
42     if (user->getBlock() != matcherBlock)
43       continue;
44     if (isa<AttributeOp, OperandOp, OperandsOp, OperationOp>(user))
45       return true;
46     // Only the first operand of RewriteOp may be bound to, i.e. the root
47     // operation of the pattern.
48     if (isa<RewriteOp>(user) && use.getOperandNumber() == 0)
49       return true;
50     // A result by itself is not binding, it must also be bound.
51     if (isa<ResultOp, ResultsOp>(user) &&
52         hasBindingUseInMatcher(user, matcherBlock))
53       return true;
54   }
55   return false;
56 }
57 
58 /// Returns success if the given operation is used by a "binding" pdl operation
59 /// within the main matcher body of a `pdl.pattern`. On failure, emits an error
60 /// with the given context message.
61 static LogicalResult
62 verifyHasBindingUseInMatcher(Operation *op,
63                              StringRef bindableContextStr = "`pdl.operation`") {
64   // If the pattern is not a pattern, there is nothing to do.
65   if (!isa<PatternOp>(op->getParentOp()))
66     return success();
67   if (hasBindingUseInMatcher(op, op->getBlock()))
68     return success();
69   return op->emitOpError()
70          << "expected a bindable (i.e. " << bindableContextStr
71          << ") user when defined in the matcher body of a `pdl.pattern`";
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // pdl::ApplyNativeConstraintOp
76 //===----------------------------------------------------------------------===//
77 
78 static LogicalResult verify(ApplyNativeConstraintOp op) {
79   if (op.getNumOperands() == 0)
80     return op.emitOpError("expected at least one argument");
81   return success();
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // pdl::ApplyNativeRewriteOp
86 //===----------------------------------------------------------------------===//
87 
88 static LogicalResult verify(ApplyNativeRewriteOp op) {
89   if (op.getNumOperands() == 0 && op.getNumResults() == 0)
90     return op.emitOpError("expected at least one argument or result");
91   return success();
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // pdl::AttributeOp
96 //===----------------------------------------------------------------------===//
97 
98 static LogicalResult verify(AttributeOp op) {
99   Value attrType = op.type();
100   Optional<Attribute> attrValue = op.value();
101 
102   if (!attrValue && isa<RewriteOp>(op->getParentOp()))
103     return op.emitOpError("expected constant value when specified within a "
104                           "`pdl.rewrite`");
105   if (attrValue && attrType)
106     return op.emitOpError("expected only one of [`type`, `value`] to be set");
107   return verifyHasBindingUseInMatcher(op);
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // pdl::OperandOp
112 //===----------------------------------------------------------------------===//
113 
114 static LogicalResult verify(OperandOp op) {
115   return verifyHasBindingUseInMatcher(op);
116 }
117 
118 //===----------------------------------------------------------------------===//
119 // pdl::OperandsOp
120 //===----------------------------------------------------------------------===//
121 
122 static LogicalResult verify(OperandsOp op) {
123   return verifyHasBindingUseInMatcher(op);
124 }
125 
126 //===----------------------------------------------------------------------===//
127 // pdl::OperationOp
128 //===----------------------------------------------------------------------===//
129 
130 static ParseResult parseOperationOpAttributes(
131     OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands,
132     ArrayAttr &attrNamesAttr) {
133   Builder &builder = p.getBuilder();
134   SmallVector<Attribute, 4> attrNames;
135   if (succeeded(p.parseOptionalLBrace())) {
136     do {
137       StringAttr nameAttr;
138       OpAsmParser::OperandType operand;
139       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
140           p.parseOperand(operand))
141         return failure();
142       attrNames.push_back(nameAttr);
143       attrOperands.push_back(operand);
144     } while (succeeded(p.parseOptionalComma()));
145     if (p.parseRBrace())
146       return failure();
147   }
148   attrNamesAttr = builder.getArrayAttr(attrNames);
149   return success();
150 }
151 
152 static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
153                                        OperandRange attrArgs,
154                                        ArrayAttr attrNames) {
155   if (attrNames.empty())
156     return;
157   p << " {";
158   interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
159                   [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
160   p << '}';
161 }
162 
163 /// Verifies that the result types of this operation, defined within a
164 /// `pdl.rewrite`, can be inferred.
165 static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
166                                                     OperandRange resultTypes) {
167   // Functor that returns if the given use can be used to infer a type.
168   Block *rewriterBlock = op->getBlock();
169   auto canInferTypeFromUse = [&](OpOperand &use) {
170     // If the use is within a ReplaceOp and isn't the operation being replaced
171     // (i.e. is not the first operand of the replacement), we can infer a type.
172     ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
173     if (!replOpUser || use.getOperandNumber() == 0)
174       return false;
175     // Make sure the replaced operation was defined before this one.
176     Operation *replacedOp = replOpUser.operation().getDefiningOp();
177     return replacedOp->getBlock() != rewriterBlock ||
178            replacedOp->isBeforeInBlock(op);
179   };
180 
181   // Check to see if the uses of the operation itself can be used to infer
182   // types.
183   if (llvm::any_of(op.op().getUses(), canInferTypeFromUse))
184     return success();
185 
186   // Otherwise, make sure each of the types can be inferred.
187   for (auto it : llvm::enumerate(resultTypes)) {
188     Operation *resultTypeOp = it.value().getDefiningOp();
189     assert(resultTypeOp && "expected valid result type operation");
190 
191     // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be
192     // usable.
193     if (isa<ApplyNativeRewriteOp>(resultTypeOp))
194       continue;
195 
196     // If the type operation was defined in the matcher and constrains the
197     // result of an input operation, it can be used.
198     auto constrainsInputOp = [rewriterBlock](Operation *user) {
199       return user->getBlock() != rewriterBlock && isa<OperationOp>(user);
200     };
201     if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
202       if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInputOp))
203         continue;
204     } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
205       if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInputOp))
206         continue;
207     }
208 
209     return op
210         .emitOpError("must have inferable or constrained result types when "
211                      "nested within `pdl.rewrite`")
212         .attachNote()
213         .append("result type #", it.index(), " was not constrained");
214   }
215   return success();
216 }
217 
218 static LogicalResult verify(OperationOp op) {
219   bool isWithinRewrite = isa<RewriteOp>(op->getParentOp());
220   if (isWithinRewrite && !op.name())
221     return op.emitOpError("must have an operation name when nested within "
222                           "a `pdl.rewrite`");
223   ArrayAttr attributeNames = op.attributeNames();
224   auto attributeValues = op.attributes();
225   if (attributeNames.size() != attributeValues.size()) {
226     return op.emitOpError()
227            << "expected the same number of attribute values and attribute "
228               "names, got "
229            << attributeNames.size() << " names and " << attributeValues.size()
230            << " values";
231   }
232 
233   // If the operation is within a rewrite body and doesn't have type inference,
234   // ensure that the result types can be resolved.
235   if (isWithinRewrite && !op.hasTypeInference()) {
236     if (failed(verifyResultTypesAreInferrable(op, op.types())))
237       return failure();
238   }
239 
240   return verifyHasBindingUseInMatcher(op, "`pdl.operation` or `pdl.rewrite`");
241 }
242 
243 bool OperationOp::hasTypeInference() {
244   Optional<StringRef> opName = name();
245   if (!opName)
246     return false;
247 
248   OperationName name(*opName, getContext());
249   if (const AbstractOperation *op = name.getAbstractOperation())
250     return op->getInterface<InferTypeOpInterface>();
251   return false;
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // pdl::PatternOp
256 //===----------------------------------------------------------------------===//
257 
258 static LogicalResult verify(PatternOp pattern) {
259   Region &body = pattern.body();
260   auto *term = body.front().getTerminator();
261   if (!isa<RewriteOp>(term)) {
262     return pattern.emitOpError("expected body to terminate with `pdl.rewrite`")
263         .attachNote(term->getLoc())
264         .append("see terminator defined here");
265   }
266 
267   // Check that all values defined in the top-level pattern are referenced at
268   // least once in the source tree.
269   WalkResult result = body.walk([&](Operation *op) -> WalkResult {
270     if (!isa_and_nonnull<PDLDialect>(op->getDialect())) {
271       pattern
272           .emitOpError("expected only `pdl` operations within the pattern body")
273           .attachNote(op->getLoc())
274           .append("see non-`pdl` operation defined here");
275       return WalkResult::interrupt();
276     }
277     return WalkResult::advance();
278   });
279   return failure(result.wasInterrupted());
280 }
281 
282 void PatternOp::build(OpBuilder &builder, OperationState &state,
283                       Optional<StringRef> rootKind, Optional<uint16_t> benefit,
284                       Optional<StringRef> name) {
285   build(builder, state,
286         rootKind ? builder.getStringAttr(*rootKind) : StringAttr(),
287         builder.getI16IntegerAttr(benefit ? *benefit : 0),
288         name ? builder.getStringAttr(*name) : StringAttr());
289   state.regions[0]->emplaceBlock();
290 }
291 
292 /// Returns the rewrite operation of this pattern.
293 RewriteOp PatternOp::getRewriter() {
294   return cast<RewriteOp>(body().front().getTerminator());
295 }
296 
297 /// Return the root operation kind that this pattern matches, or None if
298 /// there isn't a specific root.
299 Optional<StringRef> PatternOp::getRootKind() {
300   OperationOp rootOp = cast<OperationOp>(getRewriter().root().getDefiningOp());
301   return rootOp.name();
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // pdl::ReplaceOp
306 //===----------------------------------------------------------------------===//
307 
308 static LogicalResult verify(ReplaceOp op) {
309   if (op.replOperation() && !op.replValues().empty())
310     return op.emitOpError() << "expected no replacement values to be provided"
311                                " when the replacement operation is present";
312   return success();
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // pdl::ResultsOp
317 //===----------------------------------------------------------------------===//
318 
319 static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index,
320                                          Type &resultType) {
321   if (!index) {
322     resultType = RangeType::get(p.getBuilder().getType<ValueType>());
323     return success();
324   }
325   if (p.parseArrow() || p.parseType(resultType))
326     return failure();
327   return success();
328 }
329 
330 static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
331                                   IntegerAttr index, Type resultType) {
332   if (index)
333     p << " -> " << resultType;
334 }
335 
336 static LogicalResult verify(ResultsOp op) {
337   if (!op.index() && op.getType().isa<pdl::ValueType>()) {
338     return op.emitOpError() << "expected `pdl.range<value>` result type when "
339                                "no index is specified, but got: "
340                             << op.getType();
341   }
342   return success();
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // pdl::RewriteOp
347 //===----------------------------------------------------------------------===//
348 
349 static LogicalResult verify(RewriteOp op) {
350   Region &rewriteRegion = op.body();
351 
352   // Handle the case where the rewrite is external.
353   if (op.name()) {
354     if (!rewriteRegion.empty()) {
355       return op.emitOpError()
356              << "expected rewrite region to be empty when rewrite is external";
357     }
358     return success();
359   }
360 
361   // Otherwise, check that the rewrite region only contains a single block.
362   if (rewriteRegion.empty()) {
363     return op.emitOpError() << "expected rewrite region to be non-empty if "
364                                "external name is not specified";
365   }
366 
367   // Check that no additional arguments were provided.
368   if (!op.externalArgs().empty()) {
369     return op.emitOpError() << "expected no external arguments when the "
370                                "rewrite is specified inline";
371   }
372   if (op.externalConstParams()) {
373     return op.emitOpError() << "expected no external constant parameters when "
374                                "the rewrite is specified inline";
375   }
376 
377   return success();
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // pdl::TypeOp
382 //===----------------------------------------------------------------------===//
383 
384 static LogicalResult verify(TypeOp op) {
385   return verifyHasBindingUseInMatcher(
386       op, "`pdl.attribute`, `pdl.operand`, or `pdl.operation`");
387 }
388 
389 //===----------------------------------------------------------------------===//
390 // pdl::TypesOp
391 //===----------------------------------------------------------------------===//
392 
393 static LogicalResult verify(TypesOp op) {
394   return verifyHasBindingUseInMatcher(op, "`pdl.operands`, or `pdl.operation`");
395 }
396 
397 //===----------------------------------------------------------------------===//
398 // TableGen'd op method definitions
399 //===----------------------------------------------------------------------===//
400 
401 #define GET_OP_CLASSES
402 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
403