1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDL/IR/PDL.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/Interfaces/InferTypeOpInterface.h" 14 #include "llvm/ADT/StringSwitch.h" 15 16 using namespace mlir; 17 using namespace mlir::pdl; 18 19 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc" 20 21 //===----------------------------------------------------------------------===// 22 // PDLDialect 23 //===----------------------------------------------------------------------===// 24 25 void PDLDialect::initialize() { 26 addOperations< 27 #define GET_OP_LIST 28 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" 29 >(); 30 registerTypes(); 31 } 32 33 //===----------------------------------------------------------------------===// 34 // PDL Operations 35 //===----------------------------------------------------------------------===// 36 37 /// Returns true if the given operation is used by a "binding" pdl operation 38 /// within the main matcher body of a `pdl.pattern`. 39 static bool hasBindingUseInMatcher(Operation *op, Block *matcherBlock) { 40 for (OpOperand &use : op->getUses()) { 41 Operation *user = use.getOwner(); 42 if (user->getBlock() != matcherBlock) 43 continue; 44 if (isa<AttributeOp, OperandOp, OperandsOp, OperationOp>(user)) 45 return true; 46 // Only the first operand of RewriteOp may be bound to, i.e. the root 47 // operation of the pattern. 48 if (isa<RewriteOp>(user) && use.getOperandNumber() == 0) 49 return true; 50 // A result by itself is not binding, it must also be bound. 51 if (isa<ResultOp, ResultsOp>(user) && 52 hasBindingUseInMatcher(user, matcherBlock)) 53 return true; 54 } 55 return false; 56 } 57 58 /// Returns success if the given operation is used by a "binding" pdl operation 59 /// within the main matcher body of a `pdl.pattern`. On failure, emits an error 60 /// with the given context message. 61 static LogicalResult 62 verifyHasBindingUseInMatcher(Operation *op, 63 StringRef bindableContextStr = "`pdl.operation`") { 64 // If the pattern is not a pattern, there is nothing to do. 65 if (!isa<PatternOp>(op->getParentOp())) 66 return success(); 67 if (hasBindingUseInMatcher(op, op->getBlock())) 68 return success(); 69 return op->emitOpError() 70 << "expected a bindable (i.e. " << bindableContextStr 71 << ") user when defined in the matcher body of a `pdl.pattern`"; 72 } 73 74 //===----------------------------------------------------------------------===// 75 // pdl::ApplyNativeConstraintOp 76 //===----------------------------------------------------------------------===// 77 78 static LogicalResult verify(ApplyNativeConstraintOp op) { 79 if (op.getNumOperands() == 0) 80 return op.emitOpError("expected at least one argument"); 81 return success(); 82 } 83 84 //===----------------------------------------------------------------------===// 85 // pdl::ApplyNativeRewriteOp 86 //===----------------------------------------------------------------------===// 87 88 static LogicalResult verify(ApplyNativeRewriteOp op) { 89 if (op.getNumOperands() == 0 && op.getNumResults() == 0) 90 return op.emitOpError("expected at least one argument or result"); 91 return success(); 92 } 93 94 //===----------------------------------------------------------------------===// 95 // pdl::AttributeOp 96 //===----------------------------------------------------------------------===// 97 98 static LogicalResult verify(AttributeOp op) { 99 Value attrType = op.type(); 100 Optional<Attribute> attrValue = op.value(); 101 102 if (!attrValue && isa<RewriteOp>(op->getParentOp())) 103 return op.emitOpError("expected constant value when specified within a " 104 "`pdl.rewrite`"); 105 if (attrValue && attrType) 106 return op.emitOpError("expected only one of [`type`, `value`] to be set"); 107 return verifyHasBindingUseInMatcher(op); 108 } 109 110 //===----------------------------------------------------------------------===// 111 // pdl::OperandOp 112 //===----------------------------------------------------------------------===// 113 114 static LogicalResult verify(OperandOp op) { 115 return verifyHasBindingUseInMatcher(op); 116 } 117 118 //===----------------------------------------------------------------------===// 119 // pdl::OperandsOp 120 //===----------------------------------------------------------------------===// 121 122 static LogicalResult verify(OperandsOp op) { 123 return verifyHasBindingUseInMatcher(op); 124 } 125 126 //===----------------------------------------------------------------------===// 127 // pdl::OperationOp 128 //===----------------------------------------------------------------------===// 129 130 static ParseResult parseOperationOpAttributes( 131 OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands, 132 ArrayAttr &attrNamesAttr) { 133 Builder &builder = p.getBuilder(); 134 SmallVector<Attribute, 4> attrNames; 135 if (succeeded(p.parseOptionalLBrace())) { 136 do { 137 StringAttr nameAttr; 138 OpAsmParser::OperandType operand; 139 if (p.parseAttribute(nameAttr) || p.parseEqual() || 140 p.parseOperand(operand)) 141 return failure(); 142 attrNames.push_back(nameAttr); 143 attrOperands.push_back(operand); 144 } while (succeeded(p.parseOptionalComma())); 145 if (p.parseRBrace()) 146 return failure(); 147 } 148 attrNamesAttr = builder.getArrayAttr(attrNames); 149 return success(); 150 } 151 152 static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op, 153 OperandRange attrArgs, 154 ArrayAttr attrNames) { 155 if (attrNames.empty()) 156 return; 157 p << " {"; 158 interleaveComma(llvm::seq<int>(0, attrNames.size()), p, 159 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); 160 p << '}'; 161 } 162 163 /// Verifies that the result types of this operation, defined within a 164 /// `pdl.rewrite`, can be inferred. 165 static LogicalResult verifyResultTypesAreInferrable(OperationOp op, 166 OperandRange resultTypes) { 167 // Functor that returns if the given use can be used to infer a type. 168 Block *rewriterBlock = op->getBlock(); 169 auto canInferTypeFromUse = [&](OpOperand &use) { 170 // If the use is within a ReplaceOp and isn't the operation being replaced 171 // (i.e. is not the first operand of the replacement), we can infer a type. 172 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner()); 173 if (!replOpUser || use.getOperandNumber() == 0) 174 return false; 175 // Make sure the replaced operation was defined before this one. 176 Operation *replacedOp = replOpUser.operation().getDefiningOp(); 177 return replacedOp->getBlock() != rewriterBlock || 178 replacedOp->isBeforeInBlock(op); 179 }; 180 181 // Check to see if the uses of the operation itself can be used to infer 182 // types. 183 if (llvm::any_of(op.op().getUses(), canInferTypeFromUse)) 184 return success(); 185 186 // Otherwise, make sure each of the types can be inferred. 187 for (auto it : llvm::enumerate(resultTypes)) { 188 Operation *resultTypeOp = it.value().getDefiningOp(); 189 assert(resultTypeOp && "expected valid result type operation"); 190 191 // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be 192 // usable. 193 if (isa<ApplyNativeRewriteOp>(resultTypeOp)) 194 continue; 195 196 // If the type operation was defined in the matcher and constrains the 197 // result of an input operation, it can be used. 198 auto constrainsInputOp = [rewriterBlock](Operation *user) { 199 return user->getBlock() != rewriterBlock && isa<OperationOp>(user); 200 }; 201 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) { 202 if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInputOp)) 203 continue; 204 } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) { 205 if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInputOp)) 206 continue; 207 } 208 209 return op 210 .emitOpError("must have inferable or constrained result types when " 211 "nested within `pdl.rewrite`") 212 .attachNote() 213 .append("result type #", it.index(), " was not constrained"); 214 } 215 return success(); 216 } 217 218 static LogicalResult verify(OperationOp op) { 219 bool isWithinRewrite = isa<RewriteOp>(op->getParentOp()); 220 if (isWithinRewrite && !op.name()) 221 return op.emitOpError("must have an operation name when nested within " 222 "a `pdl.rewrite`"); 223 ArrayAttr attributeNames = op.attributeNames(); 224 auto attributeValues = op.attributes(); 225 if (attributeNames.size() != attributeValues.size()) { 226 return op.emitOpError() 227 << "expected the same number of attribute values and attribute " 228 "names, got " 229 << attributeNames.size() << " names and " << attributeValues.size() 230 << " values"; 231 } 232 233 // If the operation is within a rewrite body and doesn't have type inference, 234 // ensure that the result types can be resolved. 235 if (isWithinRewrite && !op.hasTypeInference()) { 236 if (failed(verifyResultTypesAreInferrable(op, op.types()))) 237 return failure(); 238 } 239 240 return verifyHasBindingUseInMatcher(op, "`pdl.operation` or `pdl.rewrite`"); 241 } 242 243 bool OperationOp::hasTypeInference() { 244 Optional<StringRef> opName = name(); 245 if (!opName) 246 return false; 247 248 OperationName name(*opName, getContext()); 249 if (const AbstractOperation *op = name.getAbstractOperation()) 250 return op->getInterface<InferTypeOpInterface>(); 251 return false; 252 } 253 254 //===----------------------------------------------------------------------===// 255 // pdl::PatternOp 256 //===----------------------------------------------------------------------===// 257 258 static LogicalResult verify(PatternOp pattern) { 259 Region &body = pattern.body(); 260 auto *term = body.front().getTerminator(); 261 if (!isa<RewriteOp>(term)) { 262 return pattern.emitOpError("expected body to terminate with `pdl.rewrite`") 263 .attachNote(term->getLoc()) 264 .append("see terminator defined here"); 265 } 266 267 // Check that all values defined in the top-level pattern are referenced at 268 // least once in the source tree. 269 WalkResult result = body.walk([&](Operation *op) -> WalkResult { 270 if (!isa_and_nonnull<PDLDialect>(op->getDialect())) { 271 pattern 272 .emitOpError("expected only `pdl` operations within the pattern body") 273 .attachNote(op->getLoc()) 274 .append("see non-`pdl` operation defined here"); 275 return WalkResult::interrupt(); 276 } 277 return WalkResult::advance(); 278 }); 279 return failure(result.wasInterrupted()); 280 } 281 282 void PatternOp::build(OpBuilder &builder, OperationState &state, 283 Optional<StringRef> rootKind, Optional<uint16_t> benefit, 284 Optional<StringRef> name) { 285 build(builder, state, 286 rootKind ? builder.getStringAttr(*rootKind) : StringAttr(), 287 builder.getI16IntegerAttr(benefit ? *benefit : 0), 288 name ? builder.getStringAttr(*name) : StringAttr()); 289 state.regions[0]->emplaceBlock(); 290 } 291 292 /// Returns the rewrite operation of this pattern. 293 RewriteOp PatternOp::getRewriter() { 294 return cast<RewriteOp>(body().front().getTerminator()); 295 } 296 297 /// Return the root operation kind that this pattern matches, or None if 298 /// there isn't a specific root. 299 Optional<StringRef> PatternOp::getRootKind() { 300 OperationOp rootOp = cast<OperationOp>(getRewriter().root().getDefiningOp()); 301 return rootOp.name(); 302 } 303 304 //===----------------------------------------------------------------------===// 305 // pdl::ReplaceOp 306 //===----------------------------------------------------------------------===// 307 308 static LogicalResult verify(ReplaceOp op) { 309 if (op.replOperation() && !op.replValues().empty()) 310 return op.emitOpError() << "expected no replacement values to be provided" 311 " when the replacement operation is present"; 312 return success(); 313 } 314 315 //===----------------------------------------------------------------------===// 316 // pdl::ResultsOp 317 //===----------------------------------------------------------------------===// 318 319 static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index, 320 Type &resultType) { 321 if (!index) { 322 resultType = RangeType::get(p.getBuilder().getType<ValueType>()); 323 return success(); 324 } 325 if (p.parseArrow() || p.parseType(resultType)) 326 return failure(); 327 return success(); 328 } 329 330 static void printResultsValueType(OpAsmPrinter &p, ResultsOp op, 331 IntegerAttr index, Type resultType) { 332 if (index) 333 p << " -> " << resultType; 334 } 335 336 static LogicalResult verify(ResultsOp op) { 337 if (!op.index() && op.getType().isa<pdl::ValueType>()) { 338 return op.emitOpError() << "expected `pdl.range<value>` result type when " 339 "no index is specified, but got: " 340 << op.getType(); 341 } 342 return success(); 343 } 344 345 //===----------------------------------------------------------------------===// 346 // pdl::RewriteOp 347 //===----------------------------------------------------------------------===// 348 349 static LogicalResult verify(RewriteOp op) { 350 Region &rewriteRegion = op.body(); 351 352 // Handle the case where the rewrite is external. 353 if (op.name()) { 354 if (!rewriteRegion.empty()) { 355 return op.emitOpError() 356 << "expected rewrite region to be empty when rewrite is external"; 357 } 358 return success(); 359 } 360 361 // Otherwise, check that the rewrite region only contains a single block. 362 if (rewriteRegion.empty()) { 363 return op.emitOpError() << "expected rewrite region to be non-empty if " 364 "external name is not specified"; 365 } 366 367 // Check that no additional arguments were provided. 368 if (!op.externalArgs().empty()) { 369 return op.emitOpError() << "expected no external arguments when the " 370 "rewrite is specified inline"; 371 } 372 if (op.externalConstParams()) { 373 return op.emitOpError() << "expected no external constant parameters when " 374 "the rewrite is specified inline"; 375 } 376 377 return success(); 378 } 379 380 //===----------------------------------------------------------------------===// 381 // pdl::TypeOp 382 //===----------------------------------------------------------------------===// 383 384 static LogicalResult verify(TypeOp op) { 385 return verifyHasBindingUseInMatcher( 386 op, "`pdl.attribute`, `pdl.operand`, or `pdl.operation`"); 387 } 388 389 //===----------------------------------------------------------------------===// 390 // pdl::TypesOp 391 //===----------------------------------------------------------------------===// 392 393 static LogicalResult verify(TypesOp op) { 394 return verifyHasBindingUseInMatcher(op, "`pdl.operands`, or `pdl.operation`"); 395 } 396 397 //===----------------------------------------------------------------------===// 398 // TableGen'd op method definitions 399 //===----------------------------------------------------------------------===// 400 401 #define GET_OP_CLASSES 402 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" 403