1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 #include "mlir/Transforms/FoldUtils.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17
18 using namespace mlir;
19 using namespace mlir::test;
20
21 // Native function for testing NativeCodeCall
chooseOperand(Value input1,Value input2,BoolAttr choice)22 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
23 return choice.getValue() ? input1 : input2;
24 }
25
createOpI(PatternRewriter & rewriter,Location loc,Value input)26 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
27 rewriter.create<OpI>(loc, input);
28 }
29
handleNoResultOp(PatternRewriter & rewriter,OpSymbolBindingNoResult op)30 static void handleNoResultOp(PatternRewriter &rewriter,
31 OpSymbolBindingNoResult op) {
32 // Turn the no result op to a one-result op.
33 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
34 op.operand());
35 }
36
37 // Test that natives calls are only called once during rewrites.
38 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
39 // This let us check the number of times OpM_Test was called by inspecting
40 // the returned value in the MLIR output.
41 static int64_t opMIncreasingValue = 314159265;
OpMTest(PatternRewriter & rewriter,Value val)42 static Attribute OpMTest(PatternRewriter &rewriter, Value val) {
43 int64_t i = opMIncreasingValue++;
44 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
45 }
46
47 namespace {
48 #include "TestPatterns.inc"
49 } // end anonymous namespace
50
51 //===----------------------------------------------------------------------===//
52 // Canonicalizer Driver.
53 //===----------------------------------------------------------------------===//
54
55 namespace {
56 struct FoldingPattern : public RewritePattern {
57 public:
FoldingPattern__anonac0207690211::FoldingPattern58 FoldingPattern(MLIRContext *context)
59 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
60 /*benefit=*/1, context) {}
61
matchAndRewrite__anonac0207690211::FoldingPattern62 LogicalResult matchAndRewrite(Operation *op,
63 PatternRewriter &rewriter) const override {
64 // Exercise OperationFolder API for a single-result operation that is folded
65 // upon construction. The operation being created through the folder has an
66 // in-place folder, and it should be still present in the output.
67 // Furthermore, the folder should not crash when attempting to recover the
68 // (unchanged) operation result.
69 OperationFolder folder(op->getContext());
70 Value result = folder.create<TestOpInPlaceFold>(
71 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
72 rewriter.getI32IntegerAttr(0));
73 assert(result);
74 rewriter.replaceOp(op, result);
75 return success();
76 }
77 };
78
79 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
runOnFunction__anonac0207690211::TestPatternDriver80 void runOnFunction() override {
81 mlir::OwningRewritePatternList patterns;
82 populateWithGenerated(&getContext(), patterns);
83
84 // Verify named pattern is generated with expected name.
85 patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext());
86
87 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
88 }
89 };
90 } // end anonymous namespace
91
92 //===----------------------------------------------------------------------===//
93 // ReturnType Driver.
94 //===----------------------------------------------------------------------===//
95
96 namespace {
97 // Generate ops for each instance where the type can be successfully inferred.
98 template <typename OpTy>
invokeCreateWithInferredReturnType(Operation * op)99 static void invokeCreateWithInferredReturnType(Operation *op) {
100 auto *context = op->getContext();
101 auto fop = op->getParentOfType<FuncOp>();
102 auto location = UnknownLoc::get(context);
103 OpBuilder b(op);
104 b.setInsertionPointAfter(op);
105
106 // Use permutations of 2 args as operands.
107 assert(fop.getNumArguments() >= 2);
108 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
109 for (int j = 0; j < e; ++j) {
110 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
111 SmallVector<Type, 2> inferredReturnTypes;
112 if (succeeded(OpTy::inferReturnTypes(
113 context, llvm::None, values, op->getAttrDictionary(),
114 op->getRegions(), inferredReturnTypes))) {
115 OperationState state(location, OpTy::getOperationName());
116 // TODO: Expand to regions.
117 OpTy::build(b, state, values, op->getAttrs());
118 (void)b.createOperation(state);
119 }
120 }
121 }
122 }
123
reifyReturnShape(Operation * op)124 static void reifyReturnShape(Operation *op) {
125 OpBuilder b(op);
126
127 // Use permutations of 2 args as operands.
128 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
129 SmallVector<Value, 2> shapes;
130 if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
131 return;
132 for (auto it : llvm::enumerate(shapes))
133 op->emitRemark() << "value " << it.index() << ": "
134 << it.value().getDefiningOp();
135 }
136
137 struct TestReturnTypeDriver
138 : public PassWrapper<TestReturnTypeDriver, FunctionPass> {
runOnFunction__anonac0207690311::TestReturnTypeDriver139 void runOnFunction() override {
140 if (getFunction().getName() == "testCreateFunctions") {
141 std::vector<Operation *> ops;
142 // Collect ops to avoid triggering on inserted ops.
143 for (auto &op : getFunction().getBody().front())
144 ops.push_back(&op);
145 // Generate test patterns for each, but skip terminator.
146 for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
147 // Test create method of each of the Op classes below. The resultant
148 // output would be in reverse order underneath `op` from which
149 // the attributes and regions are used.
150 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
151 invokeCreateWithInferredReturnType<
152 OpWithShapedTypeInferTypeInterfaceOp>(op);
153 };
154 return;
155 }
156 if (getFunction().getName() == "testReifyFunctions") {
157 std::vector<Operation *> ops;
158 // Collect ops to avoid triggering on inserted ops.
159 for (auto &op : getFunction().getBody().front())
160 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
161 ops.push_back(&op);
162 // Generate test patterns for each, but skip terminator.
163 for (auto *op : ops)
164 reifyReturnShape(op);
165 }
166 }
167 };
168 } // end anonymous namespace
169
170 namespace {
171 struct TestDerivedAttributeDriver
172 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
173 void runOnFunction() override;
174 };
175 } // end anonymous namespace
176
runOnFunction()177 void TestDerivedAttributeDriver::runOnFunction() {
178 getFunction().walk([](DerivedAttributeOpInterface dOp) {
179 auto dAttr = dOp.materializeDerivedAttributes();
180 if (!dAttr)
181 return;
182 for (auto d : dAttr)
183 dOp.emitRemark() << d.first << " = " << d.second;
184 });
185 }
186
187 //===----------------------------------------------------------------------===//
188 // Legalization Driver.
189 //===----------------------------------------------------------------------===//
190
191 namespace {
192 //===----------------------------------------------------------------------===//
193 // Region-Block Rewrite Testing
194
195 /// This pattern is a simple pattern that inlines the first region of a given
196 /// operation into the parent region.
197 struct TestRegionRewriteBlockMovement : public ConversionPattern {
TestRegionRewriteBlockMovement__anonac0207690611::TestRegionRewriteBlockMovement198 TestRegionRewriteBlockMovement(MLIRContext *ctx)
199 : ConversionPattern("test.region", 1, ctx) {}
200
201 LogicalResult
matchAndRewrite__anonac0207690611::TestRegionRewriteBlockMovement202 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
203 ConversionPatternRewriter &rewriter) const final {
204 // Inline this region into the parent region.
205 auto &parentRegion = *op->getParentRegion();
206 auto &opRegion = op->getRegion(0);
207 if (op->getAttr("legalizer.should_clone"))
208 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
209 else
210 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
211
212 if (op->getAttr("legalizer.erase_old_blocks")) {
213 while (!opRegion.empty())
214 rewriter.eraseBlock(&opRegion.front());
215 }
216
217 // Drop this operation.
218 rewriter.eraseOp(op);
219 return success();
220 }
221 };
222 /// This pattern is a simple pattern that generates a region containing an
223 /// illegal operation.
224 struct TestRegionRewriteUndo : public RewritePattern {
TestRegionRewriteUndo__anonac0207690611::TestRegionRewriteUndo225 TestRegionRewriteUndo(MLIRContext *ctx)
226 : RewritePattern("test.region_builder", 1, ctx) {}
227
matchAndRewrite__anonac0207690611::TestRegionRewriteUndo228 LogicalResult matchAndRewrite(Operation *op,
229 PatternRewriter &rewriter) const final {
230 // Create the region operation with an entry block containing arguments.
231 OperationState newRegion(op->getLoc(), "test.region");
232 newRegion.addRegion();
233 auto *regionOp = rewriter.createOperation(newRegion);
234 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0));
235 entryBlock->addArgument(rewriter.getIntegerType(64));
236
237 // Add an explicitly illegal operation to ensure the conversion fails.
238 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
239 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
240
241 // Drop this operation.
242 rewriter.eraseOp(op);
243 return success();
244 }
245 };
246 /// A simple pattern that creates a block at the end of the parent region of the
247 /// matched operation.
248 struct TestCreateBlock : public RewritePattern {
TestCreateBlock__anonac0207690611::TestCreateBlock249 TestCreateBlock(MLIRContext *ctx)
250 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
251
matchAndRewrite__anonac0207690611::TestCreateBlock252 LogicalResult matchAndRewrite(Operation *op,
253 PatternRewriter &rewriter) const final {
254 Region ®ion = *op->getParentRegion();
255 Type i32Type = rewriter.getIntegerType(32);
256 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type});
257 rewriter.create<TerminatorOp>(op->getLoc());
258 rewriter.replaceOp(op, {});
259 return success();
260 }
261 };
262
263 /// A simple pattern that creates a block containing an invalid operation in
264 /// order to trigger the block creation undo mechanism.
265 struct TestCreateIllegalBlock : public RewritePattern {
TestCreateIllegalBlock__anonac0207690611::TestCreateIllegalBlock266 TestCreateIllegalBlock(MLIRContext *ctx)
267 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
268
matchAndRewrite__anonac0207690611::TestCreateIllegalBlock269 LogicalResult matchAndRewrite(Operation *op,
270 PatternRewriter &rewriter) const final {
271 Region ®ion = *op->getParentRegion();
272 Type i32Type = rewriter.getIntegerType(32);
273 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type});
274 // Create an illegal op to ensure the conversion fails.
275 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
276 rewriter.create<TerminatorOp>(op->getLoc());
277 rewriter.replaceOp(op, {});
278 return success();
279 }
280 };
281
282 /// A simple pattern that tests the undo mechanism when replacing the uses of a
283 /// block argument.
284 struct TestUndoBlockArgReplace : public ConversionPattern {
TestUndoBlockArgReplace__anonac0207690611::TestUndoBlockArgReplace285 TestUndoBlockArgReplace(MLIRContext *ctx)
286 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
287
288 LogicalResult
matchAndRewrite__anonac0207690611::TestUndoBlockArgReplace289 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
290 ConversionPatternRewriter &rewriter) const final {
291 auto illegalOp =
292 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
293 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
294 illegalOp);
295 rewriter.updateRootInPlace(op, [] {});
296 return success();
297 }
298 };
299
300 /// A rewrite pattern that tests the undo mechanism when erasing a block.
301 struct TestUndoBlockErase : public ConversionPattern {
TestUndoBlockErase__anonac0207690611::TestUndoBlockErase302 TestUndoBlockErase(MLIRContext *ctx)
303 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
304
305 LogicalResult
matchAndRewrite__anonac0207690611::TestUndoBlockErase306 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
307 ConversionPatternRewriter &rewriter) const final {
308 Block *secondBlock = &*std::next(op->getRegion(0).begin());
309 rewriter.setInsertionPointToStart(secondBlock);
310 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
311 rewriter.eraseBlock(secondBlock);
312 rewriter.updateRootInPlace(op, [] {});
313 return success();
314 }
315 };
316
317 //===----------------------------------------------------------------------===//
318 // Type-Conversion Rewrite Testing
319
320 /// This patterns erases a region operation that has had a type conversion.
321 struct TestDropOpSignatureConversion : public ConversionPattern {
TestDropOpSignatureConversion__anonac0207690611::TestDropOpSignatureConversion322 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
323 : ConversionPattern("test.drop_region_op", 1, converter, ctx) {}
324 LogicalResult
matchAndRewrite__anonac0207690611::TestDropOpSignatureConversion325 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
326 ConversionPatternRewriter &rewriter) const override {
327 Region ®ion = op->getRegion(0);
328 Block *entry = ®ion.front();
329
330 // Convert the original entry arguments.
331 TypeConverter &converter = *getTypeConverter();
332 TypeConverter::SignatureConversion result(entry->getNumArguments());
333 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
334 result)) ||
335 failed(rewriter.convertRegionTypes(®ion, converter, &result)))
336 return failure();
337
338 // Convert the region signature and just drop the operation.
339 rewriter.eraseOp(op);
340 return success();
341 }
342 };
343 /// This pattern simply updates the operands of the given operation.
344 struct TestPassthroughInvalidOp : public ConversionPattern {
TestPassthroughInvalidOp__anonac0207690611::TestPassthroughInvalidOp345 TestPassthroughInvalidOp(MLIRContext *ctx)
346 : ConversionPattern("test.invalid", 1, ctx) {}
347 LogicalResult
matchAndRewrite__anonac0207690611::TestPassthroughInvalidOp348 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
349 ConversionPatternRewriter &rewriter) const final {
350 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
351 llvm::None);
352 return success();
353 }
354 };
355 /// This pattern handles the case of a split return value.
356 struct TestSplitReturnType : public ConversionPattern {
TestSplitReturnType__anonac0207690611::TestSplitReturnType357 TestSplitReturnType(MLIRContext *ctx)
358 : ConversionPattern("test.return", 1, ctx) {}
359 LogicalResult
matchAndRewrite__anonac0207690611::TestSplitReturnType360 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
361 ConversionPatternRewriter &rewriter) const final {
362 // Check for a return of F32.
363 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
364 return failure();
365
366 // Check if the first operation is a cast operation, if it is we use the
367 // results directly.
368 auto *defOp = operands[0].getDefiningOp();
369 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
370 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
371 return success();
372 }
373
374 // Otherwise, fail to match.
375 return failure();
376 }
377 };
378
379 //===----------------------------------------------------------------------===//
380 // Multi-Level Type-Conversion Rewrite Testing
381 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
TestChangeProducerTypeI32ToF32__anonac0207690611::TestChangeProducerTypeI32ToF32382 TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
383 : ConversionPattern("test.type_producer", 1, ctx) {}
384 LogicalResult
matchAndRewrite__anonac0207690611::TestChangeProducerTypeI32ToF32385 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
386 ConversionPatternRewriter &rewriter) const final {
387 // If the type is I32, change the type to F32.
388 if (!Type(*op->result_type_begin()).isSignlessInteger(32))
389 return failure();
390 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
391 return success();
392 }
393 };
394 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
TestChangeProducerTypeF32ToF64__anonac0207690611::TestChangeProducerTypeF32ToF64395 TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
396 : ConversionPattern("test.type_producer", 1, ctx) {}
397 LogicalResult
matchAndRewrite__anonac0207690611::TestChangeProducerTypeF32ToF64398 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
399 ConversionPatternRewriter &rewriter) const final {
400 // If the type is F32, change the type to F64.
401 if (!Type(*op->result_type_begin()).isF32())
402 return rewriter.notifyMatchFailure(op, "expected single f32 operand");
403 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
404 return success();
405 }
406 };
407 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
TestChangeProducerTypeF32ToInvalid__anonac0207690611::TestChangeProducerTypeF32ToInvalid408 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
409 : ConversionPattern("test.type_producer", 10, ctx) {}
410 LogicalResult
matchAndRewrite__anonac0207690611::TestChangeProducerTypeF32ToInvalid411 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
412 ConversionPatternRewriter &rewriter) const final {
413 // Always convert to B16, even though it is not a legal type. This tests
414 // that values are unmapped correctly.
415 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
416 return success();
417 }
418 };
419 struct TestUpdateConsumerType : public ConversionPattern {
TestUpdateConsumerType__anonac0207690611::TestUpdateConsumerType420 TestUpdateConsumerType(MLIRContext *ctx)
421 : ConversionPattern("test.type_consumer", 1, ctx) {}
422 LogicalResult
matchAndRewrite__anonac0207690611::TestUpdateConsumerType423 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
424 ConversionPatternRewriter &rewriter) const final {
425 // Verify that the incoming operand has been successfully remapped to F64.
426 if (!operands[0].getType().isF64())
427 return failure();
428 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
429 return success();
430 }
431 };
432
433 //===----------------------------------------------------------------------===//
434 // Non-Root Replacement Rewrite Testing
435 /// This pattern generates an invalid operation, but replaces it before the
436 /// pattern is finished. This checks that we don't need to legalize the
437 /// temporary op.
438 struct TestNonRootReplacement : public RewritePattern {
TestNonRootReplacement__anonac0207690611::TestNonRootReplacement439 TestNonRootReplacement(MLIRContext *ctx)
440 : RewritePattern("test.replace_non_root", 1, ctx) {}
441
matchAndRewrite__anonac0207690611::TestNonRootReplacement442 LogicalResult matchAndRewrite(Operation *op,
443 PatternRewriter &rewriter) const final {
444 auto resultType = *op->result_type_begin();
445 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
446 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
447
448 rewriter.replaceOp(illegalOp, {legalOp});
449 rewriter.replaceOp(op, {illegalOp});
450 return success();
451 }
452 };
453
454 //===----------------------------------------------------------------------===//
455 // Recursive Rewrite Testing
456 /// This pattern is applied to the same operation multiple times, but has a
457 /// bounded recursion.
458 struct TestBoundedRecursiveRewrite
459 : public OpRewritePattern<TestRecursiveRewriteOp> {
TestBoundedRecursiveRewrite__anonac0207690611::TestBoundedRecursiveRewrite460 TestBoundedRecursiveRewrite(MLIRContext *ctx)
461 : OpRewritePattern<TestRecursiveRewriteOp>(ctx) {
462 // The conversion target handles bounding the recursion of this pattern.
463 setHasBoundedRewriteRecursion();
464 }
465
matchAndRewrite__anonac0207690611::TestBoundedRecursiveRewrite466 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
467 PatternRewriter &rewriter) const final {
468 // Decrement the depth of the op in-place.
469 rewriter.updateRootInPlace(op, [&] {
470 op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1));
471 });
472 return success();
473 }
474 };
475
476 struct TestNestedOpCreationUndoRewrite
477 : public OpRewritePattern<IllegalOpWithRegionAnchor> {
478 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
479
matchAndRewrite__anonac0207690611::TestNestedOpCreationUndoRewrite480 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
481 PatternRewriter &rewriter) const final {
482 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
483 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
484 return success();
485 };
486 };
487 } // namespace
488
489 namespace {
490 struct TestTypeConverter : public TypeConverter {
491 using TypeConverter::TypeConverter;
TestTypeConverter__anonac0207690a11::TestTypeConverter492 TestTypeConverter() {
493 addConversion(convertType);
494 addArgumentMaterialization(materializeCast);
495 addArgumentMaterialization(materializeOneToOneCast);
496 addSourceMaterialization(materializeCast);
497 }
498
convertType__anonac0207690a11::TestTypeConverter499 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
500 // Drop I16 types.
501 if (t.isSignlessInteger(16))
502 return success();
503
504 // Convert I64 to F64.
505 if (t.isSignlessInteger(64)) {
506 results.push_back(FloatType::getF64(t.getContext()));
507 return success();
508 }
509
510 // Convert I42 to I43.
511 if (t.isInteger(42)) {
512 results.push_back(IntegerType::get(t.getContext(), 43));
513 return success();
514 }
515
516 // Split F32 into F16,F16.
517 if (t.isF32()) {
518 results.assign(2, FloatType::getF16(t.getContext()));
519 return success();
520 }
521
522 // Otherwise, convert the type directly.
523 results.push_back(t);
524 return success();
525 }
526
527 /// Hook for materializing a conversion. This is necessary because we generate
528 /// 1->N type mappings.
materializeCast__anonac0207690a11::TestTypeConverter529 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
530 ValueRange inputs, Location loc) {
531 if (inputs.size() == 1)
532 return inputs[0];
533 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
534 }
535
536 /// Materialize the cast for one-to-one conversion from i64 to f64.
materializeOneToOneCast__anonac0207690a11::TestTypeConverter537 static Optional<Value> materializeOneToOneCast(OpBuilder &builder,
538 IntegerType resultType,
539 ValueRange inputs,
540 Location loc) {
541 if (resultType.getWidth() == 42 && inputs.size() == 1)
542 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
543 return llvm::None;
544 }
545 };
546
547 struct TestLegalizePatternDriver
548 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
549 /// The mode of conversion to use with the driver.
550 enum class ConversionMode { Analysis, Full, Partial };
551
TestLegalizePatternDriver__anonac0207690a11::TestLegalizePatternDriver552 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
553
runOnOperation__anonac0207690a11::TestLegalizePatternDriver554 void runOnOperation() override {
555 TestTypeConverter converter;
556 mlir::OwningRewritePatternList patterns;
557 populateWithGenerated(&getContext(), patterns);
558 patterns.insert<
559 TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
560 TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase,
561 TestPassthroughInvalidOp, TestSplitReturnType,
562 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
563 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
564 TestNonRootReplacement, TestBoundedRecursiveRewrite,
565 TestNestedOpCreationUndoRewrite>(&getContext());
566 patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
567 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
568 converter);
569 mlir::populateCallOpTypeConversionPattern(patterns, &getContext(),
570 converter);
571
572 // Define the conversion target used for the test.
573 ConversionTarget target(getContext());
574 target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
575 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
576 TerminatorOp>();
577 target
578 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
579 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
580 // Don't allow F32 operands.
581 return llvm::none_of(op.getOperandTypes(),
582 [](Type type) { return type.isF32(); });
583 });
584 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
585 return converter.isSignatureLegal(op.getType()) &&
586 converter.isLegal(&op.getBody());
587 });
588
589 // Expect the type_producer/type_consumer operations to only operate on f64.
590 target.addDynamicallyLegalOp<TestTypeProducerOp>(
591 [](TestTypeProducerOp op) { return op.getType().isF64(); });
592 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
593 return op.getOperand().getType().isF64();
594 });
595
596 // Check support for marking certain operations as recursively legal.
597 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
598 return static_cast<bool>(
599 op->getAttrOfType<UnitAttr>("test.recursively_legal"));
600 });
601
602 // Mark the bound recursion operation as dynamically legal.
603 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
604 [](TestRecursiveRewriteOp op) { return op.depth() == 0; });
605
606 // Handle a partial conversion.
607 if (mode == ConversionMode::Partial) {
608 DenseSet<Operation *> unlegalizedOps;
609 (void)applyPartialConversion(getOperation(), target, std::move(patterns),
610 &unlegalizedOps);
611 // Emit remarks for each legalizable operation.
612 for (auto *op : unlegalizedOps)
613 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
614 return;
615 }
616
617 // Handle a full conversion.
618 if (mode == ConversionMode::Full) {
619 // Check support for marking unknown operations as dynamically legal.
620 target.markUnknownOpDynamicallyLegal([](Operation *op) {
621 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
622 });
623
624 (void)applyFullConversion(getOperation(), target, std::move(patterns));
625 return;
626 }
627
628 // Otherwise, handle an analysis conversion.
629 assert(mode == ConversionMode::Analysis);
630
631 // Analyze the convertible operations.
632 DenseSet<Operation *> legalizedOps;
633 if (failed(applyAnalysisConversion(getOperation(), target,
634 std::move(patterns), legalizedOps)))
635 return signalPassFailure();
636
637 // Emit remarks for each legalizable operation.
638 for (auto *op : legalizedOps)
639 op->emitRemark() << "op '" << op->getName() << "' is legalizable";
640 }
641
642 /// The mode of conversion to use.
643 ConversionMode mode;
644 };
645 } // end anonymous namespace
646
// Command-line option `-test-legalize-mode` selecting which conversion mode
// TestLegalizePatternDriver runs: "analysis", "full", or "partial" (the
// default, per llvm::cl::init below).
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
659
660 //===----------------------------------------------------------------------===//
661 // ConversionPatternRewriter::getRemappedValue testing. This method is used
662 // to get the remapped value of an original value that was replaced using
663 // ConversionPatternRewriter.
664 namespace {
665 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
666 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
667 /// operand twice.
668 ///
669 /// Example:
670 /// %1 = test.one_variadic_out_one_variadic_in1"(%0)
671 /// is replaced with:
672 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
673 struct OneVResOneVOperandOp1Converter
674 : public OpConversionPattern<OneVResOneVOperandOp1> {
675 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
676
677 LogicalResult
matchAndRewrite__anonac0207691311::OneVResOneVOperandOp1Converter678 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
679 ConversionPatternRewriter &rewriter) const override {
680 auto origOps = op.getOperands();
681 assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
682 "One operand expected");
683 Value origOp = *origOps.begin();
684 SmallVector<Value, 2> remappedOperands;
685 // Replicate the remapped original operand twice. Note that we don't used
686 // the remapped 'operand' since the goal is testing 'getRemappedValue'.
687 remappedOperands.push_back(rewriter.getRemappedValue(origOp));
688 remappedOperands.push_back(rewriter.getRemappedValue(origOp));
689
690 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
691 remappedOperands);
692 return success();
693 }
694 };
695
696 struct TestRemappedValue
697 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
runOnFunction__anonac0207691311::TestRemappedValue698 void runOnFunction() override {
699 mlir::OwningRewritePatternList patterns;
700 patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
701
702 mlir::ConversionTarget target(getContext());
703 target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
704 // We make OneVResOneVOperandOp1 legal only when it has more that one
705 // operand. This will trigger the conversion that will replace one-operand
706 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
707 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
708 [](Operation *op) -> bool {
709 return std::distance(op->operand_begin(), op->operand_end()) > 1;
710 });
711
712 if (failed(mlir::applyFullConversion(getFunction(), target,
713 std::move(patterns)))) {
714 signalPassFailure();
715 }
716 }
717 };
718 } // end anonymous namespace
719
720 //===----------------------------------------------------------------------===//
721 // Test patterns without a specific root operation kind
722 //===----------------------------------------------------------------------===//
723
724 namespace {
725 /// This pattern matches and removes any operation in the test dialect.
726 struct RemoveTestDialectOps : public RewritePattern {
RemoveTestDialectOps__anonac0207691511::RemoveTestDialectOps727 RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
728
matchAndRewrite__anonac0207691511::RemoveTestDialectOps729 LogicalResult matchAndRewrite(Operation *op,
730 PatternRewriter &rewriter) const override {
731 if (!isa<TestDialect>(op->getDialect()))
732 return failure();
733 rewriter.eraseOp(op);
734 return success();
735 }
736 };
737
738 struct TestUnknownRootOpDriver
739 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
runOnFunction__anonac0207691511::TestUnknownRootOpDriver740 void runOnFunction() override {
741 mlir::OwningRewritePatternList patterns;
742 patterns.insert<RemoveTestDialectOps>();
743
744 mlir::ConversionTarget target(getContext());
745 target.addIllegalDialect<TestDialect>();
746 if (failed(
747 applyPartialConversion(getFunction(), target, std::move(patterns))))
748 signalPassFailure();
749 }
750 };
751 } // end anonymous namespace
752
753 //===----------------------------------------------------------------------===//
754 // Test type conversions
755 //===----------------------------------------------------------------------===//
756
757 namespace {
758 struct TestTypeConversionProducer
759 : public OpConversionPattern<TestTypeProducerOp> {
760 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
761 LogicalResult
matchAndRewrite__anonac0207691611::TestTypeConversionProducer762 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands,
763 ConversionPatternRewriter &rewriter) const final {
764 Type resultType = op.getType();
765 if (resultType.isa<FloatType>())
766 resultType = rewriter.getF64Type();
767 else if (resultType.isInteger(16))
768 resultType = rewriter.getIntegerType(64);
769 else
770 return failure();
771
772 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
773 return success();
774 }
775 };
776
777 struct TestTypeConversionDriver
778 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
getDependentDialects__anonac0207691611::TestTypeConversionDriver779 void getDependentDialects(DialectRegistry ®istry) const override {
780 registry.insert<TestDialect>();
781 }
782
runOnOperation__anonac0207691611::TestTypeConversionDriver783 void runOnOperation() override {
784 // Initialize the type converter.
785 TypeConverter converter;
786
787 /// Add the legal set of type conversions.
788 converter.addConversion([](Type type) -> Type {
789 // Treat F64 as legal.
790 if (type.isF64())
791 return type;
792 // Allow converting BF16/F16/F32 to F64.
793 if (type.isBF16() || type.isF16() || type.isF32())
794 return FloatType::getF64(type.getContext());
795 // Otherwise, the type is illegal.
796 return nullptr;
797 });
798 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
799 // Drop all integer types.
800 return success();
801 });
802
803 /// Add the legal set of type materializations.
804 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
805 ValueRange inputs,
806 Location loc) -> Value {
807 // Allow casting from F64 back to F32.
808 if (!resultType.isF16() && inputs.size() == 1 &&
809 inputs[0].getType().isF64())
810 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
811 // Allow producing an i32 or i64 from nothing.
812 if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
813 inputs.empty())
814 return builder.create<TestTypeProducerOp>(loc, resultType);
815 // Allow producing an i64 from an integer.
816 if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
817 inputs[0].getType().isa<IntegerType>())
818 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
819 // Otherwise, fail.
820 return nullptr;
821 });
822
823 // Initialize the conversion target.
824 mlir::ConversionTarget target(getContext());
825 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
826 return op.getType().isF64() || op.getType().isInteger(64);
827 });
828 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
829 return converter.isSignatureLegal(op.getType()) &&
830 converter.isLegal(&op.getBody());
831 });
832 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
833 // Allow casts from F64 to F32.
834 return (*op.operand_type_begin()).isF64() && op.getType().isF32();
835 });
836
837 // Initialize the set of rewrite patterns.
838 OwningRewritePatternList patterns;
839 patterns.insert<TestTypeConversionProducer>(converter, &getContext());
840 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
841 converter);
842
843 if (failed(applyPartialConversion(getOperation(), target,
844 std::move(patterns))))
845 signalPassFailure();
846 }
847 };
848 } // end anonymous namespace
849
850 //===----------------------------------------------------------------------===//
851 // Test Block Merging
852 //===----------------------------------------------------------------------===//
853
854 namespace {
855 /// A rewriter pattern that tests that blocks can be merged.
856 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
857 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
858
859 LogicalResult
matchAndRewrite__anonac0207691d11::TestMergeBlock860 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands,
861 ConversionPatternRewriter &rewriter) const final {
862 Block &firstBlock = op.body().front();
863 Operation *branchOp = firstBlock.getTerminator();
864 Block *secondBlock = &*(std::next(op.body().begin()));
865 auto succOperands = branchOp->getOperands();
866 SmallVector<Value, 2> replacements(succOperands);
867 rewriter.eraseOp(branchOp);
868 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
869 rewriter.updateRootInPlace(op, [] {});
870 return success();
871 }
872 };
873
874 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
875 struct TestUndoBlocksMerge : public ConversionPattern {
TestUndoBlocksMerge__anonac0207691d11::TestUndoBlocksMerge876 TestUndoBlocksMerge(MLIRContext *ctx)
877 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
878 LogicalResult
matchAndRewrite__anonac0207691d11::TestUndoBlocksMerge879 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
880 ConversionPatternRewriter &rewriter) const final {
881 Block &firstBlock = op->getRegion(0).front();
882 Operation *branchOp = firstBlock.getTerminator();
883 Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
884 rewriter.setInsertionPointToStart(secondBlock);
885 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
886 auto succOperands = branchOp->getOperands();
887 SmallVector<Value, 2> replacements(succOperands);
888 rewriter.eraseOp(branchOp);
889 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
890 rewriter.updateRootInPlace(op, [] {});
891 return success();
892 }
893 };
894
895 /// A rewrite mechanism to inline the body of the op into its parent, when both
896 /// ops can have a single block.
897 struct TestMergeSingleBlockOps
898 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
899 using OpConversionPattern<
900 SingleBlockImplicitTerminatorOp>::OpConversionPattern;
901
902 LogicalResult
matchAndRewrite__anonac0207691d11::TestMergeSingleBlockOps903 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands,
904 ConversionPatternRewriter &rewriter) const final {
905 SingleBlockImplicitTerminatorOp parentOp =
906 op->getParentOfType<SingleBlockImplicitTerminatorOp>();
907 if (!parentOp)
908 return failure();
909 Block &innerBlock = op.region().front();
910 TerminatorOp innerTerminator =
911 cast<TerminatorOp>(innerBlock.getTerminator());
912 rewriter.mergeBlockBefore(&innerBlock, op);
913 rewriter.eraseOp(innerTerminator);
914 rewriter.eraseOp(op);
915 rewriter.updateRootInPlace(op, [] {});
916 return success();
917 }
918 };
919
920 struct TestMergeBlocksPatternDriver
921 : public PassWrapper<TestMergeBlocksPatternDriver,
922 OperationPass<ModuleOp>> {
runOnOperation__anonac0207691d11::TestMergeBlocksPatternDriver923 void runOnOperation() override {
924 mlir::OwningRewritePatternList patterns;
925 MLIRContext *context = &getContext();
926 patterns
927 .insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
928 context);
929 ConversionTarget target(*context);
930 target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp,
931 TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp,
932 TestReturnOp>();
933 target.addIllegalOp<ILLegalOpF>();
934
935 /// Expect the op to have a single block after legalization.
936 target.addDynamicallyLegalOp<TestMergeBlocksOp>(
937 [&](TestMergeBlocksOp op) -> bool {
938 return llvm::hasSingleElement(op.body());
939 });
940
941 /// Only allow `test.br` within test.merge_blocks op.
942 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
943 return op->getParentOfType<TestMergeBlocksOp>();
944 });
945
946 /// Expect that all nested test.SingleBlockImplicitTerminator ops are
947 /// inlined.
948 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
949 [&](SingleBlockImplicitTerminatorOp op) -> bool {
950 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
951 });
952
953 DenseSet<Operation *> unlegalizedOps;
954 (void)applyPartialConversion(getOperation(), target, std::move(patterns),
955 &unlegalizedOps);
956 for (auto *op : unlegalizedOps)
957 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
958 }
959 };
960 } // namespace
961
962 //===----------------------------------------------------------------------===//
963 // Test Selective Replacement
964 //===----------------------------------------------------------------------===//
965
966 namespace {
967 /// A rewrite mechanism to inline the body of the op into its parent, when both
968 /// ops can have a single block.
969 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
970 using OpRewritePattern<TestCastOp>::OpRewritePattern;
971
matchAndRewrite__anonac0207692411::TestSelectiveOpReplacementPattern972 LogicalResult matchAndRewrite(TestCastOp op,
973 PatternRewriter &rewriter) const final {
974 if (op.getNumOperands() != 2)
975 return failure();
976 OperandRange operands = op.getOperands();
977
978 // Replace non-terminator uses with the first operand.
979 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
980 return operand.getOwner()->isKnownTerminator();
981 });
982 // Replace everything else with the second operand if the operation isn't
983 // dead.
984 rewriter.replaceOp(op, op.getOperand(1));
985 return success();
986 }
987 };
988
989 struct TestSelectiveReplacementPatternDriver
990 : public PassWrapper<TestSelectiveReplacementPatternDriver,
991 OperationPass<>> {
runOnOperation__anonac0207692411::TestSelectiveReplacementPatternDriver992 void runOnOperation() override {
993 mlir::OwningRewritePatternList patterns;
994 MLIRContext *context = &getContext();
995 patterns.insert<TestSelectiveOpReplacementPattern>(context);
996 applyPatternsAndFoldGreedily(getOperation()->getRegions(),
997 std::move(patterns));
998 }
999 };
1000 } // namespace
1001
1002 //===----------------------------------------------------------------------===//
1003 // PassRegistration
1004 //===----------------------------------------------------------------------===//
1005
1006 namespace mlir {
1007 namespace test {
registerPatternsTestPass()1008 void registerPatternsTestPass() {
1009 PassRegistration<TestReturnTypeDriver>("test-return-type",
1010 "Run return type functions");
1011
1012 PassRegistration<TestDerivedAttributeDriver>("test-derived-attr",
1013 "Run test derived attributes");
1014
1015 PassRegistration<TestPatternDriver>("test-patterns",
1016 "Run test dialect patterns");
1017
1018 PassRegistration<TestLegalizePatternDriver>(
1019 "test-legalize-patterns", "Run test dialect legalization patterns", [] {
1020 return std::make_unique<TestLegalizePatternDriver>(
1021 legalizerConversionMode);
1022 });
1023
1024 PassRegistration<TestRemappedValue>(
1025 "test-remapped-value",
1026 "Test public remapped value mechanism in ConversionPatternRewriter");
1027
1028 PassRegistration<TestUnknownRootOpDriver>(
1029 "test-legalize-unknown-root-patterns",
1030 "Test public remapped value mechanism in ConversionPatternRewriter");
1031
1032 PassRegistration<TestTypeConversionDriver>(
1033 "test-legalize-type-conversion",
1034 "Test various type conversion functionalities in DialectConversion");
1035
1036 PassRegistration<TestMergeBlocksPatternDriver>{
1037 "test-merge-blocks",
1038 "Test Merging operation in ConversionPatternRewriter"};
1039 PassRegistration<TestSelectiveReplacementPatternDriver>{
1040 "test-pattern-selective-replacement",
1041 "Test selective replacement in the PatternRewriter"};
1042 }
1043 } // namespace test
1044 } // namespace mlir
1045