1 //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements view-based alias and dependence analyses.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/Debug.h"
19
20 #define DEBUG_TYPE "linalg-dependence-analysis"
21
22 using namespace mlir;
23 using namespace mlir::linalg;
24
25 using llvm::dbgs;
26
find(Value v)27 Value Aliases::find(Value v) {
28 if (v.isa<BlockArgument>())
29 return v;
30
31 auto it = aliases.find(v);
32 if (it != aliases.end()) {
33 assert(it->getSecond().getType().isa<MemRefType>() && "Memref expected");
34 return it->getSecond();
35 }
36
37 while (true) {
38 if (v.isa<BlockArgument>())
39 return v;
40
41 Operation *defOp = v.getDefiningOp();
42 if (!defOp)
43 return v;
44
45 if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) {
46 // Collect all memory effects on `v`.
47 SmallVector<MemoryEffects::EffectInstance, 1> effects;
48 memEffect.getEffectsOnValue(v, effects);
49
50 // If we have the 'Allocate' memory effect on `v`, then `v` should be the
51 // original buffer.
52 if (llvm::any_of(
53 effects, [](const MemoryEffects::EffectInstance &instance) {
54 return isa<MemoryEffects::Allocate>(instance.getEffect());
55 }))
56 return v;
57 }
58
59 if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) {
60 auto it =
61 aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource())));
62 return it.first->second;
63 }
64
65 llvm::errs() << "View alias analysis reduces to: " << v << "\n";
66 llvm_unreachable("unsupported view alias case");
67 }
68 }
69
getDependenceTypeStr(DependenceType depType)70 StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
71 switch (depType) {
72 case LinalgDependenceGraph::DependenceType::RAW:
73 return "RAW";
74 case LinalgDependenceGraph::DependenceType::RAR:
75 return "RAR";
76 case LinalgDependenceGraph::DependenceType::WAR:
77 return "WAR";
78 case LinalgDependenceGraph::DependenceType::WAW:
79 return "WAW";
80 default:
81 break;
82 }
83 llvm_unreachable("Unexpected DependenceType");
84 }
85
86 LinalgDependenceGraph
buildDependenceGraph(Aliases & aliases,FuncOp f)87 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
88 SmallVector<Operation *, 8> linalgOps;
89 f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
90 return LinalgDependenceGraph(aliases, linalgOps);
91 }
92
LinalgDependenceGraph(Aliases & aliases,ArrayRef<Operation * > ops)93 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
94 ArrayRef<Operation *> ops)
95 : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
96 for (auto en : llvm::enumerate(linalgOps)) {
97 assert(isa<LinalgOp>(en.value()) && "Expected value for LinalgOp");
98 linalgOpPositions.insert(std::make_pair(en.value(), en.index()));
99 }
100 for (unsigned i = 0, e = ops.size(); i < e; ++i) {
101 for (unsigned j = i + 1; j < e; ++j) {
102 addDependencesBetween(cast<LinalgOp>(ops[i]), cast<LinalgOp>(ops[j]));
103 }
104 }
105 }
106
addDependenceElem(DependenceType dt,LinalgOpView indexingOpView,LinalgOpView dependentOpView)107 void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
108 LinalgOpView indexingOpView,
109 LinalgOpView dependentOpView) {
110 LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t"
111 << *indexingOpView.op << " -> " << *dependentOpView.op);
112 dependencesFromGraphs[dt][indexingOpView.op].push_back(
113 LinalgDependenceGraphElem{dependentOpView, indexingOpView.view});
114 dependencesIntoGraphs[dt][dependentOpView.op].push_back(
115 LinalgDependenceGraphElem{indexingOpView, dependentOpView.view});
116 }
117
118 LinalgDependenceGraph::dependence_range
getDependencesFrom(LinalgOp src,LinalgDependenceGraph::DependenceType dt) const119 LinalgDependenceGraph::getDependencesFrom(
120 LinalgOp src, LinalgDependenceGraph::DependenceType dt) const {
121 return getDependencesFrom(src.getOperation(), dt);
122 }
123
124 LinalgDependenceGraph::dependence_range
getDependencesFrom(Operation * src,LinalgDependenceGraph::DependenceType dt) const125 LinalgDependenceGraph::getDependencesFrom(
126 Operation *src, LinalgDependenceGraph::DependenceType dt) const {
127 auto iter = dependencesFromGraphs[dt].find(src);
128 if (iter == dependencesFromGraphs[dt].end())
129 return llvm::make_range(nullptr, nullptr);
130 return llvm::make_range(iter->second.begin(), iter->second.end());
131 }
132
133 LinalgDependenceGraph::dependence_range
getDependencesInto(LinalgOp dst,LinalgDependenceGraph::DependenceType dt) const134 LinalgDependenceGraph::getDependencesInto(
135 LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const {
136 return getDependencesInto(dst.getOperation(), dt);
137 }
138
139 LinalgDependenceGraph::dependence_range
getDependencesInto(Operation * dst,LinalgDependenceGraph::DependenceType dt) const140 LinalgDependenceGraph::getDependencesInto(
141 Operation *dst, LinalgDependenceGraph::DependenceType dt) const {
142 auto iter = dependencesIntoGraphs[dt].find(dst);
143 if (iter == dependencesIntoGraphs[dt].end())
144 return llvm::make_range(nullptr, nullptr);
145 return llvm::make_range(iter->second.begin(), iter->second.end());
146 }
147
addDependencesBetween(LinalgOp src,LinalgOp dst)148 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
149 assert(src.hasBufferSemantics() &&
150 "expected linalg op with buffer semantics");
151 assert(dst.hasBufferSemantics() &&
152 "expected linalg op with buffer semantics");
153 for (auto srcView : src.getOutputBuffers()) { // W
154 // RAW graph
155 for (auto dstView : dst.getInputs()) { // R
156 if (aliases.alias(srcView, dstView)) { // if alias, fill RAW
157 addDependenceElem(DependenceType::RAW,
158 LinalgOpView{src.getOperation(), srcView},
159 LinalgOpView{dst.getOperation(), dstView});
160 }
161 }
162 // WAW graph
163 for (auto dstView : dst.getOutputBuffers()) { // W
164 if (aliases.alias(srcView, dstView)) { // if alias, fill WAW
165 addDependenceElem(DependenceType::WAW,
166 LinalgOpView{src.getOperation(), srcView},
167 LinalgOpView{dst.getOperation(), dstView});
168 }
169 }
170 }
171 for (auto srcView : src.getInputs()) { // R
172 // RAR graph
173 for (auto dstView : dst.getInputs()) { // R
174 if (aliases.alias(srcView, dstView)) { // if alias, fill RAR
175 addDependenceElem(DependenceType::RAR,
176 LinalgOpView{src.getOperation(), srcView},
177 LinalgOpView{dst.getOperation(), dstView});
178 }
179 }
180 // WAR graph
181 for (auto dstView : dst.getOutputBuffers()) { // W
182 if (aliases.alias(srcView, dstView)) { // if alias, fill WAR
183 addDependenceElem(DependenceType::WAR,
184 LinalgOpView{src.getOperation(), srcView},
185 LinalgOpView{dst.getOperation(), dstView});
186 }
187 }
188 }
189 }
190
191 SmallVector<Operation *, 8>
findCoveringDependences(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp) const192 LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
193 LinalgOp dstLinalgOp) const {
194 return findOperationsWithCoveringDependences(
195 srcLinalgOp, dstLinalgOp, nullptr,
196 {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
197 }
198
findCoveringWrites(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view) const199 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites(
200 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
201 return findOperationsWithCoveringDependences(
202 srcLinalgOp, dstLinalgOp, view,
203 {DependenceType::WAW, DependenceType::WAR});
204 }
205
findCoveringReads(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view) const206 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads(
207 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
208 return findOperationsWithCoveringDependences(
209 srcLinalgOp, dstLinalgOp, view,
210 {DependenceType::RAR, DependenceType::RAW});
211 }
212
213 SmallVector<Operation *, 8>
findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view,ArrayRef<DependenceType> types) const214 LinalgDependenceGraph::findOperationsWithCoveringDependences(
215 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view,
216 ArrayRef<DependenceType> types) const {
217 auto *src = srcLinalgOp.getOperation();
218 auto *dst = dstLinalgOp.getOperation();
219 auto srcPos = linalgOpPositions.lookup(src);
220 auto dstPos = linalgOpPositions.lookup(dst);
221 assert(srcPos < dstPos && "expected dst after src in IR traversal order");
222
223 SmallVector<Operation *, 8> res;
224 // Consider an intermediate interleaved `interim` op, look for any dependence
225 // to an aliasing view on a src -> op -> dst path.
226 // TODO: we are not considering paths yet, just interleaved positions.
227 for (auto dt : types) {
228 for (auto dependence : getDependencesFrom(src, dt)) {
229 auto interimPos = linalgOpPositions.lookup(dependence.dependentOpView.op);
230 // Skip if not interleaved.
231 if (interimPos >= dstPos || interimPos <= srcPos)
232 continue;
233 if (view && !aliases.alias(view, dependence.indexingView))
234 continue;
235 auto *op = dependence.dependentOpView.op;
236 LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
237 << getDependenceTypeStr(dt) << ": " << *src << " -> "
238 << *op << " on " << dependence.indexingView);
239 res.push_back(op);
240 }
241 }
242 return res;
243 }
244