1 //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Quant/FakeQuantSupport.h"
11 #include "mlir/Dialect/Quant/Passes.h"
12 #include "mlir/Dialect/Quant/QuantOps.h"
13 #include "mlir/Dialect/Quant/UniformSupport.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 
17 using namespace mlir;
18 using namespace mlir::quant;
19 
20 namespace {
21 struct ConvertSimulatedQuantPass
22     : public QuantConvertSimulatedQuantBase<ConvertSimulatedQuantPass> {
23   void runOnFunction() override;
24 };
25 
26 /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
27 template <typename ConcreteRewriteClass, typename FakeQuantOp>
28 class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
29 public:
30   using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
31 
FakeQuantRewrite(MLIRContext * ctx,bool * hadFailure)32   FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
33       : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
34 
matchAndRewrite(FakeQuantOp op,PatternRewriter & rewriter) const35   LogicalResult matchAndRewrite(FakeQuantOp op,
36                                 PatternRewriter &rewriter) const override {
37     // TODO: If this pattern comes up more frequently, consider adding core
38     // support for failable rewrites.
39     if (failableRewrite(op, rewriter)) {
40       *hadFailure = true;
41       return failure();
42     }
43 
44     return success();
45   }
46 
47 private:
48   bool *hadFailure;
49 
failableRewrite(FakeQuantOp op,PatternRewriter & rewriter) const50   bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
51     auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
52     if (!converter) {
53       return (op.emitError("unsupported quantized type conversion"), true);
54     }
55 
56     QuantizedType elementType =
57         static_cast<const ConcreteRewriteClass *>(this)
58             ->convertFakeQuantAttrsToType(op, converter.expressedType);
59 
60     if (!elementType) {
61       // Note that the fakeQuantAttrsToType will have emitted the error.
62       return true;
63     }
64 
65     Type quantizedType = converter.convert(elementType);
66     assert(quantizedType &&
67            "Converter accepted a type that it did not convert");
68 
69     // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
70     // this is a forced/hard-coded constraint.
71     auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
72                                                     op.inputs());
73     rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
74                                                   qbarrier.getResult());
75 
76     return false;
77   }
78 };
79 
80 class ConstFakeQuantRewrite
81     : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
82 public:
83   using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
84 
ConstFakeQuantRewrite(MLIRContext * ctx,bool * hadFailure)85   ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
86       : BaseRewrite(ctx, hadFailure) {}
87 
convertFakeQuantAttrsToType(ConstFakeQuant fqOp,Type expressedType) const88   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
89                                             Type expressedType) const {
90     return fakeQuantAttrsToType(
91         fqOp.getLoc(), fqOp.num_bits(), fqOp.min().convertToFloat(),
92         fqOp.max().convertToFloat(), fqOp.narrow_range(), expressedType,
93         fqOp.is_signed());
94   }
95 };
96 
97 class ConstFakeQuantPerAxisRewrite
98     : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
99                               ConstFakeQuantPerAxis> {
100 public:
101   using BaseRewrite =
102       FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
103 
ConstFakeQuantPerAxisRewrite(MLIRContext * ctx,bool * hadFailure)104   ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
105       : BaseRewrite(ctx, hadFailure) {}
106 
convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,Type expressedType) const107   QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
108                                             Type expressedType) const {
109     SmallVector<double, 4> min, max;
110     min.reserve(fqOp.min().size());
111     max.reserve(fqOp.max().size());
112     for (auto m : fqOp.min())
113       min.push_back(m.cast<FloatAttr>().getValueAsDouble());
114     for (auto m : fqOp.max())
115       max.push_back(m.cast<FloatAttr>().getValueAsDouble());
116 
117     return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits(), fqOp.axis(),
118                                 min, max, fqOp.narrow_range(), expressedType,
119                                 fqOp.is_signed());
120   }
121 };
122 
123 } // namespace
124 
runOnFunction()125 void ConvertSimulatedQuantPass::runOnFunction() {
126   bool hadFailure = false;
127   auto func = getFunction();
128   RewritePatternSet patterns(func.getContext());
129   auto ctx = func.getContext();
130   patterns.add<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
131       ctx, &hadFailure);
132   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
133   if (hadFailure)
134     signalPassFailure();
135 }
136 
137 std::unique_ptr<OperationPass<FuncOp>>
createConvertSimulatedQuantPass()138 mlir::quant::createConvertSimulatedQuantPass() {
139   return std::make_unique<ConvertSimulatedQuantPass>();
140 }
141