1 //===- InferQuantizedTypesPass.cpp - Infers quantized types ---------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the primary pass for instantiating a CAG, running it to
10 // convergence on a module to determine eligible quantized type transforms, and
11 // applying those transforms to the IR.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Dialect/QuantOps/QuantOps.h"
16 #include "mlir/Dialect/QuantOps/QuantTypes.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Quantizer/Configurations/FxpMathConfig.h"
19 #include "mlir/Quantizer/Support/Configuration.h"
20 #include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
21 #include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
22 #include "mlir/Quantizer/Transforms/Passes.h"
23 #include "mlir/Support/LogicalResult.h"
24 #include "llvm/Support/DOTGraphTraits.h"
25 #include "llvm/Support/GraphWriter.h"
26 #include "llvm/Support/raw_ostream.h"
27
28 using namespace mlir;
29 using namespace mlir::quantizer;
30 using namespace mlir::quant;
31
32 namespace llvm {
33
34 template <>
35 struct DOTGraphTraits<const CAGSlice *>
36 : public DOTGraphTraits<const CAGNode *> {
DOTGraphTraitsllvm::DOTGraphTraits37 DOTGraphTraits(bool isSimple = false)
38 : DOTGraphTraits<const CAGNode *>(isSimple) {}
39
getNodeLabelllvm::DOTGraphTraits40 std::string getNodeLabel(const CAGNode *node, const CAGSlice *graph) {
41 std::string s;
42 llvm::raw_string_ostream out(s);
43 node->printLabel(out);
44 return out.str();
45 }
46
getGraphPropertiesllvm::DOTGraphTraits47 static std::string getGraphProperties(const CAGSlice *) {
48 return "rankdir=LR;";
49 }
50
isNodeHiddenllvm::DOTGraphTraits51 static bool isNodeHidden(const CAGNode *node) {
52 // Filter constraint nodes with no incoming or outgoing connections.
53 // These orphans are often created as part of graph merging operations.
54 return llvm::isa<CAGConstraintNode>(node) && node->isOrphan();
55 }
56
getNodeAttributesllvm::DOTGraphTraits57 std::string getNodeAttributes(const CAGNode *node, const CAGSlice *graph) {
58 switch (node->getKind()) {
59 default:
60 return std::string();
61 case CAGNode::Kind::OperandAnchor:
62 return "shape=record,color=yellow,style=filled";
63 case CAGNode::Kind::ResultAnchor:
64 return "shape=record,color=lightblue,style=filled";
65 case CAGNode::Kind::Constraint:
66 return "shape=record,style=dotted";
67 }
68 }
69 };
70
71 } // end namespace llvm
72
73 namespace {
74
75 class InferQuantizedTypesPass : public ModulePass<InferQuantizedTypesPass> {
76 public:
77 InferQuantizedTypesPass() = default;
InferQuantizedTypesPass(SolverContext & solverContext,const TargetConfiguration & config)78 InferQuantizedTypesPass(SolverContext &solverContext,
79 const TargetConfiguration &config)
80 : explicitSolverContext(&solverContext), explicitConfig(&config) {}
81 void runOnModule() override;
82 void runWithConfig(SolverContext &solverContext,
83 const TargetConfiguration &config);
84
85 void transformOperandType(CAGOperandAnchor *anchor, Type newType);
86 void transformResultType(CAGResultAnchor *anchor, Type newType);
87
88 private:
89 SolverContext *explicitSolverContext = nullptr;
90 const TargetConfiguration *explicitConfig = nullptr;
91 };
92
93 } // end anonymous namespace
94
/// Upper bound on CAG propagation rounds. If the graph still changes after
/// this many rounds, the pass reports an error instead of looping forever.
static constexpr int kMaximumPropagationRounds = 1000;
98
validateTypeConversion(Type newType,Type origType,Operation * op)99 static LogicalResult validateTypeConversion(Type newType, Type origType,
100 Operation *op) {
101 if (!newType) {
102 return op->emitOpError() << "unsupported type conversion from " << newType;
103 }
104 return success();
105 }
106
runOnModule()107 void InferQuantizedTypesPass::runOnModule() {
108 if (explicitSolverContext && explicitConfig) {
109 // If explicitly constructed with a config and context.
110 runWithConfig(*explicitSolverContext, *explicitConfig);
111 return;
112 }
113
114 // For global pass registration, use defaults.
115 SolverContext solverContext(*getModule().getContext());
116 auto config = FxpMathTargetConfig::create(solverContext);
117 runWithConfig(solverContext, *config);
118 }
119
runWithConfig(SolverContext & solverContext,const TargetConfiguration & config)120 void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
121 const TargetConfiguration &config) {
122 CAGSlice cag(solverContext);
123 for (auto f : getModule().getOps<FuncOp>()) {
124 f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
125 }
126 config.finalizeAnchors(cag);
127
128 // Propagate.
129 int propRound;
130 for (propRound = kMaximumPropagationRounds; propRound > 0; --propRound) {
131 auto propCount = cag.propagate(config);
132 if (propCount == 0)
133 break;
134 }
135 if (propRound == 0) {
136 emitError(UnknownLoc::get(&getContext()),
137 "exceeded maximum number of solver iterations (infinite loop?)");
138 return;
139 }
140
141 // TODO: Only dump the GraphViz if a flag is set and move to a utility.
142 // GraphViz.
143 if (!solverContext.getDebugCAGDotPath().empty()) {
144 auto actFileName =
145 llvm::WriteGraph(const_cast<const CAGSlice *>(&cag), "CAG",
146 /*ShortNames=*/false,
147 /*Title=*/"CAG",
148 /*Filename=*/solverContext.getDebugCAGDotPath());
149 llvm::errs() << "Wrote graphviz file: " << actFileName << "\n";
150 }
151
152 // Start transforming the types in order of anchor type (results, then
153 // operands).
154 // Apply result types.
155 for (auto *node : cag) {
156 auto anchorNode = dyn_cast<CAGResultAnchor>(node);
157 if (!anchorNode)
158 continue;
159 if (Type newType = anchorNode->getTransformedType())
160 transformResultType(anchorNode, newType);
161 }
162
163 // Apply operand types.
164 for (auto *node : cag) {
165 auto anchorNode = dyn_cast<CAGOperandAnchor>(node);
166 if (!anchorNode)
167 continue;
168 if (Type newType = anchorNode->getTransformedType())
169 transformOperandType(anchorNode, newType);
170 }
171 }
172
transformOperandType(CAGOperandAnchor * anchor,Type newType)173 void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor,
174 Type newType) {
175 Value inputValue = anchor->getValue();
176 Operation *op = anchor->getOp();
177 OpBuilder b(op->getBlock(), Block::iterator(op));
178
179 SmallVector<Value, 1> removeValuesIfDead;
180
181 // Because we've already run the result transforms at this phase, it is
182 // very likely that inputValue points to a dcast op whose input matches
183 // our type. We detect that situation and route around just to save some
184 // bulk in the IR.
185 Value newTypedInputValue = inputValue;
186 auto inputDcastOp =
187 dyn_cast_or_null<DequantizeCastOp>(inputValue.getDefiningOp());
188 if (inputDcastOp && inputDcastOp.arg().getType() == newType) {
189 // Can just use the dcast's input value.
190 newTypedInputValue = inputDcastOp.arg();
191 removeValuesIfDead.push_back(inputDcastOp);
192 } else {
193 // Need to synthesize a qcast.
194 newTypedInputValue =
195 b.create<QuantizeCastOp>(op->getLoc(), newType, inputValue);
196 }
197
198 switch (anchor->getTypeTransformRule()) {
199 case CAGAnchorNode::TypeTransformRule::Direct:
200 anchor->getOp()->setOperand(anchor->getOperandIdx(), newTypedInputValue);
201 break;
202
203 case CAGAnchorNode::TypeTransformRule::DirectStorage: {
204 Type storageType = QuantizedType::castToStorageType(newType);
205 if (failed(validateTypeConversion(storageType, newType, op)))
206 return;
207 anchor->getOp()->setOperand(
208 anchor->getOperandIdx(),
209 b.create<StorageCastOp>(op->getLoc(), storageType, newTypedInputValue));
210 break;
211 }
212
213 case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
214 // Leave the anchor as-is and just cast in/out after it.
215 anchor->getOp()->setOperand(
216 anchor->getOperandIdx(),
217 b.create<DequantizeCastOp>(op->getLoc(), anchor->getOriginalType(),
218 newTypedInputValue));
219 break;
220 }
221
222 for (Value removeValueIfDead : removeValuesIfDead) {
223 if (removeValueIfDead.use_empty()) {
224 removeValueIfDead.getDefiningOp()->erase();
225 }
226 }
227 }
228
transformResultType(CAGResultAnchor * anchor,Type newType)229 void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor,
230 Type newType) {
231 Value origResultValue = anchor->getValue();
232 Operation *op = origResultValue.getDefiningOp();
233 OpBuilder b(op->getBlock(), ++Block::iterator(op));
234
235 Value replacedResultValue = nullptr;
236 Value newResultValue = nullptr;
237 switch (anchor->getTypeTransformRule()) {
238 case CAGAnchorNode::TypeTransformRule::Direct:
239 origResultValue.setType(newType);
240 replacedResultValue = newResultValue = b.create<DequantizeCastOp>(
241 op->getLoc(), anchor->getOriginalType(), origResultValue);
242 break;
243
244 case CAGAnchorNode::TypeTransformRule::DirectStorage: {
245 Type storageType = QuantizedType::castToStorageType(newType);
246 if (failed(validateTypeConversion(storageType, newType, op)))
247 return;
248 origResultValue.setType(storageType);
249 replacedResultValue =
250 b.create<StorageCastOp>(op->getLoc(), newType, origResultValue);
251 newResultValue = b.create<DequantizeCastOp>(
252 op->getLoc(), anchor->getOriginalType(), replacedResultValue);
253 break;
254 }
255
256 case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
257 // Leave the anchor as-is and just cast in/out after it.
258 replacedResultValue =
259 b.create<QuantizeCastOp>(op->getLoc(), newType, origResultValue);
260 newResultValue = b.create<DequantizeCastOp>(
261 op->getLoc(), anchor->getOriginalType(), replacedResultValue);
262 break;
263 }
264
265 if (replacedResultValue) {
266 // Transform:
267 // origResultValue --> replaceResultValue -> newResultValue
268 // \-> [original uses]
269 // To:
270 // origResultValue -> replaceResultValue ->
271 // newResultValue -> [original uses]
272 // Note that replaceResultValue may equal newResultValue or there may
273 // be operands between the two.
274 origResultValue.replaceAllUsesWith(newResultValue);
275 replacedResultValue.getDefiningOp()->replaceUsesOfWith(newResultValue,
276 origResultValue);
277 }
278 }
279
280 std::unique_ptr<OpPassBase<ModuleOp>>
createInferQuantizedTypesPass(SolverContext & solverContext,const TargetConfiguration & config)281 mlir::quantizer::createInferQuantizedTypesPass(
282 SolverContext &solverContext, const TargetConfiguration &config) {
283 return std::make_unique<InferQuantizedTypesPass>(solverContext, config);
284 }
285
286 static PassRegistration<InferQuantizedTypesPass>
287 pass("quantizer-infer-quantized-types",
288 "Infers quantized types for a module");
289