1 //===- TestMatchReduction.cpp - Test the match reduction utility ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a test pass for the match reduction utility.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Analysis/LoopAnalysis.h"
14 #include "mlir/Pass/Pass.h"
15
16 using namespace mlir;
17
18 namespace {
19
printReductionResult(Operation * redRegionOp,unsigned numOutput,Value reducedValue,ArrayRef<Operation * > combinerOps)20 void printReductionResult(Operation *redRegionOp, unsigned numOutput,
21 Value reducedValue,
22 ArrayRef<Operation *> combinerOps) {
23 if (reducedValue) {
24 redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!";
25 redRegionOp->emitRemark("Reduced Value: ") << reducedValue;
26 for (Operation *combOp : combinerOps)
27 redRegionOp->emitRemark("Combiner Op: ") << *combOp;
28
29 return;
30 }
31
32 redRegionOp->emitRemark("Reduction NOT found in output #")
33 << numOutput << "!";
34 }
35
36 struct TestMatchReductionPass
37 : public PassWrapper<TestMatchReductionPass, FunctionPass> {
getArgument__anonf8a13a7e0111::TestMatchReductionPass38 StringRef getArgument() const final { return "test-match-reduction"; }
getDescription__anonf8a13a7e0111::TestMatchReductionPass39 StringRef getDescription() const final {
40 return "Test the match reduction utility.";
41 }
42
runOnFunction__anonf8a13a7e0111::TestMatchReductionPass43 void runOnFunction() override {
44 FuncOp func = getFunction();
45 func->emitRemark("Testing function");
46
47 func.walk<WalkOrder::PreOrder>([](Operation *op) {
48 if (isa<FuncOp>(op))
49 return;
50
51 // Limit testing to ops with only one region.
52 if (op->getNumRegions() != 1)
53 return;
54
55 Region ®ion = op->getRegion(0);
56 if (!region.hasOneBlock())
57 return;
58
59 // We expect all the tested region ops to have 1 input by default. The
60 // remaining arguments are assumed to be outputs/reductions and there must
61 // be at least one.
62 // TODO: Extend it to support more generic cases.
63 Block ®ionEntry = region.front();
64 auto args = regionEntry.getArguments();
65 if (args.size() < 2)
66 return;
67
68 auto outputs = args.drop_front();
69 for (int i = 0, size = outputs.size(); i < size; ++i) {
70 SmallVector<Operation *, 4> combinerOps;
71 Value reducedValue = matchReduction(outputs, i, combinerOps);
72 printReductionResult(op, i, reducedValue, combinerOps);
73 }
74 });
75 }
76 };
77
78 } // end anonymous namespace
79
80 namespace mlir {
81 namespace test {
registerTestMatchReductionPass()82 void registerTestMatchReductionPass() {
83 PassRegistration<TestMatchReductionPass>();
84 }
85 } // namespace test
86 } // namespace mlir
87