1 //===- DataFlowAnalysis.h - General DataFlow Analysis Utilities -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file provides several utilities and algorithms that perform abstract
// dataflow analysis over the IR. These allow users to hook into various analysis
// propagation algorithms without needing to reinvent the traversal over the
// different types of control structures present within MLIR, such as regions,
// the callgraph, etc. A few of the main entry points are detailed below:
14 //
// ForwardDataFlowAnalysis:
16 //  This class provides support for defining dataflow algorithms that are
17 //  forward, sparse, pessimistic (except along unreached backedges) and
18 //  context-insensitive for the interprocedural aspects.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #ifndef MLIR_ANALYSIS_DATAFLOWANALYSIS_H
23 #define MLIR_ANALYSIS_DATAFLOWANALYSIS_H
24 
25 #include "mlir/IR/Value.h"
26 #include "mlir/Interfaces/ControlFlowInterfaces.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/Optional.h"
29 #include "llvm/Support/Allocator.h"
30 
31 namespace mlir {
32 //===----------------------------------------------------------------------===//
33 // ChangeResult
34 //===----------------------------------------------------------------------===//
35 
/// Outcome of an operation that may or may not have modified some state.
/// The operators below treat `Change` as "true" and `NoChange` as "false":
/// `|` yields `Change` when either side is `Change`, and `&` yields `NoChange`
/// when the left side is `NoChange` (otherwise it yields the right side).
enum class ChangeResult {
  NoChange,
  Change,
};

/// Logical OR over change results.
inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) {
  if (lhs == ChangeResult::Change)
    return ChangeResult::Change;
  return rhs;
}

/// In-place logical OR; returns a reference to the updated `lhs`.
inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
  return lhs = lhs | rhs;
}

/// Logical AND over change results.
inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
  if (lhs == ChangeResult::NoChange)
    return ChangeResult::NoChange;
  return rhs;
}
52 
53 //===----------------------------------------------------------------------===//
54 // AbstractLatticeElement
55 //===----------------------------------------------------------------------===//
56 
57 namespace detail {
/// This class represents an abstract lattice. A lattice is what gets propagated
/// across the IR, and contains the information for a specific Value.
///
/// This is the type-erased interface used by the non-templated analysis
/// driver; `LatticeElement<ValueT>` is the concrete implementation.
class AbstractLatticeElement {
public:
  virtual ~AbstractLatticeElement();

  /// Returns true if this lattice has not yet been assigned any value, i.e.
  /// nothing has been joined into it and no fixpoint has been marked.
  virtual bool isUninitialized() const = 0;

  /// Join the information contained in `rhs` into this lattice. Returns
  /// `ChangeResult::Change` if the state of this lattice changed, and
  /// `ChangeResult::NoChange` otherwise.
  /// NOTE: `rhs` is expected to have the same concrete type as this lattice;
  /// the `LatticeElement` implementation static_casts it without checking.
  virtual ChangeResult join(const AbstractLatticeElement &rhs) = 0;

  /// Mark the lattice element as having reached a pessimistic fixpoint. This
  /// means that the lattice may potentially have conflicting value states, and
  /// only the most conservative value should be relied on. Returns whether
  /// the state of the lattice changed.
  virtual ChangeResult markPessimisticFixpoint() = 0;

  /// Mark the lattice element as having reached an optimistic fixpoint. This
  /// means that we optimistically assume the current value is the true state.
  virtual void markOptimisticFixpoint() = 0;

  /// Returns true if the lattice has reached a fixpoint. A fixpoint is when the
  /// information optimistically assumed to be true is the same as the
  /// information known to be true.
  virtual bool isAtFixpoint() const = 0;
};
86 } // namespace detail
87 
88 //===----------------------------------------------------------------------===//
89 // LatticeElement
90 //===----------------------------------------------------------------------===//
91 
92 /// This class represents a lattice holding a specific value of type `ValueT`.
93 /// Lattice values (`ValueT`) are required to adhere to the following:
94 ///   * static ValueT join(const ValueT &lhs, const ValueT &rhs);
95 ///     - This method conservatively joins the information held by `lhs`
96 ///       and `rhs` into a new value. This method is required to be monotonic.
97 ///   * static ValueT getPessimisticValueState(MLIRContext *context);
98 ///     - This method computes a pessimistic/conservative value state assuming
99 ///       no information about the state of the IR.
100 ///   * static ValueT getPessimisticValueState(Value value);
101 ///     - This method computes a pessimistic/conservative value state for
102 ///       `value` assuming only information present in the current IR.
103 ///   * bool operator==(const ValueT &rhs) const;
104 ///
105 template <typename ValueT>
106 class LatticeElement final : public detail::AbstractLatticeElement {
107 public:
108   LatticeElement() = delete;
LatticeElement(const ValueT & knownValue)109   LatticeElement(const ValueT &knownValue) : knownValue(knownValue) {}
110 
111   /// Return the value held by this lattice. This requires that the value is
112   /// initialized.
getValue()113   ValueT &getValue() {
114     assert(!isUninitialized() && "expected known lattice element");
115     return *optimisticValue;
116   }
getValue()117   const ValueT &getValue() const {
118     assert(!isUninitialized() && "expected known lattice element");
119     return *optimisticValue;
120   }
121 
122   /// Returns true if the value of this lattice hasn't yet been initialized.
isUninitialized()123   bool isUninitialized() const final { return !optimisticValue.hasValue(); }
124 
125   /// Join the information contained in the 'rhs' lattice into this
126   /// lattice. Returns if the state of the current lattice changed.
join(const detail::AbstractLatticeElement & rhs)127   ChangeResult join(const detail::AbstractLatticeElement &rhs) final {
128     const LatticeElement<ValueT> &rhsLattice =
129         static_cast<const LatticeElement<ValueT> &>(rhs);
130 
131     // If we are at a fixpoint, or rhs is uninitialized, there is nothing to do.
132     if (isAtFixpoint() || rhsLattice.isUninitialized())
133       return ChangeResult::NoChange;
134 
135     // Join the rhs value into this lattice.
136     return join(rhsLattice.getValue());
137   }
138 
139   /// Join the information contained in the 'rhs' value into this
140   /// lattice. Returns if the state of the current lattice changed.
join(const ValueT & rhs)141   ChangeResult join(const ValueT &rhs) {
142     // If the current lattice is uninitialized, copy the rhs value.
143     if (isUninitialized()) {
144       optimisticValue = rhs;
145       return ChangeResult::Change;
146     }
147 
148     // Otherwise, join rhs with the current optimistic value.
149     ValueT newValue = ValueT::join(*optimisticValue, rhs);
150     assert(ValueT::join(newValue, *optimisticValue) == newValue &&
151            "expected `join` to be monotonic");
152     assert(ValueT::join(newValue, rhs) == newValue &&
153            "expected `join` to be monotonic");
154 
155     // Update the current optimistic value if something changed.
156     if (newValue == optimisticValue)
157       return ChangeResult::NoChange;
158 
159     optimisticValue = newValue;
160     return ChangeResult::Change;
161   }
162 
163   /// Mark the lattice element as having reached a pessimistic fixpoint. This
164   /// means that the lattice may potentially have conflicting value states, and
165   /// only the conservatively known value state should be relied on.
markPessimisticFixpoint()166   ChangeResult markPessimisticFixpoint() final {
167     if (isAtFixpoint())
168       return ChangeResult::NoChange;
169 
170     // For this fixed point, we take whatever we knew to be true and set that to
171     // our optimistic value.
172     optimisticValue = knownValue;
173     return ChangeResult::Change;
174   }
175 
176   /// Mark the lattice element as having reached an optimistic fixpoint. This
177   /// means that we optimistically assume the current value is the true state.
markOptimisticFixpoint()178   void markOptimisticFixpoint() final {
179     assert(!isUninitialized() && "expected an initialized value");
180     knownValue = *optimisticValue;
181   }
182 
183   /// Returns true if the lattice has reached a fixpoint. A fixpoint is when the
184   /// information optimistically assumed to be true is the same as the
185   /// information known to be true.
isAtFixpoint()186   bool isAtFixpoint() const final { return optimisticValue == knownValue; }
187 
188 private:
189   /// The value that is conservatively known to be true.
190   ValueT knownValue;
191   /// The currently computed value that is optimistically assumed to be true, or
192   /// None if the lattice element is uninitialized.
193   Optional<ValueT> optimisticValue;
194 };
195 
196 //===----------------------------------------------------------------------===//
197 // ForwardDataFlowAnalysisBase
198 //===----------------------------------------------------------------------===//
199 
200 namespace detail {
/// This class is the non-templated virtual base class for the
/// ForwardDataFlowAnalysis. This class provides opaque hooks to the main
/// algorithm.
///
/// Lattices are handled here only through `AbstractLatticeElement` pointers;
/// their concrete type (and allocation) is supplied by the derived class via
/// `createLatticeElement`.
class ForwardDataFlowAnalysisBase {
public:
  virtual ~ForwardDataFlowAnalysisBase();

  /// Initialize and compute the analysis on operations rooted under the given
  /// top-level operation. Note that the top-level operation is not visited.
  void run(Operation *topLevelOp);

  /// Return the lattice element attached to the given value. If a lattice has
  /// not been added for the given value, a new 'uninitialized' value is
  /// inserted and returned.
  AbstractLatticeElement &getLatticeElement(Value value);

  /// Return the lattice element attached to the given value, or nullptr if no
  /// lattice for the value has yet been created.
  AbstractLatticeElement *lookupLatticeElement(Value value);

  /// Visit the given operation, and join any necessary analysis state
  /// into the lattices for the results and block arguments owned by this
  /// operation using the provided set of operand lattice elements (all pointer
  /// values are guaranteed to be non-null). Returns whether any result or
  /// block argument value lattices changed during the visit. The lattice for a
  /// result or block argument value can be obtained and join'ed into by using
  /// `getLatticeElement`.
  virtual ChangeResult
  visitOperation(Operation *op,
                 ArrayRef<AbstractLatticeElement *> operands) = 0;

  /// Given a BranchOpInterface, and the current lattice elements that
  /// correspond to the branch operands (all pointer values are guaranteed to be
  /// non-null), try to compute a specific set of successors that would be
  /// selected for the branch. Returns failure if not computable, or if all of
  /// the successors would be chosen. If a subset of successors can be selected,
  /// `successors` is populated.
  virtual LogicalResult
  getSuccessorsForOperands(BranchOpInterface branch,
                           ArrayRef<AbstractLatticeElement *> operands,
                           SmallVectorImpl<Block *> &successors) = 0;

  /// Given a RegionBranchOpInterface, and the current lattice elements that
  /// correspond to the branch operands (all pointer values are guaranteed to be
  /// non-null), compute a specific set of region successors that would be
  /// selected. `sourceIndex` identifies the region being branched from, or is
  /// None when branching from the parent operation itself (presumably, as in
  /// `RegionBranchOpInterface::getSuccessorRegions` — confirm).
  virtual void
  getSuccessorsForOperands(RegionBranchOpInterface branch,
                           Optional<unsigned> sourceIndex,
                           ArrayRef<AbstractLatticeElement *> operands,
                           SmallVectorImpl<RegionSuccessor> &successors) = 0;

  /// Create a new uninitialized lattice element. An optional value is provided
  /// which, if valid, should be used to initialize the known conservative state
  /// of the lattice.
  /// NOTE: the `= {}` default argument is taken from the static type at the
  /// call site; overriding declarations do not inherit it.
  virtual AbstractLatticeElement *createLatticeElement(Value value = {}) = 0;

private:
  /// A map from SSA value to lattice element.
  /// NOTE(review): the elements come from `createLatticeElement`, whose
  /// ForwardDataFlowAnalysis override allocates them from its own bump
  /// allocator, so this map presumably does not own them — confirm against
  /// the out-of-line implementation.
  DenseMap<Value, AbstractLatticeElement *> latticeValues;
};
262 } // namespace detail
263 
264 //===----------------------------------------------------------------------===//
265 // ForwardDataFlowAnalysis
266 //===----------------------------------------------------------------------===//
267 
268 /// This class provides a general forward dataflow analysis driver
269 /// utilizing the lattice classes defined above, to enable the easy definition
270 /// of dataflow analysis algorithms. More specifically this driver is useful for
271 /// defining analyses that are forward, sparse, pessimistic (except along
272 /// unreached backedges) and context-insensitive for the interprocedural
273 /// aspects.
274 template <typename ValueT>
275 class ForwardDataFlowAnalysis : public detail::ForwardDataFlowAnalysisBase {
276 public:
ForwardDataFlowAnalysis(MLIRContext * context)277   ForwardDataFlowAnalysis(MLIRContext *context) : context(context) {}
278 
279   /// Return the MLIR context used when constructing this analysis.
getContext()280   MLIRContext *getContext() { return context; }
281 
282   /// Compute the analysis on operations rooted under the given top-level
283   /// operation. Note that the top-level operation is not visited.
run(Operation * topLevelOp)284   void run(Operation *topLevelOp) {
285     detail::ForwardDataFlowAnalysisBase::run(topLevelOp);
286   }
287 
288   /// Return the lattice element attached to the given value, or nullptr if no
289   /// lattice for the value has yet been created.
lookupLatticeElement(Value value)290   LatticeElement<ValueT> *lookupLatticeElement(Value value) {
291     return static_cast<LatticeElement<ValueT> *>(
292         detail::ForwardDataFlowAnalysisBase::lookupLatticeElement(value));
293   }
294 
295 protected:
296   /// Return the lattice element attached to the given value. If a lattice has
297   /// not been added for the given value, a new 'uninitialized' value is
298   /// inserted and returned.
getLatticeElement(Value value)299   LatticeElement<ValueT> &getLatticeElement(Value value) {
300     return static_cast<LatticeElement<ValueT> &>(
301         detail::ForwardDataFlowAnalysisBase::getLatticeElement(value));
302   }
303 
304   /// Mark all of the lattices for the given range of Values as having reached a
305   /// pessimistic fixpoint.
markAllPessimisticFixpoint(ValueRange values)306   ChangeResult markAllPessimisticFixpoint(ValueRange values) {
307     ChangeResult result = ChangeResult::NoChange;
308     for (Value value : values)
309       result |= getLatticeElement(value).markPessimisticFixpoint();
310     return result;
311   }
312 
313   /// Visit the given operation, and join any necessary analysis state
314   /// into the lattices for the results and block arguments owned by this
315   /// operation using the provided set of operand lattice elements (all pointer
316   /// values are guaranteed to be non-null). Returns if any result or block
317   /// argument value lattices changed during the visit. The lattice for a result
318   /// or block argument value can be obtained by using
319   /// `getLatticeElement`.
320   virtual ChangeResult
321   visitOperation(Operation *op,
322                  ArrayRef<LatticeElement<ValueT> *> operands) = 0;
323 
324   /// Given a BranchOpInterface, and the current lattice elements that
325   /// correspond to the branch operands (all pointer values are guaranteed to be
326   /// non-null), try to compute a specific set of successors that would be
327   /// selected for the branch. Returns failure if not computable, or if all of
328   /// the successors would be chosen. If a subset of successors can be selected,
329   /// `successors` is populated.
330   virtual LogicalResult
getSuccessorsForOperands(BranchOpInterface branch,ArrayRef<LatticeElement<ValueT> * > operands,SmallVectorImpl<Block * > & successors)331   getSuccessorsForOperands(BranchOpInterface branch,
332                            ArrayRef<LatticeElement<ValueT> *> operands,
333                            SmallVectorImpl<Block *> &successors) {
334     return failure();
335   }
336 
337   /// Given a RegionBranchOpInterface, and the current lattice elements that
338   /// correspond to the branch operands (all pointer values are guaranteed to be
339   /// non-null), compute a specific set of region successors that would be
340   /// selected.
341   virtual void
getSuccessorsForOperands(RegionBranchOpInterface branch,Optional<unsigned> sourceIndex,ArrayRef<LatticeElement<ValueT> * > operands,SmallVectorImpl<RegionSuccessor> & successors)342   getSuccessorsForOperands(RegionBranchOpInterface branch,
343                            Optional<unsigned> sourceIndex,
344                            ArrayRef<LatticeElement<ValueT> *> operands,
345                            SmallVectorImpl<RegionSuccessor> &successors) {
346     SmallVector<Attribute> constantOperands(operands.size());
347     branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
348   }
349 
350 private:
351   /// Type-erased wrappers that convert the abstract lattice operands to derived
352   /// lattices and invoke the virtual hooks operating on the derived lattices.
353   ChangeResult
visitOperation(Operation * op,ArrayRef<detail::AbstractLatticeElement * > operands)354   visitOperation(Operation *op,
355                  ArrayRef<detail::AbstractLatticeElement *> operands) final {
356     LatticeElement<ValueT> *const *derivedOperandBase =
357         reinterpret_cast<LatticeElement<ValueT> *const *>(operands.data());
358     return visitOperation(
359         op, llvm::makeArrayRef(derivedOperandBase, operands.size()));
360   }
361   LogicalResult
getSuccessorsForOperands(BranchOpInterface branch,ArrayRef<detail::AbstractLatticeElement * > operands,SmallVectorImpl<Block * > & successors)362   getSuccessorsForOperands(BranchOpInterface branch,
363                            ArrayRef<detail::AbstractLatticeElement *> operands,
364                            SmallVectorImpl<Block *> &successors) final {
365     LatticeElement<ValueT> *const *derivedOperandBase =
366         reinterpret_cast<LatticeElement<ValueT> *const *>(operands.data());
367     return getSuccessorsForOperands(
368         branch, llvm::makeArrayRef(derivedOperandBase, operands.size()),
369         successors);
370   }
371   void
getSuccessorsForOperands(RegionBranchOpInterface branch,Optional<unsigned> sourceIndex,ArrayRef<detail::AbstractLatticeElement * > operands,SmallVectorImpl<RegionSuccessor> & successors)372   getSuccessorsForOperands(RegionBranchOpInterface branch,
373                            Optional<unsigned> sourceIndex,
374                            ArrayRef<detail::AbstractLatticeElement *> operands,
375                            SmallVectorImpl<RegionSuccessor> &successors) final {
376     LatticeElement<ValueT> *const *derivedOperandBase =
377         reinterpret_cast<LatticeElement<ValueT> *const *>(operands.data());
378     getSuccessorsForOperands(
379         branch, sourceIndex,
380         llvm::makeArrayRef(derivedOperandBase, operands.size()), successors);
381   }
382 
383   /// Create a new uninitialized lattice element. An optional value is provided,
384   /// which if valid, should be used to initialize the known conservative state
385   /// of the lattice.
createLatticeElement(Value value)386   detail::AbstractLatticeElement *createLatticeElement(Value value) final {
387     ValueT knownValue = value ? ValueT::getPessimisticValueState(value)
388                               : ValueT::getPessimisticValueState(context);
389     return new (allocator.Allocate()) LatticeElement<ValueT>(knownValue);
390   }
391 
392   /// An allocator used for new lattice elements.
393   llvm::SpecificBumpPtrAllocator<LatticeElement<ValueT>> allocator;
394 
395   /// The MLIRContext of this solver.
396   MLIRContext *context;
397 };
398 
399 } // end namespace mlir
400 
401 #endif // MLIR_ANALYSIS_DATAFLOWANALYSIS_H
402