1 //===- DependenceAnalysis.h - Dependence analysis on SSA views --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
10 #define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
11 
12 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/OpDefinition.h"
15 
16 namespace mlir {
17 class FuncOp;
18 
19 namespace linalg {
20 
21 class LinalgOp;
22 
23 /// A very primitive alias analysis which just records for each view, either:
24 ///   1. The base buffer, or
25 ///   2. The block argument view
26 /// that it indexes into.
27 /// This does not perform inter-block or inter-procedural analysis and assumes
28 /// that different block argument views do not alias.
29 class Aliases {
30 public:
31   /// Returns true if v1 and v2 alias.
alias(Value v1,Value v2)32   bool alias(Value v1, Value v2) { return find(v1) == find(v2); }
33 
34 private:
35   /// Returns the base buffer or block argument into which the view `v` aliases.
36   /// This lazily records the new aliases discovered while walking back the
37   /// use-def chain.
38   Value find(Value v);
39 
40   DenseMap<Value, Value> aliases;
41 };
42 
43 /// Data structure for holding a dependence graph that operates on LinalgOp and
44 /// views as SSA values.
45 class LinalgDependenceGraph {
46 public:
47   enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes };
48   // TODO: OpOperand tracks dependencies on buffer operands. Tensor result will
49   // need an extension to use OpResult.
50   struct LinalgDependenceGraphElem {
51     using OpView = PointerUnion<OpOperand *, Value>;
52     // dependentOpView may be either:
53     //   1. src in the case of dependencesIntoGraphs.
54     //   2. dst in the case of dependencesFromDstGraphs.
55     OpView dependentOpView;
56     // View in the op that is used to index in the graph:
57     //   1. src in the case of dependencesFromDstGraphs.
58     //   2. dst in the case of dependencesIntoGraphs.
59     OpView indexingOpView;
60     // Type of the dependence.
61     DependenceType dependenceType;
62 
63     // Return the Operation that owns the operand or result represented in
64     // `opView`.
getOwnerLinalgDependenceGraphElem65     static Operation *getOwner(OpView opView) {
66       if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
67         return operand->getOwner();
68       return opView.get<Value>().cast<OpResult>().getOwner();
69     }
70     // Return the operand or the result Value represented by the `opView`.
getValueLinalgDependenceGraphElem71     static Value getValue(OpView opView) {
72       if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
73         return operand->get();
74       return opView.get<Value>();
75     }
76     // Return the indexing map of the operand/result in `opView` specified in
77     // the owning LinalgOp. If the owner is not a LinalgOp returns llvm::None.
getIndexingMapLinalgDependenceGraphElem78     static Optional<AffineMap> getIndexingMap(OpView opView) {
79       auto owner = dyn_cast<LinalgOp>(getOwner(opView));
80       if (!owner)
81         return llvm::None;
82       if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
83         return owner.getTiedIndexingMap(operand);
84       return owner.getTiedIndexingMap(owner.getOutputOperand(
85           opView.get<Value>().cast<OpResult>().getResultNumber()));
86     }
87     // Return the operand number if the `opView` is an OpOperand *. Otherwise
88     // return llvm::None.
getOperandNumberLinalgDependenceGraphElem89     static Optional<unsigned> getOperandNumber(OpView opView) {
90       if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
91         return operand->getOperandNumber();
92       return llvm::None;
93     }
94     // Return the result number if the `opView` is an OpResult. Otherwise return
95     // llvm::None.
getResultNumberLinalgDependenceGraphElem96     static Optional<unsigned> getResultNumber(OpView opView) {
97       if (OpResult result = opView.dyn_cast<Value>().cast<OpResult>())
98         return result.getResultNumber();
99       return llvm::None;
100     }
101 
102     // Return the owner of the dependent OpView.
getDependentOpLinalgDependenceGraphElem103     Operation *getDependentOp() const { return getOwner(dependentOpView); }
104 
105     // Return the owner of the indexing OpView.
getIndexingOpLinalgDependenceGraphElem106     Operation *getIndexingOp() const { return getOwner(indexingOpView); }
107 
108     // Return the operand or result stored in the dependentOpView.
getDependentValueLinalgDependenceGraphElem109     Value getDependentValue() const { return getValue(dependentOpView); }
110 
111     // Return the operand or result stored in the indexingOpView.
getIndexingValueLinalgDependenceGraphElem112     Value getIndexingValue() const { return getValue(indexingOpView); }
113 
114     // If the dependent OpView is an operand, return operand number. Return
115     // llvm::None otherwise.
getDependentOpViewOperandNumLinalgDependenceGraphElem116     Optional<unsigned> getDependentOpViewOperandNum() const {
117       return getOperandNumber(dependentOpView);
118     }
119 
120     // If the indexing OpView is an operand, return operand number. Return
121     // llvm::None otherwise.
getIndexingOpViewOperandNumLinalgDependenceGraphElem122     Optional<unsigned> getIndexingOpViewOperandNum() const {
123       return getOperandNumber(indexingOpView);
124     }
125 
126     // If the dependent OpView is a result value, return the result
127     // number. Return llvm::None otherwise.
getDependentOpViewResultNumLinalgDependenceGraphElem128     Optional<unsigned> getDependentOpViewResultNum() const {
129       return getResultNumber(dependentOpView);
130     }
131 
132     // If the dependent OpView is a result value, return the result
133     // number. Return llvm::None otherwise.
getIndexingOpViewResultNumLinalgDependenceGraphElem134     Optional<unsigned> getIndexingOpViewResultNum() const {
135       return getResultNumber(indexingOpView);
136     }
137 
138     // Return the indexing map of the operand/result in the dependent OpView as
139     // specified in the owner of the OpView.
getDependentOpViewIndexingMapLinalgDependenceGraphElem140     Optional<AffineMap> getDependentOpViewIndexingMap() const {
141       return getIndexingMap(dependentOpView);
142     }
143 
144     // Return the indexing map of the operand/result in the indexing OpView as
145     // specified in the owner of the OpView.
getIndexingOpViewIndexingMapLinalgDependenceGraphElem146     Optional<AffineMap> getIndexingOpViewIndexingMap() const {
147       return getIndexingMap(indexingOpView);
148     }
149   };
150   using LinalgDependences = SmallVector<LinalgDependenceGraphElem, 8>;
151   using DependenceGraph = DenseMap<Operation *, LinalgDependences>;
152   using dependence_iterator = LinalgDependences::const_iterator;
153   using dependence_range = iterator_range<dependence_iterator>;
154 
155   static StringRef getDependenceTypeStr(DependenceType depType);
156 
157   // Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
158   static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f);
159   LinalgDependenceGraph(Aliases &aliases, ArrayRef<LinalgOp> ops);
160 
161   /// Returns the X such that op -> X is a dependence of type dt.
162   dependence_range getDependencesFrom(Operation *src, DependenceType dt) const;
163   dependence_range getDependencesFrom(LinalgOp src, DependenceType dt) const;
164 
165   /// Returns the X such that X -> op is a dependence of type dt.
166   dependence_range getDependencesInto(Operation *dst, DependenceType dt) const;
167   dependence_range getDependencesInto(LinalgOp dst, DependenceType dt) const;
168 
169   /// Returns the operations that are interleaved between `srcLinalgOp` and
170   /// `dstLinalgOp` and that are involved in any RAW, WAR or WAW dependence
171   /// relation with `srcLinalgOp`, on any view.
172   /// Any such operation prevents reordering.
173   SmallVector<Operation *, 8>
174   findCoveringDependences(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp) const;
175 
176   /// Returns the operations that are interleaved between `srcLinalgOp` and
177   /// `dstLinalgOp` and that are involved in a RAR or RAW with `srcLinalgOp`.
178   /// Dependences are restricted to views aliasing `view`.
179   SmallVector<Operation *, 8> findCoveringReads(LinalgOp srcLinalgOp,
180                                                 LinalgOp dstLinalgOp,
181                                                 Value view) const;
182 
183   /// Returns the operations that are interleaved between `srcLinalgOp` and
184   /// `dstLinalgOp` and that are involved in a WAR or WAW with `srcLinalgOp`.
185   /// Dependences are restricted to views aliasing `view`.
186   SmallVector<Operation *, 8> findCoveringWrites(LinalgOp srcLinalgOp,
187                                                  LinalgOp dstLinalgOp,
188                                                  Value view) const;
189 
190   /// Returns true if the two operations have the specified dependence from
191   /// `srcLinalgOp` to `dstLinalgOp`.
192   bool hasDependenceFrom(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
193                          ArrayRef<DependenceType> depTypes = {
194                              DependenceType::RAW, DependenceType::WAW}) const;
195 
196   /// Returns true if the `linalgOp` has dependences into it.
197   bool hasDependentOperationsInto(LinalgOp linalgOp,
198                                   ArrayRef<DependenceType> depTypes = {
199                                       DependenceType::RAW,
200                                       DependenceType::WAW}) const;
201 
202   /// Returns true if the `linalgOp` has dependences from it.
203   bool hasDependentOperationsFrom(LinalgOp linalgOp,
204                                   ArrayRef<DependenceType> depTypes = {
205                                       DependenceType::RAW,
206                                       DependenceType::WAW}) const;
207 
208   /// Returns true if the `linalgOp` has dependences into or from it.
209   bool hasDependentOperations(LinalgOp linalgOp,
210                               ArrayRef<DependenceType> depTypes = {
211                                   DependenceType::RAW,
212                                   DependenceType::WAW}) const;
213 
214   /// Returns all operations that have a dependence into `linalgOp` of types
215   /// listed in `depTypes`.
216   SmallVector<LinalgDependenceGraphElem, 2> getDependentOperationsInto(
217       LinalgOp linalgOp, ArrayRef<DependenceType> depTypes = {
218                              DependenceType::RAW, DependenceType::WAW}) const;
219 
220   /// Returns all operations that have a dependence from `linalgOp` of types
221   /// listed in `depTypes`.
222   SmallVector<LinalgDependenceGraphElem, 2> getDependentOperationsFrom(
223       LinalgOp linalgOp, ArrayRef<DependenceType> depTypes = {
224                              DependenceType::RAW, DependenceType::WAW}) const;
225 
226   /// Returns all dependent operations (into and from) given `operation`.
227   SmallVector<LinalgDependenceGraphElem, 2>
228   getDependentOperations(LinalgOp linalgOp,
229                          ArrayRef<DependenceType> depTypes = {
230                              DependenceType::RAW, DependenceType::WAW}) const;
231 
232   void print(raw_ostream &os) const;
233 
234   void dump() const;
235 
236 private:
237   // Keep dependences in both directions, this is not just a performance gain
238   // but it also reduces usage errors.
239   // Dependence information is stored as a map of:
240   //   (source operation -> LinalgDependenceGraphElem)
241   DependenceGraph dependencesFromGraphs[DependenceType::NumTypes];
242   // Reverse dependence information is stored as a map of:
243   //   (destination operation -> LinalgDependenceGraphElem)
244   DependenceGraph dependencesIntoGraphs[DependenceType::NumTypes];
245 
246   /// Analyses the aliasing views between `src` and `dst` and inserts the proper
247   /// dependences in the graph.
248   void addDependencesBetween(LinalgOp src, LinalgOp dst);
249 
250   // Adds an new dependence unit in the proper graph.
251   // Uses std::pair to keep operations and view together and avoid usage errors
252   // related to src/dst and producer/consumer terminology in the context of
253   // dependences.
254   void addDependenceElem(DependenceType dt,
255                          LinalgDependenceGraphElem::OpView indexingOpView,
256                          LinalgDependenceGraphElem::OpView dependentOpView);
257 
258   /// Implementation detail for findCoveringxxx.
259   SmallVector<Operation *, 8>
260   findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,
261                                         LinalgOp dstLinalgOp, Value view,
262                                         ArrayRef<DependenceType> types) const;
263 
264   Aliases &aliases;
265   SmallVector<LinalgOp, 8> linalgOps;
266   DenseMap<Operation *, unsigned> linalgOpPositions;
267 };
268 } // namespace linalg
269 } // namespace mlir
270 
271 #endif // MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
272