1 //===- ConstraintAnalysisGraph.h - Graphs type for constraints --*- C++ -*-===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides graph-based data structures for representing anchors 10 // and constraints between them. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H 15 #define MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H 16 17 #include <utility> 18 #include <vector> 19 20 #include "mlir/IR/Function.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/IR/Module.h" 23 #include "mlir/IR/Operation.h" 24 #include "mlir/IR/Types.h" 25 #include "mlir/Quantizer/Support/Metadata.h" 26 #include "llvm/ADT/DenseMap.h" 27 28 namespace mlir { 29 namespace quantizer { 30 31 class CAGNode; 32 class CAGSlice; 33 class TargetConfiguration; 34 35 /// A node in the Constraint Analysis Graph. 36 /// Nodes are either anchors (representing results and operands) or constraints. 37 /// Anchor nodes are connected to other anchor nodes via constraints. 38 /// Nodes exist within graph slices, which are typically analyses attached to 39 /// the function or module. Slices can contain other slices, which mirrors 40 /// the nesting of analyses. 41 /// 42 /// Nodes have directed relationships which propagate successor-ward when dirty. 43 /// Relationships can be bi-directional, in which case, the constraint's 44 /// propagation mechanism must ensure convergence. 45 class CAGNode { 46 public: 47 enum class Kind { 48 /// Anchors. 49 Anchor, 50 OperandAnchor, 51 ResultAnchor, 52 LastAnchor = ResultAnchor, 53 54 /// Constraints. 55 Constraint, 56 SolveUniformConstraint, 57 UniformPropagateExplicitScale, 58 LastConstraint = UniformPropagateExplicitScale, 59 }; 60 61 // Vector and iterator over nodes. 62 using node_vector = SmallVector<CAGNode *, 1>; 63 using iterator = node_vector::iterator; 64 using const_iterator = node_vector::const_iterator; 65 66 virtual ~CAGNode() = default; 67 getKind()68 Kind getKind() const { return kind; } 69 70 /// Unique id of the node within the slice. getNodeId()71 int getNodeId() const { return nodeId; } 72 73 /// Whether the node is dirty, requiring one or more calls to propagate(). isDirty()74 bool isDirty() const { return dirty; } markDirty()75 void markDirty() { dirty = true; } clearDirty()76 void clearDirty() { dirty = false; } 77 78 /// Iterator over this node's children (outgoing) nodes. begin()79 const_iterator begin() const { return outgoing.begin(); } end()80 const_iterator end() const { return outgoing.end(); } begin()81 iterator begin() { return outgoing.begin(); } end()82 iterator end() { return outgoing.end(); } 83 84 /// Iterator over this parents (incoming) nodes. incoming_begin()85 const_iterator incoming_begin() const { return incoming.begin(); } incoming_end()86 const_iterator incoming_end() const { return incoming.end(); } incoming_begin()87 iterator incoming_begin() { return incoming.begin(); } incoming_end()88 iterator incoming_end() { return incoming.end(); } 89 propagate(SolverContext & solverContext,const TargetConfiguration & config)90 virtual void propagate(SolverContext &solverContext, 91 const TargetConfiguration &config) {} 92 93 /// Prints the node label, suitable for one-line display. 94 virtual void printLabel(raw_ostream &os) const; 95 findChildrenOfKind(SmallVectorImpl<T * > & found)96 template <typename T> void findChildrenOfKind(SmallVectorImpl<T *> &found) { 97 for (CAGNode *child : *this) { 98 T *ofKind = dyn_cast<T>(child); 99 if (ofKind) { 100 found.push_back(ofKind); 101 } 102 } 103 } 104 105 /// Replaces this node by rerouting any parent nodes to have otherNode 106 /// as a child. 107 void replaceIncoming(CAGNode *otherNode); 108 109 /// Adds an outgoing connection to this node (and corresponding back 110 /// incoming connection). 111 void addOutgoing(CAGNode *toNode); 112 113 /// Whether this node is an orphan (has no incoming or outgoing connections). isOrphan()114 bool isOrphan() const { return incoming.empty() && outgoing.empty(); } 115 116 protected: CAGNode(Kind kind)117 CAGNode(Kind kind) : kind(kind) {} 118 119 private: 120 Kind kind; 121 int nodeId = -1; 122 node_vector outgoing; 123 node_vector incoming; 124 bool dirty = false; 125 126 friend class CAGSlice; 127 }; 128 129 /// Anchor nodes represent points in the source IR where we may choose to 130 /// introduce a type transition. These include operands, results, arguments 131 /// returns, etc. 132 class CAGAnchorNode : public CAGNode { 133 public: 134 enum class TypeTransformRule { 135 /// The owning op directly supports all transformed types. In practice, 136 /// this means that the op supports QuantizedType for this anchor. 137 Direct, 138 139 /// The type of this anchor should be set to the QuantizedType storage 140 /// type. This will only be valid if constraints are such that all 141 /// inputs/outputs converge to the same storage type (i.e. coupled). 142 DirectStorage, 143 144 /// The anchor must only be typed based on the expressed type. This is 145 /// used for ops that do not natively support quantization, and suitable 146 /// casts will be inserted. 147 ExpressedOnly, 148 }; 149 150 /// Metadata for solving uniform quantization params. getUniformMetadata()151 CAGUniformMetadata &getUniformMetadata() { return uniformMetadata; } getUniformMetadata()152 const CAGUniformMetadata &getUniformMetadata() const { 153 return uniformMetadata; 154 } 155 156 virtual Operation *getOp() const = 0; 157 virtual Value getValue() const = 0; 158 classof(const CAGNode * n)159 static bool classof(const CAGNode *n) { 160 return n->getKind() >= Kind::Anchor && n->getKind() <= Kind::LastAnchor; 161 } 162 163 void propagate(SolverContext &solverContext, 164 const TargetConfiguration &config) override; 165 166 void printLabel(raw_ostream &os) const override; 167 168 /// Given the anchor metadata and resolved solutions, chooses the most 169 /// salient and returns an appropriate type to represent it. 170 Type getTransformedType(); 171 getTypeTransformRule()172 TypeTransformRule getTypeTransformRule() const { return typeTransformRule; } 173 setTypeTransformRule(TypeTransformRule r)174 void setTypeTransformRule(TypeTransformRule r) { typeTransformRule = r; } 175 176 /// Gets the Type that was defined for this anchor at the time of 177 /// construction. getOriginalType()178 Type getOriginalType() const { return originalType; } 179 180 protected: CAGAnchorNode(Kind kind,Type originalType)181 CAGAnchorNode(Kind kind, Type originalType) 182 : CAGNode(kind), originalType(originalType) {} 183 184 private: 185 CAGUniformMetadata uniformMetadata; 186 Type originalType; 187 TypeTransformRule typeTransformRule = TypeTransformRule::Direct; 188 }; 189 190 /// An anchor tied to a specific operand. 191 /// Since operand anchors can be rewritten so that the operand refers to 192 /// a new result, they are maintained by reference (to the op and index). 193 class CAGOperandAnchor : public CAGAnchorNode { 194 public: 195 CAGOperandAnchor(Operation *op, unsigned operandIdx); 196 getOp()197 Operation *getOp() const final { return op; } getOperandIdx()198 unsigned getOperandIdx() const { return operandIdx; } 199 classof(const CAGNode * n)200 static bool classof(const CAGNode *n) { 201 return n->getKind() == Kind::Anchor || n->getKind() == Kind::OperandAnchor; 202 } 203 getValue()204 Value getValue() const final { return op->getOperand(operandIdx); } 205 206 void printLabel(raw_ostream &os) const override; 207 208 private: 209 Operation *op; 210 unsigned operandIdx; 211 }; 212 213 /// An anchor tied to a specific result. 214 /// Since a result is already anchored to its defining op, result anchors refer 215 /// directly to the underlying Value. 216 class CAGResultAnchor : public CAGAnchorNode { 217 public: 218 CAGResultAnchor(Operation *op, unsigned resultIdx); 219 classof(const CAGNode * n)220 static bool classof(const CAGNode *n) { 221 return n->getKind() == Kind::Anchor || n->getKind() == Kind::ResultAnchor; 222 } 223 getOp()224 Operation *getOp() const final { return resultValue.getDefiningOp(); } getValue()225 Value getValue() const final { return resultValue; } 226 227 void printLabel(raw_ostream &os) const override; 228 229 private: 230 Value resultValue; 231 }; 232 233 /// Base class for constraint nodes. 234 class CAGConstraintNode : public CAGNode { 235 public: CAGConstraintNode(Kind kind)236 CAGConstraintNode(Kind kind) : CAGNode(kind) {} 237 classof(const CAGNode * n)238 static bool classof(const CAGNode *n) { 239 return n->getKind() >= Kind::Constraint && 240 n->getKind() <= Kind::LastConstraint; 241 } 242 }; 243 244 /// A slice of a CAG (which may be the whole graph). 245 class CAGSlice { 246 public: 247 CAGSlice(SolverContext &context); 248 ~CAGSlice(); 249 250 using node_vector = std::vector<CAGNode *>; 251 using iterator = node_vector::iterator; 252 using const_iterator = node_vector::const_iterator; 253 begin()254 iterator begin() { return allNodes.begin(); } end()255 iterator end() { return allNodes.end(); } begin()256 const_iterator begin() const { return allNodes.begin(); } end()257 const_iterator end() const { return allNodes.end(); } 258 259 /// Gets an operand anchor node. 260 CAGOperandAnchor *getOperandAnchor(Operation *op, unsigned operandIdx); 261 262 /// Gets a result anchor node. 263 CAGResultAnchor *getResultAnchor(Operation *op, unsigned resultIdx); 264 265 /// Adds a relation constraint with incoming 'from' anchors and outgoing 'to' 266 /// anchors. 267 template <typename T, typename... Args> addUniqueConstraint(ArrayRef<CAGAnchorNode * > anchors,Args...args)268 T *addUniqueConstraint(ArrayRef<CAGAnchorNode *> anchors, Args... args) { 269 static_assert(std::is_convertible<T *, CAGConstraintNode *>(), 270 "T must be a CAGConstraingNode"); 271 T *constraintNode = addNode(std::make_unique<T>(args...)); 272 for (auto *anchor : anchors) 273 anchor->addOutgoing(constraintNode); 274 return constraintNode; 275 } 276 277 /// Adds a unidirectional constraint from a node to an array of target nodes. 278 template <typename T, typename... Args> addUnidirectionalConstraint(CAGAnchorNode * fromAnchor,ArrayRef<CAGAnchorNode * > toAnchors,Args...args)279 T *addUnidirectionalConstraint(CAGAnchorNode *fromAnchor, 280 ArrayRef<CAGAnchorNode *> toAnchors, 281 Args... args) { 282 static_assert(std::is_convertible<T *, CAGConstraintNode *>(), 283 "T must be a CAGConstraingNode"); 284 T *constraintNode = addNode(std::make_unique<T>(args...)); 285 fromAnchor->addOutgoing(constraintNode); 286 for (auto *toAnchor : toAnchors) { 287 constraintNode->addOutgoing(toAnchor); 288 } 289 return constraintNode; 290 } 291 292 template <typename T> addClusteredConstraint(ArrayRef<CAGAnchorNode * > anchors)293 T *addClusteredConstraint(ArrayRef<CAGAnchorNode *> anchors) { 294 static_assert(std::is_convertible<T *, CAGConstraintNode *>(), 295 "T must be a CAGConstraingNode"); 296 SmallVector<T *, 8> cluster; 297 for (auto *anchor : anchors) { 298 anchor->findChildrenOfKind<T>(cluster); 299 } 300 301 T *constraintNode; 302 if (cluster.empty()) { 303 // Create new. 304 constraintNode = addNode(std::make_unique<T>()); 305 } else { 306 // Merge existing. 307 constraintNode = cluster[0]; 308 for (size_t i = 1, e = cluster.size(); i < e; ++i) { 309 cluster[i]->replaceIncoming(constraintNode); 310 } 311 } 312 for (auto *anchor : anchors) { 313 anchor->addOutgoing(constraintNode); 314 } 315 return constraintNode; 316 } 317 318 /// Enumerates all implied connections in the slice. 319 /// An implied connection is any two nodes that physically refer to the 320 /// same value in the IR, such as result->operand. 321 /// Typically this will be modeled with some kind of strong or weak 322 /// identity constraint such that types propagate. 323 /// This is usually called when the slice has been fully constructed in 324 /// order to add final constraints. 325 /// It is legal for the callback to modify the graph by adding constraints. 326 void enumerateImpliedConnections( 327 std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback); 328 329 /// Performs one round of propagation, returning the number of nodes 330 /// propagates. If returns > 0, then additional propagate() rounds are 331 /// required. 332 unsigned propagate(const TargetConfiguration &config); 333 334 private: 335 /// Adds a node to the graph. 336 /// The node should be a subclass of TransformNode. 337 /// Returns the raw pointer to the node. 338 template <typename T> addNode(std::unique_ptr<T> node)339 T *addNode(std::unique_ptr<T> node) { 340 node->nodeId = allNodes.size(); 341 T *unownedNode = node.release(); 342 allNodes.push_back(unownedNode); 343 return unownedNode; 344 } 345 346 SolverContext &context; 347 std::vector<CAGNode *> allNodes; 348 DenseMap<std::pair<Operation *, unsigned>, CAGOperandAnchor *> operandAnchors; 349 DenseMap<std::pair<Operation *, unsigned>, CAGResultAnchor *> resultAnchors; 350 }; 351 352 inline raw_ostream &operator<<(raw_ostream &os, const CAGNode &node) { 353 node.printLabel(os); 354 return os; 355 } 356 357 } // namespace quantizer 358 } // namespace mlir 359 360 #endif // MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H 361