//===- LowerUniformRealMath.cpp  ------------------------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "UniformKernelUtils.h"

#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/FxpMathOps/Passes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::fxpmath;
using namespace mlir::fxpmath::detail;
using namespace mlir::quant;

23 namespace {
24 
25 struct LowerUniformRealMathPass
26     : public FunctionPass<LowerUniformRealMathPass> {
27   void runOnFunction() override;
28 };
29 
30 struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
31   void runOnFunction() override;
32 };
33 
34 } // end anonymous namespace
35 
//===----------------------------------------------------------------------===//
// Dequantize
//===----------------------------------------------------------------------===//

emitUniformPerLayerDequantize(Location loc,Value input,UniformQuantizedType elementType,PatternRewriter & rewriter)40 static Value emitUniformPerLayerDequantize(Location loc, Value input,
41                                            UniformQuantizedType elementType,
42                                            PatternRewriter &rewriter) {
43   // Pre-conditions.
44   if (!elementType.isSigned()) {
45     // TODO: Support unsigned storage type.
46     emitWarning(loc, "unimplemented: dequantize signed uniform");
47     return nullptr;
48   }
49 
50   Type storageType = elementType.castToStorageType(input.getType());
51   Type realType = elementType.castToExpressedType(input.getType());
52   Type intermediateType =
53       castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
54   assert(storageType && "cannot cast to storage type");
55   assert(realType && "cannot cast to expressed type");
56 
57   // Cast to storage type.
58   input = rewriter.create<StorageCastOp>(loc, storageType, input);
59 
60   // Promote to intermediate type.
61   input = rewriter.create<ConvertISOp>(loc, intermediateType, input);
62 
63   // Apply zero-point offset.
64   if (elementType.getZeroPoint() != 0) {
65     Value negZeroPointConst = rewriter.create<ConstantOp>(
66         loc, broadcastScalarConstIntValue(intermediateType,
67                                           -elementType.getZeroPoint()));
68     input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
69   }
70 
71   // Convert to float.
72   input = rewriter.create<ConvertISToFOp>(loc, realType, input);
73 
74   // Mul by scale.
75   Value scaleConst = rewriter.create<ConstantOp>(
76       loc, broadcastScalarConstFloatValue(realType,
77                                           APFloat(elementType.getScale())));
78   return rewriter.create<MulFOp>(loc, input, scaleConst);
79 }
80 
81 static Value
emitUniformPerAxisDequantize(Location loc,Value input,UniformQuantizedPerAxisType elementType,PatternRewriter & rewriter)82 emitUniformPerAxisDequantize(Location loc, Value input,
83                              UniformQuantizedPerAxisType elementType,
84                              PatternRewriter &rewriter) {
85   // TODO: Support per-axis dequantize.
86   rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning)
87       << "unimplemented: per-axis uniform dequantization";
88   return nullptr;
89 }
90 
emitDequantize(Location loc,Value input,PatternRewriter & rewriter)91 static Value emitDequantize(Location loc, Value input,
92                             PatternRewriter &rewriter) {
93   Type inputType = input.getType();
94   QuantizedType qElementType =
95       QuantizedType::getQuantizedElementType(inputType);
96   if (auto uperLayerElementType =
97           qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
98     return emitUniformPerLayerDequantize(loc, input, uperLayerElementType,
99                                          rewriter);
100   } else if (auto uperAxisElementType =
101                  qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
102     return emitUniformPerAxisDequantize(loc, input, uperAxisElementType,
103                                         rewriter);
104   } else {
105     return nullptr;
106   }
107 }
108 
109 namespace {
110 
111 struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
112   using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
113 
matchAndRewrite__anona4a7cf450211::UniformDequantizePattern114   PatternMatchResult matchAndRewrite(DequantizeCastOp op,
115                                      PatternRewriter &rewriter) const override {
116     Type inputType = op.arg().getType();
117     Type outputType = op.getResult().getType();
118 
119     QuantizedType inputElementType =
120         QuantizedType::getQuantizedElementType(inputType);
121     Type expressedOutputType = inputElementType.castToExpressedType(inputType);
122     if (expressedOutputType != outputType) {
123       // Not a valid uniform cast.
124       return matchFailure();
125     }
126 
127     Value dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
128     if (!dequantizedValue) {
129       return matchFailure();
130     }
131 
132     rewriter.replaceOp(op, dequantizedValue);
133     return matchSuccess();
134   }
135 };
136 
137 } // end anonymous namespace
138 
//===----------------------------------------------------------------------===//
// Elementwise add
//===----------------------------------------------------------------------===//

143 static LogicalResult
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo & info,PatternRewriter & rewriter)144 tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
145                                       PatternRewriter &rewriter) {
146   if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
147       info.rhsType != info.resultType) {
148     return failure();
149   }
150 
151   // Choose a byte aligned intermediate width big enough to perform the
152   // calculation without overflow.
153   // TODO: This should probably be made just big enough to avoid overflow and
154   // leave the downstream tooling to decide how to align that to machine
155   // word sizes.
156   unsigned intermediateWidth =
157       info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
158   IntegerType intermediateElementType =
159       IntegerType::get(intermediateWidth, rewriter.getContext());
160   Type intermediateType =
161       castElementType(info.resultStorageType, intermediateElementType);
162 
163   // Cast operands to storage type.
164   Value lhsValue = rewriter
165                        .create<StorageCastOp>(info.op->getLoc(),
166                                               info.lhsStorageType, info.lhs)
167                        .getResult();
168   Value rhsValue = rewriter
169                        .create<StorageCastOp>(info.op->getLoc(),
170                                               info.rhsStorageType, info.rhs)
171                        .getResult();
172 
173   // Cast to the intermediate sized type.
174   lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
175                                           lhsValue);
176   rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
177                                           rhsValue);
178 
179   // Add.
180   Value resultValue =
181       rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
182 
183   // Zero point offset adjustment.
184   // result = (lhs - zp) + (rhs - zp) + zp
185   // zpOffset = -zp
186   int zpOffset = -1 * info.resultType.getZeroPoint();
187   if (zpOffset != 0) {
188     Value zpOffsetConst = rewriter.create<ConstantOp>(
189         info.op->getLoc(),
190         broadcastScalarConstIntValue(intermediateType, zpOffset));
191     resultValue =
192         rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
193   }
194 
195   // Clamp.
196   auto clampMinMax = info.getClampMinMax(intermediateElementType);
197   resultValue = rewriter.create<ClampISOp>(
198       info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
199 
200   // Convert back to original type.
201   resultValue = rewriter.create<ConvertISOp>(
202       info.op->getLoc(), info.resultStorageType, resultValue);
203 
204   // Cast back for new result.
205   rewriter.replaceOpWithNewOp<StorageCastOp>(
206       info.op, info.getQuantizedResultType(), resultValue);
207 
208   return success();
209 }
210 
//===----------------------------------------------------------------------===//
// Elementwise mul
//===----------------------------------------------------------------------===//

215 static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo & info,PatternRewriter & rewriter)216 tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
217                             PatternRewriter &rewriter) {
218   if (!info.resultType.isSigned()) {
219     return failure();
220   }
221 
222   double outputMultiplierReal = info.lhsType.getScale() *
223                                 info.rhsType.getScale() /
224                                 info.resultType.getScale();
225   if (outputMultiplierReal > 1.0) {
226     info.op->emitWarning(
227         "unimplemented: cannot multiply with multiplier > 1.0");
228     return failure();
229   }
230 
231   // TODO: Choose an appropriate intermediate width for muls > 8 bits to
232   // avoid overflow.
233   unsigned intermediateWidth = 32;
234   IntegerType intermediateElementType =
235       IntegerType::get(intermediateWidth, rewriter.getContext());
236   Type intermediateType =
237       castElementType(info.resultStorageType, intermediateElementType);
238 
239   // Cast operands to storage type.
240   Value lhsValue = rewriter
241                        .create<StorageCastOp>(info.op->getLoc(),
242                                               info.lhsStorageType, info.lhs)
243                        .getResult();
244   Value rhsValue = rewriter
245                        .create<StorageCastOp>(info.op->getLoc(),
246                                               info.rhsStorageType, info.rhs)
247                        .getResult();
248 
249   // Cast to the intermediate sized type.
250   lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
251                                           lhsValue);
252   rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
253                                           rhsValue);
254 
255   // Apply argument zeroPoints.
256   if (info.lhsType.getZeroPoint() != 0) {
257     Value zpOffsetConst = rewriter.create<ConstantOp>(
258         info.op->getLoc(), broadcastScalarConstIntValue(
259                                intermediateType, -info.lhsType.getZeroPoint()));
260     lhsValue =
261         rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
262   }
263 
264   if (info.rhsType.getZeroPoint() != 0) {
265     Value zpOffsetConst = rewriter.create<ConstantOp>(
266         info.op->getLoc(), broadcastScalarConstIntValue(
267                                intermediateType, -info.rhsType.getZeroPoint()));
268     rhsValue =
269         rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
270   }
271 
272   // Mul.
273   Value resultValue =
274       rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);
275 
276   // Scale output.
277   QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
278   resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
279       info.op->getLoc(), resultValue,
280       IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
281   resultValue = rewriter.create<RoundingDivideByPotISOp>(
282       info.op->getLoc(), resultValue,
283       IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));
284 
285   // Zero point offset adjustment.
286   if (info.resultType.getZeroPoint() != 0) {
287     Value zpOffsetConst = rewriter.create<ConstantOp>(
288         info.op->getLoc(),
289         broadcastScalarConstIntValue(intermediateType,
290                                      info.resultType.getZeroPoint()));
291     resultValue =
292         rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
293   }
294 
295   // Clamp.
296   auto clampMinMax = info.getClampMinMax(intermediateElementType);
297   resultValue = rewriter.create<ClampISOp>(
298       info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
299 
300   // Convert back to original type.
301   resultValue = rewriter.create<ConvertISOp>(
302       info.op->getLoc(), info.resultStorageType, resultValue);
303 
304   // Cast back for new result.
305   rewriter.replaceOpWithNewOp<StorageCastOp>(
306       info.op, info.getQuantizedResultType(), resultValue);
307 
308   return success();
309 }
310 
311 namespace {
312 
313 struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
314   using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
315 
matchAndRewrite__anona4a7cf450311::UniformRealAddEwPattern316   PatternMatchResult matchAndRewrite(RealAddEwOp op,
317                                      PatternRewriter &rewriter) const override {
318     const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
319                                    op.clamp_max());
320     if (!info.isValid()) {
321       return matchFailure();
322     }
323 
324     // Try all of the permutations we support.
325     if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
326       return matchSuccess();
327     }
328 
329     return matchFailure();
330   }
331 };
332 
333 struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
334   using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
335 
matchAndRewrite__anona4a7cf450311::UniformRealMulEwPattern336   PatternMatchResult matchAndRewrite(RealMulEwOp op,
337                                      PatternRewriter &rewriter) const override {
338     const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
339                                    op.clamp_max());
340     if (!info.isValid()) {
341       return matchFailure();
342     }
343 
344     // Try all of the permutations we support.
345     if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
346       return matchSuccess();
347     }
348 
349     return matchFailure();
350   }
351 };
352 
353 } // end anonymous namespace
354 
//===----------------------------------------------------------------------===//
// LowerUniformRealMath pass
//===----------------------------------------------------------------------===//

runOnFunction()359 void LowerUniformRealMathPass::runOnFunction() {
360   auto fn = getFunction();
361   OwningRewritePatternList patterns;
362   auto *context = &getContext();
363   patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
364   applyPatternsGreedily(fn, patterns);
365 }
366 
createLowerUniformRealMathPass()367 OpPassBase<FuncOp> *mlir::fxpmath::createLowerUniformRealMathPass() {
368   return new LowerUniformRealMathPass();
369 }
370 
371 static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
372     "fxpmath-lower-uniform-real-math",
373     "Lowers uniform-quantized real math ops to integer arithmetic.");
374 
//===----------------------------------------------------------------------===//
// LowerUniformCasts pass
//===----------------------------------------------------------------------===//

runOnFunction()379 void LowerUniformCastsPass::runOnFunction() {
380   auto fn = getFunction();
381   OwningRewritePatternList patterns;
382   auto *context = &getContext();
383   patterns.insert<UniformDequantizePattern>(context);
384   applyPatternsGreedily(fn, patterns);
385 }
386 
createLowerUniformCastsPass()387 OpPassBase<FuncOp> *mlir::fxpmath::createLowerUniformCastsPass() {
388   return new LowerUniformCastsPass();
389 }
390 
391 static PassRegistration<LowerUniformCastsPass>
392     lowerUniformCastsPass("fxpmath-lower-uniform-casts",
393                           "Lowers uniform-quantized casts.");
394