1 //===- TestMatchReduction.cpp - Test the match reduction utility ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a test pass for the match reduction utility.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/LoopAnalysis.h"
14 #include "mlir/Pass/Pass.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 
printReductionResult(Operation * redRegionOp,unsigned numOutput,Value reducedValue,ArrayRef<Operation * > combinerOps)20 void printReductionResult(Operation *redRegionOp, unsigned numOutput,
21                           Value reducedValue,
22                           ArrayRef<Operation *> combinerOps) {
23   if (reducedValue) {
24     redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!";
25     redRegionOp->emitRemark("Reduced Value: ") << reducedValue;
26     for (Operation *combOp : combinerOps)
27       redRegionOp->emitRemark("Combiner Op: ") << *combOp;
28 
29     return;
30   }
31 
32   redRegionOp->emitRemark("Reduction NOT found in output #")
33       << numOutput << "!";
34 }
35 
36 struct TestMatchReductionPass
37     : public PassWrapper<TestMatchReductionPass, FunctionPass> {
getArgument__anonf8a13a7e0111::TestMatchReductionPass38   StringRef getArgument() const final { return "test-match-reduction"; }
getDescription__anonf8a13a7e0111::TestMatchReductionPass39   StringRef getDescription() const final {
40     return "Test the match reduction utility.";
41   }
42 
runOnFunction__anonf8a13a7e0111::TestMatchReductionPass43   void runOnFunction() override {
44     FuncOp func = getFunction();
45     func->emitRemark("Testing function");
46 
47     func.walk<WalkOrder::PreOrder>([](Operation *op) {
48       if (isa<FuncOp>(op))
49         return;
50 
51       // Limit testing to ops with only one region.
52       if (op->getNumRegions() != 1)
53         return;
54 
55       Region &region = op->getRegion(0);
56       if (!region.hasOneBlock())
57         return;
58 
59       // We expect all the tested region ops to have 1 input by default. The
60       // remaining arguments are assumed to be outputs/reductions and there must
61       // be at least one.
62       // TODO: Extend it to support more generic cases.
63       Block &regionEntry = region.front();
64       auto args = regionEntry.getArguments();
65       if (args.size() < 2)
66         return;
67 
68       auto outputs = args.drop_front();
69       for (int i = 0, size = outputs.size(); i < size; ++i) {
70         SmallVector<Operation *, 4> combinerOps;
71         Value reducedValue = matchReduction(outputs, i, combinerOps);
72         printReductionResult(op, i, reducedValue, combinerOps);
73       }
74     });
75   }
76 };
77 
78 } // end anonymous namespace
79 
80 namespace mlir {
81 namespace test {
registerTestMatchReductionPass()82 void registerTestMatchReductionPass() {
83   PassRegistration<TestMatchReductionPass>();
84 }
85 } // namespace test
86 } // namespace mlir
87