1 //===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/SCF/SCF.h"
11 #include "mlir/Transforms/LoopUtils.h"
12 #include "mlir/Transforms/Passes.h"
13 #include "mlir/Transforms/RegionUtils.h"
14 #include "llvm/Support/Debug.h"
15 
16 #define PASS_NAME "loop-coalescing"
17 #define DEBUG_TYPE PASS_NAME
18 
19 using namespace mlir;
20 
21 namespace {
22 struct LoopCoalescingPass : public LoopCoalescingBase<LoopCoalescingPass> {
runOnFunction__anon8324cc070111::LoopCoalescingPass23   void runOnFunction() override {
24     FuncOp func = getFunction();
25 
26     func.walk([](scf::ForOp op) {
27       // Ignore nested loops.
28       if (op.getParentOfType<scf::ForOp>())
29         return;
30 
31       SmallVector<scf::ForOp, 4> loops;
32       getPerfectlyNestedLoops(loops, op);
33       LLVM_DEBUG(llvm::dbgs()
34                  << "found a perfect nest of depth " << loops.size() << '\n');
35 
36       // Look for a band of loops that can be coalesced, i.e. perfectly nested
37       // loops with bounds defined above some loop.
38       // 1. For each loop, find above which parent loop its operands are
39       // defined.
40       SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
41       for (unsigned i = 0, e = loops.size(); i < e; ++i) {
42         operandsDefinedAbove[i] = i;
43         for (unsigned j = 0; j < i; ++j) {
44           if (areValuesDefinedAbove(loops[i].getOperands(),
45                                     loops[j].region())) {
46             operandsDefinedAbove[i] = j;
47             break;
48           }
49         }
50         LLVM_DEBUG(llvm::dbgs()
51                    << "  bounds of loop " << i << " are known above depth "
52                    << operandsDefinedAbove[i] << '\n');
53       }
54 
55       // 2. Identify bands of loops such that the operands of all of them are
56       // defined above the first loop in the band.  Traverse the nest bottom-up
57       // so that modifications don't invalidate the inner loops.
58       for (unsigned end = loops.size(); end > 0; --end) {
59         unsigned start = 0;
60         for (; start < end - 1; ++start) {
61           auto maxPos =
62               *std::max_element(std::next(operandsDefinedAbove.begin(), start),
63                                 std::next(operandsDefinedAbove.begin(), end));
64           if (maxPos > start)
65             continue;
66 
67           assert(maxPos == start &&
68                  "expected loop bounds to be known at the start of the band");
69           LLVM_DEBUG(llvm::dbgs() << "  found coalesceable band from " << start
70                                   << " to " << end << '\n');
71 
72           auto band =
73               llvm::makeMutableArrayRef(loops.data() + start, end - start);
74           coalesceLoops(band);
75           break;
76         }
77         // If a band was found and transformed, keep looking at the loops above
78         // the outermost transformed loop.
79         if (start != end - 1)
80           end = start + 1;
81       }
82     });
83   }
84 };
85 
86 } // namespace
87 
createLoopCoalescingPass()88 std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopCoalescingPass() {
89   return std::make_unique<LoopCoalescingPass>();
90 }
91