//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

using namespace mlir;

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

template <typename T>
static mlir::ConstantOp
createConstFromIntAttribute(Operation *op, std::string attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
  return rewriter.create<mlir::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}

template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  for (Attribute val : attr.getValue()) {
    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

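// Clamps 'arg' to the inclusive range [min, max]: a compare of type T with
// predicate 'pred' plus a select picks the bound whenever it is exceeded.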
template <typename T, typename P>
static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min,
                                  mlir::ConstantOp max, P pred,
                                  OpBuilder &rewriter) {
  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
  auto minOrArg =
      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
}

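// Pads 'input' with 'padAttr' using linalg.pad_tensor. 'pad' holds interleaved
// (low, high) amounts per dimension; the input is returned unchanged when all
// amounts are zero.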
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            Attribute padAttr, OpBuilder &rewriter) {
  // Skip padding when every pad amount is zero.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (int i = 0, s = inputShape.size(); i < s; i++) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<ConstantOp>(loc, padAttr);

  return linalg::PadTensorOp::createPadScalarOp(
             RankedTensorType::get(paddedShape, inputETy), input, padValue,
             lowIndices, highIndices, loc, rewriter)
      .result();
}

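// Emits the scalar computation for one TOSA elementwise op inside a
// linalg.generic body. Returns a null Value for unsupported op/element-type
// combinations.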
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, rewriter.getZeroAttr(elementTy));
    auto cmp =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0], zero);
    auto neg = rewriter.create<mlir::SubIOp>(loc, zero, args[0]);
    return rewriter.create<mlir::SelectOp>(loc, cmp, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    if (dyn_cast<tosa::MulOp>(op).shift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
  }

  // tosa::DivOp
  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SignedDivIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, args[0]);
  }

  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<TruncateIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<SignExtendIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<SignExtendIOp>(loc, resultTypes[0], b);

    return rewriter.create<mlir::MulIOp>(loc, resultTypes, a, b);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      !cast<tosa::NegateOp>(op).quantization_info()) {
    auto constant =
        rewriter.create<ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<SubIOp>(loc, resultTypes, constant, args[0]);
  }

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      cast<tosa::NegateOp>(op).quantization_info()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp =
        quantizationInfo.getValue().input_zp().getValue().getSExtValue();
    int64_t outZp =
        quantizationInfo.getValue().output_zp().getValue().getSExtValue();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //  outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<SignExtendIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    auto min = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMinValue(inputBitWidth).getSExtValue()));
    auto max = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMaxValue(inputBitWidth).getSExtValue()));
    auto clamp = clampHelper<mlir::CmpIOp>(loc, sub, min, max,
                                           CmpIPredicate::slt, rewriter);

    // Truncate to the final value.
    return rewriter.create<TruncateIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnesValue(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<ConstantOp>(loc, allOnesAttr);
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::ShiftLeftOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
    auto result =
        rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, args);
    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
    if (!round) {
      return result;
    }

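    // Rounding adds one to the shifted result when the shift amount is
    // positive and the last bit shifted out was set.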
    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Check that the shift amount is greater than zero.
    auto shiftValueGreaterThanZero =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[1], zero);

    // Check whether the last bit shifted out of input1 was set.
    auto subtract =
        rewriter.create<mlir::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted = rewriter
                       .create<mlir::SignedShiftRightOp>(loc, resultTypes,
                                                         args[0], subtract)
                       ->getResults();
    auto truncated =
        rewriter.create<mlir::TruncateIOp>(loc, i1Ty, shifted, mlir::None);
    auto isInputOdd = rewriter.create<mlir::AndOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<mlir::AndOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<ZeroExtendIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<mlir::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT, args[0],
                                         args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0],
                                         args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, args[0],
                                         args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
                                         args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OEQ, args[0],
                                         args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0],
                                         args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CeilFOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    auto min = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("min_fp"));
    auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], min, max, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
                                                    rewriter);
    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                    rewriter);
    return clampHelper<mlir::CmpIOp>(loc, args[0], min, max, CmpIPredicate::slt,
                                     rewriter);
  }

  // tosa::ReluNOp
  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
    auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                               op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], zero, n, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                  rewriter);
    return clampHelper<mlir::CmpIOp>(loc, args[0], zero, n, CmpIPredicate::slt,
                                     rewriter);
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<mlir::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<mlir::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
      return rewriter.create<mlir::FPExtOp>(loc, resultTypes, args, mlir::None);

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
      return rewriter.create<mlir::FPTruncOp>(loc, resultTypes, args,
                                              mlir::None);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::UIToFPOp>(loc, resultTypes, args,
                                             mlir::None);

    if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<mlir::ZeroExtendIOp>(loc, resultTypes, args,
                                                  mlir::None);

    // All other si-to-fp conversions should be handled by SIToFP.
    if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::SIToFPOp>(loc, resultTypes, args,
                                             mlir::None);

    // When casting to bool, a float only needs to be checked as not-equal to
    // zero.
    if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
      Value zero =
          rewriter.create<ConstantOp>(loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::UNE,
                                           args.front(), zero);
    }

    if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto zero =
          rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
      auto half =
          rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));

      auto intMin = rewriter.create<ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto intMax = rewriter.create<ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

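      // Round half away from zero: add 0.5 to positive inputs and subtract
      // 0.5 from negative inputs before truncating toward zero.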
      auto added = rewriter.create<AddFOp>(loc, args[0], half);
      auto subbed = rewriter.create<SubFOp>(loc, args[0], half);
      auto negative =
          rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT, args[0], zero);
      auto rounded =
          rewriter.create<mlir::SelectOp>(loc, negative, subbed, added);

      auto clamped = clampHelper<mlir::CmpFOp>(loc, rounded, intMin, intMax,
                                               CmpFPredicate::OLT, rewriter);

      return rewriter.create<mlir::FPToSIOp>(loc, dstTy, clamped);
    }

    // When casting to bool, an integer only needs to be checked as not-equal
    // to zero.
    if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
      Value zero =
          rewriter.create<ConstantIntOp>(loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(),
                                           zero);
    }

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
                                                  mlir::None);

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
      auto intMin = rewriter.create<ConstantIntOp>(
          loc,
          APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto intMax = rewriter.create<ConstantIntOp>(
          loc,
          APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto clamped = clampHelper<mlir::CmpIOp>(loc, args[0], intMin, intMax,
                                               CmpIPredicate::slt, rewriter);
      return rewriter.create<mlir::TruncateIOp>(loc, dstTy, clamped);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}

static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  auto loc = operation->getLoc();

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  auto results = operation->getResults();
  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();

  if (!resultTy)
    return rewriter.notifyMatchFailure(operation,
                                       "All results must be a shaped type");

  unsigned rank = resultTy.getRank();

  // Construct the indexing maps needed for linalg.generic ops.
  SmallVector<Type> bodyArgTypes;

  for (Value in : operation->getOperands())
    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));

  SmallVector<Type> opResultTypes;
  SmallVector<Value> initTensors;
  for (auto result : results) {
    auto resultTy = result.getType().template cast<ShapedType>();
    if (!resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          operation,
          "tosa to linalg conversion expects statically shaped tensors");

    initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, ArrayRef<Value>({}), resultTy.getShape(),
        resultTy.getElementType()));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      initTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  SmallVector<Value, 2> operands;
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcast.
  for (Value operand : operation->getOperands()) {
    ShapedType type = operand.getType().cast<ShapedType>();

    if (type.getShape() == resultTy.getShape()) {
      operands.push_back(operand);
      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
      continue;
    }

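    // Collapse broadcast dimensions: keep only the dims whose size matches
    // the result, then reshape the operand down so a lower-rank indexing map
    // can broadcast it.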
    SmallVector<int64_t, 5> newShape;
    SmallVector<AffineExpr, 4> affineExprs;
    newShape.reserve(type.getRank());
    for (auto it : llvm::enumerate(type.getShape())) {
      if (it.value() == resultTy.getDimSize(it.index())) {
        newShape.push_back(it.value());
        affineExprs.push_back(
            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
      }
    }

    if (newShape.size() != rank) {
      operand = rewriter.create<tosa::ReshapeOp>(
          loc, RankedTensorType::get(newShape, type.getElementType()), operand);
    }

    operands.push_back(operand);
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs,
        rewriter.getContext()));
  }

  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(rank));

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operands, initTensors, indexingMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        if (!opResult) {
          didEncounterError = true;
          return;
        }
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnesValue(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getNullValue(1));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}

// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::AndOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::OrOp>(loc, args);

  return {};
}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
  auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  llvm::SmallVector<int64_t> reduceShape;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i)
      reduceShape.push_back(inputTy.getDimSize(i));
  }

  Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());

  // First fill the output buffer with the init value.
  auto initTensor =
      rewriter
          .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}), reduceShape,
                                        resultTy.getElementType())
          .result();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<ConstantOp>(loc, fillValueAttr);
  auto filledTensor =
      rewriter.create<linalg::FillOp>(loc, fillValue, initTensor).result();

  SmallVector<AffineExpr, 2> srcExprs;
  SmallVector<AffineExpr, 2> dstExprs;
  SmallVector<StringRef, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));

    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
                                      : getParallelIteratorTypeName());
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

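  // Note: despite its name, the flag below records whether a reduction body
  // was successfully created; the pattern fails afterwards when it was not.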
  bool didEncounterError = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return failure();

  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
                                               linalgOp.getResults());
  return success();
}

static LogicalResult
convolutionMatchAndRewriterHelper(Operation *op,
                                  ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  Value input = op->getOperand(0);
  Value weight = op->getOperand(1);
  Value bias = op->getOperand(2);

  ShapedType inputTy = input.getType().cast<ShapedType>();
  ShapedType weightTy = weight.getType().cast<ShapedType>();
  ShapedType biasTy = bias.getType().cast<ShapedType>();
  ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

  Type inputETy = inputTy.getElementType();
  Type resultETy = resultTy.getElementType();

  auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
  auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
  auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();

  bool isQuantized = op->hasAttr("quantization_info");
  IntegerAttr iZp;
  IntegerAttr kZp;
  if (isQuantized) {
    auto quantizationInfo =
        op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
    iZp = rewriter.getI32IntegerAttr(
        quantizationInfo.input_zp().getValue().getSExtValue());
    kZp = rewriter.getI32IntegerAttr(
        quantizationInfo.weight_zp().getValue().getSExtValue());
  }

  if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
      !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
    return rewriter.notifyMatchFailure(op,
                                       "tosa.conv ops require static shapes");

  auto weightShape = weightTy.getShape();
  auto resultShape = resultTy.getShape();

  // Apply padding as necessary.
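  // The TOSA pad attribute covers only the two spatial dimensions; zero
  // entries are added here for the batch and channel dimensions of the NHWC
  // input.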
  Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
  llvm::SmallVector<int64_t> pad;
  pad.resize(2, 0);
  getValuesFromIntArrayAttribute(padAttr, pad);
  pad.resize(pad.size() + 2, 0);

  input = applyPad(loc, input, pad, zeroAttr, rewriter);

  // Broadcast the bias across the output tensor before convolving.
  SmallVector<AffineMap, 4> indexingMaps;
  indexingMaps.push_back(AffineMap::get(
      /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
      {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

  Value initTensor = rewriter.create<linalg::InitTensorOp>(
      loc, resultTy.getShape(), resultTy.getElementType());

  Value biasBroadcast =
      rewriter
          .create<linalg::GenericOp>(
              loc, resultTy, bias, initTensor, indexingMaps,
              getNParallelLoopsAttrs(resultTy.getRank()),
              [&](OpBuilder &nestedBuilder, Location nestedLoc,
                  ValueRange args) {
                nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
              })
          .getResult(0);

  // Extract the attributes for convolution.
  llvm::SmallVector<int64_t> stride, dilation;
  getValuesFromIntArrayAttribute(strideTosaAttr, stride);
  getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

  // Create the convolution op.
  auto strideAttr = DenseIntElementsAttr::get(
      RankedTensorType::get({2}, rewriter.getI64Type()), stride);
  auto dilationAttr = DenseIntElementsAttr::get(
      RankedTensorType::get({2}, rewriter.getI64Type()), dilation);

  if (isa<tosa::Conv2DOp>(op) && !isQuantized) {
    rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyOp>(
        op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast},
        strideAttr, dilationAttr);
    return success();
  }

  if (isa<tosa::Conv2DOp>(op) && isQuantized) {
    auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
    auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
    rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyQOp>(
        op, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
        ValueRange{biasBroadcast}, strideAttr, dilationAttr);
    return success();
  }

  if (isa<tosa::DepthwiseConv2DOp>(op) && !isQuantized) {
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);

    Value biasReshape =
        rewriter.create<tosa::ReshapeOp>(loc, linalgConvTy, biasBroadcast);
    Value conv = rewriter
                     .create<linalg::DepthwiseConvInputNHWCFilterHWCFOp>(
                         loc, linalgConvTy, ValueRange{input, weight},
                         ValueRange{biasReshape}, dilationAttr, strideAttr)
                     .getResult(0);

    Value reshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
    rewriter.replaceOp(op, reshape);
    return success();
  }

  return failure();
}

namespace {

template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(op, rewriter);
  }
};

template <typename T>
class ConvConverter : public OpConversionPattern<T> {
public:
  using OpConversionPattern<T>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(T op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    return convolutionMatchAndRewriterHelper(op, rewriter);
  }
};

class TransposeConvConverter
    : public OpConversionPattern<tosa::TransposeConv2DOp> {
public:
  using OpConversionPattern<tosa::TransposeConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::TransposeConv2DOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;
    llvm::SmallVector<int64_t> dilation;

    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);

    // We have not solved for stride / dilation yet. Dilation should be
    // straightforward, but stride is more complicated. Linalg work is likely
    // required for an efficient implementation.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();
    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t inputHeight = inputTy.getDimSize(1);
    int64_t inputWidth = inputTy.getDimSize(2);
    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);
    int64_t outputHeight = resultTy.getDimSize(1);
    int64_t outputWidth = resultTy.getDimSize(2);

    int64_t requiredInputHeight = outputHeight + kernelHeight - 1;
    int64_t requiredInputWidth = outputWidth + kernelWidth - 1;

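    // With unit stride, a transposed convolution is equivalent to a regular
    // convolution over an input padded so that every output pixel sees a full
    // (reversed) kernel window.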
    llvm::SmallVector<int64_t> newPad(4, 0);
    newPad[0] = kernelHeight - 1 - pad[0];
    newPad[2] = kernelWidth - 1 - pad[1];

    newPad[1] = requiredInputHeight - newPad[0] - inputHeight;
    newPad[3] = requiredInputWidth - newPad[2] - inputWidth;

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));

    Value conv2d;
    if (op.quantization_info().hasValue()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr(dilation),
          op.quantization_info().getValue());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr(dilation));
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    tosa::MatMulOp::Adaptor adaptor(args);

    Location loc = op.getLoc();

    auto outputTy = op.getType().cast<ShapedType>();
    auto outputElementTy = outputTy.getElementType();
    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, outputTy.getShape(), outputTy.getElementType());
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
    if (!op.quantization_info()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
          ValueRange{zeroTensor});
      return success();
    }

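    // The quantized path passes the operand zero points as extra scalar
    // inputs to the quantized batch matmul.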
    auto quantizationInfo = op.quantization_info().getValue();
    auto aZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.a_zp().getValue().getSExtValue()));
    auto bZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.b_zp().getValue().getSExtValue()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);

    return success();
  }
};

class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = op.getType().cast<ShapedType>();
    auto input = op.input();
    auto weight = op.weight();
    auto bias = op.bias();

    auto weightTy = weight.getType().cast<ShapedType>();
    auto weightShape = weightTy.getShape();

    // Create indexing maps for broadcasting the bias into the matmul output.
    SmallVector<AffineMap, 4> indexingMaps;

    // Broadcast the bias.
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(1)},
                                          rewriter.getContext()));

    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, outputTy.getShape(),
                                          outputTy.getElementType())
            ->getResults();

    auto linalgOp =
        rewriter
            .create<linalg::GenericOp>(
                loc, outputTy, bias, initTensor, indexingMaps,
                getNParallelLoopsAttrs(outputTy.getRank()),
                [&](OpBuilder &nested_builder, Location nested_loc,
                    ValueRange args) {
                  nested_builder.create<linalg::YieldOp>(loc, *args.begin());
                })
            ->getResults();

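    // tosa.fully_connected stores the weight as [out_channels, in_channels];
    // linalg.matmul expects the reduction dimension first, so transpose it.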
    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    if (!op.quantization_info()) {
      rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
          op, TypeRange{op.getType()}, ValueRange{input, transposedWeight},
          linalgOp);
      return success();
    }

    auto quantizationInfo = op.quantization_info().getValue();
    auto inputZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.input_zp().getValue().getSExtValue()));
    auto outputZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.weight_zp().getValue().getSExtValue()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{input, transposedWeight, inputZp, outputZp}, linalgOp);

    return success();
  }
};

class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    typename tosa::ReshapeOp::Adaptor operands(args);

    ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();

    if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, args[0]);
      return success();
    }

    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    // Compute the reassociation maps for the linalg operation.
    ArrayRef<int64_t> expandedShape =
        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
                                                  : resultTy.getShape());
    ArrayRef<int64_t> collapsedShape =
        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
                                                  : operandTy.getShape());
    unsigned currSrcDim = 0, currDstDim = 0;
    SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());

    // First scan all dimensions in the source shapes to see whether we have a
    // perfect case where consecutive dimensions in the source are collapsed.
    // In that case we can generate a single linalg reshape.
    bool isCollapsingSource = true;
    while (currSrcDim < expandedShape.size() &&
           currDstDim < collapsedShape.size()) {
      int64_t dstSize = collapsedShape[currDstDim];
      int64_t srcSize = expandedShape[currSrcDim];
      while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        srcSize *= expandedShape[currSrcDim];
      }
      if (srcSize == dstSize) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        // If the next dim in collapsedShape is not 1, treat subsequent dims in
        // expandedShape which are 1 to be collapsed.
        if (currDstDim == collapsedShape.size() - 1 ||
            collapsedShape[currDstDim + 1] != 1) {
          while (currSrcDim < expandedShape.size() &&
                 expandedShape[currSrcDim] == 1) {
            reassociationMap[currDstDim].push_back(
                rewriter.getAffineDimExpr(currSrcDim++));
          }
        }
      } else {
        isCollapsingSource = false;
        break;
      }
      currDstDim++;
    }

    // Check if any remaining dimensions exist. If either shape is rank-0, only
    // the direct lowering is required.
    if (currSrcDim != expandedShape.size() ||
        currDstDim != collapsedShape.size())
      isCollapsingSource = collapsedShape.empty() || expandedShape.empty();

    // Otherwise, we need to first reduce all source dimensions into one and
    // then expand to the destination dimensions.
    if (!isCollapsingSource) {
      auto getIdentityExprs = [&rewriter](int n) {
        SmallVector<AffineExpr, 4> exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
      Location loc = reshape.getLoc();
      int64_t totalElems =
          std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
                          std::multiplies<int64_t>());
      auto elemTy = operandTy.getElementType();
      SmallVector<ReassociationExprs, 4> collapsingMap = {
          // Use operandTy here because we need to collapse all operand
          // dimensions.
          getIdentityExprs(operandTy.getShape().size())};
      SmallVector<ReassociationExprs, 4> expandingMap = {
          // Use resultTy here because we need to expand to all result
          // dimensions.
          getIdentityExprs(resultTy.getShape().size())};

      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
      Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, collapsedTy, args[0], collapsingMap);
      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
          reshape, resultTy, collapsedOp, expandingMap);

      return success();
    }

    if (resultTy.getRank() < args[0].getType().cast<ShapedType>().getRank())
      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
          reshape, resultTy, args[0], reassociationMap);
    else
      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
          reshape, resultTy, args[0], reassociationMap);

    return success();
  }
};

class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    if (!matchPattern(op.perms(), m_Constant(&perms))) {
      return failure();
    }

    auto resultTy = op.getType().cast<ShapedType>();
    if (!resultTy.hasStaticShape())
      return failure();

    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());
    for (auto permutation : llvm::enumerate(perms.getIntValues())) {
      inputExprs[permutation.value().getZExtValue()] =
          rewriter.getAffineDimExpr(permutation.index());
    }
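    // The loop above inverts the permutation: output dimension i reads from
    // input dimension perms[i].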
1322
1323 auto initTensor = rewriter.create<linalg::InitTensorOp>(
1324 op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
1325 resultTy.getElementType());
1326
1327 SmallVector<AffineMap, 2> affineMaps = {
1328 AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
1329 rewriter.getContext()),
1330 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1331
1332 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1333 op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
1334 getNParallelLoopsAttrs(resultTy.getRank()),
1335 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1336 nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1337 });
1338 return success();
1339 }
1340 };
1341
1342 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1343 public:
1344 using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1345
matchAndRewrite(tosa::RescaleOp op,PatternRewriter & rewriter) const1346 LogicalResult matchAndRewrite(tosa::RescaleOp op,
1347 PatternRewriter &rewriter) const final {
1348 auto loc = op.getLoc();
1349 auto input = op.input();
1350 auto inputTy = op.input().getType().cast<ShapedType>();
1351 auto outputTy = op.output().getType().cast<ShapedType>();
1352 unsigned rank = inputTy.getRank();
1353
1354 // This is an illegal configuration. terminate and log an error
1355 if (op.double_round() && !op.scale32())
1356 return rewriter.notifyMatchFailure(
1357 op, "tosa.rescale requires scale32 for double_round to be true");
1358
1359 if (!outputTy.hasStaticShape())
1360 return rewriter.notifyMatchFailure(
1361 op, "tosa to linalg conversion expects statically shaped tensors");
1362
1363 // The shift and multiplier values.
1364 SmallVector<int32_t> multiplierValues;
1365 getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues);
1366
1367 SmallVector<int8_t> shiftValues;
1368 getValuesFromIntArrayAttribute(op.shift(), shiftValues);
1369
1370 // Double round only occurs if shift is greater than 31, check that this
1371 // is ever true.
1372 bool doubleRound =
1373 op.double_round() &&
1374 llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1375
1376 SmallVector<AffineMap> indexingMaps = {
1377 rewriter.getMultiDimIdentityMap(rank)};
1378 SmallVector<Value, 4> genericInputs = {input};
1379
1380 // If we are rescaling per-channel then we need to store the multiplier
1381 // values in a buffer.
1382 Value multiplierConstant;
1383 int64_t multiplierArg = 0;
1384 if (multiplierValues.size() == 1) {
1385 multiplierConstant = rewriter.create<ConstantOp>(
1386 loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1387 } else {
1388 SmallVector<AffineExpr, 2> multiplierExprs{
1389 rewriter.getAffineDimExpr(rank - 1)};
1390 auto multiplierType =
1391 RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1392 rewriter.getI32Type());
1393 genericInputs.push_back(rewriter.create<ConstantOp>(
1394 loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1395
1396 indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1397 /*symbolCount=*/0, multiplierExprs,
1398 rewriter.getContext()));
1399
1400 multiplierArg = indexingMaps.size() - 1;
1401 }
1402
1403 // If we are rescaling per-channel then we need to store the shift
1404 // values in a buffer.
1405 Value shiftConstant;
1406 int64_t shiftArg = 0;
1407 if (shiftValues.size() == 1) {
1408 shiftConstant = rewriter.create<ConstantOp>(
1409 loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1410 } else {
1411 SmallVector<AffineExpr, 2> shiftExprs = {
1412 rewriter.getAffineDimExpr(rank - 1)};
1413 auto shiftType =
1414 RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1415 rewriter.getIntegerType(8));
1416 genericInputs.push_back(rewriter.create<ConstantOp>(
1417 loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1418 indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1419 /*symbolCount=*/0, shiftExprs,
1420 rewriter.getContext()));
1421 shiftArg = indexingMaps.size() - 1;
1422 }
1423
1424 // Indexing maps for output values.
1425 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1426
1427 // Construct the indexing maps needed for linalg.generic ops.
1428 Value initTensor = rewriter.create<linalg::InitTensorOp>(
1429 loc, ArrayRef<Value>({}), outputTy.getShape(),
1430 outputTy.getElementType());
1431
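// The generic op body constructed below follows the TOSA rescale pseudocode
// per element (a sketch; apply_scale is the tosa.apply_scale op):
//
//   out = clamp(apply_scale(in - input_zp, multiplier, shift, double_round)
//                   + output_zp,
//               int_min(out_width), int_max(out_width))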
1432 auto linalgOp = rewriter.create<linalg::GenericOp>(
1433 loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
1434 getNParallelLoopsAttrs(rank),
1435 [&](OpBuilder &nestedBuilder, Location nestedLoc,
1436 ValueRange blockArgs) {
1437 Value value = blockArgs[0];
1438
1439 // For now we do the math in 32 bits (48 bits when the input is wider
1440 // than 32 bits). This is not optimal but should be correct for now;
1441 // consider computing the exact required bit depth later.
1442 int32_t inBitwidth =
1443 value.getType().getIntOrFloatBitWidth() > 32 ? 48 : 32;
1444
1445 auto inputZp = createConstFromIntAttribute<int32_t>(
1446 op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
1447 nestedBuilder);
1448 auto outputZp = createConstFromIntAttribute<int32_t>(
1449 op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
1450
1451 Value multiplier = multiplierConstant ? multiplierConstant
1452 : blockArgs[multiplierArg];
1453 Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1454
1455 if (value.getType().getIntOrFloatBitWidth() < 32) {
1456 value = nestedBuilder.create<SignExtendIOp>(
1457 nestedLoc, nestedBuilder.getI32Type(), value);
1458 }
1459
1460 value = nestedBuilder.create<SubIOp>(nestedLoc, value, inputZp);
1461
1462 value = nestedBuilder.create<tosa::ApplyScaleOp>(
1463 loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1464 nestedBuilder.getBoolAttr(doubleRound));
1465
1466 // Move to the new zero-point.
1467 value = nestedBuilder.create<AddIOp>(nestedLoc, value, outputZp);
1468
1469 // Saturate to the output size.
1470 IntegerType outIntType =
1471 blockArgs.back().getType().cast<IntegerType>();
1472 unsigned outBitWidth = outIntType.getWidth();
1473 auto intMin = nestedBuilder.create<ConstantOp>(
1474 loc, nestedBuilder.getIntegerAttr(
1475 nestedBuilder.getI32Type(),
1476 APInt::getSignedMinValue(outBitWidth).getSExtValue()));
1477 auto intMax = nestedBuilder.create<ConstantOp>(
1478 loc, nestedBuilder.getIntegerAttr(
1479 nestedBuilder.getI32Type(),
1480 APInt::getSignedMaxValue(outBitWidth).getSExtValue()));
1481
1482 value = clampHelper<mlir::CmpIOp>(nestedLoc, value, intMin, intMax,
1483 CmpIPredicate::slt, nestedBuilder);
1484
1485 if (outIntType.getWidth() < 32) {
1486 value =
1487 nestedBuilder.create<TruncateIOp>(nestedLoc, outIntType, value);
1488 }
1489
1490 nestedBuilder.create<linalg::YieldOp>(loc, value);
1491 });
1492
1493 rewriter.replaceOp(op, linalgOp->getResults());
1494 return success();
1495 }
1496 };
1497
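// Lowers tosa.resize to a linalg.generic whose body computes, per output
// element, the source coordinates and gathers from the input. A sketch of
// the per-element math (names illustrative; fixed-point mode scales by
// 1 << shift instead of working in f32):
//
//   y = oy * stride_y + offset_y
//   iy = floor(y); dy = y - iy        // likewise for x
//
// NEAREST_NEIGHBOR rounds (iy, ix) to the nearest pixel and clamps to the
// image bounds; BILINEAR blends the four neighbors:
//
//   result = (1 - dy) * ((1 - dx) * v(y0, x0) + dx * v(y0, x1))
//          +      dy  * ((1 - dx) * v(y1, x0) + dx * v(y1, x1))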
1498 class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1499 public:
1500 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1501
1502 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1503 PatternRewriter &rewriter) const final {
1504 Location loc = op.getLoc();
1505 auto input = op.input();
1506 auto inputTy = input.getType().cast<ShapedType>();
1507 auto resultTy = op.getType().cast<ShapedType>();
1508 auto resultElementTy = resultTy.getElementType();
1509
1510 auto imageH = inputTy.getShape()[1];
1511 auto imageW = inputTy.getShape()[2];
1512
1513 if (!resultTy.hasStaticShape())
1514 return failure();
1515 if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
1516 return failure();
1517
1518 auto initTensor =
1519 rewriter
1520 .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
1521 resultTy.getShape(), resultElementTy)
1522 .result();
1523
1524 SmallVector<AffineMap, 2> affineMaps = {
1525 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1526
1527 auto genericOp = rewriter.create<linalg::GenericOp>(
1528 loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
1529 getNParallelLoopsAttrs(resultTy.getRank()));
1530 rewriter.replaceOp(op, genericOp.getResult(0));
1531
1532 {
1533 OpBuilder::InsertionGuard regionGuard(rewriter);
1534 rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
1535 TypeRange({resultElementTy}));
1536 Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
1537 Value y = rewriter.create<linalg::IndexOp>(loc, 1);
1538 Value x = rewriter.create<linalg::IndexOp>(loc, 2);
1539 Value channel = rewriter.create<linalg::IndexOp>(loc, 3);
1540
1541 auto hwMin =
1542 rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1543 auto hMax = rewriter.create<ConstantOp>(
1544 loc, rewriter.getI32IntegerAttr(imageH - 1));
1545 auto wMax = rewriter.create<ConstantOp>(
1546 loc, rewriter.getI32IntegerAttr(imageW - 1));
1547
1548 Value inY = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(), y);
1549 Value inX = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(), x);
1550
1551 int32_t shift = op.shift();
1552 bool floatingPointMode = shift == 0;
1553
1554 Value yStride, xStride, yOffset, xOffset;
1555 if (floatingPointMode) {
1556 yStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[0]);
1557 xStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[1]);
1558 yOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[0]);
1559 xOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[1]);
1560 } else {
1561 SmallVector<int32_t> stride, offset;
1562 getValuesFromIntArrayAttribute(op.stride(), stride);
1563 getValuesFromIntArrayAttribute(op.offset(), offset);
1564
1565 yStride = rewriter.create<ConstantOp>(
1566 loc, rewriter.getI32IntegerAttr(stride[0]));
1567 xStride = rewriter.create<ConstantOp>(
1568 loc, rewriter.getI32IntegerAttr(stride[1]));
1569 yOffset = rewriter.create<ConstantOp>(
1570 loc, rewriter.getI32IntegerAttr(offset[0]));
1571 xOffset = rewriter.create<ConstantOp>(
1572 loc, rewriter.getI32IntegerAttr(offset[1]));
1573 }
1574
1575 // Compute the integer index and partial offset.
1576 // x = x * stride + offset;
1577 // ix = floor(x)
1578 // dx = x - ix
1579 Value ix, iy, dx, dy;
1580 if (floatingPointMode) {
1581 Value y = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inY);
1582 Value x = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inX);
1583
1584 y = rewriter.create<MulFOp>(loc, y, yStride);
1585 x = rewriter.create<MulFOp>(loc, x, xStride);
1586
1587 y = rewriter.create<AddFOp>(loc, y, yOffset);
1588 x = rewriter.create<AddFOp>(loc, x, xOffset);
1589
1590 iy = rewriter.create<FloorFOp>(loc, y);
1591 ix = rewriter.create<FloorFOp>(loc, x);
1592
1593 dy = rewriter.create<SubFOp>(loc, y, iy);
1594 dx = rewriter.create<SubFOp>(loc, x, ix);
1595
1596 iy = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), iy);
1597 ix = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), ix);
1598 } else {
1599 Value shiftVal =
1600 rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(shift));
1601
1602 Value y = rewriter.create<MulIOp>(loc, inY, yStride);
1603 Value x = rewriter.create<MulIOp>(loc, inX, xStride);
1604
1605 y = rewriter.create<AddIOp>(loc, y, yOffset);
1606 x = rewriter.create<AddIOp>(loc, x, xOffset);
1607
1608 iy = rewriter.create<SignedShiftRightOp>(loc, y, shiftVal);
1609 ix = rewriter.create<SignedShiftRightOp>(loc, x, shiftVal);
1610
1611 Value yTrunc = rewriter.create<ShiftLeftOp>(loc, iy, shiftVal);
1612 Value xTrunc = rewriter.create<ShiftLeftOp>(loc, ix, shiftVal);
1613
1614 dy = rewriter.create<SubIOp>(loc, y, yTrunc);
1615 dx = rewriter.create<SubIOp>(loc, x, xTrunc);
1616 }
1617
1618 if (op.mode() == "NEAREST_NEIGHBOR") {
1619 Value yPred, xPred;
1620 // Round the index position towards the closest pixel location.
1621 if (floatingPointMode) {
1622 auto halfVal =
1623 rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));
1624 yPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dy,
1625 halfVal);
1626 xPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dx,
1627 halfVal);
1628 } else {
1629 auto halfVal = rewriter.create<ConstantOp>(
1630 loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
1631 yPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dy,
1632 halfVal);
1633 xPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dx,
1634 halfVal);
1635 }
1636
1637 auto zeroVal =
1638 rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1639 auto oneVal =
1640 rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
1641
1642 auto yOffset =
1643 rewriter.create<mlir::SelectOp>(loc, yPred, oneVal, zeroVal);
1644 auto xOffset =
1645 rewriter.create<mlir::SelectOp>(loc, xPred, oneVal, zeroVal);
1646
1647 iy = rewriter.create<AddIOp>(loc, iy, yOffset);
1648 ix = rewriter.create<AddIOp>(loc, ix, xOffset);
1649
1650 // Clamp to be within the bounds of the input image.
1651
1652 iy = clampHelper<mlir::CmpIOp>(loc, iy, hwMin, hMax, CmpIPredicate::slt,
1653 rewriter);
1654 ix = clampHelper<mlir::CmpIOp>(loc, ix, hwMin, wMax, CmpIPredicate::slt,
1655 rewriter);
1656
1657 // Read the value from the input array.
1658 iy = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), iy);
1659 ix = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), ix);
1660
1661 Value result = rewriter.create<tensor::ExtractOp>(
1662 loc, input, ValueRange{batch, iy, ix, channel});
1663
1664 rewriter.create<linalg::YieldOp>(loc, result);
1665
1666 return success();
1667 }
1668
1669 if (op.mode() == "BILINEAR") {
1670 Value y0 = iy;
1671 Value x0 = ix;
1672
1673 auto oneVal =
1674 rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
1675 Value y1 = rewriter.create<AddIOp>(loc, y0, oneVal);
1676 Value x1 = rewriter.create<AddIOp>(loc, x0, oneVal);
1677
1678 y0 = clampHelper<mlir::CmpIOp>(loc, y0, hwMin, hMax, CmpIPredicate::slt,
1679 rewriter);
1680 y1 = clampHelper<mlir::CmpIOp>(loc, y1, hwMin, hMax, CmpIPredicate::slt,
1681 rewriter);
1682
1683 x0 = clampHelper<mlir::CmpIOp>(loc, x0, hwMin, wMax, CmpIPredicate::slt,
1684 rewriter);
1685 x1 = clampHelper<mlir::CmpIOp>(loc, x1, hwMin, wMax, CmpIPredicate::slt,
1686 rewriter);
1687
1688 y0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y0);
1689 y1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y1);
1690 x0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x0);
1691 x1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x1);
1692
1693 Value y0x0 = rewriter.create<tensor::ExtractOp>(
1694 loc, input, ValueRange{batch, y0, x0, channel});
1695 Value y0x1 = rewriter.create<tensor::ExtractOp>(
1696 loc, input, ValueRange{batch, y0, x1, channel});
1697 Value y1x0 = rewriter.create<tensor::ExtractOp>(
1698 loc, input, ValueRange{batch, y1, x0, channel});
1699 Value y1x1 = rewriter.create<tensor::ExtractOp>(
1700 loc, input, ValueRange{batch, y1, x1, channel});
1701
1702 if (floatingPointMode) {
1703 auto oneVal =
1704 rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.f));
1705 Value rightPart = dx;
1706 Value leftPart = rewriter.create<SubFOp>(loc, oneVal, dx);
1707
1708 y0x0 = rewriter.create<MulFOp>(loc, y0x0, leftPart);
1709 y0x1 = rewriter.create<MulFOp>(loc, y0x1, rightPart);
1710 Value topAcc = rewriter.create<AddFOp>(loc, y0x0, y0x1);
1711
1712 y1x0 = rewriter.create<MulFOp>(loc, y1x0, leftPart);
1713 y1x1 = rewriter.create<MulFOp>(loc, y1x1, rightPart);
1714 Value bottomAcc = rewriter.create<AddFOp>(loc, y1x0, y1x1);
1715
1716 Value bottomPart = dy;
1717 Value topPart = rewriter.create<SubFOp>(loc, oneVal, dy);
1718 topAcc = rewriter.create<MulFOp>(loc, topAcc, topPart);
1719 bottomAcc = rewriter.create<MulFOp>(loc, bottomAcc, bottomPart);
1720 Value result = rewriter.create<AddFOp>(loc, topAcc, bottomAcc);
1721
1722 rewriter.create<linalg::YieldOp>(loc, result);
1723 return success();
1724 } else {
1725 y0x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x0);
1726 y0x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x1);
1727 y1x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x0);
1728 y1x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x1);
1729
1730 if (resultElementTy.getIntOrFloatBitWidth() > 32) {
1731 dx = rewriter.create<SignExtendIOp>(loc, resultElementTy, dx);
1732 dy = rewriter.create<SignExtendIOp>(loc, resultElementTy, dy);
1733 }
1734
1735 auto unitVal = rewriter.create<ConstantOp>(
1736 loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
1737 Value rightPart = dx;
1738 Value leftPart = rewriter.create<SubIOp>(loc, unitVal, dx);
1739
1740 y0x0 = rewriter.create<MulIOp>(loc, y0x0, leftPart);
1741 y0x1 = rewriter.create<MulIOp>(loc, y0x1, rightPart);
1742 Value topAcc = rewriter.create<AddIOp>(loc, y0x0, y0x1);
1743
1744 y1x0 = rewriter.create<MulIOp>(loc, y1x0, leftPart);
1745 y1x1 = rewriter.create<MulIOp>(loc, y1x1, rightPart);
1746 Value bottomAcc = rewriter.create<AddIOp>(loc, y1x0, y1x1);
1747
1748 Value bottomPart = dy;
1749 Value topPart = rewriter.create<SubIOp>(loc, unitVal, dy);
1750 topAcc = rewriter.create<MulIOp>(loc, topAcc, topPart);
1751 bottomAcc = rewriter.create<MulIOp>(loc, bottomAcc, bottomPart);
1752 Value result = rewriter.create<AddIOp>(loc, topAcc, bottomAcc);
1753
1754 rewriter.create<linalg::YieldOp>(loc, result);
1755 return success();
1756 }
1757 }
1758
1759 return failure();
1760 }
1761
1762 return success();
1763 }
1764 };
1765
1766 // At the codegen level any identity operations should be removed. Any cases
1767 // where identity is load-bearing (e.g. cross device computation) should be
1768 // handled before lowering to codegen.
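// For example, `%0 = "tosa.identity"(%arg0)` is erased by replacing all uses
// of %0 with %arg0.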
1769 template <typename SrcOp>
1770 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1771 public:
1772 using OpRewritePattern<SrcOp>::OpRewritePattern;
1773
1774 LogicalResult matchAndRewrite(SrcOp op,
1775 PatternRewriter &rewriter) const final {
1776 rewriter.replaceOp(op, op.getOperation()->getOperands());
1777 return success();
1778 }
1779 };
1780
1781 template <typename SrcOp>
1782 class ReduceConverter : public OpRewritePattern<SrcOp> {
1783 public:
1784 using OpRewritePattern<SrcOp>::OpRewritePattern;
1785
1786 LogicalResult matchAndRewrite(SrcOp reduceOp,
1787 PatternRewriter &rewriter) const final {
1788 return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter);
1789 }
1790 };
1791
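// Lowers tosa.concat by filling a zero-initialized init tensor and inserting
// each operand as a slice at a running offset along the concat axis. A rough
// sketch for two operands concatenated on axis 0 (shapes illustrative):
//
//   %0 = tensor.insert_slice %arg0 into %fill[0, 0] [2, 4] [1, 1]
//   %1 = tensor.insert_slice %arg1 into %0[2, 0] [3, 4] [1, 1]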
1792 struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
1793 using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
1794
1795 LogicalResult
1796 matchAndRewrite(tosa::ConcatOp op, ArrayRef<Value> args,
1797 ConversionPatternRewriter &rewriter) const override {
1798 auto resultType = op.getType().dyn_cast<RankedTensorType>();
1799 if (!resultType || !resultType.hasStaticShape()) {
1800 return rewriter.notifyMatchFailure(op,
1801 "expected static shaped tensor type");
1802 }
1803
1804 Location loc = op.getLoc();
1805 int axis = op.axis();
1806 Value axisValue =
1807 rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(axis));
1808 int rank = resultType.getRank();
1809 SmallVector<Value, 3> offsets, sizes, strides;
1810 sizes.reserve(rank);
1811 strides.resize(rank, rewriter.create<ConstantIndexOp>(loc, 1));
1812 offsets.resize(rank, rewriter.create<ConstantIndexOp>(loc, 0));
1813
1814 for (int i = 0; i < rank; ++i) {
1815 sizes.push_back(rewriter.create<tensor::DimOp>(loc, args[0], i));
1816 }
1817
1818 Value resultDimSize = sizes[axis];
1819 for (auto arg : args.drop_front()) {
1820 auto size = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
1821 resultDimSize = rewriter.create<AddIOp>(loc, resultDimSize, size);
1822 }
1823 sizes[axis] = resultDimSize;
1824
1825 Value init = rewriter.create<linalg::InitTensorOp>(
1826 loc, resultType.getShape(), resultType.getElementType());
1827
1828 Value zeroVal = rewriter.create<ConstantOp>(
1829 loc, rewriter.getZeroAttr(resultType.getElementType()));
1830 Value result =
1831 rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0);
1832
1833 for (auto arg : args) {
1834 sizes[axis] = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
1835 result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
1836 sizes, strides);
1837 offsets[axis] = rewriter.create<AddIOp>(loc, offsets[axis], sizes[axis]);
1838 }
1839 rewriter.replaceOp(op, result);
1840 return success();
1841 }
1842 };
1843
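// Lowers tosa.reverse to a rank-preserving linalg.generic whose input
// indexing map flips the reversed axis. E.g. reversing axis 1 of a
// tensor<2x4x8xf32> uses the input map (d0, d1, d2) -> (d0, 3 - d1, d2)
// together with an identity output map.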
1844 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
1845 public:
1846 using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
1847
1848 LogicalResult matchAndRewrite(tosa::ReverseOp op,
1849 PatternRewriter &rewriter) const final {
1850 auto loc = op.getLoc();
1851 Value input = op.input();
1852 auto inputTy = input.getType().template cast<ShapedType>();
1853 auto resultTy = op.getType().template cast<ShapedType>();
1854 auto rank = resultTy.getRank();
1855 auto axis = op.axis();
1856
1857 if (!inputTy.hasStaticShape())
1858 return rewriter.notifyMatchFailure(
1859 op, "No initial value found for reduction operation");
1860
1861 // Create the init tensor for the output of the generic op.
1862 auto initTensor = rewriter
1863 .create<linalg::InitTensorOp>(
1864 loc, ArrayRef<Value>({}), inputTy.getShape(),
1865 inputTy.getElementType())
1866 .result();
1867
1868 SmallVector<AffineExpr, 2> inputExprs;
1869 inputExprs.resize(resultTy.getRank());
1870
1871 for (int i = 0; i < rank; i++)
1872 inputExprs[i] = rewriter.getAffineDimExpr(i);
1873
1874 inputExprs[axis] =
1875 rewriter.getAffineConstantExpr(inputTy.getDimSize(axis) - 1) -
1876 inputExprs[axis];
1877
1878 SmallVector<AffineMap, 2> affineMaps = {
1879 AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
1880 rewriter.getContext()),
1881 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1882
1883 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1884 op, resultTy, op.input(), ValueRange{initTensor}, affineMaps,
1885 getNParallelLoopsAttrs(resultTy.getRank()),
1886 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1887 nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1888 });
1889 return success();
1890 }
1891 };
1892
1893 // This converter translates a tile operation to a reshape, broadcast, and
1894 // reshape. The first reshape minimally expands each tiled dimension to
1895 // include a preceding size-1 dim. This dim is then broadcast to the
1896 // appropriate multiple.
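//
// E.g. tiling a tensor<3x4xf32> by multiples [2, 1] broadcasts through an
// intermediate tensor<2x3x1x4xf32> that a final tosa.reshape collapses to
// the tensor<6x4xf32> result (shapes illustrative).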
1897 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
1898 using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
1899
1900 LogicalResult
1901 matchAndRewrite(tosa::TileOp op, ArrayRef<Value> args,
1902 ConversionPatternRewriter &rewriter) const override {
1903 auto loc = op.getLoc();
1904 auto input = op.input1();
1905 auto inputTy = input.getType().cast<ShapedType>();
1906 auto inputShape = inputTy.getShape();
1907 auto resultTy = op.getType().cast<ShapedType>();
1908 auto elementTy = inputTy.getElementType();
1909 int64_t rank = inputTy.getRank();
1910
1911 if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
1912 return failure();
1913
1914 SmallVector<int64_t> multiples;
1915 getValuesFromIntArrayAttribute(op.multiples(), multiples);
1916
1917 // Broadcast the newly added dimensions to their appropriate multiple.
1918 SmallVector<int64_t, 2> genericShape;
1919 for (int i = 0; i < rank; i++) {
1920 genericShape.push_back(multiples[i]);
1921 genericShape.push_back(inputShape[i]);
1922 }
1923
1924 auto initTensor = rewriter.create<linalg::InitTensorOp>(
1925 op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);
1926
1927 // We need to map the input shape to the non-broadcast dimensions.
1928 SmallVector<AffineExpr, 4> dimExprs;
1929 dimExprs.reserve(rank);
1930 for (unsigned i = 0; i < rank; ++i)
1931 dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
1932
1933 auto readAffineMap =
1934 AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
1935 rewriter.getContext());
1936
1937 SmallVector<AffineMap, 2> affineMaps = {
1938 readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1939
1940 auto genericOp = rewriter.create<linalg::GenericOp>(
1941 loc, RankedTensorType::get(genericShape, elementTy), input,
1942 ValueRange{initTensor}, affineMaps,
1943 getNParallelLoopsAttrs(genericShape.size()),
1944 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1945 nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1946 });
1947
1948 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1949 op, resultTy, genericOp.getResult(0),
1950 rewriter.getI64ArrayAttr(resultTy.getShape()));
1951 return success();
1952 }
1953 };
1954
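// Lowers tosa.pad to linalg.pad_tensor. The low/high amounts per dimension
// are read from the padding tensor with tensor.extract, and the pad value is
// zero (or the input zero-point for quantized types). A rough sketch:
//
//   %padded = linalg.pad_tensor %input low[%l0, %l1] high[%h0, %h1] {
//     ^bb0(%i: index, %j: index):
//       linalg.yield %cst : f32
//   } : tensor<?x?xf32> to tensor<?x?xf32>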
1955 class PadConverter : public OpRewritePattern<tosa::PadOp> {
1956 public:
1957 using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
1958
1959 LogicalResult matchAndRewrite(tosa::PadOp padOp,
1960 PatternRewriter &rewriter) const final {
1961 auto loc = padOp.getLoc();
1962 auto input = padOp.input1();
1963 auto padding = padOp.padding();
1964
1965 ShapedType inputTy = input.getType().cast<ShapedType>();
1966 ShapedType paddingTy = padding.getType().cast<ShapedType>();
1967 Type elementTy = inputTy.getElementType();
1968 int64_t rank = inputTy.getRank();
1969
1970 if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
1971 return rewriter.notifyMatchFailure(
1972 padOp,
1973 "Pad converter requires static shaped input / padding values.");
1974 }
1975
1976 Attribute constantAttr;
1977 if (elementTy.isa<FloatType>())
1978 constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
1979 else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
1980 constantAttr = rewriter.getIntegerAttr(elementTy, 0);
1981 else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
1982 auto value = padOp.quantization_info().getValue().input_zp().getValue();
1983 constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
1984 }
1985
1986 if (!constantAttr) {
1987 return rewriter.notifyMatchFailure(
1988 padOp,
1989 "tosa.pad to linalg lowering encountered an unknown element type");
1990 }
1991
1992 Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
1993 Value highIndex =
1994 rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
1995
1996 SmallVector<OpFoldResult, 3> lowValues;
1997 SmallVector<OpFoldResult, 3> highValues;
1998
1999 lowValues.reserve(rank);
2000 highValues.reserve(rank);
2001
2002 for (int i = 0; i < rank; i++) {
2003 Value inputIndex = rewriter.createOrFold<ConstantIndexOp>(loc, i);
2004 Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
2005 loc, padding, ValueRange({inputIndex, lowIndex}));
2006 Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
2007 loc, padding, ValueRange({inputIndex, highIndex}));
2008
2009 lowVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
2010 lowVal);
2011 highVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
2012 highVal);
2013
2014 lowValues.push_back(lowVal);
2015 highValues.push_back(highVal);
2016 }
2017
2018 Value constant = rewriter.create<ConstantOp>(loc, constantAttr);
2019
2020 auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
2021 padOp.getType(), input, constant, lowValues, highValues, loc, rewriter);
2022
2023 rewriter.replaceOp(padOp, newPadOp.getResult());
2024 return success();
2025 }
2026 };
2027
2028 // Tosa argmax lowering represents the ArgMax op as a linalg.generic op,
2029 // producing two output buffers.
2030 //
2031 // The first output buffer contains the index of the found maximum value. It
2032 // is initialized to 0 and is of the resulting integer type.
2033 //
2034 // The second output buffer contains the maximum value found. It is
2035 // initialized to the minimum representable value of the input element type.
2036 // After being populated by the generic op, this buffer is discarded as only
2037 // the index is requested.
2038 //
2039 // The generic op updates both the maximum value and index if the
2040 // current value exceeds the running max.
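//
// In pseudocode, per element visited along the reduction axis:
//
//   if (in[i] > max) { max = in[i]; idx = i; }
//
// realized below with a compare (OGT for floats, sgt for integers) and two
// selects.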
2041 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
2042 public:
2043 using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
2044
2045 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2046 PatternRewriter &rewriter) const final {
2047 auto loc = argmaxOp.getLoc();
2048 Value input = argmaxOp.input();
2049 auto inputTy = input.getType().cast<ShapedType>();
2050 auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
2051 auto inElementTy = inputTy.getElementType();
2052 auto outElementTy = resultTy.getElementType();
2053 int axis = argmaxOp.axis();
2054 auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
2055
2056 if (!inputTy.hasStaticShape())
2057 return rewriter.notifyMatchFailure(
2058 argmaxOp,
2059 "tosa.arg_max to linalg.* requires statically shaped input");
2060
2061 if (!outElementTy.isa<IntegerType>())
2062 return rewriter.notifyMatchFailure(
2063 argmaxOp,
2064 "tosa.arg_max to linalg.* requires integer-like result type");
2065
2066 // First fill the output buffer for the index.
2067 auto initTensorIdx =
2068 rewriter
2069 .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
2070 resultTy.getShape(), outElementTy)
2071 .result();
2072 auto fillValueIdx = rewriter.create<ConstantOp>(
2073 loc, rewriter.getIntegerAttr(outElementTy, 0));
2074 auto filledTensorIdx =
2075 rewriter.create<linalg::FillOp>(loc, fillValueIdx, initTensorIdx)
2076 .result();
2077
2078 // Second fill the output buffer for the running max.
2079 auto initTensorMax =
2080 rewriter
2081 .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
2082 resultTy.getShape(), inElementTy)
2083 .result();
2084 auto fillValueMaxAttr =
2085 createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2086
2087 if (!fillValueMaxAttr)
2088 return rewriter.notifyMatchFailure(
2089 argmaxOp, "unsupported tosa.argmax element type");
2090
2091 auto fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
2092 auto filledTensorMax =
2093 rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
2094 .result();
2095
2096 // We need to reduce along the arg-max axis, with parallel operations along
2097 // the rest.
2098 SmallVector<StringRef, 4> iteratorTypes;
2099 iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
2100 iteratorTypes[axis] = getReductionIteratorTypeName();
2101
2102 SmallVector<AffineExpr, 2> srcExprs;
2103 SmallVector<AffineExpr, 2> dstExprs;
2104 for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2105 srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2106 if (axis != i)
2107 dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2108 }
2109
2110 bool didEncounterError = false;
2111 auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
2112 auto linalgOp = rewriter.create<linalg::GenericOp>(
2113 loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2114 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2115 [&](OpBuilder &nestedBuilder, Location nestedLoc,
2116 ValueRange blockArgs) {
2117 auto newValue = blockArgs[0];
2118 auto oldIndex = blockArgs[1];
2119 auto oldValue = blockArgs[2];
2120
2121 Value newIndex = rewriter.create<IndexCastOp>(
2122 nestedLoc, oldIndex.getType(),
2123 rewriter.create<linalg::IndexOp>(loc, axis));
2124
2125 Value predicate;
2126 if (inElementTy.isa<FloatType>()) {
2127 predicate = rewriter.create<mlir::CmpFOp>(
2128 nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
2129 } else if (inElementTy.isa<IntegerType>()) {
2130 predicate = rewriter.create<mlir::CmpIOp>(
2131 nestedLoc, CmpIPredicate::sgt, newValue, oldValue);
2132 } else {
2133 didEncounterError = true;
2134 return;
2135 }
2136
2137 auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
2138 newValue, oldValue);
2139 auto resultIndex = rewriter.create<mlir::SelectOp>(
2140 nestedLoc, predicate, newIndex, oldIndex);
2141 nestedBuilder.create<linalg::YieldOp>(
2142 nestedLoc, ValueRange({resultIndex, resultMax}));
2143 });
2144
2145 if (didEncounterError)
2146 return rewriter.notifyMatchFailure(
2147 argmaxOp, "unsupported tosa.argmax element type");
2148
2149 rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2150 return success();
2151 }
2152 };
2153
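// Lowers tosa.gather to a linalg.generic over the indices tensor; the body
// extracts input[batch, indices[batch, n], channel]. A sketch of the body
// (SSA names illustrative):
//
//   %b   = linalg.index 0 : index
//   %c   = linalg.index 2 : index
//   %idx = index_cast %indexValue : i32 to index
//   %v   = tensor.extract %input[%b, %idx, %c]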
2154 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2155 public:
2156 using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2157 LogicalResult
2158 matchAndRewrite(tosa::GatherOp op, ArrayRef<Value> args,
2159 ConversionPatternRewriter &rewriter) const final {
2160 auto input = args[0];
2161 auto indices = args[1];
2162
2163 auto inputTy = input.getType().cast<ShapedType>();
2164 auto indicesTy = indices.getType().cast<ShapedType>();
2165 auto resultTy = op.getType().cast<ShapedType>();
2166
2167 if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape())
2168 return rewriter.notifyMatchFailure(
2169 op, "require input type to have static shape");
2170
2171 auto resultElementTy = resultTy.getElementType();
2172
2173 auto loc = op.getLoc();
2174
2175 auto initTensor =
2176 rewriter
2177 .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
2178 resultTy.getShape(), resultElementTy)
2179 .result();
2180
2181 SmallVector<AffineMap, 2> affineMaps = {
2182 AffineMap::get(
2183 /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2184 {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2185 rewriter.getContext()),
2186 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2187
2188 auto genericOp = rewriter.create<linalg::GenericOp>(
2189 loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2190 ValueRange{initTensor}, affineMaps,
2191 getNParallelLoopsAttrs(resultTy.getRank()),
2192 [&](OpBuilder &b, Location loc, ValueRange args) {
2193 auto indexValue = args[0];
2194 auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2195 Value index1 = rewriter.create<IndexCastOp>(
2196 loc, rewriter.getIndexType(), indexValue);
2197 auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2198 Value extract = rewriter.create<tensor::ExtractOp>(
2199 loc, input, ValueRange{index0, index1, index2});
2200 rewriter.create<linalg::YieldOp>(loc, extract);
2201 });
2202 rewriter.replaceOp(op, genericOp.getResult(0));
2203 return success();
2204 }
2205 };
2206
2207 // Lowers the TableOp to a series of gathers and numeric operations. This
2208 // includes interpolation between the high/low values. For the I8 variant,
2209 // this simplifies to a single gather operation.
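//
// For the I16 variant the body computes, per element (a sketch of the steps
// implemented below):
//
//   v = in + 32768; index = v >> 7; fraction = v & 0b1111111
//   result = (table[index] << 7)
//          + (table[index + 1] - table[index]) * fraction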
2210 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2211 public:
2212 using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2213
2214 LogicalResult matchAndRewrite(tosa::TableOp op,
2215 PatternRewriter &rewriter) const final {
2216 auto loc = op.getLoc();
2217 Value input = op.input();
2218 Value table = op.table();
2219 auto inputTy = input.getType().cast<ShapedType>();
2220 auto tableTy = table.getType().cast<ShapedType>();
2221 auto resultTy = op.getType().cast<ShapedType>();
2222
2223 if (!inputTy.hasStaticShape())
2224 return rewriter.notifyMatchFailure(
2225 op, "require input type to have static shape");
2226
2227 auto inputElementTy = inputTy.getElementType();
2228 auto tableElementTy = tableTy.getElementType();
2229 auto resultElementTy = resultTy.getElementType();
2230
2231 auto initTensor =
2232 rewriter
2233 .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
2234 resultTy.getShape(), resultElementTy)
2235 .result();
2236
2237 SmallVector<AffineMap, 2> affineMaps = {
2238 rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2239 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2240
2241 auto genericOp = rewriter.create<linalg::GenericOp>(
2242 loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
2243 getNParallelLoopsAttrs(resultTy.getRank()));
2244 rewriter.replaceOp(op, genericOp.getResult(0));
2245
2246 {
2247 OpBuilder::InsertionGuard regionGuard(rewriter);
2248 Block *block =
2249 rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
2250 TypeRange({inputElementTy, resultElementTy}));
2251
2252 auto inputValue = block->getArgument(0);
2253 rewriter.setInsertionPointToStart(block);
2254 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2255 resultElementTy.isInteger(8)) {
2256 Value index = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(),
2257 inputValue);
2258 Value extract =
2259 rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2260 rewriter.create<linalg::YieldOp>(loc, extract);
2261 return success();
2262 }
2263
2264 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2265 resultElementTy.isInteger(32)) {
2266 Value extend = rewriter.create<SignExtendIOp>(
2267 loc, rewriter.getI32Type(), inputValue);
2268
2269 auto offset =
2270 rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(32768));
2271 auto seven =
2272 rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(7));
2273 auto one =
2274 rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
2275 auto b1111111 =
2276 rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(127));
2277
2278 // Compute the index and fractional part from the input value:
2279 // value = value + 32768
2280 // index = value >> 7;
2281 // fraction = 0b1111111 & value
2282 auto extendAdd = rewriter.create<AddIOp>(loc, extend, offset);
2283 Value index =
2284 rewriter.create<UnsignedShiftRightOp>(loc, extendAdd, seven);
2285 Value fraction = rewriter.create<mlir::AndOp>(loc, extendAdd, b1111111);
2286
2287 // Extract the base and next values from the table.
2288 // base = (int32_t) table[index];
2289 // next = (int32_t) table[index + 1];
2290 Value indexPlusOne = rewriter.create<AddIOp>(loc, index, one);
2291
2292 index =
2293 rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), index);
2294 indexPlusOne = rewriter.create<IndexCastOp>(
2295 loc, rewriter.getIndexType(), indexPlusOne);
2296
2297 Value base =
2298 rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2299 Value next = rewriter.create<tensor::ExtractOp>(
2300 loc, table, ValueRange{indexPlusOne});
2301
2302 base = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), base);
2303 next = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), next);
2304
2305 // Use the fractional part to interpolate between the input values:
2306 // result = (base << 7) + (next - base) * fraction
2307 Value baseScaled = rewriter.create<ShiftLeftOp>(loc, base, seven);
2308 Value diff = rewriter.create<SubIOp>(loc, next, base);
2309 Value diffScaled = rewriter.create<MulIOp>(loc, diff, fraction);
2310 Value result = rewriter.create<AddIOp>(loc, baseScaled, diffScaled);
2311
2312 rewriter.create<linalg::YieldOp>(loc, result);
2313
2314 return success();
2315 }
2316 }
2317
2318 return rewriter.notifyMatchFailure(
2319 op, "unable to create body for tosa.table op");
2320 }
2321 };
2322
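// Lowers tosa.max_pool2d and tosa.avg_pool2d to the named linalg pooling ops
// on a (possibly padded) NHWC input. A rough sketch of the avg_pool2d case,
// where the summed window is normalized by the number of valid (non-pad)
// elements covered at each output position:
//
//   %sum = linalg.pooling_nhwc_sum ins(%padded, %window) outs(%fill)
//   %avg = linalg.generic ... { linalg.yield %element / %count }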
2323 template <typename SrcOp>
2324 class Pool2dConverter : public OpRewritePattern<SrcOp> {
2325 public:
2326 using OpRewritePattern<SrcOp>::OpRewritePattern;
2327
2328 LogicalResult matchAndRewrite(SrcOp op,
2329 PatternRewriter &rewriter) const final {
2330 Location loc = op.getLoc();
2331 Value input = op.input();
2332 ShapedType inputTy = input.getType().cast<ShapedType>();
2333 Type inElementTy = inputTy.getElementType();
2334
2335 ShapedType resultTy = op.getType().template cast<ShapedType>();
2336 Type outElementTy = inputTy.getElementType();
2337
2338 if (!inputTy.hasStaticShape())
2339 return failure();
2340
2341 // Determine what the initial value needs to be for the pooling op.
2342 Attribute initialAttr;
2343 if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isF32())
2344 initialAttr = rewriter.getFloatAttr(
2345 outElementTy,
2346 APFloat::getLargest(
2347 outElementTy.cast<FloatType>().getFloatSemantics(), true));
2348
2349 if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isa<IntegerType>())
2350 initialAttr = rewriter.getIntegerAttr(
2351 outElementTy,
2352 APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));
2353
2354 if (isa<tosa::AvgPool2dOp>(op) && outElementTy.isa<FloatType>())
2355 initialAttr = rewriter.getZeroAttr(outElementTy);
2356
2357 if (!initialAttr)
2358 return rewriter.notifyMatchFailure(
2359 op, "Unsupported initial value for tosa.maxpool_2d op");
2360
2361 // Apply padding as necessary.
2362 llvm::SmallVector<int64_t> pad;
2363 pad.resize(2, 0);
2364 getValuesFromIntArrayAttribute(op.pad(), pad);
2365 pad.resize(pad.size() + 2, 0);
2366 Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
2367
2368 Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
2369
2370 SmallVector<int64_t> kernel, stride;
2371 getValuesFromIntArrayAttribute(op.kernel(), kernel);
2372 getValuesFromIntArrayAttribute(op.stride(), stride);
2373
2374 Attribute strideAttr = rewriter.getI64VectorAttr(stride);
2375 Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
2376
2377 // Create the linalg op that performs pooling.
2378 Value initTensor = rewriter.create<linalg::InitTensorOp>(
2379 loc, resultTy.getShape(), resultTy.getElementType());
2380
2381 Value filledInitTensor =
2382 rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();
2383
2384 Value fakeWindowDims =
2385 rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);
2386
2387 if (isa<tosa::MaxPool2dOp>(op)) {
2388 rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
2389 op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
2390 filledInitTensor, strideAttr, dilationAttr);
2391 return success();
2392 }
2393
2394 if (isa<tosa::AvgPool2dOp>(op) && inElementTy.isF32()) {
2395 Value poolingOp = rewriter
2396 .create<linalg::PoolingNhwcSumOp>(
2397 loc, ArrayRef<Type>{resultTy},
2398 ValueRange{paddedInput, fakeWindowDims},
2399 filledInitTensor, strideAttr, dilationAttr)
2400 .getResult(0);
2401 auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
2402 auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
2403 auto genericOp = rewriter.create<linalg::GenericOp>(
2404 loc, ArrayRef<Type>({resultTy}), ValueRange{}, ValueRange{poolingOp},
2405 ArrayRef<AffineMap>({affineMap}),
2406 getNParallelLoopsAttrs(resultTy.getRank()),
2407 [&](OpBuilder &b, Location loc, ValueRange args) {
2408 auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
2409 auto one = rewriter.create<ConstantIndexOp>(loc, 1);
2410 auto iH = rewriter.create<ConstantIndexOp>(
2411 loc, poolingOpTy.getDimSize(1) - 1);
2412 auto iW = rewriter.create<ConstantIndexOp>(
2413 loc, poolingOpTy.getDimSize(2) - 1);
2414
2415 // Compute the indices from either end.
2416 auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
2417 auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
2418 auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
2419 auto x1 = rewriter.create<SubIOp>(loc, iW, x0);
2420
2421 // Determines what portion of the valid input is covered by the
2422 // kernel.
2423 auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
2424 if (pad == 0)
2425 return v;
2426
2427 auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
2428 Value dx = rewriter.create<SubIOp>(loc, x, padVal);
2429
2430 Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
2431 dx, zero);
2432 Value offset =
2433 rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
2434 return rewriter.create<mlir::AddIOp>(loc, v, offset)
2435 ->getResult(0);
2436 };
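// E.g. with pad == 2 at this edge and output index x == 0, the window
// starts two elements outside the input (dx == -2), so padFn shrinks the
// valid kernel extent: it returns v - 2.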
2437
2438 // Compute the vertical component of coverage.
2439 auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
2440 auto kH1 = padFn(kH0, y0, pad[2]);
2441 auto kH2 = padFn(kH1, y1, pad[3]);
2442 auto kHCmp =
2443 rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
2444 auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
2445
2446 // Compute the horizontal component of coverage.
2447 auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
2448 auto kW1 = padFn(kW0, x0, pad[4]);
2449 auto kW2 = padFn(kW1, x1, pad[5]);
2450 auto kWCmp =
2451 rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
2452 auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
2453
2454 // Compute the total number of elements and normalize.
2455 Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
2456 auto countI = rewriter.create<mlir::IndexCastOp>(
2457 loc, rewriter.getI32Type(), count);
2458 auto countF =
2459 rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);
2460
2461 auto div =
2462 rewriter.create<DivFOp>(loc, args[0], countF)->getResult(0);
2463
2464 rewriter.create<linalg::YieldOp>(loc, div);
2465 });
2466
2467 rewriter.replaceOp(op, genericOp.getResult(0));
2468 return success();
2469 }
2470
2471 return failure();
2472 }
2473 };
2474
2475 } // namespace
2476
2477 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
2478 RewritePatternSet *patterns) {
2479 patterns->add<
2480 // clang-format off
2481 PointwiseConverter<tosa::AddOp>,
2482 PointwiseConverter<tosa::SubOp>,
2483 PointwiseConverter<tosa::MulOp>,
2484 PointwiseConverter<tosa::DivOp>,
2485 PointwiseConverter<tosa::NegateOp>,
2486 PointwiseConverter<tosa::PowOp>,
2487 PointwiseConverter<tosa::ReciprocalOp>,
2488 PointwiseConverter<tosa::RsqrtOp>,
2489 PointwiseConverter<tosa::LogOp>,
2490 PointwiseConverter<tosa::ExpOp>,
2491 PointwiseConverter<tosa::AbsOp>,
2492 PointwiseConverter<tosa::TanhOp>,
2493 PointwiseConverter<tosa::BitwiseAndOp>,
2494 PointwiseConverter<tosa::BitwiseOrOp>,
2495 PointwiseConverter<tosa::BitwiseNotOp>,
2496 PointwiseConverter<tosa::BitwiseXorOp>,
2497 PointwiseConverter<tosa::LogicalAndOp>,
2498 PointwiseConverter<tosa::LogicalNotOp>,
2499 PointwiseConverter<tosa::LogicalOrOp>,
2500 PointwiseConverter<tosa::LogicalXorOp>,
2501 PointwiseConverter<tosa::CastOp>,
2502 PointwiseConverter<tosa::LogicalLeftShiftOp>,
2503 PointwiseConverter<tosa::LogicalRightShiftOp>,
2504 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2505 PointwiseConverter<tosa::SelectOp>,
2506 PointwiseConverter<tosa::GreaterOp>,
2507 PointwiseConverter<tosa::GreaterEqualOp>,
2508 PointwiseConverter<tosa::EqualOp>,
2509 PointwiseConverter<tosa::MaximumOp>,
2510 PointwiseConverter<tosa::MinimumOp>,
2511 PointwiseConverter<tosa::CeilOp>,
2512 PointwiseConverter<tosa::FloorOp>,
2513 PointwiseConverter<tosa::ClampOp>,
2514 PointwiseConverter<tosa::ReluNOp>,
2515 PointwiseConverter<tosa::SigmoidOp>,
2516 IdentityNConverter<tosa::IdentityOp>,
2517 ReduceConverter<tosa::ReduceAllOp>,
2518 ReduceConverter<tosa::ReduceAnyOp>,
2519 ReduceConverter<tosa::ReduceMinOp>,
2520 ReduceConverter<tosa::ReduceMaxOp>,
2521 ReduceConverter<tosa::ReduceSumOp>,
2522 ReduceConverter<tosa::ReduceProdOp>,
2523 ArgMaxConverter,
2524 ConcatConverter,
2525 ConvConverter<tosa::Conv2DOp>,
2526 ConvConverter<tosa::DepthwiseConv2DOp>,
2527 TransposeConvConverter,
2528 GatherConverter,
2529 PadConverter,
2530 ReshapeConverter,
2531 RescaleConverter,
2532 ResizeConverter,
2533 ReverseConverter,
2534 TableConverter,
2535 TileConverter,
2536 TransposeConverter,
2537 MatMulConverter,
2538 Pool2dConverter<tosa::AvgPool2dOp>,
2539 Pool2dConverter<tosa::MaxPool2dOp>,
2540 FullyConnectedConverter>(patterns->getContext());
2541 // clang-format on
2542 }
2543