1 //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements view-based alias and dependence analyses.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "linalg-dependence-analysis"
21 
22 using namespace mlir;
23 using namespace mlir::linalg;
24 
25 using llvm::dbgs;
26 
find(Value v)27 Value Aliases::find(Value v) {
28   if (v.isa<BlockArgument>())
29     return v;
30 
31   auto it = aliases.find(v);
32   if (it != aliases.end()) {
33     assert(it->getSecond().getType().isa<MemRefType>() && "Memref expected");
34     return it->getSecond();
35   }
36 
37   while (true) {
38     if (v.isa<BlockArgument>())
39       return v;
40 
41     Operation *defOp = v.getDefiningOp();
42     if (!defOp)
43       return v;
44 
45     if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) {
46       // Collect all memory effects on `v`.
47       SmallVector<MemoryEffects::EffectInstance, 1> effects;
48       memEffect.getEffectsOnValue(v, effects);
49 
50       // If we have the 'Allocate' memory effect on `v`, then `v` should be the
51       // original buffer.
52       if (llvm::any_of(
53               effects, [](const MemoryEffects::EffectInstance &instance) {
54                 return isa<MemoryEffects::Allocate>(instance.getEffect());
55               }))
56         return v;
57     }
58 
59     if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) {
60       auto it =
61           aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource())));
62       return it.first->second;
63     }
64 
65     llvm::errs() << "View alias analysis reduces to: " << v << "\n";
66     llvm_unreachable("unsupported view alias case");
67   }
68 }
69 
getDependenceTypeStr(DependenceType depType)70 StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
71   switch (depType) {
72   case LinalgDependenceGraph::DependenceType::RAW:
73     return "RAW";
74   case LinalgDependenceGraph::DependenceType::RAR:
75     return "RAR";
76   case LinalgDependenceGraph::DependenceType::WAR:
77     return "WAR";
78   case LinalgDependenceGraph::DependenceType::WAW:
79     return "WAW";
80   default:
81     break;
82   }
83   llvm_unreachable("Unexpected DependenceType");
84 }
85 
86 LinalgDependenceGraph
buildDependenceGraph(Aliases & aliases,FuncOp f)87 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
88   SmallVector<Operation *, 8> linalgOps;
89   f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
90   return LinalgDependenceGraph(aliases, linalgOps);
91 }
92 
LinalgDependenceGraph(Aliases & aliases,ArrayRef<Operation * > ops)93 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
94                                              ArrayRef<Operation *> ops)
95     : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
96   for (auto en : llvm::enumerate(linalgOps)) {
97     assert(isa<LinalgOp>(en.value()) && "Expected value for LinalgOp");
98     linalgOpPositions.insert(std::make_pair(en.value(), en.index()));
99   }
100   for (unsigned i = 0, e = ops.size(); i < e; ++i) {
101     for (unsigned j = i + 1; j < e; ++j) {
102       addDependencesBetween(cast<LinalgOp>(ops[i]), cast<LinalgOp>(ops[j]));
103     }
104   }
105 }
106 
addDependenceElem(DependenceType dt,LinalgOpView indexingOpView,LinalgOpView dependentOpView)107 void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
108                                               LinalgOpView indexingOpView,
109                                               LinalgOpView dependentOpView) {
110   LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t"
111                     << *indexingOpView.op << " -> " << *dependentOpView.op);
112   dependencesFromGraphs[dt][indexingOpView.op].push_back(
113       LinalgDependenceGraphElem{dependentOpView, indexingOpView.view});
114   dependencesIntoGraphs[dt][dependentOpView.op].push_back(
115       LinalgDependenceGraphElem{indexingOpView, dependentOpView.view});
116 }
117 
118 LinalgDependenceGraph::dependence_range
getDependencesFrom(LinalgOp src,LinalgDependenceGraph::DependenceType dt) const119 LinalgDependenceGraph::getDependencesFrom(
120     LinalgOp src, LinalgDependenceGraph::DependenceType dt) const {
121   return getDependencesFrom(src.getOperation(), dt);
122 }
123 
124 LinalgDependenceGraph::dependence_range
getDependencesFrom(Operation * src,LinalgDependenceGraph::DependenceType dt) const125 LinalgDependenceGraph::getDependencesFrom(
126     Operation *src, LinalgDependenceGraph::DependenceType dt) const {
127   auto iter = dependencesFromGraphs[dt].find(src);
128   if (iter == dependencesFromGraphs[dt].end())
129     return llvm::make_range(nullptr, nullptr);
130   return llvm::make_range(iter->second.begin(), iter->second.end());
131 }
132 
133 LinalgDependenceGraph::dependence_range
getDependencesInto(LinalgOp dst,LinalgDependenceGraph::DependenceType dt) const134 LinalgDependenceGraph::getDependencesInto(
135     LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const {
136   return getDependencesInto(dst.getOperation(), dt);
137 }
138 
139 LinalgDependenceGraph::dependence_range
getDependencesInto(Operation * dst,LinalgDependenceGraph::DependenceType dt) const140 LinalgDependenceGraph::getDependencesInto(
141     Operation *dst, LinalgDependenceGraph::DependenceType dt) const {
142   auto iter = dependencesIntoGraphs[dt].find(dst);
143   if (iter == dependencesIntoGraphs[dt].end())
144     return llvm::make_range(nullptr, nullptr);
145   return llvm::make_range(iter->second.begin(), iter->second.end());
146 }
147 
addDependencesBetween(LinalgOp src,LinalgOp dst)148 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
149   assert(src.hasBufferSemantics() &&
150          "expected linalg op with buffer semantics");
151   assert(dst.hasBufferSemantics() &&
152          "expected linalg op with buffer semantics");
153   for (auto srcView : src.getOutputBuffers()) { // W
154     // RAW graph
155     for (auto dstView : dst.getInputs()) {   // R
156       if (aliases.alias(srcView, dstView)) { // if alias, fill RAW
157         addDependenceElem(DependenceType::RAW,
158                           LinalgOpView{src.getOperation(), srcView},
159                           LinalgOpView{dst.getOperation(), dstView});
160       }
161     }
162     // WAW graph
163     for (auto dstView : dst.getOutputBuffers()) { // W
164       if (aliases.alias(srcView, dstView)) { // if alias, fill WAW
165         addDependenceElem(DependenceType::WAW,
166                           LinalgOpView{src.getOperation(), srcView},
167                           LinalgOpView{dst.getOperation(), dstView});
168       }
169     }
170   }
171   for (auto srcView : src.getInputs()) { // R
172     // RAR graph
173     for (auto dstView : dst.getInputs()) {   // R
174       if (aliases.alias(srcView, dstView)) { // if alias, fill RAR
175         addDependenceElem(DependenceType::RAR,
176                           LinalgOpView{src.getOperation(), srcView},
177                           LinalgOpView{dst.getOperation(), dstView});
178       }
179     }
180     // WAR graph
181     for (auto dstView : dst.getOutputBuffers()) { // W
182       if (aliases.alias(srcView, dstView)) { // if alias, fill WAR
183         addDependenceElem(DependenceType::WAR,
184                           LinalgOpView{src.getOperation(), srcView},
185                           LinalgOpView{dst.getOperation(), dstView});
186       }
187     }
188   }
189 }
190 
191 SmallVector<Operation *, 8>
findCoveringDependences(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp) const192 LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
193                                                LinalgOp dstLinalgOp) const {
194   return findOperationsWithCoveringDependences(
195       srcLinalgOp, dstLinalgOp, nullptr,
196       {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
197 }
198 
findCoveringWrites(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view) const199 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites(
200     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
201   return findOperationsWithCoveringDependences(
202       srcLinalgOp, dstLinalgOp, view,
203       {DependenceType::WAW, DependenceType::WAR});
204 }
205 
findCoveringReads(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view) const206 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads(
207     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
208   return findOperationsWithCoveringDependences(
209       srcLinalgOp, dstLinalgOp, view,
210       {DependenceType::RAR, DependenceType::RAW});
211 }
212 
213 SmallVector<Operation *, 8>
findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view,ArrayRef<DependenceType> types) const214 LinalgDependenceGraph::findOperationsWithCoveringDependences(
215     LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view,
216     ArrayRef<DependenceType> types) const {
217   auto *src = srcLinalgOp.getOperation();
218   auto *dst = dstLinalgOp.getOperation();
219   auto srcPos = linalgOpPositions.lookup(src);
220   auto dstPos = linalgOpPositions.lookup(dst);
221   assert(srcPos < dstPos && "expected dst after src in IR traversal order");
222 
223   SmallVector<Operation *, 8> res;
224   // Consider an intermediate interleaved `interim` op, look for any dependence
225   // to an aliasing view on a src -> op -> dst path.
226   // TODO: we are not considering paths yet, just interleaved positions.
227   for (auto dt : types) {
228     for (auto dependence : getDependencesFrom(src, dt)) {
229       auto interimPos = linalgOpPositions.lookup(dependence.dependentOpView.op);
230       // Skip if not interleaved.
231       if (interimPos >= dstPos || interimPos <= srcPos)
232         continue;
233       if (view && !aliases.alias(view, dependence.indexingView))
234         continue;
235       auto *op = dependence.dependentOpView.op;
236       LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
237                         << getDependenceTypeStr(dt) << ": " << *src << " -> "
238                         << *op << " on " << dependence.indexingView);
239       res.push_back(op);
240     }
241   }
242   return res;
243 }
244