1 //===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/LoopOps/LoopOps.h"
10 #include "mlir/Dialect/StandardOps/Ops.h"
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Transforms/LoopUtils.h"
13 #include "mlir/Transforms/Passes.h"
14 #include "mlir/Transforms/RegionUtils.h"
15 #include "llvm/Support/Debug.h"
16 
17 #define PASS_NAME "loop-coalescing"
18 #define DEBUG_TYPE PASS_NAME
19 
20 using namespace mlir;
21 
22 namespace {
23 class LoopCoalescingPass : public FunctionPass<LoopCoalescingPass> {
24 public:
runOnFunction()25   void runOnFunction() override {
26     FuncOp func = getFunction();
27 
28     func.walk([](loop::ForOp op) {
29       // Ignore nested loops.
30       if (op.getParentOfType<loop::ForOp>())
31         return;
32 
33       SmallVector<loop::ForOp, 4> loops;
34       getPerfectlyNestedLoops(loops, op);
35       LLVM_DEBUG(llvm::dbgs()
36                  << "found a perfect nest of depth " << loops.size() << '\n');
37 
38       // Look for a band of loops that can be coalesced, i.e. perfectly nested
39       // loops with bounds defined above some loop.
40       // 1. For each loop, find above which parent loop its operands are
41       // defined.
42       SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
43       for (unsigned i = 0, e = loops.size(); i < e; ++i) {
44         operandsDefinedAbove[i] = i;
45         for (unsigned j = 0; j < i; ++j) {
46           if (areValuesDefinedAbove(loops[i].getOperands(),
47                                     loops[j].region())) {
48             operandsDefinedAbove[i] = j;
49             break;
50           }
51         }
52         LLVM_DEBUG(llvm::dbgs()
53                    << "  bounds of loop " << i << " are known above depth "
54                    << operandsDefinedAbove[i] << '\n');
55       }
56 
57       // 2. Identify bands of loops such that the operands of all of them are
58       // defined above the first loop in the band.  Traverse the nest bottom-up
59       // so that modifications don't invalidate the inner loops.
60       for (unsigned end = loops.size(); end > 0; --end) {
61         unsigned start = 0;
62         for (; start < end - 1; ++start) {
63           auto maxPos =
64               *std::max_element(std::next(operandsDefinedAbove.begin(), start),
65                                 std::next(operandsDefinedAbove.begin(), end));
66           if (maxPos > start)
67             continue;
68 
69           assert(maxPos == start &&
70                  "expected loop bounds to be known at the start of the band");
71           LLVM_DEBUG(llvm::dbgs() << "  found coalesceable band from " << start
72                                   << " to " << end << '\n');
73 
74           auto band =
75               llvm::makeMutableArrayRef(loops.data() + start, end - start);
76           coalesceLoops(band);
77           break;
78         }
79         // If a band was found and transformed, keep looking at the loops above
80         // the outermost transformed loop.
81         if (start != end - 1)
82           end = start + 1;
83       }
84     });
85   }
86 };
87 
88 } // namespace
89 
createLoopCoalescingPass()90 std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopCoalescingPass() {
91   return std::make_unique<LoopCoalescingPass>();
92 }
93 
94 static PassRegistration<LoopCoalescingPass>
95     reg(PASS_NAME,
96         "coalesce nested loops with independent bounds into a single loop");
97