1 //===- LowerUniformRealMath.cpp ------------------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "UniformKernelUtils.h"
10
11 #include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
12 #include "mlir/Dialect/FxpMathOps/Passes.h"
13 #include "mlir/Dialect/StandardOps/Ops.h"
14 #include "mlir/IR/Diagnostics.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
17
18 using namespace mlir;
19 using namespace mlir::fxpmath;
20 using namespace mlir::fxpmath::detail;
21 using namespace mlir::quant;
22
23 namespace {
24
/// Function pass that lowers supported uniform-quantized real math ops
/// (elementwise add/mul — see the patterns registered in runOnFunction)
/// into integer arithmetic on the quantized storage type.
struct LowerUniformRealMathPass
    : public FunctionPass<LowerUniformRealMathPass> {
  void runOnFunction() override;
};
29
/// Function pass that lowers uniform-quantized cast ops (currently only
/// dequantize casts — see runOnFunction) into explicit arithmetic.
struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
  void runOnFunction() override;
};
33
34 } // end anonymous namespace
35
36 //===----------------------------------------------------------------------===//
37 // Dequantize
38 //===----------------------------------------------------------------------===//
39
emitUniformPerLayerDequantize(Location loc,Value input,UniformQuantizedType elementType,PatternRewriter & rewriter)40 static Value emitUniformPerLayerDequantize(Location loc, Value input,
41 UniformQuantizedType elementType,
42 PatternRewriter &rewriter) {
43 // Pre-conditions.
44 if (!elementType.isSigned()) {
45 // TODO: Support unsigned storage type.
46 emitWarning(loc, "unimplemented: dequantize signed uniform");
47 return nullptr;
48 }
49
50 Type storageType = elementType.castToStorageType(input.getType());
51 Type realType = elementType.castToExpressedType(input.getType());
52 Type intermediateType =
53 castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
54 assert(storageType && "cannot cast to storage type");
55 assert(realType && "cannot cast to expressed type");
56
57 // Cast to storage type.
58 input = rewriter.create<StorageCastOp>(loc, storageType, input);
59
60 // Promote to intermediate type.
61 input = rewriter.create<ConvertISOp>(loc, intermediateType, input);
62
63 // Apply zero-point offset.
64 if (elementType.getZeroPoint() != 0) {
65 Value negZeroPointConst = rewriter.create<ConstantOp>(
66 loc, broadcastScalarConstIntValue(intermediateType,
67 -elementType.getZeroPoint()));
68 input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
69 }
70
71 // Convert to float.
72 input = rewriter.create<ConvertISToFOp>(loc, realType, input);
73
74 // Mul by scale.
75 Value scaleConst = rewriter.create<ConstantOp>(
76 loc, broadcastScalarConstFloatValue(realType,
77 APFloat(elementType.getScale())));
78 return rewriter.create<MulFOp>(loc, input, scaleConst);
79 }
80
81 static Value
emitUniformPerAxisDequantize(Location loc,Value input,UniformQuantizedPerAxisType elementType,PatternRewriter & rewriter)82 emitUniformPerAxisDequantize(Location loc, Value input,
83 UniformQuantizedPerAxisType elementType,
84 PatternRewriter &rewriter) {
85 // TODO: Support per-axis dequantize.
86 rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning)
87 << "unimplemented: per-axis uniform dequantization";
88 return nullptr;
89 }
90
emitDequantize(Location loc,Value input,PatternRewriter & rewriter)91 static Value emitDequantize(Location loc, Value input,
92 PatternRewriter &rewriter) {
93 Type inputType = input.getType();
94 QuantizedType qElementType =
95 QuantizedType::getQuantizedElementType(inputType);
96 if (auto uperLayerElementType =
97 qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
98 return emitUniformPerLayerDequantize(loc, input, uperLayerElementType,
99 rewriter);
100 } else if (auto uperAxisElementType =
101 qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
102 return emitUniformPerAxisDequantize(loc, input, uperAxisElementType,
103 rewriter);
104 } else {
105 return nullptr;
106 }
107 }
108
109 namespace {
110
111 struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
112 using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
113
matchAndRewrite__anona4a7cf450211::UniformDequantizePattern114 PatternMatchResult matchAndRewrite(DequantizeCastOp op,
115 PatternRewriter &rewriter) const override {
116 Type inputType = op.arg().getType();
117 Type outputType = op.getResult().getType();
118
119 QuantizedType inputElementType =
120 QuantizedType::getQuantizedElementType(inputType);
121 Type expressedOutputType = inputElementType.castToExpressedType(inputType);
122 if (expressedOutputType != outputType) {
123 // Not a valid uniform cast.
124 return matchFailure();
125 }
126
127 Value dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
128 if (!dequantizedValue) {
129 return matchFailure();
130 }
131
132 rewriter.replaceOp(op, dequantizedValue);
133 return matchSuccess();
134 }
135 };
136
137 } // end anonymous namespace
138
139 //===----------------------------------------------------------------------===//
140 // Elementwise add
141 //===----------------------------------------------------------------------===//
142
143 static LogicalResult
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo & info,PatternRewriter & rewriter)144 tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
145 PatternRewriter &rewriter) {
146 if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
147 info.rhsType != info.resultType) {
148 return failure();
149 }
150
151 // Choose a byte aligned intermediate width big enough to perform the
152 // calculation without overflow.
153 // TODO: This should probably be made just big enough to avoid overflow and
154 // leave the downstream tooling to decide how to align that to machine
155 // word sizes.
156 unsigned intermediateWidth =
157 info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
158 IntegerType intermediateElementType =
159 IntegerType::get(intermediateWidth, rewriter.getContext());
160 Type intermediateType =
161 castElementType(info.resultStorageType, intermediateElementType);
162
163 // Cast operands to storage type.
164 Value lhsValue = rewriter
165 .create<StorageCastOp>(info.op->getLoc(),
166 info.lhsStorageType, info.lhs)
167 .getResult();
168 Value rhsValue = rewriter
169 .create<StorageCastOp>(info.op->getLoc(),
170 info.rhsStorageType, info.rhs)
171 .getResult();
172
173 // Cast to the intermediate sized type.
174 lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
175 lhsValue);
176 rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
177 rhsValue);
178
179 // Add.
180 Value resultValue =
181 rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
182
183 // Zero point offset adjustment.
184 // result = (lhs - zp) + (rhs - zp) + zp
185 // zpOffset = -zp
186 int zpOffset = -1 * info.resultType.getZeroPoint();
187 if (zpOffset != 0) {
188 Value zpOffsetConst = rewriter.create<ConstantOp>(
189 info.op->getLoc(),
190 broadcastScalarConstIntValue(intermediateType, zpOffset));
191 resultValue =
192 rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
193 }
194
195 // Clamp.
196 auto clampMinMax = info.getClampMinMax(intermediateElementType);
197 resultValue = rewriter.create<ClampISOp>(
198 info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
199
200 // Convert back to original type.
201 resultValue = rewriter.create<ConvertISOp>(
202 info.op->getLoc(), info.resultStorageType, resultValue);
203
204 // Cast back for new result.
205 rewriter.replaceOpWithNewOp<StorageCastOp>(
206 info.op, info.getQuantizedResultType(), resultValue);
207
208 return success();
209 }
210
211 //===----------------------------------------------------------------------===//
212 // Elementwise mul
213 //===----------------------------------------------------------------------===//
214
215 static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo & info,PatternRewriter & rewriter)216 tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
217 PatternRewriter &rewriter) {
218 if (!info.resultType.isSigned()) {
219 return failure();
220 }
221
222 double outputMultiplierReal = info.lhsType.getScale() *
223 info.rhsType.getScale() /
224 info.resultType.getScale();
225 if (outputMultiplierReal > 1.0) {
226 info.op->emitWarning(
227 "unimplemented: cannot multiply with multiplier > 1.0");
228 return failure();
229 }
230
231 // TODO: Choose an appropriate intermediate width for muls > 8 bits to
232 // avoid overflow.
233 unsigned intermediateWidth = 32;
234 IntegerType intermediateElementType =
235 IntegerType::get(intermediateWidth, rewriter.getContext());
236 Type intermediateType =
237 castElementType(info.resultStorageType, intermediateElementType);
238
239 // Cast operands to storage type.
240 Value lhsValue = rewriter
241 .create<StorageCastOp>(info.op->getLoc(),
242 info.lhsStorageType, info.lhs)
243 .getResult();
244 Value rhsValue = rewriter
245 .create<StorageCastOp>(info.op->getLoc(),
246 info.rhsStorageType, info.rhs)
247 .getResult();
248
249 // Cast to the intermediate sized type.
250 lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
251 lhsValue);
252 rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
253 rhsValue);
254
255 // Apply argument zeroPoints.
256 if (info.lhsType.getZeroPoint() != 0) {
257 Value zpOffsetConst = rewriter.create<ConstantOp>(
258 info.op->getLoc(), broadcastScalarConstIntValue(
259 intermediateType, -info.lhsType.getZeroPoint()));
260 lhsValue =
261 rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
262 }
263
264 if (info.rhsType.getZeroPoint() != 0) {
265 Value zpOffsetConst = rewriter.create<ConstantOp>(
266 info.op->getLoc(), broadcastScalarConstIntValue(
267 intermediateType, -info.rhsType.getZeroPoint()));
268 rhsValue =
269 rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
270 }
271
272 // Mul.
273 Value resultValue =
274 rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);
275
276 // Scale output.
277 QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
278 resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
279 info.op->getLoc(), resultValue,
280 IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
281 resultValue = rewriter.create<RoundingDivideByPotISOp>(
282 info.op->getLoc(), resultValue,
283 IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));
284
285 // Zero point offset adjustment.
286 if (info.resultType.getZeroPoint() != 0) {
287 Value zpOffsetConst = rewriter.create<ConstantOp>(
288 info.op->getLoc(),
289 broadcastScalarConstIntValue(intermediateType,
290 info.resultType.getZeroPoint()));
291 resultValue =
292 rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
293 }
294
295 // Clamp.
296 auto clampMinMax = info.getClampMinMax(intermediateElementType);
297 resultValue = rewriter.create<ClampISOp>(
298 info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
299
300 // Convert back to original type.
301 resultValue = rewriter.create<ConvertISOp>(
302 info.op->getLoc(), info.resultStorageType, resultValue);
303
304 // Cast back for new result.
305 rewriter.replaceOpWithNewOp<StorageCastOp>(
306 info.op, info.getQuantizedResultType(), resultValue);
307
308 return success();
309 }
310
311 namespace {
312
313 struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
314 using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
315
matchAndRewrite__anona4a7cf450311::UniformRealAddEwPattern316 PatternMatchResult matchAndRewrite(RealAddEwOp op,
317 PatternRewriter &rewriter) const override {
318 const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
319 op.clamp_max());
320 if (!info.isValid()) {
321 return matchFailure();
322 }
323
324 // Try all of the permutations we support.
325 if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
326 return matchSuccess();
327 }
328
329 return matchFailure();
330 }
331 };
332
333 struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
334 using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
335
matchAndRewrite__anona4a7cf450311::UniformRealMulEwPattern336 PatternMatchResult matchAndRewrite(RealMulEwOp op,
337 PatternRewriter &rewriter) const override {
338 const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
339 op.clamp_max());
340 if (!info.isValid()) {
341 return matchFailure();
342 }
343
344 // Try all of the permutations we support.
345 if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
346 return matchSuccess();
347 }
348
349 return matchFailure();
350 }
351 };
352
353 } // end anonymous namespace
354
355 //===----------------------------------------------------------------------===//
356 // LowerUniformRealMath pass
357 //===----------------------------------------------------------------------===//
358
runOnFunction()359 void LowerUniformRealMathPass::runOnFunction() {
360 auto fn = getFunction();
361 OwningRewritePatternList patterns;
362 auto *context = &getContext();
363 patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
364 applyPatternsGreedily(fn, patterns);
365 }
366
/// Factory for the uniform real math lowering pass. Ownership of the returned
/// pass transfers to the caller (raw-pointer factory, matching the pass API
/// used throughout this file).
OpPassBase<FuncOp> *mlir::fxpmath::createLowerUniformRealMathPass() {
  return new LowerUniformRealMathPass();
}
370
// Registers the pass under the -fxpmath-lower-uniform-real-math flag.
static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
    "fxpmath-lower-uniform-real-math",
    "Lowers uniform-quantized real math ops to integer arithmetic.");
374
375 //===----------------------------------------------------------------------===//
376 // LowerUniformCasts pass
377 //===----------------------------------------------------------------------===//
378
runOnFunction()379 void LowerUniformCastsPass::runOnFunction() {
380 auto fn = getFunction();
381 OwningRewritePatternList patterns;
382 auto *context = &getContext();
383 patterns.insert<UniformDequantizePattern>(context);
384 applyPatternsGreedily(fn, patterns);
385 }
386
/// Factory for the uniform casts lowering pass. Ownership of the returned
/// pass transfers to the caller.
OpPassBase<FuncOp> *mlir::fxpmath::createLowerUniformCastsPass() {
  return new LowerUniformCastsPass();
}
390
// Registers the pass under the -fxpmath-lower-uniform-casts flag.
static PassRegistration<LowerUniformCastsPass>
    lowerUniformCastsPass("fxpmath-lower-uniform-casts",
                          "Lowers uniform-quantized casts.");