1 //===- NumberOfExecutions.cpp - Number of executions analysis -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implementation of the number of executions analysis.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Analysis/NumberOfExecutions.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/RegionKindInterface.h"
17 #include "mlir/Interfaces/ControlFlowInterfaces.h"
18
19 #include "llvm/ADT/FunctionExtras.h"
20 #include "llvm/ADT/SmallSet.h"
21 #include "llvm/Support/raw_ostream.h"
22
23 #define DEBUG_TYPE "number-of-executions-analysis"
24
25 using namespace mlir;
26
27 //===----------------------------------------------------------------------===//
28 // NumberOfExecutions
29 //===----------------------------------------------------------------------===//
30
31 /// Computes blocks number of executions information for the given region.
computeRegionBlockNumberOfExecutions(Region & region,DenseMap<Block *,BlockNumberOfExecutionsInfo> & blockInfo)32 static void computeRegionBlockNumberOfExecutions(
33 Region ®ion, DenseMap<Block *, BlockNumberOfExecutionsInfo> &blockInfo) {
34 Operation *parentOp = region.getParentOp();
35 int regionId = region.getRegionNumber();
36
37 auto regionKindInterface = dyn_cast<RegionKindInterface>(parentOp);
38 bool isGraphRegion =
39 regionKindInterface &&
40 regionKindInterface.getRegionKind(regionId) == RegionKind::Graph;
41
42 // CFG analysis does not make sense for Graph regions, set the number of
43 // executions for all blocks as unknown.
44 if (isGraphRegion) {
45 for (Block &block : region)
46 blockInfo.insert({&block, {&block, None, None}});
47 return;
48 }
49
50 // Number of region invocations for all attached regions.
51 SmallVector<int64_t, 4> numRegionsInvocations;
52
53 // Query RegionBranchOpInterface interface if it is available.
54 if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp)) {
55 SmallVector<Attribute, 4> operands(parentOp->getNumOperands());
56 for (auto operandIt : llvm::enumerate(parentOp->getOperands()))
57 matchPattern(operandIt.value(), m_Constant(&operands[operandIt.index()]));
58
59 regionInterface.getNumRegionInvocations(operands, numRegionsInvocations);
60 }
61
62 // Number of region invocations *each time* parent operation is invoked.
63 Optional<int64_t> numRegionInvocations;
64
65 if (!numRegionsInvocations.empty() &&
66 numRegionsInvocations[regionId] != kUnknownNumRegionInvocations) {
67 numRegionInvocations = numRegionsInvocations[regionId];
68 }
69
70 // DFS traversal looking for loops in the CFG.
71 llvm::SmallSet<Block *, 4> loopStart;
72
73 llvm::unique_function<void(Block *, llvm::SmallSet<Block *, 4> &)> dfs =
74 [&](Block *block, llvm::SmallSet<Block *, 4> &visited) {
75 // Found a loop in the CFG that starts at the `block`.
76 if (visited.contains(block)) {
77 loopStart.insert(block);
78 return;
79 }
80
81 // Continue DFS traversal.
82 visited.insert(block);
83 for (Block *successor : block->getSuccessors())
84 dfs(successor, visited);
85 visited.erase(block);
86 };
87
88 llvm::SmallSet<Block *, 4> visited;
89 dfs(®ion.front(), visited);
90
91 // Start from the entry block and follow only blocks with single succesor.
92 Block *block = ®ion.front();
93 while (block && !loopStart.contains(block)) {
94 // Block will be executed exactly once.
95 blockInfo.insert(
96 {block, BlockNumberOfExecutionsInfo(block, numRegionInvocations,
97 /*numberOfBlockExecutions=*/1)});
98
99 // We reached the exit block or block with multiple successors.
100 if (block->getNumSuccessors() != 1)
101 break;
102
103 // Continue traversal.
104 block = block->getSuccessor(0);
105 }
106
107 // For all blocks that we did not visit set the executions number to unknown.
108 for (Block &block : region)
109 if (blockInfo.count(&block) == 0)
110 blockInfo.insert({&block, BlockNumberOfExecutionsInfo(
111 &block, numRegionInvocations,
112 /*numberOfBlockExecutions=*/None)});
113 }
114
115 /// Creates a new NumberOfExecutions analysis that computes how many times a
116 /// block within a region is executed for all associated regions.
NumberOfExecutions(Operation * op)117 NumberOfExecutions::NumberOfExecutions(Operation *op) : operation(op) {
118 operation->walk([&](Region *region) {
119 computeRegionBlockNumberOfExecutions(*region, blockNumbersOfExecution);
120 });
121 }
122
123 Optional<int64_t>
getNumberOfExecutions(Operation * op,Region * perEntryOfThisRegion) const124 NumberOfExecutions::getNumberOfExecutions(Operation *op,
125 Region *perEntryOfThisRegion) const {
126 // Assuming that all operations complete in a finite amount of time (do not
127 // abort and do not go into the infinite loop), the number of operation
128 // executions is equal to the number of block executions that contains the
129 // operation.
130 return getNumberOfExecutions(op->getBlock(), perEntryOfThisRegion);
131 }
132
133 Optional<int64_t>
getNumberOfExecutions(Block * block,Region * perEntryOfThisRegion) const134 NumberOfExecutions::getNumberOfExecutions(Block *block,
135 Region *perEntryOfThisRegion) const {
136 // Return None if the given `block` does not lie inside the
137 // `perEntryOfThisRegion` region.
138 if (!perEntryOfThisRegion->findAncestorBlockInRegion(*block))
139 return None;
140
141 // Find the block information for the given `block.
142 auto blockIt = blockNumbersOfExecution.find(block);
143 if (blockIt == blockNumbersOfExecution.end())
144 return None;
145 const auto &blockInfo = blockIt->getSecond();
146
147 // Override the number of region invocations with `1` if the
148 // `perEntryOfThisRegion` region owns the block.
149 auto getNumberOfExecutions = [&](const BlockNumberOfExecutionsInfo &info) {
150 if (info.getBlock()->getParent() == perEntryOfThisRegion)
151 return info.getNumberOfExecutions(/*numberOfRegionInvocations=*/1);
152 return info.getNumberOfExecutions();
153 };
154
155 // Immediately return None if we do not know the block number of executions.
156 auto blockExecutions = getNumberOfExecutions(blockInfo);
157 if (!blockExecutions.hasValue())
158 return None;
159
160 // Follow parent operations until we reach the operations that owns the
161 // `perEntryOfThisRegion`.
162 int64_t numberOfExecutions = *blockExecutions;
163 Operation *parentOp = block->getParentOp();
164
165 while (parentOp != perEntryOfThisRegion->getParentOp()) {
166 // Find how many times will be executed the block that owns the parent
167 // operation.
168 Block *parentBlock = parentOp->getBlock();
169
170 auto parentBlockIt = blockNumbersOfExecution.find(parentBlock);
171 if (parentBlockIt == blockNumbersOfExecution.end())
172 return None;
173 const auto &parentBlockInfo = parentBlockIt->getSecond();
174 auto parentBlockExecutions = getNumberOfExecutions(parentBlockInfo);
175
176 // We stumbled upon an operation with unknown number of executions.
177 if (!parentBlockExecutions.hasValue())
178 return None;
179
180 // Number of block executions is a product of all parent blocks executions.
181 numberOfExecutions *= *parentBlockExecutions;
182 parentOp = parentOp->getParentOp();
183
184 assert(parentOp != nullptr);
185 }
186
187 return numberOfExecutions;
188 }
189
printBlockExecutions(raw_ostream & os,Region * perEntryOfThisRegion) const190 void NumberOfExecutions::printBlockExecutions(
191 raw_ostream &os, Region *perEntryOfThisRegion) const {
192 unsigned blockId = 0;
193
194 operation->walk([&](Block *block) {
195 llvm::errs() << "Block: " << blockId++ << "\n";
196 llvm::errs() << "Number of executions: ";
197 if (auto n = getNumberOfExecutions(block, perEntryOfThisRegion))
198 llvm::errs() << *n << "\n";
199 else
200 llvm::errs() << "<unknown>\n";
201 });
202 }
203
printOperationExecutions(raw_ostream & os,Region * perEntryOfThisRegion) const204 void NumberOfExecutions::printOperationExecutions(
205 raw_ostream &os, Region *perEntryOfThisRegion) const {
206 operation->walk([&](Block *block) {
207 block->walk([&](Operation *operation) {
208 // Skip the operation that was used to build the analysis.
209 if (operation == this->operation)
210 return;
211
212 llvm::errs() << "Operation: " << operation->getName() << "\n";
213 llvm::errs() << "Number of executions: ";
214 if (auto n = getNumberOfExecutions(operation, perEntryOfThisRegion))
215 llvm::errs() << *n << "\n";
216 else
217 llvm::errs() << "<unknown>\n";
218 });
219 });
220 }
221
222 //===----------------------------------------------------------------------===//
223 // BlockNumberOfExecutionsInfo
224 //===----------------------------------------------------------------------===//
225
BlockNumberOfExecutionsInfo(Block * block,Optional<int64_t> numberOfRegionInvocations,Optional<int64_t> numberOfBlockExecutions)226 BlockNumberOfExecutionsInfo::BlockNumberOfExecutionsInfo(
227 Block *block, Optional<int64_t> numberOfRegionInvocations,
228 Optional<int64_t> numberOfBlockExecutions)
229 : block(block), numberOfRegionInvocations(numberOfRegionInvocations),
230 numberOfBlockExecutions(numberOfBlockExecutions) {}
231
getNumberOfExecutions() const232 Optional<int64_t> BlockNumberOfExecutionsInfo::getNumberOfExecutions() const {
233 if (numberOfRegionInvocations && numberOfBlockExecutions)
234 return *numberOfRegionInvocations * *numberOfBlockExecutions;
235 return None;
236 }
237
getNumberOfExecutions(int64_t numberOfRegionInvocations) const238 Optional<int64_t> BlockNumberOfExecutionsInfo::getNumberOfExecutions(
239 int64_t numberOfRegionInvocations) const {
240 if (numberOfBlockExecutions)
241 return numberOfRegionInvocations * *numberOfBlockExecutions;
242 return None;
243 }
244