//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

using namespace mlir;

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

template <typename T>
static mlir::ConstantOp
createConstFromIntAttribute(Operation *op, std::string attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
  return rewriter.create<mlir::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}

template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  for (Attribute val : attr.getValue()) {
    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

template <typename T, typename P>
static mlir::SelectOp clampHelper(Location loc, Value arg,
                                  mlir::ConstantOp min, mlir::ConstantOp max,
                                  P pred, OpBuilder &rewriter) {
  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
  auto minOrArg =
      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
}

static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            Attribute padAttr, OpBuilder &rewriter) {
  // Input should be padded if necessary.
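  // Illustrative walkthrough (example values, not taken from the original
  // source): for an input of type tensor<4x4xf32> with pad = [1, 1, 2, 2],
  // dimension i gets low padding pad[2 * i] and high padding pad[2 * i + 1],
  // so the padded result type is tensor<6x8xf32>, filled with `padAttr`
  // outside the original data.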
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t> paddedShape;
  SmallVector<OpFoldResult> lowIndices;
  SmallVector<OpFoldResult> highIndices;
  for (int i = 0, s = inputShape.size(); i < s; i++) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<ConstantOp>(loc, padAttr);

  return linalg::PadTensorOp::createPadScalarOp(
             RankedTensorType::get(paddedShape, inputETy), input, padValue,
             lowIndices, highIndices, loc, rewriter)
      .result();
}

static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, rewriter.getZeroAttr(elementTy));
    auto cmp =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0], zero);
    auto neg = rewriter.create<mlir::SubIOp>(loc, zero, args[0]);
    return rewriter.create<mlir::SelectOp>(loc, cmp, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    if (dyn_cast<tosa::MulOp>(op).shift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
  }

  // tosa::DivOp
  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SignedDivIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, args[0]);
  }

  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<TruncateIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<SignExtendIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<SignExtendIOp>(loc, resultTypes[0], b);

    return rewriter.create<mlir::MulIOp>(loc, resultTypes, a, b);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      !cast<tosa::NegateOp>(op).quantization_info()) {
    auto constant =
        rewriter.create<ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<SubIOp>(loc, resultTypes, constant, args[0]);
  }

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      cast<tosa::NegateOp>(op).quantization_info()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp =
        quantizationInfo.getValue().input_zp().getValue().getSExtValue();
    int64_t outZp =
        quantizationInfo.getValue().output_zp().getValue().getSExtValue();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //  outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<SignExtendIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    auto min = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMinValue(inputBitWidth).getSExtValue()));
    auto max = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMaxValue(inputBitWidth).getSExtValue()));
    auto clamp = clampHelper<mlir::CmpIOp>(loc, sub, min, max,
                                           CmpIPredicate::slt, rewriter);

    // Truncate to the final value.
    return rewriter.create<TruncateIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnesValue(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<ConstantOp>(loc, allOnesAttr);
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::ShiftLeftOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
    auto result =
        rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, args);
    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that input2 != 0
    auto shiftValueGreaterThanZero =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[1], zero);

    // Checking for the last bit of input1 to be 1
    auto subtract =
        rewriter.create<mlir::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted = rewriter
                       .create<mlir::SignedShiftRightOp>(loc, resultTypes,
                                                         args[0], subtract)
                       ->getResults();
    auto truncated =
        rewriter.create<mlir::TruncateIOp>(loc, i1Ty, shifted, mlir::None);
    auto isInputOdd =
        rewriter.create<mlir::AndOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<mlir::AndOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
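    // Worked example (illustrative): shifting 5 (0b101) right by 1 with
    // round = true discards a set bit, so shouldRound is true and 1 is added
    // to the shifted value, producing 3 instead of 2.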
    auto extended =
        rewriter.create<mlir::ZeroExtendIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<mlir::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT, args[0],
                                         args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0],
                                         args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, args[0],
                                         args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
                                         args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OEQ, args[0],
                                         args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0],
                                         args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy =
        op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CeilFOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    auto min = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("min_fp"));
    auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], min, max,
                                     CmpFPredicate::OLT, rewriter);
  }

  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
                                                    rewriter);
    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                    rewriter);
    return clampHelper<mlir::CmpIOp>(loc, args[0], min, max,
                                     CmpIPredicate::slt, rewriter);
  }

  // tosa::ReluNOp
  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
    auto zero =
rewriter.create(loc, FloatAttr::get(elementTy, 0)); auto n = rewriter.create(loc, elementTy, op->getAttr("max_fp")); return clampHelper(loc, args[0], zero, n, CmpFPredicate::OLT, rewriter); } if (isa(op) && elementTy.isa()) { auto zero = rewriter.create(loc, IntegerAttr::get(elementTy, 0)); auto n = createConstFromIntAttribute(op, "max_int", elementTy, rewriter); return clampHelper(loc, args[0], zero, n, CmpIPredicate::slt, rewriter); } // tosa::SigmoidOp if (isa(op) && elementTy.isa()) { auto one = rewriter.create(loc, FloatAttr::get(elementTy, 1)); auto negate = rewriter.create(loc, resultTypes, args[0]); auto exp = rewriter.create(loc, resultTypes, negate); auto added = rewriter.create(loc, resultTypes, exp, one); return rewriter.create(loc, resultTypes, one, added); } // tosa::CastOp if (isa(op)) { Type srcTy = elementTy; Type dstTy = resultTypes.front(); bool bitExtend = srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth(); if (srcTy == dstTy) return args.front(); if (srcTy.isa() && dstTy.isa() && bitExtend) return rewriter.create(loc, resultTypes, args, mlir::None); if (srcTy.isa() && dstTy.isa() && !bitExtend) return rewriter.create(loc, resultTypes, args, mlir::None); // 1-bit integers need to be treated as signless. if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy)) return rewriter.create(loc, resultTypes, args, mlir::None); if (srcTy.isInteger(1) && dstTy.isa() && bitExtend) return rewriter.create(loc, resultTypes, args, mlir::None); // All other si-to-fp conversions should be handled by SIToFP. if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy)) return rewriter.create(loc, resultTypes, args, mlir::None); // Casting to boolean, floats need to only be checked as not-equal to zero. if (srcTy.isa() && dstTy.isInteger(1)) { Value zero = rewriter.create(loc, rewriter.getFloatAttr(srcTy, 0.0)); return rewriter.create(loc, CmpFPredicate::UNE, args.front(), zero); } if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) { auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto half = rewriter.create(loc, rewriter.getF32FloatAttr(0.5f)); auto intMin = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); auto intMax = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); auto added = rewriter.create(loc, args[0], half); auto subbed = rewriter.create(loc, args[0], half); auto negative = rewriter.create(loc, CmpFPredicate::OLT, args[0], zero); auto rounded = rewriter.create(loc, negative, subbed, added); auto clamped = clampHelper(loc, rounded, intMin, intMax, CmpFPredicate::OLT, rewriter); return rewriter.create(loc, dstTy, clamped); } // Casting to boolean, integers need to only be checked as not-equal to // zero. 
if (srcTy.isa() && dstTy.isInteger(1)) { Value zero = rewriter.create(loc, 0, srcTy.getIntOrFloatBitWidth()); return rewriter.create(loc, CmpIPredicate::ne, args.front(), zero); } if (srcTy.isa() && dstTy.isa() && bitExtend) return rewriter.create(loc, resultTypes, args, mlir::None); if (srcTy.isa() && dstTy.isa() && !bitExtend) { auto intMin = rewriter.create( loc, APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue(), srcTy.getIntOrFloatBitWidth()); auto intMax = rewriter.create( loc, APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue(), srcTy.getIntOrFloatBitWidth()); auto clamped = clampHelper(loc, args[0], intMin, intMax, CmpIPredicate::slt, rewriter); return rewriter.create(loc, dstTy, clamped); } } (void)rewriter.notifyMatchFailure( op, "unhandled op for linalg body calculation for elementwise op"); return nullptr; } static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, PatternRewriter &rewriter) { auto loc = operation->getLoc(); assert(operation->getNumResults() == 1 && "All TOSA elementwise ops should only return a single result."); auto results = operation->getResults(); auto resultTy = operation->getResult(0).getType().dyn_cast(); if (!resultTy) return rewriter.notifyMatchFailure(operation, "All results must be a shaped type"); unsigned rank = resultTy.getRank(); // Construct the indexing maps needed for linalg.generic ops. SmallVector bodyArgTypes; for (Value in : operation->getOperands()) bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType())); SmallVector opResultTypes; SmallVector initTensors; for (auto result : results) { auto resultTy = result.getType().template cast(); if (!resultTy.hasStaticShape()) return rewriter.notifyMatchFailure( operation, "tosa to linalg conversion expects statically shaped tensors"); initTensors.push_back(rewriter.create( loc, ArrayRef({}), resultTy.getShape(), resultTy.getElementType())); opResultTypes.push_back(result.getType()); } auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range( initTensors, [](Value v) { return getElementTypeOrSelf(v); })); SmallVector operands; SmallVector indexingMaps; indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size()); // Input indexing maps may be broadcasted. 
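  // Illustrative example: adding a tensor<1x5xf32> operand into a
  // tensor<4x5xf32> result drops the size-1 dimension below, reshaping the
  // operand to tensor<5xf32> and reading it through
  // affine_map<(d0, d1) -> (d1)>, while same-shaped operands and all results
  // use the identity map.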
for (Value operand : operation->getOperands()) { ShapedType type = operand.getType().cast(); if (type.getShape() == resultTy.getShape()) { operands.push_back(operand); indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); continue; } SmallVector newShape; SmallVector affineExprs; newShape.reserve(type.getRank()); for (auto it : llvm::enumerate(type.getShape())) { if (it.value() == resultTy.getDimSize(it.index())) { newShape.push_back(it.value()); affineExprs.push_back( mlir::getAffineDimExpr(it.index(), rewriter.getContext())); } } if (newShape.size() != rank) { operand = rewriter.create( loc, RankedTensorType::get(newShape, type.getElementType()), operand); } operands.push_back(operand); indexingMaps.push_back(AffineMap::get( /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs, rewriter.getContext())); } indexingMaps.append(operation->getNumResults(), rewriter.getMultiDimIdentityMap(rank)); bool didEncounterError = false; auto linalgOp = rewriter.create( loc, opResultTypes, operands, initTensors, indexingMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { Value opResult = createLinalgBodyCalculationForElementwiseOp( operation, blockArgs.take_front(operation->getNumOperands()), bodyResultTypes, rewriter); if (!opResult) { didEncounterError = true; return; } nestedBuilder.create(loc, opResult); }); if (didEncounterError) return failure(); rewriter.replaceOp(operation, linalgOp->getResults()); return success(); } // Returns the constant initial value for a given reduction operation. The // attribute type varies depending on the element type required. static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { if (isa(op) && elementTy.isa()) return rewriter.getFloatAttr(elementTy, 0.0); if (isa(op) && elementTy.isa()) return rewriter.getIntegerAttr(elementTy, 0); if (isa(op) && elementTy.isa()) return rewriter.getFloatAttr(elementTy, 1.0); if (isa(op) && elementTy.isa()) return rewriter.getIntegerAttr(elementTy, 1); if (isa(op) && elementTy.isa()) return rewriter.getFloatAttr( elementTy, APFloat::getLargest( elementTy.cast().getFloatSemantics(), false)); if (isa(op) && elementTy.isa()) return rewriter.getIntegerAttr( elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())); if (isa(op) && elementTy.isa()) return rewriter.getFloatAttr( elementTy, APFloat::getLargest( elementTy.cast().getFloatSemantics(), true)); if (isa(op) && elementTy.isa()) return rewriter.getIntegerAttr( elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())); if (isa(op) && elementTy.isInteger(1)) return rewriter.getIntegerAttr(elementTy, APInt::getAllOnesValue(1)); if (isa(op) && elementTy.isInteger(1)) return rewriter.getIntegerAttr(elementTy, APInt::getNullValue(1)); if (isa(op) && elementTy.isa()) return rewriter.getFloatAttr( elementTy, APFloat::getLargest( elementTy.cast().getFloatSemantics(), true)); if (isa(op) && elementTy.isa()) return rewriter.getIntegerAttr( elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())); return {}; } // Creates the body calculation for a reduction. The operations vary depending // on the input type. 
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter) { Location loc = op->getLoc(); if (isa(op) && elementTy.isa()) { return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { auto predicate = rewriter.create(loc, CmpFPredicate::OLT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { auto predicate = rewriter.create(loc, CmpIPredicate::slt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { auto predicate = rewriter.create(loc, CmpFPredicate::OGT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { auto predicate = rewriter.create(loc, CmpIPredicate::sgt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isInteger(1)) return rewriter.create(loc, args); if (isa(op) && elementTy.isInteger(1)) return rewriter.create(loc, args); return {}; } // Performs the match and rewrite for reduction operations. This includes // declaring a correctly sized initial value, and the linalg.generic operation // that reduces across the specified axis. static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, PatternRewriter &rewriter) { auto loc = op->getLoc(); auto inputTy = op->getOperand(0).getType().template cast(); auto resultTy = op->getResult(0).getType().template cast(); auto elementTy = resultTy.getElementType(); Value input = op->getOperand(0); llvm::SmallVector reduceShape; for (unsigned i = 0; i < inputTy.getRank(); i++) { if (axis != i) reduceShape.push_back(inputTy.getDimSize(i)); } Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType()); // First fill the output buffer with the init value. auto initTensor = rewriter .create(loc, ArrayRef({}), reduceShape, resultTy.getElementType()) .result(); auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); if (!fillValueAttr) return rewriter.notifyMatchFailure( op, "No initial value found for reduction operation"); auto fillValue = rewriter.create(loc, fillValueAttr); auto filledTensor = rewriter.create(loc, fillValue, initTensor).result(); SmallVector srcExprs; SmallVector dstExprs; SmallVector iteratorTypes; for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) { srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); iteratorTypes.push_back(axis == i ? 
getReductionIteratorTypeName() : getParallelIteratorTypeName()); if (axis != i) dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); } bool didEncounterError = false; auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs}); auto linalgOp = rewriter.create( loc, reduceTy, input, filledTensor, maps, iteratorTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { auto result = createLinalgBodyCalculationForReduceOp( op, blockArgs, elementTy, rewriter); if (result) didEncounterError = true; nestedBuilder.create(loc, result); }); if (!didEncounterError) return failure(); rewriter.replaceOpWithNewOp(op, resultTy, linalgOp.getResults()); return success(); } static LogicalResult convolutionMatchAndRewriterHelper(Operation *op, ConversionPatternRewriter &rewriter) { Location loc = op->getLoc(); Value input = op->getOperand(0); Value weight = op->getOperand(1); Value bias = op->getOperand(2); ShapedType inputTy = input.getType().cast(); ShapedType weightTy = weight.getType().cast(); ShapedType biasTy = bias.getType().cast(); ShapedType resultTy = op->getResult(0).getType().cast(); Type inputETy = inputTy.getElementType(); Type resultETy = resultTy.getElementType(); auto padAttr = op->getAttr("pad").cast(); auto strideTosaAttr = op->getAttr("stride").cast(); auto dilationTosaAttr = op->getAttr("dilation").cast(); bool isQuantized = op->hasAttr("quantization_info"); IntegerAttr iZp; IntegerAttr kZp; if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); iZp = rewriter.getI32IntegerAttr( quantizationInfo.input_zp().getValue().getSExtValue()); kZp = rewriter.getI32IntegerAttr( quantizationInfo.weight_zp().getValue().getSExtValue()); } if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) return rewriter.notifyMatchFailure(op, "tosa.conv ops require static shapes"); auto weightShape = weightTy.getShape(); auto resultShape = resultTy.getShape(); // Apply padding as necessary. Attribute zeroAttr = rewriter.getZeroAttr(inputETy); llvm::SmallVector pad; pad.resize(2, 0); getValuesFromIntArrayAttribute(padAttr, pad); pad.resize(pad.size() + 2, 0); input = applyPad(loc, input, pad, zeroAttr, rewriter); // Broadcast the initial value to the output tensor before convolving. SmallVector indexingMaps; indexingMaps.push_back(AffineMap::get( /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0, {rewriter.getAffineDimExpr(3)}, rewriter.getContext())); indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); Value initTensor = rewriter.create( loc, resultTy.getShape(), resultTy.getElementType()); Value biasBroadcast = rewriter .create( loc, resultTy, bias, initTensor, indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(nestedLoc, args[0]); }) .getResult(0); // Extract the attributes for convolution. llvm::SmallVector stride, dilation; getValuesFromIntArrayAttribute(strideTosaAttr, stride); getValuesFromIntArrayAttribute(dilationTosaAttr, dilation); // Create the convolution op. 
auto strideAttr = DenseIntElementsAttr::get( RankedTensorType::get({2}, rewriter.getI64Type()), stride); auto dilationAttr = DenseIntElementsAttr::get( RankedTensorType::get({2}, rewriter.getI64Type()), dilation); if (isa(op) && !isQuantized) { rewriter.replaceOpWithNewOp( op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast}, strideAttr, dilationAttr); return success(); } if (isa(op) && isQuantized) { auto iZpVal = rewriter.create(loc, iZp); auto kZpVal = rewriter.create(loc, kZp); rewriter.replaceOpWithNewOp( op, resultTy, ValueRange{input, weight, iZpVal, kZpVal}, ValueRange{biasBroadcast}, strideAttr, dilationAttr); return success(); } if (isa(op) && !isQuantized) { ShapedType linalgConvTy = RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2], weightShape[2], weightShape[3]}, resultETy); Value biasReshape = rewriter.create(loc, linalgConvTy, biasBroadcast); Value conv = rewriter .create( loc, linalgConvTy, ValueRange{input, weight}, ValueRange{biasReshape}, dilationAttr, strideAttr) .getResult(0); Value reshape = rewriter.create(loc, resultTy, conv); rewriter.replaceOp(op, reshape); return success(); } return failure(); } namespace { template class PointwiseConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const final { return elementwiseMatchAndRewriteHelper(op, rewriter); } }; template class ConvConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(T op, ArrayRef args, ConversionPatternRewriter &rewriter) const final { return convolutionMatchAndRewriterHelper(op, rewriter); } }; class TransposeConvConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op, ArrayRef args, ConversionPatternRewriter &rewriter) const final { Location loc = op->getLoc(); Value input = op->getOperand(0); Value weight = op->getOperand(1); Value bias = op->getOperand(2); ShapedType inputTy = input.getType().cast(); ShapedType weightTy = weight.getType().cast(); ShapedType biasTy = bias.getType().cast(); ShapedType resultTy = op->getResult(0).getType().cast(); llvm::SmallVector pad; llvm::SmallVector stride; llvm::SmallVector dilation; getValuesFromIntArrayAttribute(op.out_pad().cast(), pad); getValuesFromIntArrayAttribute(op.stride().cast(), stride); getValuesFromIntArrayAttribute(op.dilation().cast(), dilation); // We have not solved for stride / dilation yet. Dilation should be // straight forward but stride is more complicated. Linalg work is likely // required for efficient implementation. 
if (llvm::any_of(stride, [](int64_t v) { return v != 1; })) return failure(); if (llvm::any_of(dilation, [](int64_t v) { return v != 1; })) return failure(); if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) return failure(); int64_t inputHeight = inputTy.getDimSize(1); int64_t inputWidth = inputTy.getDimSize(2); int64_t kernelHeight = weightTy.getDimSize(1); int64_t kernelWidth = weightTy.getDimSize(2); int64_t outputHeight = resultTy.getDimSize(1); int64_t outputWidth = resultTy.getDimSize(2); int64_t requiredInputHeight = outputHeight + kernelHeight - 1; int64_t requiredInputWidth = outputWidth + kernelWidth - 1; llvm::SmallVector newPad(4, 0); newPad[0] = kernelHeight - 1 - pad[0]; newPad[2] = kernelWidth - 1 - pad[1]; newPad[1] = requiredInputHeight - newPad[0] - inputHeight; newPad[3] = requiredInputWidth - newPad[2] - inputWidth; auto reverse1 = rewriter.create( loc, weightTy, weight, rewriter.getI64IntegerAttr(1)); auto reverse2 = rewriter.create( loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2)); Value conv2d; if (op.quantization_info().hasValue()) { conv2d = rewriter.create( loc, resultTy, input, reverse2, bias, rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride), rewriter.getI64ArrayAttr(dilation), op.quantization_info().getValue()); } else { conv2d = rewriter.create( loc, resultTy, input, reverse2, bias, rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride), rewriter.getI64ArrayAttr(dilation)); } rewriter.replaceOp(op, conv2d); return success(); } }; class MatMulConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tosa::MatMulOp op, ArrayRef args, ConversionPatternRewriter &rewriter) const final { tosa::MatMulOp::Adaptor adaptor(args); Location loc = op.getLoc(); auto outputTy = op.getType().cast(); auto outputElementTy = outputTy.getElementType(); auto zeroAttr = rewriter.getZeroAttr(outputElementTy); Value zero = rewriter.create(loc, zeroAttr); auto initTensor = rewriter.create( loc, outputTy.getShape(), outputTy.getElementType()); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); if (!op.quantization_info()) { rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()}, ValueRange{zeroTensor}); return success(); } auto quantizationInfo = op.quantization_info().getValue(); auto aZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.a_zp().getValue().getSExtValue())); auto bZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.b_zp().getValue().getSExtValue())); rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor); return success(); } }; class FullyConnectedConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef args, ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); auto outputTy = op.getType().cast(); auto input = op.input(); auto weight = op.weight(); auto bias = op.bias(); auto weightTy = weight.getType().cast(); auto weightShape = weightTy.getShape(); // Creating maps for the output of MatMul and the bias SmallVector indexingMaps; // Broadcast the bias. 
indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(1)}, rewriter.getContext())); indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank())); auto initTensor = rewriter .create(loc, outputTy.getShape(), outputTy.getElementType()) ->getResults(); auto linalgOp = rewriter .create( loc, outputTy, bias, initTensor, indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()), [&](OpBuilder &nested_builder, Location nested_loc, ValueRange args) { nested_builder.create(loc, *args.begin()); }) ->getResults(); SmallVector permutation{1, 0}; auto permutationAttr = DenseIntElementsAttr::get( RankedTensorType::get({2}, rewriter.getI64Type()), permutation); Value permutationValue = rewriter.create(loc, permutationAttr); SmallVector newWeightShape{weightShape[1], weightShape[0]}; Type newWeightTy = RankedTensorType::get(newWeightShape, weightTy.getElementType()); Value transposedWeight = rewriter.create( loc, newWeightTy, weight, permutationValue); if (!op.quantization_info()) { rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, ValueRange{input, transposedWeight}, linalgOp); return success(); } auto quantizationInfo = op.quantization_info().getValue(); auto inputZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.input_zp().getValue().getSExtValue())); auto outputZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.weight_zp().getValue().getSExtValue())); rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, ValueRange{input, transposedWeight, inputZp, outputZp}, linalgOp); return success(); } }; class ReshapeConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef args, ConversionPatternRewriter &rewriter) const final { typename tosa::ReshapeOp::Adaptor operands(args); ShapedType operandTy = operands.input1().getType().cast(); ShapedType resultTy = reshape.getType().template cast(); if (operandTy == resultTy) { rewriter.replaceOp(reshape, args[0]); return success(); } if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape()) return failure(); // Compute the reassociation maps for the linalg operation. ArrayRef expandedShape = (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape() : resultTy.getShape()); ArrayRef collapsedShape = (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape() : operandTy.getShape()); unsigned currSrcDim = 0, currDstDim = 0; SmallVector reassociationMap(collapsedShape.size()); // First scan all dimensions in the source shapes to see whether we have a // perfect case where consecutive dimensions in source are collapsed. For // such case we can just generate one single linalg.reshape. bool isCollapsingSource = true; while (currSrcDim < expandedShape.size() && currDstDim < collapsedShape.size()) { int64_t dstSize = collapsedShape[currDstDim]; int64_t srcSize = expandedShape[currSrcDim]; while (srcSize < dstSize && currSrcDim < expandedShape.size()) { reassociationMap[currDstDim].push_back( rewriter.getAffineDimExpr(currSrcDim++)); srcSize *= expandedShape[currSrcDim]; } if (srcSize == dstSize) { reassociationMap[currDstDim].push_back( rewriter.getAffineDimExpr(currSrcDim++)); // If the next dim in collapsedShape is not 1, treat subsequent dims in // expandedShape which are 1 to be collapsed. 
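        // Illustrative example: collapsing tensor<2x3x1xf32> into
        // tensor<6xf32> groups all three source dims into the single result
        // dim, i.e. the reassociation map is [[d0, d1, d2]], with the trailing
        // unit dim folded in by the loop below.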
if (currDstDim == collapsedShape.size() - 1 || collapsedShape[currDstDim + 1] != 1) { while (currSrcDim < expandedShape.size() && expandedShape[currSrcDim] == 1) { reassociationMap[currDstDim].push_back( rewriter.getAffineDimExpr(currSrcDim++)); } } } else { isCollapsingSource = false; break; } currDstDim++; } // Check if any remaining dimensions exist. If either is rank-0 we only // require the directly lowering. if (currSrcDim != expandedShape.size() || currDstDim != collapsedShape.size()) isCollapsingSource = collapsedShape.empty() || expandedShape.empty(); // Otherwise, we need to first reduce all source dimensions into one and // then expand to the destination dimensions. if (!isCollapsingSource) { auto getIdentityExprs = [&rewriter](int n) { SmallVector exprs; for (int i = 0; i < n; ++i) exprs.push_back(rewriter.getAffineDimExpr(i)); return exprs; }; Location loc = reshape.getLoc(); int64_t totalElems = std::accumulate(expandedShape.begin(), expandedShape.end(), 1, std::multiplies()); auto elemTy = operandTy.getElementType(); SmallVector collapsingMap = { // Use operandTy here because we need to collapse all operands // dimensions. getIdentityExprs(operandTy.getShape().size())}; SmallVector expandingMap = { // Use resultTy here because we need to expand to all result // dimensions. getIdentityExprs(resultTy.getShape().size())}; auto collapsedTy = RankedTensorType::get({totalElems}, elemTy); Value collapsedOp = rewriter.create( loc, collapsedTy, args[0], collapsingMap); rewriter.replaceOpWithNewOp( reshape, resultTy, collapsedOp, expandingMap); return success(); } if (resultTy.getRank() < args[0].getType().cast().getRank()) rewriter.replaceOpWithNewOp( reshape, resultTy, args[0], reassociationMap); else rewriter.replaceOpWithNewOp( reshape, resultTy, args[0], reassociationMap); return success(); } }; class TransposeConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const final { DenseIntElementsAttr perms; if (!matchPattern(op.perms(), m_Constant(&perms))) { return failure(); } auto resultTy = op.getType().cast(); if (!resultTy.hasStaticShape()) return failure(); SmallVector inputExprs; inputExprs.resize(resultTy.getRank()); for (auto permutation : llvm::enumerate(perms.getIntValues())) { inputExprs[permutation.value().getZExtValue()] = rewriter.getAffineDimExpr(permutation.index()); } auto initTensor = rewriter.create( op.getLoc(), ArrayRef({}), resultTy.getShape(), resultTy.getElementType()); SmallVector affineMaps = { AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs, rewriter.getContext()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; rewriter.replaceOpWithNewOp( op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(op.getLoc(), *args.begin()); }); return success(); } }; class RescaleConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::RescaleOp op, PatternRewriter &rewriter) const final { auto loc = op.getLoc(); auto input = op.input(); auto inputTy = op.input().getType().cast(); auto outputTy = op.output().getType().cast(); unsigned rank = inputTy.getRank(); // This is an illegal configuration. 
terminate and log an error if (op.double_round() && !op.scale32()) return rewriter.notifyMatchFailure( op, "tosa.rescale requires scale32 for double_round to be true"); if (!outputTy.hasStaticShape()) return rewriter.notifyMatchFailure( op, "tosa to linalg conversion expects statically shaped tensors"); // The shift and multiplier values. SmallVector multiplierValues; getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues); SmallVector shiftValues; getValuesFromIntArrayAttribute(op.shift(), shiftValues); // Double round only occurs if shift is greater than 31, check that this // is ever true. bool doubleRound = op.double_round() && llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); SmallVector indexingMaps = { rewriter.getMultiDimIdentityMap(rank)}; SmallVector genericInputs = {input}; // If we are rescaling per-channel then we need to store the multiplier // values in a buffer. Value multiplierConstant; int64_t multiplierArg = 0; if (multiplierValues.size() == 1) { multiplierConstant = rewriter.create( loc, rewriter.getI32IntegerAttr(multiplierValues.front())); } else { SmallVector multiplierExprs{ rewriter.getAffineDimExpr(rank - 1)}; auto multiplierType = RankedTensorType::get({static_cast(multiplierValues.size())}, rewriter.getI32Type()); genericInputs.push_back(rewriter.create( loc, DenseIntElementsAttr::get(multiplierType, multiplierValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, multiplierExprs, rewriter.getContext())); multiplierArg = indexingMaps.size() - 1; } // If we are rescaling per-channel then we need to store the shift // values in a buffer. Value shiftConstant; int64_t shiftArg = 0; if (shiftValues.size() == 1) { shiftConstant = rewriter.create( loc, rewriter.getI8IntegerAttr(shiftValues.front())); } else { SmallVector shiftExprs = { rewriter.getAffineDimExpr(rank - 1)}; auto shiftType = RankedTensorType::get({static_cast(shiftValues.size())}, rewriter.getIntegerType(8)); genericInputs.push_back(rewriter.create( loc, DenseIntElementsAttr::get(shiftType, shiftValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, shiftExprs, rewriter.getContext())); shiftArg = indexingMaps.size() - 1; } // Indexing maps for output values. indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); // Construct the indexing maps needed for linalg.generic ops. Value initTensor = rewriter.create( loc, ArrayRef({}), outputTy.getShape(), outputTy.getElementType()); auto linalgOp = rewriter.create( loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { Value value = blockArgs[0]; // For now we do all of our math in 64-bit. This is not optimal but // should be correct for now, consider computing correct bit depth // later. int32_t inBitwidth = value.getType().getIntOrFloatBitWidth() > 32 ? 48 : 32; auto inputZp = createConstFromIntAttribute( op, "input_zp", nestedBuilder.getIntegerType(inBitwidth), nestedBuilder); auto outputZp = createConstFromIntAttribute( op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder); Value multiplier = multiplierConstant ? multiplierConstant : blockArgs[multiplierArg]; Value shift = shiftConstant ? 
shiftConstant : blockArgs[shiftArg]; if (value.getType().getIntOrFloatBitWidth() < 32) { value = nestedBuilder.create( nestedLoc, nestedBuilder.getI32Type(), value); } value = nestedBuilder.create(nestedLoc, value, inputZp); value = nestedBuilder.create( loc, nestedBuilder.getI32Type(), value, multiplier, shift, nestedBuilder.getBoolAttr(doubleRound)); // Move to the new zero-point. value = nestedBuilder.create(nestedLoc, value, outputZp); // Saturate to the output size. IntegerType outIntType = blockArgs.back().getType().cast(); unsigned outBitWidth = outIntType.getWidth(); auto intMin = nestedBuilder.create( loc, nestedBuilder.getIntegerAttr( nestedBuilder.getI32Type(), APInt::getSignedMinValue(outBitWidth).getSExtValue())); auto intMax = nestedBuilder.create( loc, nestedBuilder.getIntegerAttr( nestedBuilder.getI32Type(), APInt::getSignedMaxValue(outBitWidth).getSExtValue())); value = clampHelper(nestedLoc, value, intMin, intMax, CmpIPredicate::slt, nestedBuilder); if (outIntType.getWidth() < 32) { value = nestedBuilder.create(nestedLoc, outIntType, value); } nestedBuilder.create(loc, value); }); rewriter.replaceOp(op, linalgOp->getResults()); return success(); } }; class ResizeConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ResizeOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); auto input = op.input(); auto inputTy = input.getType().cast(); auto resultTy = op.getType().cast(); auto resultElementTy = resultTy.getElementType(); auto imageH = inputTy.getShape()[1]; auto imageW = inputTy.getShape()[2]; if (!resultTy.hasStaticShape()) return failure(); if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR") return failure(); auto initTensor = rewriter .create(loc, ArrayRef{}, resultTy.getShape(), resultElementTy) .result(); SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; auto genericOp = rewriter.create( loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); rewriter.replaceOp(op, genericOp.getResult(0)); { OpBuilder::InsertionGuard regionGuard(rewriter); rewriter.createBlock(&genericOp.region(), genericOp.region().end(), TypeRange({resultElementTy})); Value batch = rewriter.create(loc, 0); Value y = rewriter.create(loc, 1); Value x = rewriter.create(loc, 2); Value channel = rewriter.create(loc, 3); auto hwMin = rewriter.create(loc, rewriter.getI32IntegerAttr(0)); auto hMax = rewriter.create( loc, rewriter.getI32IntegerAttr(imageH - 1)); auto wMax = rewriter.create( loc, rewriter.getI32IntegerAttr(imageW - 1)); Value inY = rewriter.create(loc, rewriter.getI32Type(), y); Value inX = rewriter.create(loc, rewriter.getI32Type(), x); int32_t shift = op.shift(); bool floatingPointMode = shift == 0; Value yStride, xStride, yOffset, xOffset; if (floatingPointMode) { yStride = rewriter.create(loc, op.stride_fp()[0]); xStride = rewriter.create(loc, op.stride_fp()[1]); yOffset = rewriter.create(loc, op.offset_fp()[0]); xOffset = rewriter.create(loc, op.offset_fp()[1]); } else { SmallVector stride, offset; getValuesFromIntArrayAttribute(op.stride(), stride); getValuesFromIntArrayAttribute(op.offset(), offset); yStride = rewriter.create( loc, rewriter.getI32IntegerAttr(stride[0])); xStride = rewriter.create( loc, rewriter.getI32IntegerAttr(stride[1])); yOffset = rewriter.create( loc, rewriter.getI32IntegerAttr(offset[0])); xOffset = rewriter.create( loc, rewriter.getI32IntegerAttr(offset[1])); } // Compute 
the the integer index and partial offset. // x = x * stride + offset; // ix = floor(x) // dx = x - ix Value ix, iy, dx, dy; if (floatingPointMode) { Value y = rewriter.create(loc, rewriter.getF32Type(), inY); Value x = rewriter.create(loc, rewriter.getF32Type(), inX); y = rewriter.create(loc, y, yStride); x = rewriter.create(loc, x, xStride); y = rewriter.create(loc, y, yOffset); x = rewriter.create(loc, x, xOffset); iy = rewriter.create(loc, y); ix = rewriter.create(loc, x); dy = rewriter.create(loc, y, iy); dx = rewriter.create(loc, x, ix); iy = rewriter.create(loc, rewriter.getI32Type(), iy); ix = rewriter.create(loc, rewriter.getI32Type(), ix); } else { Value shiftVal = rewriter.create(loc, rewriter.getI32IntegerAttr(shift)); Value y = rewriter.create(loc, inY, yStride); Value x = rewriter.create(loc, inX, xStride); y = rewriter.create(loc, y, yOffset); x = rewriter.create(loc, x, xOffset); iy = rewriter.create(loc, y, shiftVal); ix = rewriter.create(loc, x, shiftVal); Value yTrunc = rewriter.create(loc, iy, shiftVal); Value xTrunc = rewriter.create(loc, ix, shiftVal); dy = rewriter.create(loc, y, yTrunc); dx = rewriter.create(loc, x, xTrunc); } if (op.mode() == "NEAREST_NEIGHBOR") { Value yPred, xPred; // Round the index position towards the closest pixel location. if (floatingPointMode) { auto halfVal = rewriter.create(loc, rewriter.getF32FloatAttr(0.5f)); yPred = rewriter.create(loc, CmpFPredicate::OGE, dy, halfVal); xPred = rewriter.create(loc, CmpFPredicate::OGE, dx, halfVal); } else { auto halfVal = rewriter.create( loc, rewriter.getI32IntegerAttr(1 << (shift - 1))); yPred = rewriter.create(loc, CmpIPredicate::sge, dy, halfVal); xPred = rewriter.create(loc, CmpIPredicate::sge, dx, halfVal); } auto zeroVal = rewriter.create(loc, rewriter.getI32IntegerAttr(0)); auto oneVal = rewriter.create(loc, rewriter.getI32IntegerAttr(1)); auto yOffset = rewriter.create(loc, yPred, oneVal, zeroVal); auto xOffset = rewriter.create(loc, xPred, oneVal, zeroVal); iy = rewriter.create(loc, iy, yOffset); ix = rewriter.create(loc, ix, xOffset); // Clamp the to be within the bounds of the input image. iy = clampHelper(loc, iy, hwMin, hMax, CmpIPredicate::slt, rewriter); ix = clampHelper(loc, ix, hwMin, wMax, CmpIPredicate::slt, rewriter); // Read the value from the input array. 
iy = rewriter.create(loc, rewriter.getIndexType(), iy); ix = rewriter.create(loc, rewriter.getIndexType(), ix); Value result = rewriter.create( loc, input, ValueRange{batch, iy, ix, channel}); rewriter.create(loc, result); return success(); } if (op.mode() == "BILINEAR") { Value y0 = iy; Value x0 = ix; auto oneVal = rewriter.create(loc, rewriter.getI32IntegerAttr(1)); Value y1 = rewriter.create(loc, y0, oneVal); Value x1 = rewriter.create(loc, x0, oneVal); y0 = clampHelper(loc, y0, hwMin, hMax, CmpIPredicate::slt, rewriter); y1 = clampHelper(loc, y1, hwMin, hMax, CmpIPredicate::slt, rewriter); x0 = clampHelper(loc, x0, hwMin, wMax, CmpIPredicate::slt, rewriter); x1 = clampHelper(loc, x1, hwMin, wMax, CmpIPredicate::slt, rewriter); y0 = rewriter.create(loc, rewriter.getIndexType(), y0); y1 = rewriter.create(loc, rewriter.getIndexType(), y1); x0 = rewriter.create(loc, rewriter.getIndexType(), x0); x1 = rewriter.create(loc, rewriter.getIndexType(), x1); Value y0x0 = rewriter.create( loc, input, ValueRange{batch, y0, x0, channel}); Value y0x1 = rewriter.create( loc, input, ValueRange{batch, y0, x1, channel}); Value y1x0 = rewriter.create( loc, input, ValueRange{batch, y1, x0, channel}); Value y1x1 = rewriter.create( loc, input, ValueRange{batch, y1, x1, channel}); if (floatingPointMode) { auto oneVal = rewriter.create(loc, rewriter.getF32FloatAttr(1.f)); Value rightPart = dx; Value leftPart = rewriter.create(loc, oneVal, dx); y0x0 = rewriter.create(loc, y0x0, leftPart); y0x1 = rewriter.create(loc, y0x1, rightPart); Value topAcc = rewriter.create(loc, y0x0, y0x1); y1x0 = rewriter.create(loc, y1x0, leftPart); y1x1 = rewriter.create(loc, y1x1, rightPart); Value bottomAcc = rewriter.create(loc, y1x0, y1x1); Value bottomPart = dy; Value topPart = rewriter.create(loc, oneVal, dy); topAcc = rewriter.create(loc, topAcc, topPart); bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); Value result = rewriter.create(loc, topAcc, bottomAcc); rewriter.create(loc, result); return success(); } else { y0x0 = rewriter.create(loc, resultElementTy, y0x0); y0x1 = rewriter.create(loc, resultElementTy, y0x1); y1x0 = rewriter.create(loc, resultElementTy, y1x0); y1x1 = rewriter.create(loc, resultElementTy, y1x1); if (resultElementTy.getIntOrFloatBitWidth() > 32) { dx = rewriter.create(loc, resultElementTy, dx); dy = rewriter.create(loc, resultElementTy, dy); } auto unitVal = rewriter.create( loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift)); Value rightPart = dx; Value leftPart = rewriter.create(loc, unitVal, dx); y0x0 = rewriter.create(loc, y0x0, leftPart); y0x1 = rewriter.create(loc, y0x1, rightPart); Value topAcc = rewriter.create(loc, y0x0, y0x1); y1x0 = rewriter.create(loc, y1x0, leftPart); y1x1 = rewriter.create(loc, y1x1, rightPart); Value bottomAcc = rewriter.create(loc, y1x0, y1x1); Value bottomPart = dy; Value topPart = rewriter.create(loc, unitVal, dy); topAcc = rewriter.create(loc, topAcc, topPart); bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); Value result = rewriter.create(loc, topAcc, bottomAcc); rewriter.create(loc, result); return success(); } } return failure(); } return success(); } }; // At the codegen level any identity operations should be removed. Any cases // where identity is load-bearing (e.g. cross device computation) should be // handled before lowering to codegen. 
template class IdentityNConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const final { rewriter.replaceOp(op, op.getOperation()->getOperands()); return success(); } }; template class ReduceConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp reduceOp, PatternRewriter &rewriter) const final { return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter); } }; struct ConcatConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tosa::ConcatOp op, ArrayRef args, ConversionPatternRewriter &rewriter) const override { auto resultType = op.getType().dyn_cast(); if (!resultType || !resultType.hasStaticShape()) { return rewriter.notifyMatchFailure(op, "expected static shaped tensor type"); } Location loc = op.getLoc(); int axis = op.axis(); Value axisValue = rewriter.create(loc, rewriter.getIndexAttr(axis)); int rank = resultType.getRank(); SmallVector offsets, sizes, strides; sizes.reserve(rank); strides.resize(rank, rewriter.create(loc, 1)); offsets.resize(rank, rewriter.create(loc, 0)); for (int i = 0; i < rank; ++i) { sizes.push_back(rewriter.create(loc, args[0], i)); } Value resultDimSize = sizes[axis]; for (auto arg : args.drop_front()) { auto size = rewriter.create(loc, arg, axisValue); resultDimSize = rewriter.create(loc, resultDimSize, size); } sizes[axis] = resultDimSize; Value init = rewriter.create( loc, resultType.getShape(), resultType.getElementType()); Value zeroVal = rewriter.create( loc, rewriter.getZeroAttr(resultType.getElementType())); Value result = rewriter.create(loc, zeroVal, init).getResult(0); for (auto arg : args) { sizes[axis] = rewriter.create(loc, arg, axisValue); result = rewriter.create(loc, arg, result, offsets, sizes, strides); offsets[axis] = rewriter.create(loc, offsets[axis], sizes[axis]); } rewriter.replaceOp(op, result); return success(); } }; class ReverseConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ReverseOp op, PatternRewriter &rewriter) const final { auto loc = op.getLoc(); Value input = op.input(); auto inputTy = input.getType().template cast(); auto resultTy = op.getType().template cast(); auto rank = resultTy.getRank(); auto axis = op.axis(); if (!inputTy.hasStaticShape()) return rewriter.notifyMatchFailure( op, "No initial value found for reduction operation"); // First fill the output buffer with the init value. auto initTensor = rewriter .create( loc, ArrayRef({}), inputTy.getShape(), inputTy.getElementType()) .result(); SmallVector inputExprs; inputExprs.resize(resultTy.getRank()); for (int i = 0; i < rank; i++) inputExprs[i] = rewriter.getAffineDimExpr(i); inputExprs[axis] = rewriter.getAffineConstantExpr(inputTy.getDimSize(axis) - 1) - inputExprs[axis]; SmallVector affineMaps = { AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs, rewriter.getContext()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; rewriter.replaceOpWithNewOp( op, resultTy, op.input(), ValueRange{initTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(op.getLoc(), *args.begin()); }); return success(); } }; // This converter translate a tile operation to a reshape, broadcast, reshape. 
// The first reshape minimally expands each tiled dimension to include a // proceding size-1 dim. This dim is then broadcasted to the appropriate // multiple. struct TileConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tosa::TileOp op, ArrayRef args, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto input = op.input1(); auto inputTy = input.getType().cast(); auto inputShape = inputTy.getShape(); auto resultTy = op.getType().cast(); auto elementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape()) return failure(); SmallVector multiples; getValuesFromIntArrayAttribute(op.multiples(), multiples); // Broadcast the newly added dimensions to their appropriate multiple. SmallVector genericShape; for (int i = 0; i < rank; i++) { genericShape.push_back(multiples[i]); genericShape.push_back(inputShape[i]); } auto initTensor = rewriter.create( op.getLoc(), ArrayRef({}), genericShape, elementTy); // We needs to map the input shape to the non-broadcasted dimensions. SmallVector dimExprs; dimExprs.reserve(rank); for (unsigned i = 0; i < rank; ++i) dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1)); auto readAffineMap = AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs, rewriter.getContext()); SmallVector affineMaps = { readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())}; auto genericOp = rewriter.create( loc, RankedTensorType::get(genericShape, elementTy), input, ValueRange{initTensor}, affineMaps, getNParallelLoopsAttrs(genericShape.size()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(op.getLoc(), *args.begin()); }); rewriter.replaceOpWithNewOp( op, resultTy, genericOp.getResult(0), rewriter.getI64ArrayAttr(resultTy.getShape())); return success(); } }; class PadConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::PadOp padOp, PatternRewriter &rewriter) const final { auto loc = padOp.getLoc(); auto input = padOp.input1(); auto padding = padOp.padding(); ShapedType inputTy = input.getType().cast(); ShapedType paddingTy = padding.getType().cast(); Type elementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) { return rewriter.notifyMatchFailure( padOp, "Pad converter requires static shaped input / padding values."); } Attribute constantAttr; if (elementTy.isa()) constantAttr = rewriter.getFloatAttr(elementTy, 0.0); else if (elementTy.isa() && !padOp.quantization_info()) constantAttr = rewriter.getIntegerAttr(elementTy, 0); else if (elementTy.isa() && padOp.quantization_info()) { auto value = padOp.quantization_info().getValue().input_zp().getValue(); constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue()); } if (!constantAttr) { return rewriter.notifyMatchFailure( padOp, "tosa.pad to linalg lowering encountered an unknown element type"); } Value lowIndex = rewriter.create(loc, rewriter.getIndexAttr(0)); Value highIndex = rewriter.create(loc, rewriter.getIndexAttr(1)); SmallVector lowValues; SmallVector highValues; lowValues.reserve(rank); highValues.reserve(rank); for (int i = 0; i < rank; i++) { Value inputIndex = rewriter.createOrFold(loc, i); Value lowVal = rewriter.createOrFold( loc, padding, ValueRange({inputIndex, lowIndex})); Value highVal = rewriter.createOrFold( loc, padding, 
          ValueRange({inputIndex, highIndex}));

      lowVal = rewriter.createOrFold(loc, rewriter.getIndexType(), lowVal);
      highVal = rewriter.createOrFold(loc, rewriter.getIndexType(), highVal);

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    Value constant = rewriter.create(loc, constantAttr);

    auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
        padOp.getType(), input, constant, lowValues, highValues, loc, rewriter);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and has the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.input();
    auto inputTy = input.getType().cast();
    auto resultTy = argmaxOp.output().getType().cast();
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.axis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires statically shaped input");

    if (!outElementTy.isa())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // First fill the output buffer for the index.
    auto initTensorIdx =
        rewriter
            .create(loc, ArrayRef({}), resultTy.getShape(), outElementTy)
            .result();
    auto fillValueIdx = rewriter.create(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter.create(loc, fillValueIdx, initTensorIdx).result();

    // Second fill the output buffer for the running max.
    auto initTensorMax =
        rewriter
            .create(loc, ArrayRef({}), resultTy.getShape(), inElementTy)
            .result();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax = rewriter.create(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter.create(loc, fillValueMax, initTensorMax).result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
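    // For a rank-3 input reduced along axis == 1, for instance, the iterator
    // types below become {parallel, reduction, parallel}, and the two output
    // maps drop d1 so that both outputs have the input shape with the reduced
    // axis removed.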
    SmallVector iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
    iteratorTypes[axis] = getReductionIteratorTypeName();

    SmallVector srcExprs;
    SmallVector dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
    auto linalgOp = rewriter.create(
        loc, ArrayRef({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = rewriter.create(
              nestedLoc, oldIndex.getType(), rewriter.create(loc, axis));

          Value predicate;
          if (inElementTy.isa()) {
            predicate = rewriter.create(nestedLoc, CmpFPredicate::OGT,
                                        newValue, oldValue);
          } else if (inElementTy.isa()) {
            predicate = rewriter.create(nestedLoc, CmpIPredicate::sgt,
                                        newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax =
              rewriter.create(nestedLoc, predicate, newValue, oldValue);
          auto resultIndex =
              rewriter.create(nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create(nestedLoc,
                               ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};

class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = args[0];
    auto indices = args[1];

    auto inputTy = input.getType().cast();
    auto indicesTy = indices.getType().cast();
    auto resultTy = op.getType().cast();

    if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();

    auto initTensor =
        rewriter
            .create(loc, ArrayRef{}, resultTy.getShape(), resultElementTy)
            .result();

    SmallVector affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create(
        loc, ArrayRef({resultTy}), ValueRange{indices},
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create(loc, 0);
          Value index1 =
              rewriter.create(loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create(loc, 2);
          Value extract = rewriter.create(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
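// For the I16 variant the input is first shifted into the unsigned range by
// adding 32768; the high bits (value >> 7) select a table entry and the low
// seven bits serve as the interpolation fraction between table[index] and
// table[index + 1].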
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    Value table = op.table();
    auto inputTy = input.getType().cast();
    auto tableTy = table.getType().cast();
    auto resultTy = op.getType().cast();

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    auto initTensor =
        rewriter
            .create(loc, ArrayRef{}, resultTy.getShape(), resultElementTy)
            .result();

    SmallVector affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create(
        loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.region(), genericOp.region().end(),
          TypeRange({inputElementTy, resultElementTy}));

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index =
            rewriter.create(loc, rewriter.getIndexType(), inputValue);
        Value extract = rewriter.create(loc, table, ValueRange{index});
        rewriter.create(loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend =
            rewriter.create(loc, rewriter.getI32Type(), inputValue);

        auto offset =
            rewriter.create(loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create(loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create(loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 =
            rewriter.create(loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        //   value = value + 32768
        //   index = value >> 7;
        //   fraction = value & 0x7f;
        auto extendAdd = rewriter.create(loc, extend, offset);
        Value index = rewriter.create(loc, extendAdd, seven);
        Value fraction = rewriter.create(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        //   base = (int32_t) table[index];
        //   next = (int32_t) table[index + 1];
        Value indexPlusOne = rewriter.create(loc, index, one);

        index = rewriter.create(loc, rewriter.getIndexType(), index);
        indexPlusOne =
            rewriter.create(loc, rewriter.getIndexType(), indexPlusOne);

        Value base = rewriter.create(loc, table, ValueRange{index});
        Value next = rewriter.create(loc, table, ValueRange{indexPlusOne});

        base = rewriter.create(loc, rewriter.getI32Type(), base);
        next = rewriter.create(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the input values:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create(loc, base, seven);
        Value diff = rewriter.create(loc, next, base);
        Value diffScaled = rewriter.create(loc, diff, fraction);
        Value result = rewriter.create(loc, baseScaled, diffScaled);

        rewriter.create(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};

template <typename SrcOp>
class Pool2dConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    ShapedType inputTy = input.getType().cast();
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = op.getType().template cast();
    Type outElementTy = inputTy.getElementType();

    if (!inputTy.hasStaticShape())
      return failure();

    // Determine what the initial value needs to be for the pooling op.
    Attribute initialAttr;
    if (isa(op) && outElementTy.isF32())
      initialAttr = rewriter.getFloatAttr(
          outElementTy,
          APFloat::getLargest(
              outElementTy.cast().getFloatSemantics(), true));

    if (isa(op) && outElementTy.isa())
      initialAttr = rewriter.getIntegerAttr(
          outElementTy,
          APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));

    if (isa(op) && outElementTy.isa())
      initialAttr = rewriter.getZeroAttr(outElementTy);

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.pad(), pad);
    pad.resize(pad.size() + 2, 0);
    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create(loc, initialAttr);

    SmallVector kernel, stride;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
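    // Max pooling maps directly onto a named linalg pooling op. Average
    // pooling is lowered as a windowed sum followed by the linalg.generic
    // built below, which divides each output element by the number of input
    // elements the window actually covered.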
    Value initTensor = rewriter.create(
        loc, resultTy.getShape(), resultTy.getElementType());

    Value filledInitTensor =
        rewriter.create(loc, initialValue, initTensor).result();

    Value fakeWindowDims =
        rewriter.create(loc, kernel, outElementTy);

    if (isa(op)) {
      rewriter.replaceOpWithNewOp(
          op, ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims},
          filledInitTensor, strideAttr, dilationAttr);
      return success();
    }

    if (isa(op) && inElementTy.isF32()) {
      Value poolingOp =
          rewriter
              .create(loc, ArrayRef{resultTy},
                      ValueRange{paddedInput, fakeWindowDims},
                      filledInitTensor, strideAttr, dilationAttr)
              .getResult(0);
      auto poolingOpTy = poolingOp.getType().cast();
      auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
      auto genericOp = rewriter.create(
          loc, ArrayRef({resultTy}), ValueRange{}, ValueRange{poolingOp},
          ArrayRef({affineMap}), getNParallelLoopsAttrs(resultTy.getRank()),
          [&](OpBuilder &b, Location loc, ValueRange args) {
            auto zero = rewriter.create(loc, 0);
            auto one = rewriter.create(loc, 1);
            auto iH = rewriter.create(loc, poolingOpTy.getDimSize(1) - 1);
            auto iW = rewriter.create(loc, poolingOpTy.getDimSize(2) - 1);

            // Compute the indices from either end.
            auto y0 = rewriter.create(loc, 1);
            auto x0 = rewriter.create(loc, 2);
            auto y1 = rewriter.create(loc, iH, y0);
            auto x1 = rewriter.create(loc, iW, x0);

            // Determine what portion of the valid input is covered by the
            // kernel.
            auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
              if (pad == 0)
                return v;
              auto padVal = rewriter.create(loc, pad);
              Value dx = rewriter.create(loc, x, padVal);
              Value cmp =
                  rewriter.create(loc, CmpIPredicate::slt, dx, zero);
              Value offset = rewriter.create(loc, cmp, dx, zero);
              return rewriter.create(loc, v, offset)->getResult(0);
            };

            // Compute the vertical component of coverage.
            auto kH0 = rewriter.create(loc, kernel[0]);
            auto kH1 = padFn(kH0, y0, pad[2]);
            auto kH2 = padFn(kH1, y1, pad[3]);
            auto kHCmp =
                rewriter.create(loc, CmpIPredicate::slt, kH2, one);
            auto kH3 = rewriter.create(loc, kHCmp, one, kH2);

            // Compute the horizontal component of coverage.
            auto kW0 = rewriter.create(loc, kernel[1]);
            auto kW1 = padFn(kW0, x0, pad[4]);
            auto kW2 = padFn(kW1, x1, pad[5]);
            auto kWCmp =
                rewriter.create(loc, CmpIPredicate::slt, kW2, one);
            auto kW3 = rewriter.create(loc, kWCmp, one, kW2);

            // Compute the total number of elements and normalize.
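            // For example, with a 3x3 kernel and a padding of 1 on each side,
            // the window anchored at an output corner only overlaps a 2x2
            // region of the original input, so the sum is divided by 4 rather
            // than 9.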
            Value count = rewriter.create(loc, kH3, kW3);
            auto countI = rewriter.create(loc, rewriter.getI32Type(), count);
            auto countF = rewriter.create(loc, inElementTy, countI);

            auto div = rewriter.create(loc, args[0], countF)->getResult(0);

            rewriter.create(loc, div);
          });

      rewriter.replaceOp(op, genericOp.getResult(0));
      return success();
    }

    return failure();
  }
};

} // namespace

void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
      // clang-format off
      PointwiseConverter, PointwiseConverter, PointwiseConverter,
      PointwiseConverter, PointwiseConverter, PointwiseConverter,
      PointwiseConverter, PointwiseConverter, PointwiseConverter,
      PointwiseConverter, PointwiseConverter, PointwiseConverter,
      PointwiseConverter, PointwiseConverter, PointwiseConverter,
      PointwiseConverter, PointwiseConverter, PointwiseConverter,
      PointwiseConverter, PointwiseConverter, PointwiseConverter,
      PointwiseConverter, PointwiseConverter, PointwiseConverter,
      PointwiseConverter, PointwiseConverter, PointwiseConverter,
      PointwiseConverter, PointwiseConverter, PointwiseConverter,
      PointwiseConverter, PointwiseConverter, PointwiseConverter,
      PointwiseConverter, PointwiseConverter,
      IdentityNConverter,
      ReduceConverter, ReduceConverter, ReduceConverter,
      ReduceConverter, ReduceConverter, ReduceConverter,
      ArgMaxConverter,
      ConcatConverter,
      ConvConverter, ConvConverter,
      TransposeConvConverter,
      GatherConverter,
      PadConverter,
      ReshapeConverter,
      RescaleConverter,
      ResizeConverter,
      ReverseConverter,
      TableConverter,
      TileConverter,
      TransposeConverter,
      MatMulConverter,
      Pool2dConverter, Pool2dConverter,
      FullyConnectedConverter>(patterns->getContext());
  // clang-format on
}
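// NOTE: The snippet below is only an illustrative sketch of how a conversion
// pass typically consumes the patterns registered above; the in-tree driver
// lives in TosaToLinalgPass.cpp, and the exact set of dialects marked legal
// here is an assumption.
//
//   void runOnFunction() override {
//     RewritePatternSet patterns(&getContext());
//     ConversionTarget target(getContext());
//     target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
//                            math::MathDialect, tensor::TensorDialect>();
//     target.addIllegalDialect<tosa::TosaDialect>();
//     mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(&patterns);
//     if (failed(applyFullConversion(getFunction(), target,
//                                    std::move(patterns))))
//       signalPassFailure();
//   }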