1 //===- InferQuantizedTypesPass.cpp - Infers quantized types ---------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the primary pass for instantiating a CAG, running it to
10 // convergence on a module to determine eligible quantized type transforms, and
11 // applying those transforms to the IR.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/QuantOps/QuantOps.h"
16 #include "mlir/Dialect/QuantOps/QuantTypes.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Quantizer/Configurations/FxpMathConfig.h"
19 #include "mlir/Quantizer/Support/Configuration.h"
20 #include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
21 #include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
22 #include "mlir/Quantizer/Transforms/Passes.h"
23 #include "mlir/Support/LogicalResult.h"
24 #include "llvm/Support/DOTGraphTraits.h"
25 #include "llvm/Support/GraphWriter.h"
26 #include "llvm/Support/raw_ostream.h"
27 
28 using namespace mlir;
29 using namespace mlir::quantizer;
30 using namespace mlir::quant;
31 
32 namespace llvm {
33 
34 template <>
35 struct DOTGraphTraits<const CAGSlice *>
36     : public DOTGraphTraits<const CAGNode *> {
DOTGraphTraitsllvm::DOTGraphTraits37   DOTGraphTraits(bool isSimple = false)
38       : DOTGraphTraits<const CAGNode *>(isSimple) {}
39 
getNodeLabelllvm::DOTGraphTraits40   std::string getNodeLabel(const CAGNode *node, const CAGSlice *graph) {
41     std::string s;
42     llvm::raw_string_ostream out(s);
43     node->printLabel(out);
44     return out.str();
45   }
46 
getGraphPropertiesllvm::DOTGraphTraits47   static std::string getGraphProperties(const CAGSlice *) {
48     return "rankdir=LR;";
49   }
50 
isNodeHiddenllvm::DOTGraphTraits51   static bool isNodeHidden(const CAGNode *node) {
52     // Filter constraint nodes with no incoming or outgoing connections.
53     // These orphans are often created as part of graph merging operations.
54     return llvm::isa<CAGConstraintNode>(node) && node->isOrphan();
55   }
56 
getNodeAttributesllvm::DOTGraphTraits57   std::string getNodeAttributes(const CAGNode *node, const CAGSlice *graph) {
58     switch (node->getKind()) {
59     default:
60       return std::string();
61     case CAGNode::Kind::OperandAnchor:
62       return "shape=record,color=yellow,style=filled";
63     case CAGNode::Kind::ResultAnchor:
64       return "shape=record,color=lightblue,style=filled";
65     case CAGNode::Kind::Constraint:
66       return "shape=record,style=dotted";
67     }
68   }
69 };
70 
71 } // end namespace llvm
72 
73 namespace {
74 
75 class InferQuantizedTypesPass : public ModulePass<InferQuantizedTypesPass> {
76 public:
77   InferQuantizedTypesPass() = default;
InferQuantizedTypesPass(SolverContext & solverContext,const TargetConfiguration & config)78   InferQuantizedTypesPass(SolverContext &solverContext,
79                           const TargetConfiguration &config)
80       : explicitSolverContext(&solverContext), explicitConfig(&config) {}
81   void runOnModule() override;
82   void runWithConfig(SolverContext &solverContext,
83                      const TargetConfiguration &config);
84 
85   void transformOperandType(CAGOperandAnchor *anchor, Type newType);
86   void transformResultType(CAGResultAnchor *anchor, Type newType);
87 
88 private:
89   SolverContext *explicitSolverContext = nullptr;
90   const TargetConfiguration *explicitConfig = nullptr;
91 };
92 
93 } // end anonymous namespace
94 
/// Upper bound on the number of CAG propagation rounds attempted while waiting
/// for the graph to converge; running out of rounds is reported as an error.
static constexpr int kMaximumPropagationRounds = 1000;
98 
validateTypeConversion(Type newType,Type origType,Operation * op)99 static LogicalResult validateTypeConversion(Type newType, Type origType,
100                                             Operation *op) {
101   if (!newType) {
102     return op->emitOpError() << "unsupported type conversion from " << newType;
103   }
104   return success();
105 }
106 
runOnModule()107 void InferQuantizedTypesPass::runOnModule() {
108   if (explicitSolverContext && explicitConfig) {
109     // If explicitly constructed with a config and context.
110     runWithConfig(*explicitSolverContext, *explicitConfig);
111     return;
112   }
113 
114   // For global pass registration, use defaults.
115   SolverContext solverContext(*getModule().getContext());
116   auto config = FxpMathTargetConfig::create(solverContext);
117   runWithConfig(solverContext, *config);
118 }
119 
runWithConfig(SolverContext & solverContext,const TargetConfiguration & config)120 void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
121                                             const TargetConfiguration &config) {
122   CAGSlice cag(solverContext);
123   for (auto f : getModule().getOps<FuncOp>()) {
124     f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
125   }
126   config.finalizeAnchors(cag);
127 
128   // Propagate.
129   int propRound;
130   for (propRound = kMaximumPropagationRounds; propRound > 0; --propRound) {
131     auto propCount = cag.propagate(config);
132     if (propCount == 0)
133       break;
134   }
135   if (propRound == 0) {
136     emitError(UnknownLoc::get(&getContext()),
137               "exceeded maximum number of solver iterations (infinite loop?)");
138     return;
139   }
140 
141   // TODO: Only dump the GraphViz if a flag is set and move to a utility.
142   // GraphViz.
143   if (!solverContext.getDebugCAGDotPath().empty()) {
144     auto actFileName =
145         llvm::WriteGraph(const_cast<const CAGSlice *>(&cag), "CAG",
146                          /*ShortNames=*/false,
147                          /*Title=*/"CAG",
148                          /*Filename=*/solverContext.getDebugCAGDotPath());
149     llvm::errs() << "Wrote graphviz file: " << actFileName << "\n";
150   }
151 
152   // Start transforming the types in order of anchor type (results, then
153   // operands).
154   // Apply result types.
155   for (auto *node : cag) {
156     auto anchorNode = dyn_cast<CAGResultAnchor>(node);
157     if (!anchorNode)
158       continue;
159     if (Type newType = anchorNode->getTransformedType())
160       transformResultType(anchorNode, newType);
161   }
162 
163   // Apply operand types.
164   for (auto *node : cag) {
165     auto anchorNode = dyn_cast<CAGOperandAnchor>(node);
166     if (!anchorNode)
167       continue;
168     if (Type newType = anchorNode->getTransformedType())
169       transformOperandType(anchorNode, newType);
170   }
171 }
172 
transformOperandType(CAGOperandAnchor * anchor,Type newType)173 void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor,
174                                                    Type newType) {
175   Value inputValue = anchor->getValue();
176   Operation *op = anchor->getOp();
177   OpBuilder b(op->getBlock(), Block::iterator(op));
178 
179   SmallVector<Value, 1> removeValuesIfDead;
180 
181   // Because we've already run the result transforms at this phase, it is
182   // very likely that inputValue points to a dcast op whose input matches
183   // our type. We detect that situation and route around just to save some
184   // bulk in the IR.
185   Value newTypedInputValue = inputValue;
186   auto inputDcastOp =
187       dyn_cast_or_null<DequantizeCastOp>(inputValue.getDefiningOp());
188   if (inputDcastOp && inputDcastOp.arg().getType() == newType) {
189     // Can just use the dcast's input value.
190     newTypedInputValue = inputDcastOp.arg();
191     removeValuesIfDead.push_back(inputDcastOp);
192   } else {
193     // Need to synthesize a qcast.
194     newTypedInputValue =
195         b.create<QuantizeCastOp>(op->getLoc(), newType, inputValue);
196   }
197 
198   switch (anchor->getTypeTransformRule()) {
199   case CAGAnchorNode::TypeTransformRule::Direct:
200     anchor->getOp()->setOperand(anchor->getOperandIdx(), newTypedInputValue);
201     break;
202 
203   case CAGAnchorNode::TypeTransformRule::DirectStorage: {
204     Type storageType = QuantizedType::castToStorageType(newType);
205     if (failed(validateTypeConversion(storageType, newType, op)))
206       return;
207     anchor->getOp()->setOperand(
208         anchor->getOperandIdx(),
209         b.create<StorageCastOp>(op->getLoc(), storageType, newTypedInputValue));
210     break;
211   }
212 
213   case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
214     // Leave the anchor as-is and just cast in/out after it.
215     anchor->getOp()->setOperand(
216         anchor->getOperandIdx(),
217         b.create<DequantizeCastOp>(op->getLoc(), anchor->getOriginalType(),
218                                    newTypedInputValue));
219     break;
220   }
221 
222   for (Value removeValueIfDead : removeValuesIfDead) {
223     if (removeValueIfDead.use_empty()) {
224       removeValueIfDead.getDefiningOp()->erase();
225     }
226   }
227 }
228 
transformResultType(CAGResultAnchor * anchor,Type newType)229 void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor,
230                                                   Type newType) {
231   Value origResultValue = anchor->getValue();
232   Operation *op = origResultValue.getDefiningOp();
233   OpBuilder b(op->getBlock(), ++Block::iterator(op));
234 
235   Value replacedResultValue = nullptr;
236   Value newResultValue = nullptr;
237   switch (anchor->getTypeTransformRule()) {
238   case CAGAnchorNode::TypeTransformRule::Direct:
239     origResultValue.setType(newType);
240     replacedResultValue = newResultValue = b.create<DequantizeCastOp>(
241         op->getLoc(), anchor->getOriginalType(), origResultValue);
242     break;
243 
244   case CAGAnchorNode::TypeTransformRule::DirectStorage: {
245     Type storageType = QuantizedType::castToStorageType(newType);
246     if (failed(validateTypeConversion(storageType, newType, op)))
247       return;
248     origResultValue.setType(storageType);
249     replacedResultValue =
250         b.create<StorageCastOp>(op->getLoc(), newType, origResultValue);
251     newResultValue = b.create<DequantizeCastOp>(
252         op->getLoc(), anchor->getOriginalType(), replacedResultValue);
253     break;
254   }
255 
256   case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
257     // Leave the anchor as-is and just cast in/out after it.
258     replacedResultValue =
259         b.create<QuantizeCastOp>(op->getLoc(), newType, origResultValue);
260     newResultValue = b.create<DequantizeCastOp>(
261         op->getLoc(), anchor->getOriginalType(), replacedResultValue);
262     break;
263   }
264 
265   if (replacedResultValue) {
266     // Transform:
267     //   origResultValue -->  replaceResultValue -> newResultValue
268     //                   \->  [original uses]
269     // To:
270     //   origResultValue -> replaceResultValue ->
271     //                      newResultValue -> [original uses]
272     // Note that replaceResultValue may equal newResultValue or there may
273     // be operands between the two.
274     origResultValue.replaceAllUsesWith(newResultValue);
275     replacedResultValue.getDefiningOp()->replaceUsesOfWith(newResultValue,
276                                                            origResultValue);
277   }
278 }
279 
280 std::unique_ptr<OpPassBase<ModuleOp>>
createInferQuantizedTypesPass(SolverContext & solverContext,const TargetConfiguration & config)281 mlir::quantizer::createInferQuantizedTypesPass(
282     SolverContext &solverContext, const TargetConfiguration &config) {
283   return std::make_unique<InferQuantizedTypesPass>(solverContext, config);
284 }
285 
286 static PassRegistration<InferQuantizedTypesPass>
287     pass("quantizer-infer-quantized-types",
288          "Infers quantized types for a module");
289