1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/AffineOps/AffineOps.h"
15 #include "mlir/Dialect/LoopOps/LoopOps.h"
16 #include "mlir/IR/Function.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Support/Functional.h"
19 #include "mlir/Support/LLVM.h"
20 #include "mlir/Support/STLExtras.h"
21 #include "llvm/ADT/SetVector.h"
22
23 ///
24 /// Implements Analysis functions specific to slicing in Function.
25 ///
26
27 using namespace mlir;
28
29 using llvm::SetVector;
30
getForwardSliceImpl(Operation * op,SetVector<Operation * > * forwardSlice,TransitiveFilter filter)31 static void getForwardSliceImpl(Operation *op,
32 SetVector<Operation *> *forwardSlice,
33 TransitiveFilter filter) {
34 if (!op) {
35 return;
36 }
37
38 // Evaluate whether we should keep this use.
39 // This is useful in particular to implement scoping; i.e. return the
40 // transitive forwardSlice in the current scope.
41 if (!filter(op)) {
42 return;
43 }
44
45 if (auto forOp = dyn_cast<AffineForOp>(op)) {
46 for (auto *ownerInst : forOp.getInductionVar().getUsers())
47 if (forwardSlice->count(ownerInst) == 0)
48 getForwardSliceImpl(ownerInst, forwardSlice, filter);
49 } else if (auto forOp = dyn_cast<loop::ForOp>(op)) {
50 for (auto *ownerInst : forOp.getInductionVar().getUsers())
51 if (forwardSlice->count(ownerInst) == 0)
52 getForwardSliceImpl(ownerInst, forwardSlice, filter);
53 } else {
54 assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
55 assert(op->getNumResults() <= 1 && "unexpected multiple results");
56 if (op->getNumResults() > 0) {
57 for (auto *ownerInst : op->getResult(0).getUsers())
58 if (forwardSlice->count(ownerInst) == 0)
59 getForwardSliceImpl(ownerInst, forwardSlice, filter);
60 }
61 }
62
63 forwardSlice->insert(op);
64 }
65
getForwardSlice(Operation * op,SetVector<Operation * > * forwardSlice,TransitiveFilter filter)66 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
67 TransitiveFilter filter) {
68 getForwardSliceImpl(op, forwardSlice, filter);
69 // Don't insert the top level operation, we just queried on it and don't
70 // want it in the results.
71 forwardSlice->remove(op);
72
73 // Reverse to get back the actual topological order.
74 // std::reverse does not work out of the box on SetVector and I want an
75 // in-place swap based thing (the real std::reverse, not the LLVM adapter).
76 std::vector<Operation *> v(forwardSlice->takeVector());
77 forwardSlice->insert(v.rbegin(), v.rend());
78 }
79
getBackwardSliceImpl(Operation * op,SetVector<Operation * > * backwardSlice,TransitiveFilter filter)80 static void getBackwardSliceImpl(Operation *op,
81 SetVector<Operation *> *backwardSlice,
82 TransitiveFilter filter) {
83 if (!op)
84 return;
85
86 assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) ||
87 isa<loop::ForOp>(op)) &&
88 "unexpected generic op with regions");
89
90 // Evaluate whether we should keep this def.
91 // This is useful in particular to implement scoping; i.e. return the
92 // transitive forwardSlice in the current scope.
93 if (!filter(op)) {
94 return;
95 }
96
97 for (auto en : llvm::enumerate(op->getOperands())) {
98 auto operand = en.value();
99 if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
100 if (auto affIv = getForInductionVarOwner(operand)) {
101 auto *affOp = affIv.getOperation();
102 if (backwardSlice->count(affOp) == 0)
103 getBackwardSliceImpl(affOp, backwardSlice, filter);
104 } else if (auto loopIv = loop::getForInductionVarOwner(operand)) {
105 auto *loopOp = loopIv.getOperation();
106 if (backwardSlice->count(loopOp) == 0)
107 getBackwardSliceImpl(loopOp, backwardSlice, filter);
108 } else if (blockArg.getOwner() !=
109 &op->getParentOfType<FuncOp>().getBody().front()) {
110 op->emitError("unsupported CF for operand ") << en.index();
111 llvm_unreachable("Unsupported control flow");
112 }
113 continue;
114 }
115 auto *op = operand.getDefiningOp();
116 if (backwardSlice->count(op) == 0) {
117 getBackwardSliceImpl(op, backwardSlice, filter);
118 }
119 }
120
121 backwardSlice->insert(op);
122 }
123
getBackwardSlice(Operation * op,SetVector<Operation * > * backwardSlice,TransitiveFilter filter)124 void mlir::getBackwardSlice(Operation *op,
125 SetVector<Operation *> *backwardSlice,
126 TransitiveFilter filter) {
127 getBackwardSliceImpl(op, backwardSlice, filter);
128
129 // Don't insert the top level operation, we just queried on it and don't
130 // want it in the results.
131 backwardSlice->remove(op);
132 }
133
getSlice(Operation * op,TransitiveFilter backwardFilter,TransitiveFilter forwardFilter)134 SetVector<Operation *> mlir::getSlice(Operation *op,
135 TransitiveFilter backwardFilter,
136 TransitiveFilter forwardFilter) {
137 SetVector<Operation *> slice;
138 slice.insert(op);
139
140 unsigned currentIndex = 0;
141 SetVector<Operation *> backwardSlice;
142 SetVector<Operation *> forwardSlice;
143 while (currentIndex != slice.size()) {
144 auto *currentInst = (slice)[currentIndex];
145 // Compute and insert the backwardSlice starting from currentInst.
146 backwardSlice.clear();
147 getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
148 slice.insert(backwardSlice.begin(), backwardSlice.end());
149
150 // Compute and insert the forwardSlice starting from currentInst.
151 forwardSlice.clear();
152 getForwardSlice(currentInst, &forwardSlice, forwardFilter);
153 slice.insert(forwardSlice.begin(), forwardSlice.end());
154 ++currentIndex;
155 }
156 return topologicalSort(slice);
157 }
158
159 namespace {
160 /// DFS post-order implementation that maintains a global count to work across
161 /// multiple invocations, to help implement topological sort on multi-root DAGs.
162 /// We traverse all operations but only record the ones that appear in
163 /// `toSort` for the final result.
164 struct DFSState {
DFSState__anon63f9c0190111::DFSState165 DFSState(const SetVector<Operation *> &set)
166 : toSort(set), topologicalCounts(), seen() {}
167 const SetVector<Operation *> &toSort;
168 SmallVector<Operation *, 16> topologicalCounts;
169 DenseSet<Operation *> seen;
170 };
171 } // namespace
172
DFSPostorder(Operation * current,DFSState * state)173 static void DFSPostorder(Operation *current, DFSState *state) {
174 assert(current->getNumResults() <= 1 && "NYI: multi-result");
175 if (current->getNumResults() > 0) {
176 for (auto &u : current->getResult(0).getUses()) {
177 auto *op = u.getOwner();
178 DFSPostorder(op, state);
179 }
180 }
181 bool inserted;
182 using IterTy = decltype(state->seen.begin());
183 IterTy iter;
184 std::tie(iter, inserted) = state->seen.insert(current);
185 if (inserted) {
186 if (state->toSort.count(current) > 0) {
187 state->topologicalCounts.push_back(current);
188 }
189 }
190 }
191
192 SetVector<Operation *>
topologicalSort(const SetVector<Operation * > & toSort)193 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
194 if (toSort.empty()) {
195 return toSort;
196 }
197
198 // Run from each root with global count and `seen` set.
199 DFSState state(toSort);
200 for (auto *s : toSort) {
201 assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
202 DFSPostorder(s, &state);
203 }
204
205 // Reorder and return.
206 SetVector<Operation *> res;
207 for (auto it = state.topologicalCounts.rbegin(),
208 eit = state.topologicalCounts.rend();
209 it != eit; ++it) {
210 res.insert(*it);
211 }
212 return res;
213 }
214