1 //===- ConstraintAnalysisGraph.cpp - Graphs type for constraints ----------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
10 
11 #include "mlir/IR/MLIRContext.h"
12 #include "mlir/Quantizer/Support/Configuration.h"
13 #include "llvm/Support/raw_ostream.h"
14 
15 using namespace mlir;
16 using namespace mlir::quantizer;
17 
replaceIncoming(CAGNode * otherNode)18 void CAGNode::replaceIncoming(CAGNode *otherNode) {
19   if (this == otherNode)
20     return;
21   for (CAGNode *parentNode : incoming) {
22     for (CAGNode *&it : parentNode->outgoing) {
23       if (it == this) {
24         it = otherNode;
25         otherNode->incoming.push_back(parentNode);
26       }
27     }
28   }
29   incoming.clear();
30 }
31 
addOutgoing(CAGNode * toNode)32 void CAGNode::addOutgoing(CAGNode *toNode) {
33   if (!llvm::is_contained(outgoing, toNode)) {
34     outgoing.push_back(toNode);
35     toNode->incoming.push_back(this);
36   }
37 }
38 
CAGOperandAnchor(Operation * op,unsigned operandIdx)39 CAGOperandAnchor::CAGOperandAnchor(Operation *op, unsigned operandIdx)
40     : CAGAnchorNode(Kind::OperandAnchor, op->getOperand(operandIdx).getType()),
41       op(op), operandIdx(operandIdx) {}
42 
CAGResultAnchor(Operation * op,unsigned resultIdx)43 CAGResultAnchor::CAGResultAnchor(Operation *op, unsigned resultIdx)
44     : CAGAnchorNode(Kind::ResultAnchor, op->getResult(resultIdx).getType()),
45       resultValue(op->getResult(resultIdx)) {}
46 
CAGSlice(SolverContext & context)47 CAGSlice::CAGSlice(SolverContext &context) : context(context) {}
~CAGSlice()48 CAGSlice::~CAGSlice() { llvm::DeleteContainerPointers(allNodes); }
49 
getOperandAnchor(Operation * op,unsigned operandIdx)50 CAGOperandAnchor *CAGSlice::getOperandAnchor(Operation *op,
51                                              unsigned operandIdx) {
52   assert(operandIdx < op->getNumOperands() && "illegal operand index");
53 
54   // Dedup.
55   auto key = std::make_pair(op, operandIdx);
56   auto foundIt = operandAnchors.find(key);
57   if (foundIt != operandAnchors.end()) {
58     return foundIt->second;
59   }
60 
61   // Create.
62   auto anchor = std::make_unique<CAGOperandAnchor>(op, operandIdx);
63   auto *unowned = anchor.release();
64   unowned->nodeId = allNodes.size();
65   allNodes.push_back(unowned);
66   operandAnchors.insert(std::make_pair(key, unowned));
67   return unowned;
68 }
69 
getResultAnchor(Operation * op,unsigned resultIdx)70 CAGResultAnchor *CAGSlice::getResultAnchor(Operation *op, unsigned resultIdx) {
71   assert(resultIdx < op->getNumResults() && "illegal result index");
72 
73   // Dedup.
74   auto key = std::make_pair(op, resultIdx);
75   auto foundIt = resultAnchors.find(key);
76   if (foundIt != resultAnchors.end()) {
77     return foundIt->second;
78   }
79 
80   // Create.
81   auto anchor = std::make_unique<CAGResultAnchor>(op, resultIdx);
82   auto *unowned = anchor.release();
83   unowned->nodeId = allNodes.size();
84   allNodes.push_back(unowned);
85   resultAnchors.insert(std::make_pair(key, unowned));
86   return unowned;
87 }
88 
enumerateImpliedConnections(std::function<void (CAGAnchorNode * from,CAGAnchorNode * to)> callback)89 void CAGSlice::enumerateImpliedConnections(
90     std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback) {
91   // Discover peer identity pairs (i.e. implied edges from Result->Operand and
92   // Arg->Call). Use an intermediate vector so that the callback can modify.
93   std::vector<std::pair<CAGAnchorNode *, CAGAnchorNode *>> impliedPairs;
94   for (auto &resultAnchorPair : resultAnchors) {
95     CAGResultAnchor *resultAnchor = resultAnchorPair.second;
96     Value resultValue = resultAnchor->getValue();
97     for (auto &use : resultValue.getUses()) {
98       Operation *operandOp = use.getOwner();
99       unsigned operandIdx = use.getOperandNumber();
100       auto foundIt = operandAnchors.find(std::make_pair(operandOp, operandIdx));
101       if (foundIt != operandAnchors.end()) {
102         impliedPairs.push_back(std::make_pair(resultAnchor, foundIt->second));
103       }
104     }
105   }
106 
107   // Callback for each pair.
108   for (auto &impliedPair : impliedPairs) {
109     callback(impliedPair.first, impliedPair.second);
110   }
111 }
112 
propagate(const TargetConfiguration & config)113 unsigned CAGSlice::propagate(const TargetConfiguration &config) {
114   std::vector<CAGNode *> dirtyNodes;
115   dirtyNodes.reserve(allNodes.size());
116   // Note that because iteration happens in nodeId order, there is no need
117   // to sort in order to make deterministic. If the selection method changes,
118   // a sort should be explicitly done.
119   for (CAGNode *child : *this) {
120     if (child->isDirty()) {
121       dirtyNodes.push_back(child);
122     }
123   }
124 
125   if (dirtyNodes.empty()) {
126     return 0;
127   }
128   for (auto dirtyNode : dirtyNodes) {
129     dirtyNode->clearDirty();
130     dirtyNode->propagate(context, config);
131   }
132 
133   return dirtyNodes.size();
134 }
135 
propagate(SolverContext & solverContext,const TargetConfiguration & config)136 void CAGAnchorNode::propagate(SolverContext &solverContext,
137                               const TargetConfiguration &config) {
138   for (CAGNode *child : *this) {
139     child->markDirty();
140   }
141 }
142 
getTransformedType()143 Type CAGAnchorNode::getTransformedType() {
144   if (!getUniformMetadata().selectedType) {
145     return nullptr;
146   }
147   return getUniformMetadata().selectedType.castFromExpressedType(
148       getOriginalType());
149 }
150 
printLabel(raw_ostream & os) const151 void CAGNode::printLabel(raw_ostream &os) const {
152   os << "Node<" << static_cast<const void *>(this) << ">";
153 }
154 
printLabel(raw_ostream & os) const155 void CAGAnchorNode::printLabel(raw_ostream &os) const {
156   getUniformMetadata().printSummary(os);
157 }
158 
printLabel(raw_ostream & os) const159 void CAGOperandAnchor::printLabel(raw_ostream &os) const {
160   os << "Operand<";
161   op->getName().print(os);
162   os << "," << operandIdx;
163   os << ">";
164   CAGAnchorNode::printLabel(os);
165 }
166 
printLabel(raw_ostream & os) const167 void CAGResultAnchor::printLabel(raw_ostream &os) const {
168   os << "Result<";
169   getOp()->getName().print(os);
170   os << ">";
171   CAGAnchorNode::printLabel(os);
172 }
173