1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 #include "mlir/Transforms/FoldUtils.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18
19 using namespace mlir;
20 using namespace mlir::test;
21
22 // Native function for testing NativeCodeCall
chooseOperand(Value input1,Value input2,BoolAttr choice)23 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
24 return choice.getValue() ? input1 : input2;
25 }
26
createOpI(PatternRewriter & rewriter,Location loc,Value input)27 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
28 rewriter.create<OpI>(loc, input);
29 }
30
handleNoResultOp(PatternRewriter & rewriter,OpSymbolBindingNoResult op)31 static void handleNoResultOp(PatternRewriter &rewriter,
32 OpSymbolBindingNoResult op) {
33 // Turn the no result op to a one-result op.
34 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
35 op.operand());
36 }
37
getFirstI32Result(Operation * op,Value & value)38 static bool getFirstI32Result(Operation *op, Value &value) {
39 if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
40 return false;
41 value = op->getResult(0);
42 return true;
43 }
44
bindNativeCodeCallResult(Value value)45 static Value bindNativeCodeCallResult(Value value) { return value; }
46
bindMultipleNativeCodeCallResult(Value input1,Value input2)47 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
48 Value input2) {
49 return SmallVector<Value, 2>({input2, input1});
50 }
51
52 // Test that natives calls are only called once during rewrites.
53 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
54 // This let us check the number of times OpM_Test was called by inspecting
55 // the returned value in the MLIR output.
56 static int64_t opMIncreasingValue = 314159265;
OpMTest(PatternRewriter & rewriter,Value val)57 static Attribute OpMTest(PatternRewriter &rewriter, Value val) {
58 int64_t i = opMIncreasingValue++;
59 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
60 }
61
62 namespace {
63 #include "TestPatterns.inc"
64 } // end anonymous namespace
65
66 //===----------------------------------------------------------------------===//
67 // Test Reduce Pattern Interface
68 //===----------------------------------------------------------------------===//
69
populateTestReductionPatterns(RewritePatternSet & patterns)70 void mlir::test::populateTestReductionPatterns(RewritePatternSet &patterns) {
71 populateWithGenerated(patterns);
72 }
73
74 //===----------------------------------------------------------------------===//
75 // Canonicalizer Driver.
76 //===----------------------------------------------------------------------===//
77
78 namespace {
79 struct FoldingPattern : public RewritePattern {
80 public:
FoldingPattern__anon04a6666b0211::FoldingPattern81 FoldingPattern(MLIRContext *context)
82 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
83 /*benefit=*/1, context) {}
84
matchAndRewrite__anon04a6666b0211::FoldingPattern85 LogicalResult matchAndRewrite(Operation *op,
86 PatternRewriter &rewriter) const override {
87 // Exercise OperationFolder API for a single-result operation that is folded
88 // upon construction. The operation being created through the folder has an
89 // in-place folder, and it should be still present in the output.
90 // Furthermore, the folder should not crash when attempting to recover the
91 // (unchanged) operation result.
92 OperationFolder folder(op->getContext());
93 Value result = folder.create<TestOpInPlaceFold>(
94 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
95 rewriter.getI32IntegerAttr(0));
96 assert(result);
97 rewriter.replaceOp(op, result);
98 return success();
99 }
100 };
101
102 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
getArgument__anon04a6666b0211::TestPatternDriver103 StringRef getArgument() const final { return "test-patterns"; }
getDescription__anon04a6666b0211::TestPatternDriver104 StringRef getDescription() const final { return "Run test dialect patterns"; }
runOnFunction__anon04a6666b0211::TestPatternDriver105 void runOnFunction() override {
106 mlir::RewritePatternSet patterns(&getContext());
107 populateWithGenerated(patterns);
108
109 // Verify named pattern is generated with expected name.
110 patterns.add<FoldingPattern, TestNamedPatternRule>(&getContext());
111
112 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
113 }
114 };
115 } // end anonymous namespace
116
117 //===----------------------------------------------------------------------===//
118 // ReturnType Driver.
119 //===----------------------------------------------------------------------===//
120
121 namespace {
122 // Generate ops for each instance where the type can be successfully inferred.
123 template <typename OpTy>
invokeCreateWithInferredReturnType(Operation * op)124 static void invokeCreateWithInferredReturnType(Operation *op) {
125 auto *context = op->getContext();
126 auto fop = op->getParentOfType<FuncOp>();
127 auto location = UnknownLoc::get(context);
128 OpBuilder b(op);
129 b.setInsertionPointAfter(op);
130
131 // Use permutations of 2 args as operands.
132 assert(fop.getNumArguments() >= 2);
133 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
134 for (int j = 0; j < e; ++j) {
135 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
136 SmallVector<Type, 2> inferredReturnTypes;
137 if (succeeded(OpTy::inferReturnTypes(
138 context, llvm::None, values, op->getAttrDictionary(),
139 op->getRegions(), inferredReturnTypes))) {
140 OperationState state(location, OpTy::getOperationName());
141 // TODO: Expand to regions.
142 OpTy::build(b, state, values, op->getAttrs());
143 (void)b.createOperation(state);
144 }
145 }
146 }
147 }
148
reifyReturnShape(Operation * op)149 static void reifyReturnShape(Operation *op) {
150 OpBuilder b(op);
151
152 // Use permutations of 2 args as operands.
153 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
154 SmallVector<Value, 2> shapes;
155 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
156 !llvm::hasSingleElement(shapes))
157 return;
158 for (auto it : llvm::enumerate(shapes)) {
159 op->emitRemark() << "value " << it.index() << ": "
160 << it.value().getDefiningOp();
161 }
162 }
163
164 struct TestReturnTypeDriver
165 : public PassWrapper<TestReturnTypeDriver, FunctionPass> {
getDependentDialects__anon04a6666b0311::TestReturnTypeDriver166 void getDependentDialects(DialectRegistry ®istry) const override {
167 registry.insert<tensor::TensorDialect>();
168 }
getArgument__anon04a6666b0311::TestReturnTypeDriver169 StringRef getArgument() const final { return "test-return-type"; }
getDescription__anon04a6666b0311::TestReturnTypeDriver170 StringRef getDescription() const final { return "Run return type functions"; }
171
runOnFunction__anon04a6666b0311::TestReturnTypeDriver172 void runOnFunction() override {
173 if (getFunction().getName() == "testCreateFunctions") {
174 std::vector<Operation *> ops;
175 // Collect ops to avoid triggering on inserted ops.
176 for (auto &op : getFunction().getBody().front())
177 ops.push_back(&op);
178 // Generate test patterns for each, but skip terminator.
179 for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
180 // Test create method of each of the Op classes below. The resultant
181 // output would be in reverse order underneath `op` from which
182 // the attributes and regions are used.
183 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
184 invokeCreateWithInferredReturnType<
185 OpWithShapedTypeInferTypeInterfaceOp>(op);
186 };
187 return;
188 }
189 if (getFunction().getName() == "testReifyFunctions") {
190 std::vector<Operation *> ops;
191 // Collect ops to avoid triggering on inserted ops.
192 for (auto &op : getFunction().getBody().front())
193 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
194 ops.push_back(&op);
195 // Generate test patterns for each, but skip terminator.
196 for (auto *op : ops)
197 reifyReturnShape(op);
198 }
199 }
200 };
201 } // end anonymous namespace
202
203 namespace {
204 struct TestDerivedAttributeDriver
205 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
getArgument__anon04a6666b0411::TestDerivedAttributeDriver206 StringRef getArgument() const final { return "test-derived-attr"; }
getDescription__anon04a6666b0411::TestDerivedAttributeDriver207 StringRef getDescription() const final {
208 return "Run test derived attributes";
209 }
210 void runOnFunction() override;
211 };
212 } // end anonymous namespace
213
runOnFunction()214 void TestDerivedAttributeDriver::runOnFunction() {
215 getFunction().walk([](DerivedAttributeOpInterface dOp) {
216 auto dAttr = dOp.materializeDerivedAttributes();
217 if (!dAttr)
218 return;
219 for (auto d : dAttr)
220 dOp.emitRemark() << d.first << " = " << d.second;
221 });
222 }
223
224 //===----------------------------------------------------------------------===//
225 // Legalization Driver.
226 //===----------------------------------------------------------------------===//
227
228 namespace {
229 //===----------------------------------------------------------------------===//
230 // Region-Block Rewrite Testing
231
232 /// This pattern is a simple pattern that inlines the first region of a given
233 /// operation into the parent region.
234 struct TestRegionRewriteBlockMovement : public ConversionPattern {
TestRegionRewriteBlockMovement__anon04a6666b0611::TestRegionRewriteBlockMovement235 TestRegionRewriteBlockMovement(MLIRContext *ctx)
236 : ConversionPattern("test.region", 1, ctx) {}
237
238 LogicalResult
matchAndRewrite__anon04a6666b0611::TestRegionRewriteBlockMovement239 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
240 ConversionPatternRewriter &rewriter) const final {
241 // Inline this region into the parent region.
242 auto &parentRegion = *op->getParentRegion();
243 auto &opRegion = op->getRegion(0);
244 if (op->getAttr("legalizer.should_clone"))
245 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
246 else
247 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
248
249 if (op->getAttr("legalizer.erase_old_blocks")) {
250 while (!opRegion.empty())
251 rewriter.eraseBlock(&opRegion.front());
252 }
253
254 // Drop this operation.
255 rewriter.eraseOp(op);
256 return success();
257 }
258 };
259 /// This pattern is a simple pattern that generates a region containing an
260 /// illegal operation.
261 struct TestRegionRewriteUndo : public RewritePattern {
TestRegionRewriteUndo__anon04a6666b0611::TestRegionRewriteUndo262 TestRegionRewriteUndo(MLIRContext *ctx)
263 : RewritePattern("test.region_builder", 1, ctx) {}
264
matchAndRewrite__anon04a6666b0611::TestRegionRewriteUndo265 LogicalResult matchAndRewrite(Operation *op,
266 PatternRewriter &rewriter) const final {
267 // Create the region operation with an entry block containing arguments.
268 OperationState newRegion(op->getLoc(), "test.region");
269 newRegion.addRegion();
270 auto *regionOp = rewriter.createOperation(newRegion);
271 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0));
272 entryBlock->addArgument(rewriter.getIntegerType(64));
273
274 // Add an explicitly illegal operation to ensure the conversion fails.
275 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
276 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
277
278 // Drop this operation.
279 rewriter.eraseOp(op);
280 return success();
281 }
282 };
283 /// A simple pattern that creates a block at the end of the parent region of the
284 /// matched operation.
285 struct TestCreateBlock : public RewritePattern {
TestCreateBlock__anon04a6666b0611::TestCreateBlock286 TestCreateBlock(MLIRContext *ctx)
287 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
288
matchAndRewrite__anon04a6666b0611::TestCreateBlock289 LogicalResult matchAndRewrite(Operation *op,
290 PatternRewriter &rewriter) const final {
291 Region ®ion = *op->getParentRegion();
292 Type i32Type = rewriter.getIntegerType(32);
293 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type});
294 rewriter.create<TerminatorOp>(op->getLoc());
295 rewriter.replaceOp(op, {});
296 return success();
297 }
298 };
299
300 /// A simple pattern that creates a block containing an invalid operation in
301 /// order to trigger the block creation undo mechanism.
302 struct TestCreateIllegalBlock : public RewritePattern {
TestCreateIllegalBlock__anon04a6666b0611::TestCreateIllegalBlock303 TestCreateIllegalBlock(MLIRContext *ctx)
304 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
305
matchAndRewrite__anon04a6666b0611::TestCreateIllegalBlock306 LogicalResult matchAndRewrite(Operation *op,
307 PatternRewriter &rewriter) const final {
308 Region ®ion = *op->getParentRegion();
309 Type i32Type = rewriter.getIntegerType(32);
310 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type});
311 // Create an illegal op to ensure the conversion fails.
312 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
313 rewriter.create<TerminatorOp>(op->getLoc());
314 rewriter.replaceOp(op, {});
315 return success();
316 }
317 };
318
319 /// A simple pattern that tests the undo mechanism when replacing the uses of a
320 /// block argument.
321 struct TestUndoBlockArgReplace : public ConversionPattern {
TestUndoBlockArgReplace__anon04a6666b0611::TestUndoBlockArgReplace322 TestUndoBlockArgReplace(MLIRContext *ctx)
323 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
324
325 LogicalResult
matchAndRewrite__anon04a6666b0611::TestUndoBlockArgReplace326 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
327 ConversionPatternRewriter &rewriter) const final {
328 auto illegalOp =
329 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
330 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
331 illegalOp);
332 rewriter.updateRootInPlace(op, [] {});
333 return success();
334 }
335 };
336
337 /// A rewrite pattern that tests the undo mechanism when erasing a block.
338 struct TestUndoBlockErase : public ConversionPattern {
TestUndoBlockErase__anon04a6666b0611::TestUndoBlockErase339 TestUndoBlockErase(MLIRContext *ctx)
340 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
341
342 LogicalResult
matchAndRewrite__anon04a6666b0611::TestUndoBlockErase343 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
344 ConversionPatternRewriter &rewriter) const final {
345 Block *secondBlock = &*std::next(op->getRegion(0).begin());
346 rewriter.setInsertionPointToStart(secondBlock);
347 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
348 rewriter.eraseBlock(secondBlock);
349 rewriter.updateRootInPlace(op, [] {});
350 return success();
351 }
352 };
353
354 //===----------------------------------------------------------------------===//
355 // Type-Conversion Rewrite Testing
356
357 /// This patterns erases a region operation that has had a type conversion.
358 struct TestDropOpSignatureConversion : public ConversionPattern {
TestDropOpSignatureConversion__anon04a6666b0611::TestDropOpSignatureConversion359 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
360 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
361 LogicalResult
matchAndRewrite__anon04a6666b0611::TestDropOpSignatureConversion362 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
363 ConversionPatternRewriter &rewriter) const override {
364 Region ®ion = op->getRegion(0);
365 Block *entry = ®ion.front();
366
367 // Convert the original entry arguments.
368 TypeConverter &converter = *getTypeConverter();
369 TypeConverter::SignatureConversion result(entry->getNumArguments());
370 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
371 result)) ||
372 failed(rewriter.convertRegionTypes(®ion, converter, &result)))
373 return failure();
374
375 // Convert the region signature and just drop the operation.
376 rewriter.eraseOp(op);
377 return success();
378 }
379 };
380 /// This pattern simply updates the operands of the given operation.
381 struct TestPassthroughInvalidOp : public ConversionPattern {
TestPassthroughInvalidOp__anon04a6666b0611::TestPassthroughInvalidOp382 TestPassthroughInvalidOp(MLIRContext *ctx)
383 : ConversionPattern("test.invalid", 1, ctx) {}
384 LogicalResult
matchAndRewrite__anon04a6666b0611::TestPassthroughInvalidOp385 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
386 ConversionPatternRewriter &rewriter) const final {
387 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
388 llvm::None);
389 return success();
390 }
391 };
392 /// This pattern handles the case of a split return value.
393 struct TestSplitReturnType : public ConversionPattern {
TestSplitReturnType__anon04a6666b0611::TestSplitReturnType394 TestSplitReturnType(MLIRContext *ctx)
395 : ConversionPattern("test.return", 1, ctx) {}
396 LogicalResult
matchAndRewrite__anon04a6666b0611::TestSplitReturnType397 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
398 ConversionPatternRewriter &rewriter) const final {
399 // Check for a return of F32.
400 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
401 return failure();
402
403 // Check if the first operation is a cast operation, if it is we use the
404 // results directly.
405 auto *defOp = operands[0].getDefiningOp();
406 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
407 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
408 return success();
409 }
410
411 // Otherwise, fail to match.
412 return failure();
413 }
414 };
415
416 //===----------------------------------------------------------------------===//
417 // Multi-Level Type-Conversion Rewrite Testing
418 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
TestChangeProducerTypeI32ToF32__anon04a6666b0611::TestChangeProducerTypeI32ToF32419 TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
420 : ConversionPattern("test.type_producer", 1, ctx) {}
421 LogicalResult
matchAndRewrite__anon04a6666b0611::TestChangeProducerTypeI32ToF32422 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
423 ConversionPatternRewriter &rewriter) const final {
424 // If the type is I32, change the type to F32.
425 if (!Type(*op->result_type_begin()).isSignlessInteger(32))
426 return failure();
427 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
428 return success();
429 }
430 };
431 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
TestChangeProducerTypeF32ToF64__anon04a6666b0611::TestChangeProducerTypeF32ToF64432 TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
433 : ConversionPattern("test.type_producer", 1, ctx) {}
434 LogicalResult
matchAndRewrite__anon04a6666b0611::TestChangeProducerTypeF32ToF64435 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
436 ConversionPatternRewriter &rewriter) const final {
437 // If the type is F32, change the type to F64.
438 if (!Type(*op->result_type_begin()).isF32())
439 return rewriter.notifyMatchFailure(op, "expected single f32 operand");
440 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
441 return success();
442 }
443 };
444 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
TestChangeProducerTypeF32ToInvalid__anon04a6666b0611::TestChangeProducerTypeF32ToInvalid445 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
446 : ConversionPattern("test.type_producer", 10, ctx) {}
447 LogicalResult
matchAndRewrite__anon04a6666b0611::TestChangeProducerTypeF32ToInvalid448 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
449 ConversionPatternRewriter &rewriter) const final {
450 // Always convert to B16, even though it is not a legal type. This tests
451 // that values are unmapped correctly.
452 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
453 return success();
454 }
455 };
456 struct TestUpdateConsumerType : public ConversionPattern {
TestUpdateConsumerType__anon04a6666b0611::TestUpdateConsumerType457 TestUpdateConsumerType(MLIRContext *ctx)
458 : ConversionPattern("test.type_consumer", 1, ctx) {}
459 LogicalResult
matchAndRewrite__anon04a6666b0611::TestUpdateConsumerType460 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
461 ConversionPatternRewriter &rewriter) const final {
462 // Verify that the incoming operand has been successfully remapped to F64.
463 if (!operands[0].getType().isF64())
464 return failure();
465 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
466 return success();
467 }
468 };
469
470 //===----------------------------------------------------------------------===//
471 // Non-Root Replacement Rewrite Testing
472 /// This pattern generates an invalid operation, but replaces it before the
473 /// pattern is finished. This checks that we don't need to legalize the
474 /// temporary op.
475 struct TestNonRootReplacement : public RewritePattern {
TestNonRootReplacement__anon04a6666b0611::TestNonRootReplacement476 TestNonRootReplacement(MLIRContext *ctx)
477 : RewritePattern("test.replace_non_root", 1, ctx) {}
478
matchAndRewrite__anon04a6666b0611::TestNonRootReplacement479 LogicalResult matchAndRewrite(Operation *op,
480 PatternRewriter &rewriter) const final {
481 auto resultType = *op->result_type_begin();
482 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
483 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
484
485 rewriter.replaceOp(illegalOp, {legalOp});
486 rewriter.replaceOp(op, {illegalOp});
487 return success();
488 }
489 };
490
491 //===----------------------------------------------------------------------===//
492 // Recursive Rewrite Testing
493 /// This pattern is applied to the same operation multiple times, but has a
494 /// bounded recursion.
495 struct TestBoundedRecursiveRewrite
496 : public OpRewritePattern<TestRecursiveRewriteOp> {
497 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
498
initialize__anon04a6666b0611::TestBoundedRecursiveRewrite499 void initialize() {
500 // The conversion target handles bounding the recursion of this pattern.
501 setHasBoundedRewriteRecursion();
502 }
503
matchAndRewrite__anon04a6666b0611::TestBoundedRecursiveRewrite504 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
505 PatternRewriter &rewriter) const final {
506 // Decrement the depth of the op in-place.
507 rewriter.updateRootInPlace(op, [&] {
508 op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1));
509 });
510 return success();
511 }
512 };
513
514 struct TestNestedOpCreationUndoRewrite
515 : public OpRewritePattern<IllegalOpWithRegionAnchor> {
516 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
517
matchAndRewrite__anon04a6666b0611::TestNestedOpCreationUndoRewrite518 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
519 PatternRewriter &rewriter) const final {
520 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
521 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
522 return success();
523 };
524 };
525
526 // This pattern matches `test.blackhole` and delete this op and its producer.
527 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
528 using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
529
matchAndRewrite__anon04a6666b0611::TestReplaceEraseOp530 LogicalResult matchAndRewrite(BlackHoleOp op,
531 PatternRewriter &rewriter) const final {
532 Operation *producer = op.getOperand().getDefiningOp();
533 // Always erase the user before the producer, the framework should handle
534 // this correctly.
535 rewriter.eraseOp(op);
536 rewriter.eraseOp(producer);
537 return success();
538 };
539 };
540 } // namespace
541
542 namespace {
543 struct TestTypeConverter : public TypeConverter {
544 using TypeConverter::TypeConverter;
TestTypeConverter__anon04a6666b0a11::TestTypeConverter545 TestTypeConverter() {
546 addConversion(convertType);
547 addArgumentMaterialization(materializeCast);
548 addSourceMaterialization(materializeCast);
549
550 /// Materialize the cast for one-to-one conversion from i64 to f64.
551 const auto materializeOneToOneCast =
552 [](OpBuilder &builder, IntegerType resultType, ValueRange inputs,
553 Location loc) -> Optional<Value> {
554 if (resultType.getWidth() == 42 && inputs.size() == 1)
555 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
556 return llvm::None;
557 };
558 addArgumentMaterialization(materializeOneToOneCast);
559 }
560
convertType__anon04a6666b0a11::TestTypeConverter561 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
562 // Drop I16 types.
563 if (t.isSignlessInteger(16))
564 return success();
565
566 // Convert I64 to F64.
567 if (t.isSignlessInteger(64)) {
568 results.push_back(FloatType::getF64(t.getContext()));
569 return success();
570 }
571
572 // Convert I42 to I43.
573 if (t.isInteger(42)) {
574 results.push_back(IntegerType::get(t.getContext(), 43));
575 return success();
576 }
577
578 // Split F32 into F16,F16.
579 if (t.isF32()) {
580 results.assign(2, FloatType::getF16(t.getContext()));
581 return success();
582 }
583
584 // Otherwise, convert the type directly.
585 results.push_back(t);
586 return success();
587 }
588
589 /// Hook for materializing a conversion. This is necessary because we generate
590 /// 1->N type mappings.
materializeCast__anon04a6666b0a11::TestTypeConverter591 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
592 ValueRange inputs, Location loc) {
593 if (inputs.size() == 1)
594 return inputs[0];
595 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
596 }
597 };
598
599 struct TestLegalizePatternDriver
600 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
getArgument__anon04a6666b0a11::TestLegalizePatternDriver601 StringRef getArgument() const final { return "test-legalize-patterns"; }
getDescription__anon04a6666b0a11::TestLegalizePatternDriver602 StringRef getDescription() const final {
603 return "Run test dialect legalization patterns";
604 }
605 /// The mode of conversion to use with the driver.
606 enum class ConversionMode { Analysis, Full, Partial };
607
TestLegalizePatternDriver__anon04a6666b0a11::TestLegalizePatternDriver608 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
609
runOnOperation__anon04a6666b0a11::TestLegalizePatternDriver610 void runOnOperation() override {
611 TestTypeConverter converter;
612 mlir::RewritePatternSet patterns(&getContext());
613 populateWithGenerated(patterns);
614 patterns
615 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
616 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
617 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
618 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
619 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
620 TestNonRootReplacement, TestBoundedRecursiveRewrite,
621 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>(
622 &getContext());
623 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
624 mlir::populateFuncOpTypeConversionPattern(patterns, converter);
625 mlir::populateCallOpTypeConversionPattern(patterns, converter);
626
627 // Define the conversion target used for the test.
628 ConversionTarget target(getContext());
629 target.addLegalOp<ModuleOp>();
630 target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
631 TerminatorOp>();
632 target
633 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
634 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
635 // Don't allow F32 operands.
636 return llvm::none_of(op.getOperandTypes(),
637 [](Type type) { return type.isF32(); });
638 });
639 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
640 return converter.isSignatureLegal(op.getType()) &&
641 converter.isLegal(&op.getBody());
642 });
643
644 // Expect the type_producer/type_consumer operations to only operate on f64.
645 target.addDynamicallyLegalOp<TestTypeProducerOp>(
646 [](TestTypeProducerOp op) { return op.getType().isF64(); });
647 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
648 return op.getOperand().getType().isF64();
649 });
650
651 // Check support for marking certain operations as recursively legal.
652 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
653 return static_cast<bool>(
654 op->getAttrOfType<UnitAttr>("test.recursively_legal"));
655 });
656
657 // Mark the bound recursion operation as dynamically legal.
658 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
659 [](TestRecursiveRewriteOp op) { return op.depth() == 0; });
660
661 // Handle a partial conversion.
662 if (mode == ConversionMode::Partial) {
663 DenseSet<Operation *> unlegalizedOps;
664 (void)applyPartialConversion(getOperation(), target, std::move(patterns),
665 &unlegalizedOps);
666 // Emit remarks for each legalizable operation.
667 for (auto *op : unlegalizedOps)
668 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
669 return;
670 }
671
672 // Handle a full conversion.
673 if (mode == ConversionMode::Full) {
674 // Check support for marking unknown operations as dynamically legal.
675 target.markUnknownOpDynamicallyLegal([](Operation *op) {
676 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
677 });
678
679 (void)applyFullConversion(getOperation(), target, std::move(patterns));
680 return;
681 }
682
683 // Otherwise, handle an analysis conversion.
684 assert(mode == ConversionMode::Analysis);
685
686 // Analyze the convertible operations.
687 DenseSet<Operation *> legalizedOps;
688 if (failed(applyAnalysisConversion(getOperation(), target,
689 std::move(patterns), legalizedOps)))
690 return signalPassFailure();
691
692 // Emit remarks for each legalizable operation.
693 for (auto *op : legalizedOps)
694 op->emitRemark() << "op '" << op->getName() << "' is legalizable";
695 }
696
697 /// The mode of conversion to use.
698 ConversionMode mode;
699 };
700 } // end anonymous namespace
701
702 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
703 legalizerConversionMode(
704 "test-legalize-mode",
705 llvm::cl::desc("The legalization mode to use with the test driver"),
706 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
707 llvm::cl::values(
708 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
709 "analysis", "Perform an analysis conversion"),
710 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
711 "Perform a full conversion"),
712 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
713 "partial", "Perform a partial conversion")));
714
715 //===----------------------------------------------------------------------===//
716 // ConversionPatternRewriter::getRemappedValue testing. This method is used
717 // to get the remapped value of an original value that was replaced using
718 // ConversionPatternRewriter.
719 namespace {
720 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
721 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
722 /// operand twice.
723 ///
724 /// Example:
725 /// %1 = test.one_variadic_out_one_variadic_in1"(%0)
726 /// is replaced with:
727 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
728 struct OneVResOneVOperandOp1Converter
729 : public OpConversionPattern<OneVResOneVOperandOp1> {
730 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
731
732 LogicalResult
matchAndRewrite__anon04a6666b1411::OneVResOneVOperandOp1Converter733 matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
734 ConversionPatternRewriter &rewriter) const override {
735 auto origOps = op.getOperands();
736 assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
737 "One operand expected");
738 Value origOp = *origOps.begin();
739 SmallVector<Value, 2> remappedOperands;
740 // Replicate the remapped original operand twice. Note that we don't used
741 // the remapped 'operand' since the goal is testing 'getRemappedValue'.
742 remappedOperands.push_back(rewriter.getRemappedValue(origOp));
743 remappedOperands.push_back(rewriter.getRemappedValue(origOp));
744
745 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
746 remappedOperands);
747 return success();
748 }
749 };
750
751 struct TestRemappedValue
752 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
getArgument__anon04a6666b1411::TestRemappedValue753 StringRef getArgument() const final { return "test-remapped-value"; }
getDescription__anon04a6666b1411::TestRemappedValue754 StringRef getDescription() const final {
755 return "Test public remapped value mechanism in ConversionPatternRewriter";
756 }
runOnFunction__anon04a6666b1411::TestRemappedValue757 void runOnFunction() override {
758 mlir::RewritePatternSet patterns(&getContext());
759 patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
760
761 mlir::ConversionTarget target(getContext());
762 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>();
763 // We make OneVResOneVOperandOp1 legal only when it has more that one
764 // operand. This will trigger the conversion that will replace one-operand
765 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
766 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
767 [](Operation *op) -> bool {
768 return std::distance(op->operand_begin(), op->operand_end()) > 1;
769 });
770
771 if (failed(mlir::applyFullConversion(getFunction(), target,
772 std::move(patterns)))) {
773 signalPassFailure();
774 }
775 }
776 };
777 } // end anonymous namespace
778
779 //===----------------------------------------------------------------------===//
780 // Test patterns without a specific root operation kind
781 //===----------------------------------------------------------------------===//
782
783 namespace {
784 /// This pattern matches and removes any operation in the test dialect.
785 struct RemoveTestDialectOps : public RewritePattern {
RemoveTestDialectOps__anon04a6666b1611::RemoveTestDialectOps786 RemoveTestDialectOps(MLIRContext *context)
787 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
788
matchAndRewrite__anon04a6666b1611::RemoveTestDialectOps789 LogicalResult matchAndRewrite(Operation *op,
790 PatternRewriter &rewriter) const override {
791 if (!isa<TestDialect>(op->getDialect()))
792 return failure();
793 rewriter.eraseOp(op);
794 return success();
795 }
796 };
797
798 struct TestUnknownRootOpDriver
799 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
getArgument__anon04a6666b1611::TestUnknownRootOpDriver800 StringRef getArgument() const final {
801 return "test-legalize-unknown-root-patterns";
802 }
getDescription__anon04a6666b1611::TestUnknownRootOpDriver803 StringRef getDescription() const final {
804 return "Test public remapped value mechanism in ConversionPatternRewriter";
805 }
runOnFunction__anon04a6666b1611::TestUnknownRootOpDriver806 void runOnFunction() override {
807 mlir::RewritePatternSet patterns(&getContext());
808 patterns.add<RemoveTestDialectOps>(&getContext());
809
810 mlir::ConversionTarget target(getContext());
811 target.addIllegalDialect<TestDialect>();
812 if (failed(
813 applyPartialConversion(getFunction(), target, std::move(patterns))))
814 signalPassFailure();
815 }
816 };
817 } // end anonymous namespace
818
819 //===----------------------------------------------------------------------===//
820 // Test type conversions
821 //===----------------------------------------------------------------------===//
822
823 namespace {
824 struct TestTypeConversionProducer
825 : public OpConversionPattern<TestTypeProducerOp> {
826 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
827 LogicalResult
matchAndRewrite__anon04a6666b1711::TestTypeConversionProducer828 matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands,
829 ConversionPatternRewriter &rewriter) const final {
830 Type resultType = op.getType();
831 if (resultType.isa<FloatType>())
832 resultType = rewriter.getF64Type();
833 else if (resultType.isInteger(16))
834 resultType = rewriter.getIntegerType(64);
835 else
836 return failure();
837
838 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
839 return success();
840 }
841 };
842
843 /// Call signature conversion and then fail the rewrite to trigger the undo
844 /// mechanism.
845 struct TestSignatureConversionUndo
846 : public OpConversionPattern<TestSignatureConversionUndoOp> {
847 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
848
849 LogicalResult
matchAndRewrite__anon04a6666b1711::TestSignatureConversionUndo850 matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands,
851 ConversionPatternRewriter &rewriter) const final {
852 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
853 return failure();
854 }
855 };
856
857 /// Just forward the operands to the root op. This is essentially a no-op
858 /// pattern that is used to trigger target materialization.
859 struct TestTypeConsumerForward
860 : public OpConversionPattern<TestTypeConsumerOp> {
861 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
862
863 LogicalResult
matchAndRewrite__anon04a6666b1711::TestTypeConsumerForward864 matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands,
865 ConversionPatternRewriter &rewriter) const final {
866 rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); });
867 return success();
868 }
869 };
870
871 struct TestTypeConversionAnotherProducer
872 : public OpRewritePattern<TestAnotherTypeProducerOp> {
873 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
874
matchAndRewrite__anon04a6666b1711::TestTypeConversionAnotherProducer875 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
876 PatternRewriter &rewriter) const final {
877 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
878 return success();
879 }
880 };
881
882 struct TestTypeConversionDriver
883 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
getDependentDialects__anon04a6666b1711::TestTypeConversionDriver884 void getDependentDialects(DialectRegistry ®istry) const override {
885 registry.insert<TestDialect>();
886 }
getArgument__anon04a6666b1711::TestTypeConversionDriver887 StringRef getArgument() const final {
888 return "test-legalize-type-conversion";
889 }
getDescription__anon04a6666b1711::TestTypeConversionDriver890 StringRef getDescription() const final {
891 return "Test various type conversion functionalities in DialectConversion";
892 }
893
runOnOperation__anon04a6666b1711::TestTypeConversionDriver894 void runOnOperation() override {
895 // Initialize the type converter.
896 TypeConverter converter;
897
898 /// Add the legal set of type conversions.
899 converter.addConversion([](Type type) -> Type {
900 // Treat F64 as legal.
901 if (type.isF64())
902 return type;
903 // Allow converting BF16/F16/F32 to F64.
904 if (type.isBF16() || type.isF16() || type.isF32())
905 return FloatType::getF64(type.getContext());
906 // Otherwise, the type is illegal.
907 return nullptr;
908 });
909 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
910 // Drop all integer types.
911 return success();
912 });
913
914 /// Add the legal set of type materializations.
915 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
916 ValueRange inputs,
917 Location loc) -> Value {
918 // Allow casting from F64 back to F32.
919 if (!resultType.isF16() && inputs.size() == 1 &&
920 inputs[0].getType().isF64())
921 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
922 // Allow producing an i32 or i64 from nothing.
923 if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
924 inputs.empty())
925 return builder.create<TestTypeProducerOp>(loc, resultType);
926 // Allow producing an i64 from an integer.
927 if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
928 inputs[0].getType().isa<IntegerType>())
929 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
930 // Otherwise, fail.
931 return nullptr;
932 });
933
934 // Initialize the conversion target.
935 mlir::ConversionTarget target(getContext());
936 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
937 return op.getType().isF64() || op.getType().isInteger(64);
938 });
939 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
940 return converter.isSignatureLegal(op.getType()) &&
941 converter.isLegal(&op.getBody());
942 });
943 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
944 // Allow casts from F64 to F32.
945 return (*op.operand_type_begin()).isF64() && op.getType().isF32();
946 });
947
948 // Initialize the set of rewrite patterns.
949 RewritePatternSet patterns(&getContext());
950 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
951 TestSignatureConversionUndo>(converter, &getContext());
952 patterns.add<TestTypeConversionAnotherProducer>(&getContext());
953 mlir::populateFuncOpTypeConversionPattern(patterns, converter);
954
955 if (failed(applyPartialConversion(getOperation(), target,
956 std::move(patterns))))
957 signalPassFailure();
958 }
959 };
960 } // end anonymous namespace
961
962 //===----------------------------------------------------------------------===//
963 // Test Block Merging
964 //===----------------------------------------------------------------------===//
965
966 namespace {
967 /// A rewriter pattern that tests that blocks can be merged.
968 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
969 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
970
971 LogicalResult
matchAndRewrite__anon04a6666b1f11::TestMergeBlock972 matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands,
973 ConversionPatternRewriter &rewriter) const final {
974 Block &firstBlock = op.body().front();
975 Operation *branchOp = firstBlock.getTerminator();
976 Block *secondBlock = &*(std::next(op.body().begin()));
977 auto succOperands = branchOp->getOperands();
978 SmallVector<Value, 2> replacements(succOperands);
979 rewriter.eraseOp(branchOp);
980 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
981 rewriter.updateRootInPlace(op, [] {});
982 return success();
983 }
984 };
985
986 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
987 struct TestUndoBlocksMerge : public ConversionPattern {
TestUndoBlocksMerge__anon04a6666b1f11::TestUndoBlocksMerge988 TestUndoBlocksMerge(MLIRContext *ctx)
989 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
990 LogicalResult
matchAndRewrite__anon04a6666b1f11::TestUndoBlocksMerge991 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
992 ConversionPatternRewriter &rewriter) const final {
993 Block &firstBlock = op->getRegion(0).front();
994 Operation *branchOp = firstBlock.getTerminator();
995 Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
996 rewriter.setInsertionPointToStart(secondBlock);
997 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
998 auto succOperands = branchOp->getOperands();
999 SmallVector<Value, 2> replacements(succOperands);
1000 rewriter.eraseOp(branchOp);
1001 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1002 rewriter.updateRootInPlace(op, [] {});
1003 return success();
1004 }
1005 };
1006
1007 /// A rewrite mechanism to inline the body of the op into its parent, when both
1008 /// ops can have a single block.
1009 struct TestMergeSingleBlockOps
1010 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1011 using OpConversionPattern<
1012 SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1013
1014 LogicalResult
matchAndRewrite__anon04a6666b1f11::TestMergeSingleBlockOps1015 matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands,
1016 ConversionPatternRewriter &rewriter) const final {
1017 SingleBlockImplicitTerminatorOp parentOp =
1018 op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1019 if (!parentOp)
1020 return failure();
1021 Block &innerBlock = op.region().front();
1022 TerminatorOp innerTerminator =
1023 cast<TerminatorOp>(innerBlock.getTerminator());
1024 rewriter.mergeBlockBefore(&innerBlock, op);
1025 rewriter.eraseOp(innerTerminator);
1026 rewriter.eraseOp(op);
1027 rewriter.updateRootInPlace(op, [] {});
1028 return success();
1029 }
1030 };
1031
1032 struct TestMergeBlocksPatternDriver
1033 : public PassWrapper<TestMergeBlocksPatternDriver,
1034 OperationPass<ModuleOp>> {
getArgument__anon04a6666b1f11::TestMergeBlocksPatternDriver1035 StringRef getArgument() const final { return "test-merge-blocks"; }
getDescription__anon04a6666b1f11::TestMergeBlocksPatternDriver1036 StringRef getDescription() const final {
1037 return "Test Merging operation in ConversionPatternRewriter";
1038 }
runOnOperation__anon04a6666b1f11::TestMergeBlocksPatternDriver1039 void runOnOperation() override {
1040 MLIRContext *context = &getContext();
1041 mlir::RewritePatternSet patterns(context);
1042 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1043 context);
1044 ConversionTarget target(*context);
1045 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1046 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1047 target.addIllegalOp<ILLegalOpF>();
1048
1049 /// Expect the op to have a single block after legalization.
1050 target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1051 [&](TestMergeBlocksOp op) -> bool {
1052 return llvm::hasSingleElement(op.body());
1053 });
1054
1055 /// Only allow `test.br` within test.merge_blocks op.
1056 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
1057 return op->getParentOfType<TestMergeBlocksOp>();
1058 });
1059
1060 /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1061 /// inlined.
1062 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1063 [&](SingleBlockImplicitTerminatorOp op) -> bool {
1064 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1065 });
1066
1067 DenseSet<Operation *> unlegalizedOps;
1068 (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1069 &unlegalizedOps);
1070 for (auto *op : unlegalizedOps)
1071 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1072 }
1073 };
1074 } // namespace
1075
1076 //===----------------------------------------------------------------------===//
1077 // Test Selective Replacement
1078 //===----------------------------------------------------------------------===//
1079
1080 namespace {
1081 /// A rewrite mechanism to inline the body of the op into its parent, when both
1082 /// ops can have a single block.
1083 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
1084 using OpRewritePattern<TestCastOp>::OpRewritePattern;
1085
matchAndRewrite__anon04a6666b2611::TestSelectiveOpReplacementPattern1086 LogicalResult matchAndRewrite(TestCastOp op,
1087 PatternRewriter &rewriter) const final {
1088 if (op.getNumOperands() != 2)
1089 return failure();
1090 OperandRange operands = op.getOperands();
1091
1092 // Replace non-terminator uses with the first operand.
1093 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
1094 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
1095 });
1096 // Replace everything else with the second operand if the operation isn't
1097 // dead.
1098 rewriter.replaceOp(op, op.getOperand(1));
1099 return success();
1100 }
1101 };
1102
1103 struct TestSelectiveReplacementPatternDriver
1104 : public PassWrapper<TestSelectiveReplacementPatternDriver,
1105 OperationPass<>> {
getArgument__anon04a6666b2611::TestSelectiveReplacementPatternDriver1106 StringRef getArgument() const final {
1107 return "test-pattern-selective-replacement";
1108 }
getDescription__anon04a6666b2611::TestSelectiveReplacementPatternDriver1109 StringRef getDescription() const final {
1110 return "Test selective replacement in the PatternRewriter";
1111 }
runOnOperation__anon04a6666b2611::TestSelectiveReplacementPatternDriver1112 void runOnOperation() override {
1113 MLIRContext *context = &getContext();
1114 mlir::RewritePatternSet patterns(context);
1115 patterns.add<TestSelectiveOpReplacementPattern>(context);
1116 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
1117 std::move(patterns));
1118 }
1119 };
1120 } // namespace
1121
1122 //===----------------------------------------------------------------------===//
1123 // PassRegistration
1124 //===----------------------------------------------------------------------===//
1125
1126 namespace mlir {
1127 namespace test {
/// Registers the test pattern driver passes of this file with the pass
/// registry.
void registerPatternsTestPass() {
  PassRegistration<TestReturnTypeDriver>();

  PassRegistration<TestDerivedAttributeDriver>();

  PassRegistration<TestPatternDriver>();

  // The legalization driver is constructed with the conversion mode that was
  // selected through the `test-legalize-mode` command-line option.
  PassRegistration<TestLegalizePatternDriver>([] {
    return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
  });

  PassRegistration<TestRemappedValue>();

  PassRegistration<TestUnknownRootOpDriver>();

  PassRegistration<TestTypeConversionDriver>();

  PassRegistration<TestMergeBlocksPatternDriver>();
  PassRegistration<TestSelectiveReplacementPatternDriver>();
}
1148 } // namespace test
1149 } // namespace mlir
1150