//===- LowerUniformRealMath.cpp ------------------------------------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "UniformKernelUtils.h" #include "mlir/Dialect/FxpMathOps/FxpMathOps.h" #include "mlir/Dialect/FxpMathOps/Passes.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" using namespace mlir; using namespace mlir::fxpmath; using namespace mlir::fxpmath::detail; using namespace mlir::quant; namespace { struct LowerUniformRealMathPass : public FunctionPass { void runOnFunction() override; }; struct LowerUniformCastsPass : public FunctionPass { void runOnFunction() override; }; } // end anonymous namespace //===----------------------------------------------------------------------===// // Dequantize //===----------------------------------------------------------------------===// static Value emitUniformPerLayerDequantize(Location loc, Value input, UniformQuantizedType elementType, PatternRewriter &rewriter) { // Pre-conditions. if (!elementType.isSigned()) { // TODO: Support unsigned storage type. emitWarning(loc, "unimplemented: dequantize signed uniform"); return nullptr; } Type storageType = elementType.castToStorageType(input.getType()); Type realType = elementType.castToExpressedType(input.getType()); Type intermediateType = castElementType(storageType, IntegerType::get(32, rewriter.getContext())); assert(storageType && "cannot cast to storage type"); assert(realType && "cannot cast to expressed type"); // Cast to storage type. input = rewriter.create(loc, storageType, input); // Promote to intermediate type. input = rewriter.create(loc, intermediateType, input); // Apply zero-point offset. if (elementType.getZeroPoint() != 0) { Value negZeroPointConst = rewriter.create( loc, broadcastScalarConstIntValue(intermediateType, -elementType.getZeroPoint())); input = rewriter.create(loc, input, negZeroPointConst); } // Convert to float. input = rewriter.create(loc, realType, input); // Mul by scale. Value scaleConst = rewriter.create( loc, broadcastScalarConstFloatValue(realType, APFloat(elementType.getScale()))); return rewriter.create(loc, input, scaleConst); } static Value emitUniformPerAxisDequantize(Location loc, Value input, UniformQuantizedPerAxisType elementType, PatternRewriter &rewriter) { // TODO: Support per-axis dequantize. rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning) << "unimplemented: per-axis uniform dequantization"; return nullptr; } static Value emitDequantize(Location loc, Value input, PatternRewriter &rewriter) { Type inputType = input.getType(); QuantizedType qElementType = QuantizedType::getQuantizedElementType(inputType); if (auto uperLayerElementType = qElementType.dyn_cast_or_null()) { return emitUniformPerLayerDequantize(loc, input, uperLayerElementType, rewriter); } else if (auto uperAxisElementType = qElementType.dyn_cast_or_null()) { return emitUniformPerAxisDequantize(loc, input, uperAxisElementType, rewriter); } else { return nullptr; } } namespace { struct UniformDequantizePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(DequantizeCastOp op, PatternRewriter &rewriter) const override { Type inputType = op.arg().getType(); Type outputType = op.getResult().getType(); QuantizedType inputElementType = QuantizedType::getQuantizedElementType(inputType); Type expressedOutputType = inputElementType.castToExpressedType(inputType); if (expressedOutputType != outputType) { // Not a valid uniform cast. return matchFailure(); } Value dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter); if (!dequantizedValue) { return matchFailure(); } rewriter.replaceOp(op, dequantizedValue); return matchSuccess(); } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // Elementwise add //===----------------------------------------------------------------------===// static LogicalResult tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info, PatternRewriter &rewriter) { if (!info.resultType.isSigned() || info.lhsType != info.resultType || info.rhsType != info.resultType) { return failure(); } // Choose a byte aligned intermediate width big enough to perform the // calculation without overflow. // TODO: This should probably be made just big enough to avoid overflow and // leave the downstream tooling to decide how to align that to machine // word sizes. unsigned intermediateWidth = info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32; IntegerType intermediateElementType = IntegerType::get(intermediateWidth, rewriter.getContext()); Type intermediateType = castElementType(info.resultStorageType, intermediateElementType); // Cast operands to storage type. Value lhsValue = rewriter .create(info.op->getLoc(), info.lhsStorageType, info.lhs) .getResult(); Value rhsValue = rewriter .create(info.op->getLoc(), info.rhsStorageType, info.rhs) .getResult(); // Cast to the intermediate sized type. lhsValue = rewriter.create(info.op->getLoc(), intermediateType, lhsValue); rhsValue = rewriter.create(info.op->getLoc(), intermediateType, rhsValue); // Add. Value resultValue = rewriter.create(info.op->getLoc(), lhsValue, rhsValue); // Zero point offset adjustment. // result = (lhs - zp) + (rhs - zp) + zp // zpOffset = -zp int zpOffset = -1 * info.resultType.getZeroPoint(); if (zpOffset != 0) { Value zpOffsetConst = rewriter.create( info.op->getLoc(), broadcastScalarConstIntValue(intermediateType, zpOffset)); resultValue = rewriter.create(info.op->getLoc(), resultValue, zpOffsetConst); } // Clamp. auto clampMinMax = info.getClampMinMax(intermediateElementType); resultValue = rewriter.create( info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second); // Convert back to original type. resultValue = rewriter.create( info.op->getLoc(), info.resultStorageType, resultValue); // Cast back for new result. rewriter.replaceOpWithNewOp( info.op, info.getQuantizedResultType(), resultValue); return success(); } //===----------------------------------------------------------------------===// // Elementwise mul //===----------------------------------------------------------------------===// static LogicalResult tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, PatternRewriter &rewriter) { if (!info.resultType.isSigned()) { return failure(); } double outputMultiplierReal = info.lhsType.getScale() * info.rhsType.getScale() / info.resultType.getScale(); if (outputMultiplierReal > 1.0) { info.op->emitWarning( "unimplemented: cannot multiply with multiplier > 1.0"); return failure(); } // TODO: Choose an appropriate intermediate width for muls > 8 bits to // avoid overflow. unsigned intermediateWidth = 32; IntegerType intermediateElementType = IntegerType::get(intermediateWidth, rewriter.getContext()); Type intermediateType = castElementType(info.resultStorageType, intermediateElementType); // Cast operands to storage type. Value lhsValue = rewriter .create(info.op->getLoc(), info.lhsStorageType, info.lhs) .getResult(); Value rhsValue = rewriter .create(info.op->getLoc(), info.rhsStorageType, info.rhs) .getResult(); // Cast to the intermediate sized type. lhsValue = rewriter.create(info.op->getLoc(), intermediateType, lhsValue); rhsValue = rewriter.create(info.op->getLoc(), intermediateType, rhsValue); // Apply argument zeroPoints. if (info.lhsType.getZeroPoint() != 0) { Value zpOffsetConst = rewriter.create( info.op->getLoc(), broadcastScalarConstIntValue( intermediateType, -info.lhsType.getZeroPoint())); lhsValue = rewriter.create(info.op->getLoc(), lhsValue, zpOffsetConst); } if (info.rhsType.getZeroPoint() != 0) { Value zpOffsetConst = rewriter.create( info.op->getLoc(), broadcastScalarConstIntValue( intermediateType, -info.rhsType.getZeroPoint())); rhsValue = rewriter.create(info.op->getLoc(), rhsValue, zpOffsetConst); } // Mul. Value resultValue = rewriter.create(info.op->getLoc(), lhsValue, rhsValue); // Scale output. QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal); resultValue = rewriter.create( info.op->getLoc(), resultValue, IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier)); resultValue = rewriter.create( info.op->getLoc(), resultValue, IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent)); // Zero point offset adjustment. if (info.resultType.getZeroPoint() != 0) { Value zpOffsetConst = rewriter.create( info.op->getLoc(), broadcastScalarConstIntValue(intermediateType, info.resultType.getZeroPoint())); resultValue = rewriter.create(info.op->getLoc(), resultValue, zpOffsetConst); } // Clamp. auto clampMinMax = info.getClampMinMax(intermediateElementType); resultValue = rewriter.create( info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second); // Convert back to original type. resultValue = rewriter.create( info.op->getLoc(), info.resultStorageType, resultValue); // Cast back for new result. rewriter.replaceOpWithNewOp( info.op, info.getQuantizedResultType(), resultValue); return success(); } namespace { struct UniformRealAddEwPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(RealAddEwOp op, PatternRewriter &rewriter) const override { const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(), op.clamp_max()); if (!info.isValid()) { return matchFailure(); } // Try all of the permutations we support. if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) { return matchSuccess(); } return matchFailure(); } }; struct UniformRealMulEwPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(RealMulEwOp op, PatternRewriter &rewriter) const override { const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(), op.clamp_max()); if (!info.isValid()) { return matchFailure(); } // Try all of the permutations we support. if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) { return matchSuccess(); } return matchFailure(); } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // LowerUniformRealMath pass //===----------------------------------------------------------------------===// void LowerUniformRealMathPass::runOnFunction() { auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); patterns.insert(context); applyPatternsGreedily(fn, patterns); } OpPassBase *mlir::fxpmath::createLowerUniformRealMathPass() { return new LowerUniformRealMathPass(); } static PassRegistration lowerUniformRealMathPass( "fxpmath-lower-uniform-real-math", "Lowers uniform-quantized real math ops to integer arithmetic."); //===----------------------------------------------------------------------===// // LowerUniformCasts pass //===----------------------------------------------------------------------===// void LowerUniformCastsPass::runOnFunction() { auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); patterns.insert(context); applyPatternsGreedily(fn, patterns); } OpPassBase *mlir::fxpmath::createLowerUniformCastsPass() { return new LowerUniformCastsPass(); } static PassRegistration lowerUniformCastsPass("fxpmath-lower-uniform-casts", "Lowers uniform-quantized casts.");