//===- ConstraintAnalysisGraph.h - Graph types for constraints --*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides graph-based data structures for representing anchors
10 // and constraints between them.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
15 #define MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
16 
17 #include <utility>
18 #include <vector>
19 
20 #include "mlir/IR/Function.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/Module.h"
23 #include "mlir/IR/Operation.h"
24 #include "mlir/IR/Types.h"
25 #include "mlir/Quantizer/Support/Metadata.h"
26 #include "llvm/ADT/DenseMap.h"
27 
28 namespace mlir {
29 namespace quantizer {
30 
31 class CAGNode;
32 class CAGSlice;
33 class TargetConfiguration;
34 
35 /// A node in the Constraint Analysis Graph.
36 /// Nodes are either anchors (representing results and operands) or constraints.
37 /// Anchor nodes are connected to other anchor nodes via constraints.
38 /// Nodes exist within graph slices, which are typically analyses attached to
39 /// the function or module. Slices can contain other slices, which mirrors
40 /// the nesting of analyses.
41 ///
42 /// Nodes have directed relationships which propagate successor-ward when dirty.
43 /// Relationships can be bi-directional, in which case, the constraint's
44 /// propagation mechanism must ensure convergence.
45 class CAGNode {
46 public:
47   enum class Kind {
48     /// Anchors.
49     Anchor,
50     OperandAnchor,
51     ResultAnchor,
52     LastAnchor = ResultAnchor,
53 
54     /// Constraints.
55     Constraint,
56     SolveUniformConstraint,
57     UniformPropagateExplicitScale,
58     LastConstraint = UniformPropagateExplicitScale,
59   };
60 
61   // Vector and iterator over nodes.
62   using node_vector = SmallVector<CAGNode *, 1>;
63   using iterator = node_vector::iterator;
64   using const_iterator = node_vector::const_iterator;
65 
66   virtual ~CAGNode() = default;
67 
getKind()68   Kind getKind() const { return kind; }
69 
70   /// Unique id of the node within the slice.
getNodeId()71   int getNodeId() const { return nodeId; }
72 
73   /// Whether the node is dirty, requiring one or more calls to propagate().
isDirty()74   bool isDirty() const { return dirty; }
markDirty()75   void markDirty() { dirty = true; }
clearDirty()76   void clearDirty() { dirty = false; }
77 
78   /// Iterator over this node's children (outgoing) nodes.
begin()79   const_iterator begin() const { return outgoing.begin(); }
end()80   const_iterator end() const { return outgoing.end(); }
begin()81   iterator begin() { return outgoing.begin(); }
end()82   iterator end() { return outgoing.end(); }
83 
84   /// Iterator over this parents (incoming) nodes.
incoming_begin()85   const_iterator incoming_begin() const { return incoming.begin(); }
incoming_end()86   const_iterator incoming_end() const { return incoming.end(); }
incoming_begin()87   iterator incoming_begin() { return incoming.begin(); }
incoming_end()88   iterator incoming_end() { return incoming.end(); }
89 
propagate(SolverContext & solverContext,const TargetConfiguration & config)90   virtual void propagate(SolverContext &solverContext,
91                          const TargetConfiguration &config) {}
92 
93   /// Prints the node label, suitable for one-line display.
94   virtual void printLabel(raw_ostream &os) const;
95 
findChildrenOfKind(SmallVectorImpl<T * > & found)96   template <typename T> void findChildrenOfKind(SmallVectorImpl<T *> &found) {
97     for (CAGNode *child : *this) {
98       T *ofKind = dyn_cast<T>(child);
99       if (ofKind) {
100         found.push_back(ofKind);
101       }
102     }
103   }
104 
105   /// Replaces this node by rerouting any parent nodes to have otherNode
106   /// as a child.
107   void replaceIncoming(CAGNode *otherNode);
108 
109   /// Adds an outgoing connection to this node (and corresponding back
110   /// incoming connection).
111   void addOutgoing(CAGNode *toNode);
112 
113   /// Whether this node is an orphan (has no incoming or outgoing connections).
isOrphan()114   bool isOrphan() const { return incoming.empty() && outgoing.empty(); }
115 
116 protected:
CAGNode(Kind kind)117   CAGNode(Kind kind) : kind(kind) {}
118 
119 private:
120   Kind kind;
121   int nodeId = -1;
122   node_vector outgoing;
123   node_vector incoming;
124   bool dirty = false;
125 
126   friend class CAGSlice;
127 };
128 
129 /// Anchor nodes represent points in the source IR where we may choose to
130 /// introduce a type transition. These include operands, results, arguments
131 /// returns, etc.
132 class CAGAnchorNode : public CAGNode {
133 public:
134   enum class TypeTransformRule {
135     /// The owning op directly supports all transformed types. In practice,
136     /// this means that the op supports QuantizedType for this anchor.
137     Direct,
138 
139     /// The type of this anchor should be set to the QuantizedType storage
140     /// type. This will only be valid if constraints are such that all
141     /// inputs/outputs converge to the same storage type (i.e. coupled).
142     DirectStorage,
143 
144     /// The anchor must only be typed based on the expressed type. This is
145     /// used for ops that do not natively support quantization, and suitable
146     /// casts will be inserted.
147     ExpressedOnly,
148   };
149 
150   /// Metadata for solving uniform quantization params.
getUniformMetadata()151   CAGUniformMetadata &getUniformMetadata() { return uniformMetadata; }
getUniformMetadata()152   const CAGUniformMetadata &getUniformMetadata() const {
153     return uniformMetadata;
154   }
155 
156   virtual Operation *getOp() const = 0;
157   virtual Value getValue() const = 0;
158 
classof(const CAGNode * n)159   static bool classof(const CAGNode *n) {
160     return n->getKind() >= Kind::Anchor && n->getKind() <= Kind::LastAnchor;
161   }
162 
163   void propagate(SolverContext &solverContext,
164                  const TargetConfiguration &config) override;
165 
166   void printLabel(raw_ostream &os) const override;
167 
168   /// Given the anchor metadata and resolved solutions, chooses the most
169   /// salient and returns an appropriate type to represent it.
170   Type getTransformedType();
171 
getTypeTransformRule()172   TypeTransformRule getTypeTransformRule() const { return typeTransformRule; }
173 
setTypeTransformRule(TypeTransformRule r)174   void setTypeTransformRule(TypeTransformRule r) { typeTransformRule = r; }
175 
176   /// Gets the Type that was defined for this anchor at the time of
177   /// construction.
getOriginalType()178   Type getOriginalType() const { return originalType; }
179 
180 protected:
CAGAnchorNode(Kind kind,Type originalType)181   CAGAnchorNode(Kind kind, Type originalType)
182       : CAGNode(kind), originalType(originalType) {}
183 
184 private:
185   CAGUniformMetadata uniformMetadata;
186   Type originalType;
187   TypeTransformRule typeTransformRule = TypeTransformRule::Direct;
188 };
189 
190 /// An anchor tied to a specific operand.
191 /// Since operand anchors can be rewritten so that the operand refers to
192 /// a new result, they are maintained by reference (to the op and index).
193 class CAGOperandAnchor : public CAGAnchorNode {
194 public:
195   CAGOperandAnchor(Operation *op, unsigned operandIdx);
196 
getOp()197   Operation *getOp() const final { return op; }
getOperandIdx()198   unsigned getOperandIdx() const { return operandIdx; }
199 
classof(const CAGNode * n)200   static bool classof(const CAGNode *n) {
201     return n->getKind() == Kind::Anchor || n->getKind() == Kind::OperandAnchor;
202   }
203 
getValue()204   Value getValue() const final { return op->getOperand(operandIdx); }
205 
206   void printLabel(raw_ostream &os) const override;
207 
208 private:
209   Operation *op;
210   unsigned operandIdx;
211 };
212 
213 /// An anchor tied to a specific result.
214 /// Since a result is already anchored to its defining op, result anchors refer
215 /// directly to the underlying Value.
216 class CAGResultAnchor : public CAGAnchorNode {
217 public:
218   CAGResultAnchor(Operation *op, unsigned resultIdx);
219 
classof(const CAGNode * n)220   static bool classof(const CAGNode *n) {
221     return n->getKind() == Kind::Anchor || n->getKind() == Kind::ResultAnchor;
222   }
223 
getOp()224   Operation *getOp() const final { return resultValue.getDefiningOp(); }
getValue()225   Value getValue() const final { return resultValue; }
226 
227   void printLabel(raw_ostream &os) const override;
228 
229 private:
230   Value resultValue;
231 };
232 
233 /// Base class for constraint nodes.
234 class CAGConstraintNode : public CAGNode {
235 public:
CAGConstraintNode(Kind kind)236   CAGConstraintNode(Kind kind) : CAGNode(kind) {}
237 
classof(const CAGNode * n)238   static bool classof(const CAGNode *n) {
239     return n->getKind() >= Kind::Constraint &&
240            n->getKind() <= Kind::LastConstraint;
241   }
242 };
243 
244 /// A slice of a CAG (which may be the whole graph).
245 class CAGSlice {
246 public:
247   CAGSlice(SolverContext &context);
248   ~CAGSlice();
249 
250   using node_vector = std::vector<CAGNode *>;
251   using iterator = node_vector::iterator;
252   using const_iterator = node_vector::const_iterator;
253 
begin()254   iterator begin() { return allNodes.begin(); }
end()255   iterator end() { return allNodes.end(); }
begin()256   const_iterator begin() const { return allNodes.begin(); }
end()257   const_iterator end() const { return allNodes.end(); }
258 
259   /// Gets an operand anchor node.
260   CAGOperandAnchor *getOperandAnchor(Operation *op, unsigned operandIdx);
261 
262   /// Gets a result anchor node.
263   CAGResultAnchor *getResultAnchor(Operation *op, unsigned resultIdx);
264 
265   /// Adds a relation constraint with incoming 'from' anchors and outgoing 'to'
266   /// anchors.
267   template <typename T, typename... Args>
addUniqueConstraint(ArrayRef<CAGAnchorNode * > anchors,Args...args)268   T *addUniqueConstraint(ArrayRef<CAGAnchorNode *> anchors, Args... args) {
269     static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
270                   "T must be a CAGConstraingNode");
271     T *constraintNode = addNode(std::make_unique<T>(args...));
272     for (auto *anchor : anchors)
273       anchor->addOutgoing(constraintNode);
274     return constraintNode;
275   }
276 
277   /// Adds a unidirectional constraint from a node to an array of target nodes.
278   template <typename T, typename... Args>
addUnidirectionalConstraint(CAGAnchorNode * fromAnchor,ArrayRef<CAGAnchorNode * > toAnchors,Args...args)279   T *addUnidirectionalConstraint(CAGAnchorNode *fromAnchor,
280                                  ArrayRef<CAGAnchorNode *> toAnchors,
281                                  Args... args) {
282     static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
283                   "T must be a CAGConstraingNode");
284     T *constraintNode = addNode(std::make_unique<T>(args...));
285     fromAnchor->addOutgoing(constraintNode);
286     for (auto *toAnchor : toAnchors) {
287       constraintNode->addOutgoing(toAnchor);
288     }
289     return constraintNode;
290   }
291 
292   template <typename T>
addClusteredConstraint(ArrayRef<CAGAnchorNode * > anchors)293   T *addClusteredConstraint(ArrayRef<CAGAnchorNode *> anchors) {
294     static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
295                   "T must be a CAGConstraingNode");
296     SmallVector<T *, 8> cluster;
297     for (auto *anchor : anchors) {
298       anchor->findChildrenOfKind<T>(cluster);
299     }
300 
301     T *constraintNode;
302     if (cluster.empty()) {
303       // Create new.
304       constraintNode = addNode(std::make_unique<T>());
305     } else {
306       // Merge existing.
307       constraintNode = cluster[0];
308       for (size_t i = 1, e = cluster.size(); i < e; ++i) {
309         cluster[i]->replaceIncoming(constraintNode);
310       }
311     }
312     for (auto *anchor : anchors) {
313       anchor->addOutgoing(constraintNode);
314     }
315     return constraintNode;
316   }
317 
318   /// Enumerates all implied connections in the slice.
319   /// An implied connection is any two nodes that physically refer to the
320   /// same value in the IR, such as result->operand.
321   /// Typically this will be modeled with some kind of strong or weak
322   /// identity constraint such that types propagate.
323   /// This is usually called when the slice has been fully constructed in
324   /// order to add final constraints.
325   /// It is legal for the callback to modify the graph by adding constraints.
326   void enumerateImpliedConnections(
327       std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback);
328 
329   /// Performs one round of propagation, returning the number of nodes
330   /// propagates. If returns > 0, then additional propagate() rounds are
331   /// required.
332   unsigned propagate(const TargetConfiguration &config);
333 
334 private:
335   /// Adds a node to the graph.
336   /// The node should be a subclass of TransformNode.
337   /// Returns the raw pointer to the node.
338   template <typename T>
addNode(std::unique_ptr<T> node)339   T *addNode(std::unique_ptr<T> node) {
340     node->nodeId = allNodes.size();
341     T *unownedNode = node.release();
342     allNodes.push_back(unownedNode);
343     return unownedNode;
344   }
345 
346   SolverContext &context;
347   std::vector<CAGNode *> allNodes;
348   DenseMap<std::pair<Operation *, unsigned>, CAGOperandAnchor *> operandAnchors;
349   DenseMap<std::pair<Operation *, unsigned>, CAGResultAnchor *> resultAnchors;
350 };
351 
352 inline raw_ostream &operator<<(raw_ostream &os, const CAGNode &node) {
353   node.printLabel(os);
354   return os;
355 }
356 
357 } // namespace quantizer
358 } // namespace mlir
359 
360 #endif // MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H
361