1 //===- TosaTestPasses.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Test passes to exercise TOSA helper functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
17 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
18 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
24 #define PASS_NAME "tosa-test-quant-utils"
25 
26 using namespace mlir;
27 using namespace mlir::tosa;
28 
29 // This transformation converts quantized uint8 to quantized int8. The
30 // construction of the new type invokes buildQTypeFromMinMax. Extracted from
31 // TOSA legalization infrastructure.
32 struct ConvertTosaNegateOp : public RewritePattern {
ConvertTosaNegateOpConvertTosaNegateOp33   explicit ConvertTosaNegateOp(MLIRContext *context)
34       : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {}
35   LogicalResult matchAndRewrite(Operation *op,
36                                 PatternRewriter &rewriter) const override;
37 };
38 
39 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const40 ConvertTosaNegateOp::matchAndRewrite(Operation *op,
41                                      PatternRewriter &rewriter) const {
42 
43   auto tosaNegateOp = cast<tosa::NegateOp>(op);
44 
45   auto inputType =
46       tosaNegateOp.input1().getType().dyn_cast<mlir::RankedTensorType>();
47   // skip if input is not ranked tensor type
48   if (!inputType)
49     return failure();
50 
51   // skip if it's not ranked tensor type.
52   auto outputType =
53       tosaNegateOp.getResult().getType().dyn_cast<mlir::RankedTensorType>();
54   if (!outputType)
55     return failure();
56 
57   // skip if output is not per-tensor quantized type.
58   auto outputElementType =
59       outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
60   if (!outputElementType)
61     return failure();
62 
63   // skip if output is not uint8.
64   if (outputElementType.isSigned() ||
65       outputElementType.getStorageTypeIntegralWidth() != 8)
66     return failure();
67 
68   double typeRangeMin = double(outputElementType.getStorageTypeMin() -
69                                outputElementType.getZeroPoint()) *
70                         outputElementType.getScale();
71   double typeRangeMax = double(outputElementType.getStorageTypeMax() -
72                                outputElementType.getZeroPoint()) *
73                         outputElementType.getScale();
74   bool narrow_range = outputElementType.getStorageTypeMin() == 1 ? true : false;
75 
76   auto dstQConstType = RankedTensorType::get(
77       outputType.getShape(),
78       buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(),
79                            rewriter.getF64FloatAttr(typeRangeMin),
80                            rewriter.getF64FloatAttr(typeRangeMax),
81                            rewriter.getI32IntegerAttr(
82                                outputElementType.getStorageTypeIntegralWidth()),
83                            0, true /* signed */,
84                            rewriter.getBoolAttr(narrow_range)));
85 
86   ElementsAttr inputElems;
87   if (!matchPattern(tosaNegateOp.input1(), m_Constant(&inputElems)))
88     return failure();
89 
90   auto newConstOp =
91       rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems);
92   auto newNegateOp = rewriter.create<tosa::NegateOp>(
93       op->getLoc(), dstQConstType, newConstOp.getResult());
94 
95   rewriter.replaceOp(op, {newNegateOp.getResult()});
96   return success();
97 }
98 
99 // This transformation modifies the quantized output of a test conv2d input and
100 // appends a TOSA rescale after it. The rescale op requires the invocation of
101 // computeMultiplierAndShift. From TOSA legalization infrastructure.
102 struct ConvertTosaConv2DOp : public RewritePattern {
ConvertTosaConv2DOpConvertTosaConv2DOp103   explicit ConvertTosaConv2DOp(MLIRContext *context)
104       : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {}
105   LogicalResult matchAndRewrite(Operation *op,
106                                 PatternRewriter &rewriter) const override;
107 };
108 
109 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const110 ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
111                                      PatternRewriter &rewriter) const {
112 
113   auto tosaConv2DOp = cast<tosa::Conv2DOp>(op);
114 
115   auto inputType =
116       tosaConv2DOp.input().getType().dyn_cast<mlir::RankedTensorType>();
117 
118   // skip if input is not ranked tensor type
119   if (!inputType)
120     return failure();
121 
122   auto weightType =
123       tosaConv2DOp.weight().getType().dyn_cast<mlir::RankedTensorType>();
124 
125   // skip if wt is not ranked tensor type
126   if (!weightType)
127     return failure();
128 
129   // skip if it's not ranked tensor type.
130   auto outputType =
131       tosaConv2DOp.getResult().getType().dyn_cast<mlir::RankedTensorType>();
132   if (!outputType)
133     return failure();
134 
135   auto inputQType =
136       inputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
137   auto weightQType =
138       weightType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
139   auto outputQType =
140       outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
141 
142   // Works on quantized type only.
143   if (!(inputQType && weightQType && outputQType))
144     return failure();
145 
146   auto newTosaConv2DOpType =
147       RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32));
148 
149   auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>(
150       op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.input(),
151       tosaConv2DOp.weight(), tosaConv2DOp.bias(), tosaConv2DOp.pad(),
152       tosaConv2DOp.stride(), tosaConv2DOp.dilation());
153 
154   // Create rescale to quantized type
155   double inputScale = inputQType.getScale();
156   double weightScale = weightQType.getScale();
157   double outputScale = outputQType.getScale();
158   int64_t outputZp = outputQType.getZeroPoint();
159 
160   double opTensorScale = (inputScale * weightScale) / outputScale;
161 
162   int32_t multiplier;
163   int32_t shift;
164 
165   // Obtain the quantized scale = multiplier and shift.
166   computeMultiplierAndShift(opTensorScale, multiplier, shift, 32);
167 
168   auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>(
169       op->getLoc(), outputType, newTosaConv2DOp.getResult(),
170       rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp),
171       rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
172       rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
173       rewriter.getBoolAttr(false));
174 
175   rewriter.replaceOp(op, {newTosaRescaleOp.getResult()});
176   return success();
177 }
178 
179 namespace {
180 
181 struct TosaTestQuantUtilAPI
182     : public PassWrapper<TosaTestQuantUtilAPI, FunctionPass> {
getArgument__anonc0dc2cf70111::TosaTestQuantUtilAPI183   StringRef getArgument() const final { return PASS_NAME; }
getDescription__anonc0dc2cf70111::TosaTestQuantUtilAPI184   StringRef getDescription() const final {
185     return "TOSA Test: Exercise the APIs in QuantUtils.cpp.";
186   }
187   void runOnFunction() override;
188 };
189 
runOnFunction()190 void TosaTestQuantUtilAPI::runOnFunction() {
191   auto *ctx = &getContext();
192   RewritePatternSet patterns(ctx);
193   auto func = getFunction();
194 
195   patterns.add<ConvertTosaNegateOp>(ctx);
196   patterns.add<ConvertTosaConv2DOp>(ctx);
197   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
198 }
199 
200 } // anonymous namespace
201 
202 namespace mlir {
registerTosaTestQuantUtilAPIPass()203 void registerTosaTestQuantUtilAPIPass() {
204   PassRegistration<TosaTestQuantUtilAPI>();
205 }
206 } // namespace mlir
207