1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test various loop fusion utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Analysis/AffineAnalysis.h"
14 #include "mlir/Analysis/AffineStructures.h"
15 #include "mlir/Analysis/Passes.h"
16 #include "mlir/Analysis/Utils.h"
17 #include "mlir/Dialect/AffineOps/AffineOps.h"
18 #include "mlir/Dialect/StandardOps/Ops.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/LoopFusionUtils.h"
22 #include "mlir/Transforms/Passes.h"
23
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Debug.h"
27
28 #define DEBUG_TYPE "test-loop-fusion"
29
30 using namespace mlir;
31
32 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
33
34 static llvm::cl::opt<bool> clTestDependenceCheck(
35 "test-loop-fusion-dependence-check",
36 llvm::cl::desc("Enable testing of loop fusion dependence check"),
37 llvm::cl::cat(clOptionsCategory));
38
39 static llvm::cl::opt<bool> clTestSliceComputation(
40 "test-loop-fusion-slice-computation",
41 llvm::cl::desc("Enable testing of loop fusion slice computation"),
42 llvm::cl::cat(clOptionsCategory));
43
44 namespace {
45
46 struct TestLoopFusion : public FunctionPass<TestLoopFusion> {
47 void runOnFunction() override;
48 };
49
50 } // end anonymous namespace
51
createTestLoopFusionPass()52 std::unique_ptr<OpPassBase<FuncOp>> mlir::createTestLoopFusionPass() {
53 return std::make_unique<TestLoopFusion>();
54 }
55
56 // Gathers all AffineForOps in 'block' at 'currLoopDepth' in 'depthToLoops'.
57 static void
gatherLoops(Block * block,unsigned currLoopDepth,DenseMap<unsigned,SmallVector<AffineForOp,2>> & depthToLoops)58 gatherLoops(Block *block, unsigned currLoopDepth,
59 DenseMap<unsigned, SmallVector<AffineForOp, 2>> &depthToLoops) {
60 auto &loopsAtDepth = depthToLoops[currLoopDepth];
61 for (auto &op : *block) {
62 if (auto forOp = dyn_cast<AffineForOp>(op)) {
63 loopsAtDepth.push_back(forOp);
64 gatherLoops(forOp.getBody(), currLoopDepth + 1, depthToLoops);
65 }
66 }
67 }
68
69 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
70 // in range ['loopDepth' + 1, 'maxLoopDepth'].
71 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
testDependenceCheck(SmallVector<AffineForOp,2> & loops,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)72 static void testDependenceCheck(SmallVector<AffineForOp, 2> &loops, unsigned i,
73 unsigned j, unsigned loopDepth,
74 unsigned maxLoopDepth) {
75 AffineForOp srcForOp = loops[i];
76 AffineForOp dstForOp = loops[j];
77 mlir::ComputationSliceState sliceUnion;
78 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
79 FusionResult result =
80 mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
81 if (result.value == FusionResult::FailBlockDependence) {
82 srcForOp.getOperation()->emitRemark("block-level dependence preventing"
83 " fusion of loop nest ")
84 << i << " into loop nest " << j << " at depth " << loopDepth;
85 }
86 }
87 }
88
89 // Returns the index of 'op' in its block.
getBlockIndex(Operation & op)90 static unsigned getBlockIndex(Operation &op) {
91 unsigned index = 0;
92 for (auto &opX : *op.getBlock()) {
93 if (&op == &opX)
94 break;
95 ++index;
96 }
97 return index;
98 }
99
100 // Returns a string representation of 'sliceUnion'.
getSliceStr(const mlir::ComputationSliceState & sliceUnion)101 static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) {
102 std::string result;
103 llvm::raw_string_ostream os(result);
104 // Slice insertion point format [loop-depth, operation-block-index]
105 unsigned ipd = getNestingDepth(*sliceUnion.insertPoint);
106 unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
107 os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
108 << ")";
109 assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
110 os << " loop bounds: ";
111 for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
112 os << '[';
113 sliceUnion.lbs[k].print(os);
114 os << ", ";
115 sliceUnion.ubs[k].print(os);
116 os << "] ";
117 }
118 return os.str();
119 }
120
121 // Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
122 // in range ['loopDepth' + 1, 'maxLoopDepth'].
123 // Emits a string representation of the slice union as a remark on 'loops[j]'.
testSliceComputation(SmallVector<AffineForOp,2> & loops,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)124 static void testSliceComputation(SmallVector<AffineForOp, 2> &loops, unsigned i,
125 unsigned j, unsigned loopDepth,
126 unsigned maxLoopDepth) {
127 AffineForOp forOpA = loops[i];
128 AffineForOp forOpB = loops[j];
129 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
130 mlir::ComputationSliceState sliceUnion;
131 FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
132 if (result.value == FusionResult::Success) {
133 forOpB.getOperation()->emitRemark("slice (")
134 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
135 << " : " << getSliceStr(sliceUnion) << ")";
136 }
137 }
138 }
139
runOnFunction()140 void TestLoopFusion::runOnFunction() {
141 // Gather all AffineForOps by loop depth.
142 DenseMap<unsigned, SmallVector<AffineForOp, 2>> depthToLoops;
143 for (auto &block : getFunction()) {
144 gatherLoops(&block, /*currLoopDepth=*/0, depthToLoops);
145 }
146
147 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
148 for (auto &depthAndLoops : depthToLoops) {
149 unsigned loopDepth = depthAndLoops.first;
150 auto &loops = depthAndLoops.second;
151 unsigned numLoops = loops.size();
152 for (unsigned j = 0; j < numLoops; ++j) {
153 for (unsigned k = 0; k < numLoops; ++k) {
154 if (j == k)
155 continue;
156 if (clTestDependenceCheck)
157 testDependenceCheck(loops, j, k, loopDepth, depthToLoops.size());
158 if (clTestSliceComputation)
159 testSliceComputation(loops, j, k, loopDepth, depthToLoops.size());
160 }
161 }
162 }
163 }
164
165 static PassRegistration<TestLoopFusion>
166 pass("test-loop-fusion", "Tests loop fusion utility functions.");
167