//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

using namespace mlir;

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

template <typename T>
static mlir::ConstantOp
createConstFromIntAttribute(Operation *op, std::string attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
  return rewriter.create<mlir::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}

template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  for (Attribute val : attr.getValue()) {
    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

template <typename T, typename P>
static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min,
                                  mlir::ConstantOp max, P pred,
                                  OpBuilder &rewriter) {
  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
  auto minOrArg =
      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
}
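// Pads the given input with a constant pad value. `pad` holds one (low, high)
// pair per input dimension, interleaved as [low0, high0, low1, high1, ...],
// so its length must be twice the input rank.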
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            Attribute padAttr, OpBuilder &rewriter) {
  // Nothing to do when every padding amount is zero.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (int i = 0, s = inputShape.size(); i < s; i++) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<ConstantOp>(loc, padAttr);

  return linalg::PadTensorOp::createPadScalarOp(
             RankedTensorType::get(paddedShape, inputETy), input, padValue,
             lowIndices, highIndices, loc, rewriter)
      .result();
}

static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, rewriter.getZeroAttr(elementTy));
    auto cmp =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0], zero);
    auto neg = rewriter.create<mlir::SubIOp>(loc, zero, args[0]);
    return rewriter.create<mlir::SelectOp>(loc, cmp, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    if (cast<tosa::MulOp>(op).shift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
  }

  // tosa::DivOp
  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SignedDivIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, args[0]);
  }

  // tosa::MulOp (integer). Shifted multiplies are routed through
  // tosa::ApplyScaleOp so the result is scaled with TOSA rounding rules.
  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), b);
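      // tosa.apply_scale computes a rounded (a * b) >> shift in 32 bits
      // (double rounding disabled here), matching the TOSA MUL shift
      // semantics without overflowing the element type.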
      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<TruncateIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<SignExtendIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<SignExtendIOp>(loc, resultTypes[0], b);

    return rewriter.create<mlir::MulIOp>(loc, resultTypes, a, b);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      !cast<tosa::NegateOp>(op).quantization_info()) {
    auto constant =
        rewriter.create<ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<SubIOp>(loc, resultTypes, constant, args[0]);
  }

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      cast<tosa::NegateOp>(op).quantization_info()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp =
        quantizationInfo.getValue().input_zp().getValue().getSExtValue();
    int64_t outZp =
        quantizationInfo.getValue().output_zp().getValue().getSExtValue();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    // For example, with i8 data and zero points of -128 and 127, zpAdd is -1
    // and maxValue is 127 + 1 + 1 = 129, so an i16 intermediate suffices.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //   outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<SignExtendIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    auto min = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMinValue(inputBitWidth).getSExtValue()));
    auto max = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMaxValue(inputBitWidth).getSExtValue()));
    auto clamp = clampHelper<mlir::CmpIOp>(loc, sub, min, max,
                                           CmpIPredicate::slt, rewriter);

    // Truncate to the final value.
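    // The clamp above keeps the value within elementTy's signed range, so
    // this truncation is lossless.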
    return rewriter.create<TruncateIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnesValue(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<ConstantOp>(loc, allOnesAttr);
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::ShiftLeftOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
    auto result =
        rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, args);
    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Rounding only applies when the shift amount (input2) is greater than
    // zero.
    auto shiftValueGreaterThanZero =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[1], zero);

    // Check whether the last bit shifted out of input1 is set.
    auto subtract =
        rewriter.create<mlir::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted = rewriter
                       .create<mlir::SignedShiftRightOp>(loc, resultTypes,
                                                         args[0], subtract)
                       ->getResults();
    auto truncated =
        rewriter.create<mlir::TruncateIOp>(loc, i1Ty, shifted, mlir::None);
    auto isInputOdd = rewriter.create<mlir::AndOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<mlir::AndOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<ZeroExtendIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<mlir::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT, args[0],
                                         args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0],
                                         args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, args[0],
                                         args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
                                         args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OEQ, args[0],
                                         args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0],
                                         args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CeilFOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);

  // tosa::ClampOp
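  // Clamp lowers to a min/max compare-select chain via clampHelper; the float
  // path reads the min_fp/max_fp attributes, the integer path min_int/max_int.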
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    auto min = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("min_fp"));
    auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], min, max, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
                                                    rewriter);
    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                    rewriter);
    return clampHelper<mlir::CmpIOp>(loc, args[0], min, max, CmpIPredicate::slt,
                                     rewriter);
  }

  // tosa::ReluNOp
  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
    auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                               op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], zero, n, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                  rewriter);
    return clampHelper<mlir::CmpIOp>(loc, args[0], zero, n, CmpIPredicate::slt,
                                     rewriter);
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<mlir::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<mlir::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
      return rewriter.create<mlir::FPExtOp>(loc, resultTypes, args, mlir::None);

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
      return rewriter.create<mlir::FPTruncOp>(loc, resultTypes, args,
                                              mlir::None);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::UIToFPOp>(loc, resultTypes, args,
                                             mlir::None);

    if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<mlir::ZeroExtendIOp>(loc, resultTypes, args,
                                                  mlir::None);

    // All other si-to-fp conversions should be handled by SIToFP.
    if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::SIToFPOp>(loc, resultTypes, args,
                                             mlir::None);

    // Casting to boolean: floats need only be checked as not-equal to zero.
    if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
      Value zero =
          rewriter.create<ConstantOp>(loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::UNE,
                                           args.front(), zero);
    }

    if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      // Round to the nearest integer by adding (or, for negative values,
      // subtracting) one half before truncating, then clamp to the
      // destination's signed range.
      auto zero =
          rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
      auto half =
          rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));

      auto intMin = rewriter.create<ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto intMax = rewriter.create<ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto added = rewriter.create<AddFOp>(loc, args[0], half);
      auto subbed = rewriter.create<SubFOp>(loc, args[0], half);
      auto negative =
          rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT, args[0], zero);
      auto rounded =
          rewriter.create<mlir::SelectOp>(loc, negative, subbed, added);

      auto clamped = clampHelper<mlir::CmpFOp>(loc, rounded, intMin, intMax,
                                               CmpFPredicate::OLT, rewriter);

      return rewriter.create<mlir::FPToSIOp>(loc, dstTy, clamped);
    }

    // Casting to boolean: integers need only be checked as not-equal to zero.
    if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
      Value zero =
          rewriter.create<ConstantIntOp>(loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(),
                                           zero);
    }

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
                                                  mlir::None);

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
      // Narrowing integer casts clamp to the destination range first so the
      // truncation saturates instead of wrapping.
      auto intMin = rewriter.create<ConstantIntOp>(
          loc,
          APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto intMax = rewriter.create<ConstantIntOp>(
          loc,
          APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto clamped = clampHelper<mlir::CmpIOp>(loc, args[0], intMin, intMax,
                                               CmpIPredicate::slt, rewriter);
      return rewriter.create<mlir::TruncateIOp>(loc, dstTy, clamped);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}

static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  auto loc = operation->getLoc();

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  auto results = operation->getResults();
  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();

  if (!resultTy)
    return rewriter.notifyMatchFailure(operation,
                                       "All results must be a shaped type");

  unsigned rank = resultTy.getRank();

  // Construct the indexing maps needed for linalg.generic ops.
  SmallVector<Type> bodyArgTypes;

  for (Value in : operation->getOperands())
    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));

  SmallVector<Type> opResultTypes;
  SmallVector<Value> initTensors;
  for (auto result : results) {
    auto resultTy = result.getType().template cast<ShapedType>();
    if (!resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          operation,
          "tosa to linalg conversion expects statically shaped tensors");

    initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, ArrayRef<Value>({}), resultTy.getShape(),
        resultTy.getElementType()));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      initTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  SmallVector<Value, 2> operands;
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcast: operand dimensions that match the
  // result are kept, while broadcast (size-1) dimensions are dropped. For
  // example, an operand of shape [3, 1] against a [3, 5] result is reshaped
  // to [3] and indexed with the map (d0, d1) -> (d0).
  for (Value operand : operation->getOperands()) {
    ShapedType type = operand.getType().cast<ShapedType>();

    if (type.getShape() == resultTy.getShape()) {
      operands.push_back(operand);
      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
      continue;
    }

    SmallVector<int64_t, 5> newShape;
    SmallVector<AffineExpr, 4> affineExprs;
    newShape.reserve(type.getRank());
    for (auto it : llvm::enumerate(type.getShape())) {
      if (it.value() == resultTy.getDimSize(it.index())) {
        newShape.push_back(it.value());
        affineExprs.push_back(
            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
      }
    }

    if (newShape.size() != rank) {
      operand = rewriter.create<tosa::ReshapeOp>(
          loc, RankedTensorType::get(newShape, type.getElementType()), operand);
    }

    operands.push_back(operand);
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs,
        rewriter.getContext()));
  }

  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(rank));

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operands, initTensors, indexingMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        if (!opResult) {
          didEncounterError = true;
          return;
        }
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnesValue(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getNullValue(1));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}

// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::AndOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::OrOp>(loc, args);

  return {};
}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
  auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  llvm::SmallVector<int64_t> reduceShape;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i)
      reduceShape.push_back(inputTy.getDimSize(i));
  }

  Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());
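  // The reduction lowers to a linalg.fill of the reduction's identity value
  // followed by a linalg.generic. For a rank-3 input reduced along axis 1,
  // the generic uses the maps
  //   input:  (d0, d1, d2) -> (d0, d1, d2)
  //   output: (d0, d1, d2) -> (d0, d2)
  // with iterator types [parallel, reduction, parallel].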
  // First fill the output buffer with the init value.
  auto initTensor =
      rewriter
          .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}), reduceShape,
                                        resultTy.getElementType())
          .result();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<ConstantOp>(loc, fillValueAttr);
  auto filledTensor =
      rewriter.create<linalg::FillOp>(loc, fillValue, initTensor).result();

  SmallVector<AffineExpr, 2> srcExprs;
  SmallVector<AffineExpr, 2> dstExprs;
  SmallVector<StringRef, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));

    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
                                      : getParallelIteratorTypeName());
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

  bool didEncounterError = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (!result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
                                               linalgOp.getResults());
  return success();
}

static LogicalResult
convolutionMatchAndRewriteHelper(Operation *op,
                                 ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  Value input = op->getOperand(0);
  Value weight = op->getOperand(1);
  Value bias = op->getOperand(2);

  ShapedType inputTy = input.getType().cast<ShapedType>();
  ShapedType weightTy = weight.getType().cast<ShapedType>();
  ShapedType biasTy = bias.getType().cast<ShapedType>();
  ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

  Type inputETy = inputTy.getElementType();
  Type resultETy = resultTy.getElementType();

  auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
  auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
  auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();

  bool isQuantized = op->hasAttr("quantization_info");
  IntegerAttr iZp;
  IntegerAttr kZp;
  if (isQuantized) {
    auto quantizationInfo =
        op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
    iZp = rewriter.getI32IntegerAttr(
        quantizationInfo.input_zp().getValue().getSExtValue());
    kZp = rewriter.getI32IntegerAttr(
        quantizationInfo.weight_zp().getValue().getSExtValue());
  }

  if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
      !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
    return rewriter.notifyMatchFailure(op,
                                       "tosa.conv ops require static shapes");

  auto weightShape = weightTy.getShape();
  auto resultShape = resultTy.getShape();

  // Apply padding as necessary.
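  // TOSA convolution padding covers only the spatial dimensions; surround it
  // with zero pairs for the batch and channel dimensions so the vector holds
  // one [low, high] pair per NHWC dimension, as applyPad expects.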
  Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
  llvm::SmallVector<int64_t> pad;
  pad.resize(2, 0);
  getValuesFromIntArrayAttribute(padAttr, pad);
  pad.resize(pad.size() + 2, 0);

  input = applyPad(loc, input, pad, zeroAttr, rewriter);

  // Broadcast the initial value to the output tensor before convolving.
  SmallVector<AffineMap, 4> indexingMaps;
  indexingMaps.push_back(AffineMap::get(
      /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
      {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

  Value initTensor = rewriter.create<linalg::InitTensorOp>(
      loc, resultTy.getShape(), resultTy.getElementType());

  Value biasBroadcast =
      rewriter
          .create<linalg::GenericOp>(
              loc, resultTy, bias, initTensor, indexingMaps,
              getNParallelLoopsAttrs(resultTy.getRank()),
              [&](OpBuilder &nestedBuilder, Location nestedLoc,
                  ValueRange args) {
                nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
              })
          .getResult(0);

  // Extract the attributes for convolution.
  llvm::SmallVector<int64_t> stride, dilation;
  getValuesFromIntArrayAttribute(strideTosaAttr, stride);
  getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

  // Create the convolution op.
  auto strideAttr = DenseIntElementsAttr::get(
      RankedTensorType::get({2}, rewriter.getI64Type()), stride);
  auto dilationAttr = DenseIntElementsAttr::get(
      RankedTensorType::get({2}, rewriter.getI64Type()), dilation);

  if (isa<tosa::Conv2DOp>(op) && !isQuantized) {
    rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyOp>(
        op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast},
        strideAttr, dilationAttr);
    return success();
  }

  if (isa<tosa::Conv2DOp>(op) && isQuantized) {
    auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
    auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
    rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyQOp>(
        op, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
        ValueRange{biasBroadcast}, strideAttr, dilationAttr);
    return success();
  }

  if (isa<tosa::DepthwiseConv2DOp>(op) && !isQuantized) {
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);

    Value biasReshape =
        rewriter.create<tosa::ReshapeOp>(loc, linalgConvTy, biasBroadcast);
    Value conv = rewriter
                     .create<linalg::DepthwiseConvInputNHWCFilterHWCFOp>(
                         loc, linalgConvTy, ValueRange{input, weight},
                         ValueRange{biasReshape}, dilationAttr, strideAttr)
                     .getResult(0);

    Value reshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
    rewriter.replaceOp(op, reshape);
    return success();
  }

  return failure();
}

namespace {

template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(op, rewriter);
  }
};

template <typename T>
class ConvConverter : public OpConversionPattern<T> {
public:
  using OpConversionPattern<T>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(T op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    return convolutionMatchAndRewriteHelper(op, rewriter);
  }
};

class TransposeConvConverter
    : public OpConversionPattern<tosa::TransposeConv2DOp> {
public:
  using OpConversionPattern<tosa::TransposeConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::TransposeConv2DOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;
    llvm::SmallVector<int64_t> dilation;

    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);

    // We have not solved for stride / dilation yet. Dilation should be
    // straightforward, but stride is more complicated. Linalg work is likely
    // required for an efficient implementation.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();
    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t inputHeight = inputTy.getDimSize(1);
    int64_t inputWidth = inputTy.getDimSize(2);
    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);
    int64_t outputHeight = resultTy.getDimSize(1);
    int64_t outputWidth = resultTy.getDimSize(2);

    int64_t requiredInputHeight = outputHeight + kernelHeight - 1;
    int64_t requiredInputWidth = outputWidth + kernelWidth - 1;

    // A unit-stride transpose convolution is equivalent to a regular
    // convolution of the zero-padded input with the spatially reversed
    // kernel; pad so that every output pixel sees a full kernel window.
    llvm::SmallVector<int64_t> newPad(4, 0);
    newPad[0] = kernelHeight - 1 - pad[0];
    newPad[2] = kernelWidth - 1 - pad[1];

    newPad[1] = requiredInputHeight - newPad[0] - inputHeight;
    newPad[3] = requiredInputWidth - newPad[2] - inputWidth;

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));

    Value conv2d;
    if (op.quantization_info().hasValue()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr(dilation),
          op.quantization_info().getValue());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr(dilation));
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};
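// Lowers tosa.matmul to linalg.batch_matmul (or linalg.quantized_batch_matmul
// when zero points are present), accumulating into a zero-filled init tensor.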
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    tosa::MatMulOp::Adaptor adaptor(args);

    Location loc = op.getLoc();

    auto outputTy = op.getType().cast<ShapedType>();
    auto outputElementTy = outputTy.getElementType();
    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, outputTy.getShape(), outputTy.getElementType());
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
    if (!op.quantization_info()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
          ValueRange{zeroTensor});
      return success();
    }

    auto quantizationInfo = op.quantization_info().getValue();
    auto aZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.a_zp().getValue().getSExtValue()));
    auto bZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.b_zp().getValue().getSExtValue()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);

    return success();
  }
};

class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = op.getType().cast<ShapedType>();
    auto input = op.input();
    auto weight = op.weight();
    auto bias = op.bias();

    auto weightTy = weight.getType().cast<ShapedType>();
    auto weightShape = weightTy.getShape();

    // Create indexing maps for the bias broadcast and the matmul output.
    SmallVector<AffineMap, 4> indexingMaps;

    // Broadcast the bias.
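    // The 1-D bias is indexed solely by the output-channel dimension, i.e.
    //   (d0, d1) -> (d1)
    // so the linalg.generic below broadcasts it across every output row.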
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(1)},
                                          rewriter.getContext()));

    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, outputTy.getShape(),
                                          outputTy.getElementType())
            ->getResults();

    auto linalgOp =
        rewriter
            .create<linalg::GenericOp>(
                loc, outputTy, bias, initTensor, indexingMaps,
                getNParallelLoopsAttrs(outputTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
                })
            ->getResults();

    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    if (!op.quantization_info()) {
      rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
          op, TypeRange{op.getType()}, ValueRange{input, transposedWeight},
          linalgOp);
      return success();
    }

    auto quantizationInfo = op.quantization_info().getValue();
    auto inputZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.input_zp().getValue().getSExtValue()));
    auto weightZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.weight_zp().getValue().getSExtValue()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{input, transposedWeight, inputZp, weightZp}, linalgOp);

    return success();
  }
};

class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    typename tosa::ReshapeOp::Adaptor operands(args);

    ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();

    if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, args[0]);
      return success();
    }

    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    // Compute the reassociation maps for the linalg operation.
    ArrayRef<int64_t> expandedShape =
        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
                                                  : resultTy.getShape());
    ArrayRef<int64_t> collapsedShape =
        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
                                                  : operandTy.getShape());
    unsigned currSrcDim = 0, currDstDim = 0;
    SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());

    // First scan all dimensions in the source shapes to see whether we have a
    // perfect case where consecutive dimensions in source are collapsed. For
    // such case we can just generate one single linalg.reshape.
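    // For example, collapsing tensor<2x3x4xf32> into tensor<6x4xf32> yields
    // the reassociation [[0, 1], [2]].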
    bool isCollapsingSource = true;
    while (currSrcDim < expandedShape.size() &&
           currDstDim < collapsedShape.size()) {
      int64_t dstSize = collapsedShape[currDstDim];
      int64_t srcSize = expandedShape[currSrcDim];
      while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        srcSize *= expandedShape[currSrcDim];
      }
      if (srcSize == dstSize) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        // If the next dim in collapsedShape is not 1, treat subsequent dims in
        // expandedShape which are 1 to be collapsed.
        if (currDstDim == collapsedShape.size() - 1 ||
            collapsedShape[currDstDim + 1] != 1) {
          while (currSrcDim < expandedShape.size() &&
                 expandedShape[currSrcDim] == 1) {
            reassociationMap[currDstDim].push_back(
                rewriter.getAffineDimExpr(currSrcDim++));
          }
        }
      } else {
        isCollapsingSource = false;
        break;
      }
      currDstDim++;
    }

    // Check if any remaining dimensions exist. If either shape is rank-0 we
    // only require the direct lowering.
    if (currSrcDim != expandedShape.size() ||
        currDstDim != collapsedShape.size())
      isCollapsingSource = collapsedShape.empty() || expandedShape.empty();

    // Otherwise, we need to first reduce all source dimensions into one and
    // then expand to the destination dimensions.
    if (!isCollapsingSource) {
      auto getIdentityExprs = [&rewriter](int n) {
        SmallVector<AffineExpr, 4> exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
      Location loc = reshape.getLoc();
      int64_t totalElems =
          std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
                          std::multiplies<int64_t>());
      auto elemTy = operandTy.getElementType();
      SmallVector<ReassociationExprs, 4> collapsingMap = {
          // Use operandTy here because we need to collapse all operand
          // dimensions.
          getIdentityExprs(operandTy.getShape().size())};
      SmallVector<ReassociationExprs, 4> expandingMap = {
          // Use resultTy here because we need to expand to all result
          // dimensions.
          getIdentityExprs(resultTy.getShape().size())};

      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
      Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, collapsedTy, args[0], collapsingMap);
      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
          reshape, resultTy, collapsedOp, expandingMap);

      return success();
    }

    if (resultTy.getRank() < args[0].getType().cast<ShapedType>().getRank())
      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
          reshape, resultTy, args[0], reassociationMap);
    else
      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
          reshape, resultTy, args[0], reassociationMap);

    return success();
  }
};

class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    if (!matchPattern(op.perms(), m_Constant(&perms))) {
      return failure();
    }

    auto resultTy = op.getType().cast<ShapedType>();
    if (!resultTy.hasStaticShape())
      return failure();

    // Build the input indexing map so that result dimension i reads input
    // dimension perms[i].
    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());
    for (auto permutation : llvm::enumerate(perms.getIntValues())) {
      inputExprs[permutation.value().getZExtValue()] =
          rewriter.getAffineDimExpr(permutation.index());
    }

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
        resultTy.getElementType());

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });
    return success();
  }
};

class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.input();
    auto inputTy = op.input().getType().cast<ShapedType>();
    auto outputTy = op.output().getType().cast<ShapedType>();
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration; terminate and log an error.
    if (op.double_round() && !op.scale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!outputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa to linalg conversion expects statically shaped tensors");

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues;
    getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues);

    SmallVector<int8_t> shiftValues;
    getValuesFromIntArrayAttribute(op.shift(), shiftValues);

    // Double rounding only has an effect when the shift exceeds 31, so check
    // whether that is ever the case.
    bool doubleRound =
        op.double_round() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the init tensor for the linalg.generic op.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, ArrayRef<Value>({}), outputTy.getShape(),
        outputTy.getElementType());

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];

          // Compute in 32 bits, widening to 48 bits when the input is wider
          // than 32 bits; computing an exact required bit depth is left for
          // later.
          int32_t inBitwidth =
              value.getType().getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
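          // Rescale pipeline: widen to i32, subtract the input zero point,
          // apply the scale, add the output zero point, then saturate to the
          // output type.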
          if (value.getType().getIntOrFloatBitWidth() < 32) {
            value = nestedBuilder.create<SignExtendIOp>(
                nestedLoc, nestedBuilder.getI32Type(), value);
          }

          value = nestedBuilder.create<SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value = nestedBuilder.create<AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              blockArgs.back().getType().cast<IntegerType>();
          unsigned outBitWidth = outIntType.getWidth();
          auto intMin = nestedBuilder.create<ConstantOp>(
              loc, nestedBuilder.getIntegerAttr(
                       nestedBuilder.getI32Type(),
                       APInt::getSignedMinValue(outBitWidth).getSExtValue()));
          auto intMax = nestedBuilder.create<ConstantOp>(
              loc, nestedBuilder.getIntegerAttr(
                       nestedBuilder.getI32Type(),
                       APInt::getSignedMaxValue(outBitWidth).getSExtValue()));

          value = clampHelper<mlir::CmpIOp>(nestedLoc, value, intMin, intMax,
                                            CmpIPredicate::slt, nestedBuilder);

          if (outIntType.getWidth() < 32) {
            value =
                nestedBuilder.create<TruncateIOp>(nestedLoc, outIntType, value);
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};

class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto input = op.input();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();
    auto resultElementTy = resultTy.getElementType();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    if (!resultTy.hasStaticShape())
      return failure();
    if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
      return failure();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
                           TypeRange({resultElementTy}));
      Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
      Value y = rewriter.create<linalg::IndexOp>(loc, 1);
      Value x = rewriter.create<linalg::IndexOp>(loc, 2);
      Value channel = rewriter.create<linalg::IndexOp>(loc, 3);

      auto hwMin =
          rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
      auto hMax = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(imageH - 1));
      auto wMax = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(imageW - 1));
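      // linalg.index results are `index`-typed; cast the spatial indices to
      // i32 so they can participate in the scaled arithmetic below.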
      Value inY = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(), y);
      Value inX = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(), x);

      int32_t shift = op.shift();
      bool floatingPointMode = shift == 0;

      Value yStride, xStride, yOffset, xOffset;
      if (floatingPointMode) {
        yStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[0]);
        xStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[1]);
        yOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[0]);
        xOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[1]);
      } else {
        SmallVector<int32_t> stride, offset;
        getValuesFromIntArrayAttribute(op.stride(), stride);
        getValuesFromIntArrayAttribute(op.offset(), offset);

        yStride = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(stride[0]));
        xStride = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(stride[1]));
        yOffset = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(offset[0]));
        xOffset = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(offset[1]));
      }

      // Compute the integer index and partial offset:
      //   x = x * stride + offset;
      //   ix = floor(x)
      //   dx = x - ix
      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        Value y = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inY);
        Value x = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inX);

        y = rewriter.create<MulFOp>(loc, y, yStride);
        x = rewriter.create<MulFOp>(loc, x, xStride);

        y = rewriter.create<AddFOp>(loc, y, yOffset);
        x = rewriter.create<AddFOp>(loc, x, xOffset);

        iy = rewriter.create<FloorFOp>(loc, y);
        ix = rewriter.create<FloorFOp>(loc, x);

        dy = rewriter.create<SubFOp>(loc, y, iy);
        dx = rewriter.create<SubFOp>(loc, x, ix);

        iy = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), iy);
        ix = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), ix);
      } else {
        Value shiftVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(shift));

        Value y = rewriter.create<MulIOp>(loc, inY, yStride);
        Value x = rewriter.create<MulIOp>(loc, inX, xStride);

        y = rewriter.create<AddIOp>(loc, y, yOffset);
        x = rewriter.create<AddIOp>(loc, x, xOffset);

        iy = rewriter.create<SignedShiftRightOp>(loc, y, shiftVal);
        ix = rewriter.create<SignedShiftRightOp>(loc, x, shiftVal);

        Value yTrunc = rewriter.create<ShiftLeftOp>(loc, iy, shiftVal);
        Value xTrunc = rewriter.create<ShiftLeftOp>(loc, ix, shiftVal);

        dy = rewriter.create<SubIOp>(loc, y, yTrunc);
        dx = rewriter.create<SubIOp>(loc, x, xTrunc);
      }
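
      // In fixed-point mode, `shift` gives the number of fractional bits.
      // For example, with shift == 10 a stride of 512 encodes 0.5: for
      // inY == 3, y = 3 * 512 = 1536, iy = 1536 >> 10 = 1, and
      // dy = 1536 - (1 << 10) = 512, i.e. a fractional offset of 0.5.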
      if (op.mode() == "NEAREST_NEIGHBOR") {
        Value yPred, xPred;
        // Round the index position towards the closest pixel location.
        if (floatingPointMode) {
          auto halfVal =
              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));
          yPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dy,
                                                halfVal);
          xPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dx,
                                                halfVal);
        } else {
          auto halfVal = rewriter.create<ConstantOp>(
              loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
          yPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dy,
                                                halfVal);
          xPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dx,
                                                halfVal);
        }

        auto zeroVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
        auto oneVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));

        auto yOffset =
            rewriter.create<mlir::SelectOp>(loc, yPred, oneVal, zeroVal);
        auto xOffset =
            rewriter.create<mlir::SelectOp>(loc, xPred, oneVal, zeroVal);

        iy = rewriter.create<AddIOp>(loc, iy, yOffset);
        ix = rewriter.create<AddIOp>(loc, ix, xOffset);

        // Clamp the indices to be within the bounds of the input image.
        iy = clampHelper<mlir::CmpIOp>(loc, iy, hwMin, hMax, CmpIPredicate::slt,
                                       rewriter);
        ix = clampHelper<mlir::CmpIOp>(loc, ix, hwMin, wMax, CmpIPredicate::slt,
                                       rewriter);

        // Read the value from the input array.
        iy = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), iy);
        ix = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), ix);

        Value result = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, iy, ix, channel});

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }

      if (op.mode() == "BILINEAR") {
        Value y0 = iy;
        Value x0 = ix;

        auto oneVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
        Value y1 = rewriter.create<AddIOp>(loc, y0, oneVal);
        Value x1 = rewriter.create<AddIOp>(loc, x0, oneVal);

        y0 = clampHelper<mlir::CmpIOp>(loc, y0, hwMin, hMax, CmpIPredicate::slt,
                                       rewriter);
        y1 = clampHelper<mlir::CmpIOp>(loc, y1, hwMin, hMax, CmpIPredicate::slt,
                                       rewriter);

        x0 = clampHelper<mlir::CmpIOp>(loc, x0, hwMin, wMax, CmpIPredicate::slt,
                                       rewriter);
        x1 = clampHelper<mlir::CmpIOp>(loc, x1, hwMin, wMax, CmpIPredicate::slt,
                                       rewriter);

        y0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y0);
        y1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y1);
        x0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x0);
        x1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x1);

        Value y0x0 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y1, x1, channel});
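
        // Standard bilinear blend; with dx/dy normalized to [0, 1) this is:
        //   top    = y0x0 * (1 - dx) + y0x1 * dx
        //   bottom = y1x0 * (1 - dx) + y1x1 * dx
        //   result = top * (1 - dy) + bottom * dy
        // The fixed-point path below uses (1 << shift) in place of 1.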
        if (floatingPointMode) {
          auto oneVal =
              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.f));
          Value rightPart = dx;
          Value leftPart = rewriter.create<SubFOp>(loc, oneVal, dx);

          y0x0 = rewriter.create<MulFOp>(loc, y0x0, leftPart);
          y0x1 = rewriter.create<MulFOp>(loc, y0x1, rightPart);
          Value topAcc = rewriter.create<AddFOp>(loc, y0x0, y0x1);

          y1x0 = rewriter.create<MulFOp>(loc, y1x0, leftPart);
          y1x1 = rewriter.create<MulFOp>(loc, y1x1, rightPart);
          Value bottomAcc = rewriter.create<AddFOp>(loc, y1x0, y1x1);

          Value bottomPart = dy;
          Value topPart = rewriter.create<SubFOp>(loc, oneVal, dy);
          topAcc = rewriter.create<MulFOp>(loc, topAcc, topPart);
          bottomAcc = rewriter.create<MulFOp>(loc, bottomAcc, bottomPart);
          Value result = rewriter.create<AddFOp>(loc, topAcc, bottomAcc);

          rewriter.create<linalg::YieldOp>(loc, result);
          return success();
        } else {
          y0x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x0);
          y0x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x1);
          y1x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x0);
          y1x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x1);

          if (resultElementTy.getIntOrFloatBitWidth() > 32) {
            dx = rewriter.create<SignExtendIOp>(loc, resultElementTy, dx);
            dy = rewriter.create<SignExtendIOp>(loc, resultElementTy, dy);
          }

          auto unitVal = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
          Value rightPart = dx;
          Value leftPart = rewriter.create<SubIOp>(loc, unitVal, dx);

          y0x0 = rewriter.create<MulIOp>(loc, y0x0, leftPart);
          y0x1 = rewriter.create<MulIOp>(loc, y0x1, rightPart);
          Value topAcc = rewriter.create<AddIOp>(loc, y0x0, y0x1);

          y1x0 = rewriter.create<MulIOp>(loc, y1x0, leftPart);
          y1x1 = rewriter.create<MulIOp>(loc, y1x1, rightPart);
          Value bottomAcc = rewriter.create<AddIOp>(loc, y1x0, y1x1);

          Value bottomPart = dy;
          Value topPart = rewriter.create<SubIOp>(loc, unitVal, dy);
          topAcc = rewriter.create<MulIOp>(loc, topAcc, topPart);
          bottomAcc = rewriter.create<MulIOp>(loc, bottomAcc, bottomPart);
          Value result = rewriter.create<AddIOp>(loc, topAcc, bottomAcc);

          rewriter.create<linalg::YieldOp>(loc, result);
          return success();
        }
      }

      return failure();
    }

    return success();
  }
};

// At the codegen level any identity operations should be removed. Any cases
// where identity is load-bearing (e.g. cross device computation) should be
// handled before lowering to codegen.
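// For example, `%0 = "tosa.identity"(%arg0)` is erased by replacing all uses
// of %0 with %arg0 directly.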
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter);
  }
};

struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = op.getType().dyn_cast<RankedTensorType>();
    if (!resultType || !resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected static shaped tensor type");
    }

    Location loc = op.getLoc();
    int axis = op.axis();
    Value axisValue =
        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int rank = resultType.getRank();
    SmallVector<Value, 3> offsets, sizes, strides;
    sizes.reserve(rank);
    strides.resize(rank, rewriter.create<ConstantIndexOp>(loc, 1));
    offsets.resize(rank, rewriter.create<ConstantIndexOp>(loc, 0));

    for (int i = 0; i < rank; ++i) {
      sizes.push_back(rewriter.create<tensor::DimOp>(loc, args[0], i));
    }

    Value resultDimSize = sizes[axis];
    for (auto arg : args.drop_front()) {
      auto size = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
      resultDimSize = rewriter.create<AddIOp>(loc, resultDimSize, size);
    }
    sizes[axis] = resultDimSize;

    Value init = rewriter.create<linalg::InitTensorOp>(
        loc, resultType.getShape(), resultType.getElementType());

    Value zeroVal = rewriter.create<ConstantOp>(
        loc, rewriter.getZeroAttr(resultType.getElementType()));
    Value result =
        rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0);

    for (auto arg : args) {
      sizes[axis] = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
      result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
                                                      sizes, strides);
      offsets[axis] = rewriter.create<AddIOp>(loc, offsets[axis], sizes[axis]);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    auto inputTy = input.getType().template cast<ShapedType>();
    auto resultTy = op.getType().template cast<ShapedType>();
    auto rank = resultTy.getRank();
    auto axis = op.axis();

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "reverse requires a statically shaped input");
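
    // tosa.reverse becomes a linalg.generic whose input map flips the reversed
    // axis: index i reads element (dimSize - 1 - i). For example, reversing
    // axis 0 of tensor<4xf32> reads elements 3, 2, 1, 0.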
    // Initialize the output tensor; every element is overwritten by the
    // generic op below.
    auto initTensor = rewriter
                          .create<linalg::InitTensorOp>(
                              loc, ArrayRef<Value>({}), inputTy.getShape(),
                              inputTy.getElementType())
                          .result();

    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());

    for (int i = 0; i < rank; i++)
      inputExprs[i] = rewriter.getAffineDimExpr(i);

    inputExprs[axis] =
        rewriter.getAffineConstantExpr(inputTy.getDimSize(axis) - 1) -
        inputExprs[axis];

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });
    return success();
  }
};

// This converter translates a tile operation to a reshape, broadcast, and
// reshape. The first reshape minimally expands each tiled dimension to include
// a preceding size-1 dim. This dim is then broadcast to the appropriate
// multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.input1();
    auto inputTy = input.getType().cast<ShapedType>();
    auto inputShape = inputTy.getShape();
    auto resultTy = op.getType().cast<ShapedType>();
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    SmallVector<int64_t> multiples;
    getValuesFromIntArrayAttribute(op.multiples(), multiples);

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      genericShape.push_back(multiples[i]);
      genericShape.push_back(inputShape[i]);
    }

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);
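
    // For example, tiling tensor<2x3xf32> by multiples [3, 2] broadcasts into
    // an intermediate tensor<3x2x2x3xf32>, which the trailing tosa.reshape
    // collapses to the final tensor<6x6xf32>.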
    // We need to map the input shape to the non-broadcast dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0),
        rewriter.getI64ArrayAttr(resultTy.getShape()));
    return success();
  }
};

class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.input1();
    auto padding = padOp.padding();

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType paddingTy = padding.getType().cast<ShapedType>();
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          padOp,
          "Pad converter requires static shaped input / padding values.");
    }

    Attribute constantAttr;
    if (elementTy.isa<FloatType>())
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
      auto value = padOp.quantization_info().getValue().input_zp().getValue();
      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          padOp,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value inputIndex = rewriter.createOrFold<ConstantIndexOp>(loc, i);
      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, lowIndex}));
      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, highIndex}));

      lowVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
                                                  lowVal);
      highVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
                                                   highVal);

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    Value constant = rewriter.create<ConstantOp>(loc, constantAttr);
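
    // For example, padding tensor<2x3xf32> with per-dimension (low, high)
    // amounts [[1, 1], [2, 2]] produces tensor<4x7xf32>, with the new cells
    // holding the scalar pad value.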
    auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
        padOp.getType(), input, constant, lowValues, highValues, loc, rewriter);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

// Tosa argmax lowering represents the ArgMax op as a linalg.generic op,
// producing two output tensors.
//
// The first output tensor contains the index of the found maximum value. It is
// initialized to 0 and has the resulting integer type.
//
// The second output tensor contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by the generic op, this tensor is discarded as only the index is
// requested.
//
// The generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.input();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.axis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires statically shaped input");

    if (!outElementTy.isa<IntegerType>())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // First fill the output buffer for the index.
    auto initTensorIdx =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
                                          resultTy.getShape(), outElementTy)
            .result();
    auto fillValueIdx = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter.create<linalg::FillOp>(loc, fillValueIdx, initTensorIdx)
            .result();

    // Second fill the output buffer for the running max.
    auto initTensorMax =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
                                          resultTy.getShape(), inElementTy)
            .result();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
            .result();
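
    // For example, taking the argmax over axis 1 of tensor<3x4xf32> yields an
    // index tensor (say tensor<3xi32>) alongside a running-max tensor<3xf32>
    // that is dropped once the reduction completes.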
    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<StringRef, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
    iteratorTypes[axis] = getReductionIteratorTypeName();

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = rewriter.create<IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (inElementTy.isa<FloatType>()) {
            predicate = rewriter.create<mlir::CmpFOp>(
                nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
          } else if (inElementTy.isa<IntegerType>()) {
            predicate = rewriter.create<mlir::CmpIOp>(
                nestedLoc, CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
                                                           newValue, oldValue);
          auto resultIndex = rewriter.create<mlir::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};

class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = args[0];
    auto indices = args[1];

    auto inputTy = input.getType().cast<ShapedType>();
    auto indicesTy = indices.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
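
    // The body reads input[batch, indices[batch, index], channel]; for
    // example, values tensor<2x8x3xf32> gathered with indices tensor<2x5xi32>
    // produce tensor<2x5x3xf32>.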
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    Value table = op.table();
    auto inputTy = input.getType().cast<ShapedType>();
    auto tableTy = table.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block =
          rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
                               TypeRange({inputElementTy, resultElementTy}));

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(),
                                                   inputValue);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<SignExtendIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(32768));
        auto seven =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(7));
        auto one =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(127));
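
        // Worked example of the computation below: an input of 0 becomes
        // 0 + 32768 = 32768, giving index 256 (32768 >> 7) and fraction 0
        // (32768 & 127), so the interpolation reduces to table[256] << 7;
        // an input of -32768 selects table[0] << 7.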
        // Compute the index and fractional part from the input value:
        //   value = value + 32768
        //   index = value >> 7
        //   fraction = value & 0b1111111
        auto extendAdd = rewriter.create<AddIOp>(loc, extend, offset);
        Value index =
            rewriter.create<UnsignedShiftRightOp>(loc, extendAdd, seven);
        Value fraction = rewriter.create<mlir::AndOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table:
        //   base = (int32_t) table[index];
        //   next = (int32_t) table[index + 1];
        Value indexPlusOne = rewriter.create<AddIOp>(loc, index, one);

        index =
            rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), base);
        next = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the table values:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<ShiftLeftOp>(loc, base, seven);
        Value diff = rewriter.create<SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<MulIOp>(loc, diff, fraction);
        Value result = rewriter.create<AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};

template <typename SrcOp>
class Pool2dConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    ShapedType inputTy = input.getType().cast<ShapedType>();
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type outElementTy = inputTy.getElementType();

    if (!inputTy.hasStaticShape())
      return failure();

    // Determine what the initial value needs to be for the max pool op.
    Attribute initialAttr;
    if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isF32())
      initialAttr = rewriter.getFloatAttr(
          outElementTy,
          APFloat::getLargest(
              outElementTy.cast<FloatType>().getFloatSemantics(), true));

    if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isa<IntegerType>())
      initialAttr = rewriter.getIntegerAttr(
          outElementTy,
          APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));

    if (isa<tosa::AvgPool2dOp>(op) && outElementTy.isa<FloatType>())
      initialAttr = rewriter.getZeroAttr(outElementTy);

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");
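
    // The TOSA pad attribute holds the H/W amounts [top, bottom, left, right];
    // applyPad expects (low, high) pairs for all four NHWC dimensions, so zero
    // pairs are added for the batch and channel dimensions below.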
    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.pad(), pad);
    pad.resize(pad.size() + 2, 0);
    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultTy.getElementType());

    Value filledInitTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);

    if (isa<tosa::MaxPool2dOp>(op)) {
      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
          filledInitTensor, strideAttr, dilationAttr);
      return success();
    }

    if (isa<tosa::AvgPool2dOp>(op) && inElementTy.isF32()) {
      Value poolingOp = rewriter
                            .create<linalg::PoolingNhwcSumOp>(
                                loc, ArrayRef<Type>{resultTy},
                                ValueRange{paddedInput, fakeWindowDims},
                                filledInitTensor, strideAttr, dilationAttr)
                            .getResult(0);
      auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
      auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
      auto genericOp = rewriter.create<linalg::GenericOp>(
          loc, ArrayRef<Type>({resultTy}), ValueRange{}, ValueRange{poolingOp},
          ArrayRef<AffineMap>({affineMap}),
          getNParallelLoopsAttrs(resultTy.getRank()),
          [&](OpBuilder &b, Location loc, ValueRange args) {
            auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
            auto one = rewriter.create<ConstantIndexOp>(loc, 1);
            auto iH = rewriter.create<ConstantIndexOp>(
                loc, poolingOpTy.getDimSize(1) - 1);
            auto iW = rewriter.create<ConstantIndexOp>(
                loc, poolingOpTy.getDimSize(2) - 1);

            // Compute the indices from either end.
            auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
            auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
            auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
            auto x1 = rewriter.create<SubIOp>(loc, iW, x0);

            // Determine what portion of the valid input is covered by the
            // kernel.
            auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
              if (pad == 0)
                return v;

              auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
              Value dx = rewriter.create<SubIOp>(loc, x, padVal);

              Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                        dx, zero);
              Value offset =
                  rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
              return rewriter.create<mlir::AddIOp>(loc, v, offset)
                  ->getResult(0);
            };

            // Compute the vertical component of coverage.
            auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
            auto kH1 = padFn(kH0, y0, pad[2]);
            auto kH2 = padFn(kH1, y1, pad[3]);
            auto kHCmp =
                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
            auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
            // Compute the horizontal component of coverage.
            auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
            auto kW1 = padFn(kW0, x0, pad[4]);
            auto kW2 = padFn(kW1, x1, pad[5]);
            auto kWCmp =
                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
            auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);

            // Compute the total number of elements and normalize.
            Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
            auto countI = rewriter.create<mlir::IndexCastOp>(
                loc, rewriter.getI32Type(), count);
            auto countF =
                rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);

            auto div =
                rewriter.create<DivFOp>(loc, args[0], countF)->getResult(0);

            rewriter.create<linalg::YieldOp>(loc, div);
          });

      rewriter.replaceOp(op, genericOp.getResult(0));
      return success();
    }

    return failure();
  }
};

} // namespace

void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::DivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::ReluNOp>,
      PointwiseConverter<tosa::SigmoidOp>,
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      ConcatConverter,
      ConvConverter<tosa::Conv2DOp>,
      ConvConverter<tosa::DepthwiseConv2DOp>,
      TransposeConvConverter,
      GatherConverter,
      PadConverter,
      ReshapeConverter,
      RescaleConverter,
      ResizeConverter,
      ReverseConverter,
      TableConverter,
      TileConverter,
      TransposeConverter,
      MatMulConverter,
      Pool2dConverter<tosa::AvgPool2dOp>,
      Pool2dConverter<tosa::MaxPool2dOp>,
      FullyConnectedConverter>(patterns->getContext());
  // clang-format on
}