1 //===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/LoopOps/LoopOps.h"
10 #include "mlir/Dialect/StandardOps/Ops.h"
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Transforms/LoopUtils.h"
13 #include "mlir/Transforms/Passes.h"
14 #include "mlir/Transforms/RegionUtils.h"
15 #include "llvm/Support/Debug.h"
16
17 #define PASS_NAME "loop-coalescing"
18 #define DEBUG_TYPE PASS_NAME
19
20 using namespace mlir;
21
22 namespace {
23 class LoopCoalescingPass : public FunctionPass<LoopCoalescingPass> {
24 public:
runOnFunction()25 void runOnFunction() override {
26 FuncOp func = getFunction();
27
28 func.walk([](loop::ForOp op) {
29 // Ignore nested loops.
30 if (op.getParentOfType<loop::ForOp>())
31 return;
32
33 SmallVector<loop::ForOp, 4> loops;
34 getPerfectlyNestedLoops(loops, op);
35 LLVM_DEBUG(llvm::dbgs()
36 << "found a perfect nest of depth " << loops.size() << '\n');
37
38 // Look for a band of loops that can be coalesced, i.e. perfectly nested
39 // loops with bounds defined above some loop.
40 // 1. For each loop, find above which parent loop its operands are
41 // defined.
42 SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
43 for (unsigned i = 0, e = loops.size(); i < e; ++i) {
44 operandsDefinedAbove[i] = i;
45 for (unsigned j = 0; j < i; ++j) {
46 if (areValuesDefinedAbove(loops[i].getOperands(),
47 loops[j].region())) {
48 operandsDefinedAbove[i] = j;
49 break;
50 }
51 }
52 LLVM_DEBUG(llvm::dbgs()
53 << " bounds of loop " << i << " are known above depth "
54 << operandsDefinedAbove[i] << '\n');
55 }
56
57 // 2. Identify bands of loops such that the operands of all of them are
58 // defined above the first loop in the band. Traverse the nest bottom-up
59 // so that modifications don't invalidate the inner loops.
60 for (unsigned end = loops.size(); end > 0; --end) {
61 unsigned start = 0;
62 for (; start < end - 1; ++start) {
63 auto maxPos =
64 *std::max_element(std::next(operandsDefinedAbove.begin(), start),
65 std::next(operandsDefinedAbove.begin(), end));
66 if (maxPos > start)
67 continue;
68
69 assert(maxPos == start &&
70 "expected loop bounds to be known at the start of the band");
71 LLVM_DEBUG(llvm::dbgs() << " found coalesceable band from " << start
72 << " to " << end << '\n');
73
74 auto band =
75 llvm::makeMutableArrayRef(loops.data() + start, end - start);
76 coalesceLoops(band);
77 break;
78 }
79 // If a band was found and transformed, keep looking at the loops above
80 // the outermost transformed loop.
81 if (start != end - 1)
82 end = start + 1;
83 }
84 });
85 }
86 };
87
88 } // namespace
89
createLoopCoalescingPass()90 std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopCoalescingPass() {
91 return std::make_unique<LoopCoalescingPass>();
92 }
93
94 static PassRegistration<LoopCoalescingPass>
95 reg(PASS_NAME,
96 "coalesce nested loops with independent bounds into a single loop");
97