//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

using namespace mlir;

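// Returns `nParallelLoops` copies of the parallel iterator type name, for use
// as the iterator types of an elementwise linalg.generic op.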
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

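// Materializes a constant from the integer attribute `attrName` of `op`,
// cast through `T` and rebuilt with the type `requiredAttrType`.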
template <typename T>
static mlir::ConstantOp
createConstFromIntAttribute(Operation *op, std::string attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
  return rewriter.create<mlir::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}

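// Appends the signed values of the integer array attribute `attr` to
// `arrayValues`.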
template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  for (Attribute val : attr.getValue()) {
    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

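// Clamps `arg` to the range [min, max] using two compare-and-select pairs,
// where `T` is the compare op (mlir::CmpIOp or mlir::CmpFOp) and `pred` is
// its less-than predicate.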
template <typename T, typename P>
static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min,
                                  mlir::ConstantOp max, P pred,
                                  OpBuilder &rewriter) {
  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
  auto minOrArg =
      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
}

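// Pads `input` with `padAttr`, reading two (low, high) padding amounts per
// dimension from `pad`. Returns the input unchanged when all amounts are zero.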
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            Attribute padAttr, OpBuilder &rewriter) {
  // Input should be padded if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (int i = 0, s = inputShape.size(); i < s; i++) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<ConstantOp>(loc, padAttr);

  return linalg::PadTensorOp::createPadScalarOp(
             RankedTensorType::get(paddedShape, inputETy), input, padValue,
             lowIndices, highIndices, /*nofold=*/false, loc, rewriter)
      .result();
}

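// Drops the null entries from `dynDims`, keeping one SSA value per dynamic
// dimension in their original order.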
static SmallVector<Value> filterDynamicDims(SmallVector<Value> dynDims) {
  SmallVector<Value> filteredDims;
  for (auto dim : dynDims)
    if (dim)
      filteredDims.push_back(dim);
  return filteredDims;
}

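// Emits the scalar computation for one TOSA elementwise op inside the body of
// a linalg.generic. Returns the computed value, or nullptr (after notifying a
// match failure) when the op / element type combination is unhandled.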
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, rewriter.getZeroAttr(elementTy));
    auto cmp =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0], zero);
    auto neg = rewriter.create<mlir::SubIOp>(loc, zero, args[0]);
    return rewriter.create<mlir::SelectOp>(loc, cmp, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    if (dyn_cast<tosa::MulOp>(op).shift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
  }

  // tosa::DivOp
  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SignedDivIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, args[0]);
  }

  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<TruncateIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<SignExtendIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<SignExtendIOp>(loc, resultTypes[0], b);

    return rewriter.create<mlir::MulIOp>(loc, resultTypes, a, b);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      !cast<tosa::NegateOp>(op).quantization_info()) {
    auto constant =
        rewriter.create<ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<SubIOp>(loc, resultTypes, constant, args[0]);
  }

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      cast<tosa::NegateOp>(op).quantization_info()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp =
        quantizationInfo.getValue().input_zp().getValue().getSExtValue();
    int64_t outZp =
        quantizationInfo.getValue().output_zp().getValue().getSExtValue();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //   outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<SignExtendIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    auto min = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMinValue(inputBitWidth).getSExtValue()));
    auto max = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMaxValue(inputBitWidth).getSExtValue()));
    auto clamp = clampHelper<mlir::CmpIOp>(loc, sub, min, max,
                                           CmpIPredicate::slt, rewriter);

    // Truncate to the final value.
    return rewriter.create<TruncateIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<ConstantOp>(loc, allOnesAttr);
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::ShiftLeftOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
    auto result =
        rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, args);
    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that input2 != 0
    auto shiftValueGreaterThanZero =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[1], zero);

    // Checking whether the lowest bit of input1 is 1
    auto subtract =
        rewriter.create<mlir::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted = rewriter
                       .create<mlir::SignedShiftRightOp>(loc, resultTypes,
                                                         args[0], subtract)
                       ->getResults();
    auto truncated =
        rewriter.create<mlir::TruncateIOp>(loc, i1Ty, shifted, mlir::None);
    auto isInputOdd = rewriter.create<mlir::AndOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<mlir::AndOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<ZeroExtendIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
    int bitWidth = elementTy.getIntOrFloatBitWidth();
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto leadingZeros = rewriter.create<mlir::ConstantOp>(
        loc, IntegerAttr::get(elementTy, bitWidth));

    SmallVector<Value> operands = {args[0], leadingZeros, zero};
    SmallVector<Type> types = {elementTy, elementTy, elementTy};

    auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
    Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
    Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

    // The conditional block of the while loop.
    {
      rewriter.setInsertionPointToStart(&whileOp.before().front());
      Value input = before->getArgument(0);
      Value zero = before->getArgument(2);

      Value inputLargerThanZero =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, input, zero);
      rewriter.create<scf::ConditionOp>(loc, inputLargerThanZero,
                                        before->getArguments());
    }

    // The body of the while loop: shift right until reaching a value of 0.
    {
      rewriter.setInsertionPointToStart(&whileOp.after().front());
      Value input = after->getArgument(0);
      Value leadingZeros = after->getArgument(1);

      auto one = rewriter.create<mlir::ConstantOp>(
          loc, IntegerAttr::get(elementTy, 1));
      auto shifted = rewriter.create<mlir::UnsignedShiftRightOp>(
          loc, resultTypes, input, one);
      auto leadingZerosMinusOne =
          rewriter.create<mlir::SubIOp>(loc, resultTypes, leadingZeros, one);

      rewriter.create<scf::YieldOp>(
          loc,
          ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
    }

    rewriter.setInsertionPointAfter(whileOp);
    return whileOp->getResult(1);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<mlir::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT, args[0],
                                         args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0],
                                         args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, args[0],
                                         args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
                                         args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OEQ, args[0],
                                         args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0],
                                         args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CeilFOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    auto min = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("min_fp"));
    auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], min, max, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto intTy = elementTy.cast<IntegerType>();
    int32_t min = static_cast<int32_t>(
        op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
    int32_t max = static_cast<int32_t>(
        op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());

    if (intTy.isUnsignedInteger()) {
      min = std::max<int32_t>(min, 0);
      max = std::min<int32_t>(
          max,
          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
    } else {
      min = std::max<int32_t>(
          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
      max = std::min<int32_t>(
          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
    }

    auto minVal =
        rewriter.create<ConstantIntOp>(loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal =
        rewriter.create<ConstantIntOp>(loc, max, intTy.getIntOrFloatBitWidth());
    return clampHelper<mlir::CmpIOp>(loc, args[0], minVal, maxVal,
                                     CmpIPredicate::slt, rewriter);
  }

  // tosa::ReluNOp
  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
    auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                               op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], zero, n, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                  rewriter);
    return clampHelper<mlir::CmpIOp>(loc, args[0], zero, n, CmpIPredicate::slt,
                                     rewriter);
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<mlir::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<mlir::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
      return rewriter.create<mlir::FPExtOp>(loc, resultTypes, args, mlir::None);

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
      return rewriter.create<mlir::FPTruncOp>(loc, resultTypes, args,
                                              mlir::None);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::UIToFPOp>(loc, resultTypes, args,
                                             mlir::None);

    if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<mlir::ZeroExtendIOp>(loc, resultTypes, args,
                                                  mlir::None);

    // All other si-to-fp conversions should be handled by SIToFP.
    if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::SIToFPOp>(loc, resultTypes, args,
                                             mlir::None);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && dstTy.isa<FloatType>()) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<mlir::UIToFPOp>(loc, resultTypes[0],
                                             unrealizedCast);
    }

    // When casting to boolean, floats only need to be checked as not-equal to
    // zero.
    if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
      Value zero =
          rewriter.create<ConstantOp>(loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::UNE,
                                           args.front(), zero);
    }

    if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto zero =
          rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
      auto half =
          rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));

      auto intMin = rewriter.create<ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto intMax = rewriter.create<ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto added = rewriter.create<AddFOp>(loc, args[0], half);
      auto subbed = rewriter.create<SubFOp>(loc, args[0], half);
      auto negative =
          rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT, args[0], zero);
      auto rounded =
          rewriter.create<mlir::SelectOp>(loc, negative, subbed, added);

      auto clamped = clampHelper<mlir::CmpFOp>(loc, rounded, intMin, intMax,
                                               CmpFPredicate::OLT, rewriter);

      return rewriter.create<mlir::FPToSIOp>(loc, dstTy, clamped);
    }

    // When casting to boolean, integers only need to be checked as not-equal
    // to zero.
    if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
      Value zero =
          rewriter.create<ConstantIntOp>(loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(),
                                           zero);
    }

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
                                                  mlir::None);

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
      auto intMin = rewriter.create<ConstantIntOp>(
          loc,
          APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto intMax = rewriter.create<ConstantIntOp>(
          loc,
          APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto clamped = clampHelper<mlir::CmpIOp>(loc, args[0], intMin, intMax,
                                               CmpIPredicate::slt, rewriter);
      return rewriter.create<mlir::TruncateIOp>(loc, dstTy, clamped);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}

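// Lowers a TOSA elementwise op to a linalg.generic op, reshaping broadcast
// operands so their indexing maps cover only the non-broadcast dimensions and
// delegating the scalar body to createLinalgBodyCalculationForElementwiseOp.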
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  auto loc = operation->getLoc();

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  auto results = operation->getResults();
  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();

  if (!resultTy)
    return rewriter.notifyMatchFailure(operation,
                                       "All results must be a shaped type");

  unsigned rank = resultTy.getRank();

  // Construct the indexing maps needed for linalg.generic ops.
  SmallVector<Type> bodyArgTypes;

  for (Value in : operation->getOperands())
    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));

  SmallVector<Type> opResultTypes;
  SmallVector<Value> initTensors;

  SmallVector<Value> dynDims;
  dynDims.resize(results.front().getType().cast<ShapedType>().getRank());

  for (auto arg : operation->getOperands()) {
    auto operandTy = arg.getType().cast<ShapedType>();
    for (int i = 0; i < operandTy.getRank(); i++) {
      if (operandTy.isDynamicDim(i) && !dynDims[i])
        dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
    }
  }

  SmallVector<Value> filteredDims = filterDynamicDims(dynDims);

  for (auto result : results) {
    auto resultTy = result.getType().template cast<ShapedType>();
    initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultTy.getElementType()));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      initTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  SmallVector<Value, 2> operands;
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcasted.
  for (Value operand : operation->getOperands()) {
    ShapedType type = operand.getType().cast<ShapedType>();

    if (type.getShape() == resultTy.getShape()) {
      operands.push_back(operand);
      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
      continue;
    }

    SmallVector<int64_t, 5> newShape;
    SmallVector<AffineExpr, 4> affineExprs;
    newShape.reserve(type.getRank());
    for (auto it : llvm::enumerate(type.getShape())) {
      if (it.value() == resultTy.getDimSize(it.index())) {
        newShape.push_back(it.value());
        affineExprs.push_back(
            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
      }
    }

    if (newShape.size() != rank) {
      operand = rewriter.create<tosa::ReshapeOp>(
          loc, RankedTensorType::get(newShape, type.getElementType()), operand,
          rewriter.getI64ArrayAttr(newShape));
    }

    operands.push_back(operand);
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs,
        rewriter.getContext()));
  }

  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(rank));

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operands, initTensors, indexingMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        if (!opResult) {
          didEncounterError = true;
          return;
        }
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}

// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::AndOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::OrOp>(loc, args);

  return {};
}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
  auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  llvm::SmallVector<int64_t> reduceShape;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i)
      reduceShape.push_back(inputTy.getDimSize(i));
  }

  Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());

  // First fill the output buffer with the init value.
  auto initTensor =
      rewriter
          .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}), reduceShape,
                                        resultTy.getElementType())
          .result();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<ConstantOp>(loc, fillValueAttr);
  auto filledTensor =
      rewriter.create<linalg::FillOp>(loc, fillValue, initTensor).result();

  SmallVector<AffineExpr, 2> srcExprs;
  SmallVector<AffineExpr, 2> dstExprs;
  SmallVector<StringRef, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));

    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
                                      : getParallelIteratorTypeName());
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

  bool didEncounterError = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return failure();

  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
                                               linalgOp.getResults());
  return success();
}

namespace {

template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(op, rewriter);
  }
};

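// Lowers tosa.conv2d to linalg::Conv2DNhwcHwcfOp (or its quantized variant),
// padding the input, transposing the kernel from (OC, KH, KW, IC) to
// (KH, KW, IC, OC) ordering, and adding the bias with a broadcasting
// linalg.generic.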
class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
public:
  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
    bool isQuantized = op->hasAttr("quantization_info");

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(op,
                                         "tosa.conv ops require static shapes");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = quantizationInfo.input_zp().getValue().getSExtValue();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(padAttr, pad);
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Transpose the kernel to match dimension ordering of the linalg
    // convolution operation.
    // TODO(suderman): See if this can be efficiently folded - check whether
    // the input is used anywhere else, if not fold the constant.
    SmallVector<int64_t> weightPerm{1, 2, 3, 0};
    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[2],
                                        weightShape[3], weightShape[0]};
    auto weightPermAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm);
    Value weightPermValue = rewriter.create<ConstantOp>(loc, weightPermAttr);
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());
    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                weightPermValue);

    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultETy);
    Value zero = rewriter.create<ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation;
    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

    // Create the convolution op.
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);

    // Create maps for the bias broadcast.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultETy);

    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = rewriter.getI32IntegerAttr(
          quantizationInfo.input_zp().getValue().getSExtValue());
      auto kZp = rewriter.getI32IntegerAttr(
          quantizationInfo.weight_zp().getValue().getSExtValue());

      auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::Conv2DNhwcHwcfQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              ->getResult(0);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
                  indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added =
                        nestedBuilder.create<AddIOp>(loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    Value conv = rewriter
                     .create<linalg::Conv2DNhwcHwcfOp>(
                         loc, resultTy, ValueRange{input, weight},
                         ValueRange{zeroTensor}, strideAttr, dilationAttr)
                     ->getResult(0);

    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
                indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  Value added =
                      nestedBuilder.create<AddFOp>(loc, args[0], args[1]);
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                })
            .getResult(0);

    rewriter.replaceOp(op, result);
    return success();
  }
};

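// Lowers tosa.depthwise_conv2d to linalg::DepthwiseConv2DNhwcOp (or the
// quantized DepthwiseConv2DNhwcQOp), reshaping the channel-multiplier result
// back to the TOSA output shape before adding the bias with a broadcasting
// linalg.generic.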
class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();

    bool isQuantized = op->hasAttr("quantization_info");
    IntegerAttr iZp;
    IntegerAttr kZp;
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      iZp = rewriter.getI32IntegerAttr(
          quantizationInfo.input_zp().getValue().getSExtValue());
      kZp = rewriter.getI32IntegerAttr(
          quantizationInfo.weight_zp().getValue().getSExtValue());
    }

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(op,
                                         "tosa.conv ops require static shapes");

    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = quantizationInfo.input_zp().getValue().getSExtValue();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(padAttr, pad);
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation;
    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

    // Create the convolution op.
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, linalgConvTy.getShape(), resultETy);
    Value zero = rewriter.create<ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);

    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultETy);
    if (!isQuantized) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);
      Value convReshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added =
                        nestedBuilder.create<AddFOp>(loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcQOp>(
                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);
      Value convReshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added =
                        nestedBuilder.create<AddIOp>(loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};

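// Rewrites tosa.transpose_conv2d with unit stride as a regular tosa.conv2d by
// adjusting the padding and reversing the kernel along both spatial axes;
// strided cases are left unmatched.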
class TransposeConvConverter
    : public OpConversionPattern<tosa::TransposeConv2DOp> {
public:
  using OpConversionPattern<tosa::TransposeConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::TransposeConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;
    llvm::SmallVector<int64_t> dilation;

    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);

    // If striding is all 1 we can modify padding and reverse the kernel along
    // the x/y direction to make it a regular convolution. This is much simpler
    // than handling striding.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) {
      if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
          !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
        return failure();

      int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
      int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
      int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
      int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;

      llvm::SmallVector<int64_t> convPad(4, 0);
      convPad[0] = kernelHeight - 1 - pad[0];
      convPad[2] = kernelWidth - 1 - pad[1];
      convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
      convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);

      auto reverse1 = rewriter.create<tosa::ReverseOp>(
          loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
      auto reverse2 = rewriter.create<tosa::ReverseOp>(
          loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));

      Value conv2d;
      if (op.quantization_info().hasValue()) {
        conv2d = rewriter.create<tosa::Conv2DOp>(
            loc, resultTy, input, reverse2, bias,
            rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
            rewriter.getI64ArrayAttr(dilation),
            op.quantization_info().getValue());
      } else {
        conv2d = rewriter.create<tosa::Conv2DOp>(
            loc, resultTy, input, reverse2, bias,
            rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
            rewriter.getI64ArrayAttr(dilation));
      }

      rewriter.replaceOp(op, conv2d);
      return success();
    }

    return failure();
  }
};

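// Lowers tosa.matmul to linalg::BatchMatmulOp (or QuantizedBatchMatmulOp when
// quantization info is present), accumulating into a zero-filled init tensor.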
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = op.getType().cast<ShapedType>();
    auto outputElementTy = outputTy.getElementType();

    auto firstOperandTy = op->getOperand(0).getType().cast<ShapedType>();
    auto secondOperandTy = op->getOperand(1).getType().cast<ShapedType>();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
    if (!op.quantization_info()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
          ValueRange{zeroTensor});
      return success();
    }

    auto quantizationInfo = op.quantization_info().getValue();
    auto aZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.a_zp().getValue().getSExtValue()));
    auto bZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.b_zp().getValue().getSExtValue()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);

    return success();
  }
};

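// Lowers tosa.fully_connected to linalg::MatmulOp (or QuantizedMatmulOp) on
// the transposed weight, then adds the bias with a broadcasting
// linalg.generic.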
1412 class FullyConnectedConverter
1413 : public OpConversionPattern<tosa::FullyConnectedOp> {
1414 public:
1415 using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
1416 LogicalResult
matchAndRewrite(tosa::FullyConnectedOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1417 matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
1418 ConversionPatternRewriter &rewriter) const final {
1419 Location loc = op.getLoc();
1420 auto outputTy = op.getType().cast<ShapedType>();
1421 auto input = op.input();
1422 auto inputTy = input.getType().cast<ShapedType>();
1423
1424 auto bias = op.bias();
1425
1426 auto weight = op.weight();
1427 auto weightTy = weight.getType().cast<ShapedType>();
1428 auto weightShape = weightTy.getShape();
1429
1430 auto outputETy = outputTy.getElementType();
1431
1432 SmallVector<Value> dynDims;
1433 dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
1434
1435 if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
1436 dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
1437 }
1438
1439 if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
1440 dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
1441 }
1442
1443 SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
1444
    // Creating maps for the output of MatMul and the bias
    SmallVector<AffineMap, 4> indexingMaps;

    // Broadcast the bias.
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(1)},
                                          rewriter.getContext()));

    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());

    // When quantized, the input element type is not the same as the output
    // type.
    Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
    Value zero = rewriter.create<ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);

    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    auto biasInitTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, filteredDims,
                                          outputTy.getShape(), outputETy)
            ->getResults();

    if (!op.quantization_info()) {
      Value matmul = rewriter
                         .create<linalg::MatmulOp>(
                             loc, TypeRange{op.getType()},
                             ValueRange{input, transposedWeight}, zeroTensor)
                         ->getResult(0);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                  indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added =
                        nestedBuilder.create<AddFOp>(loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    auto quantizationInfo = op.quantization_info().getValue();
    auto inputZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.input_zp().getValue().getSExtValue()));
    auto outputZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.weight_zp().getValue().getSExtValue()));
    Value matmul =
        rewriter
            .create<linalg::QuantizedMatmulOp>(
                loc, TypeRange{op.getType()},
                ValueRange{input, transposedWeight, inputZp, outputZp},
                zeroTensor)
            ->getResult(0);
    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  Value added =
                      nestedBuilder.create<AddIOp>(loc, args[0], args[1]);
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                })
            .getResult(0);
    rewriter.replaceOp(op, result);
    return success();
  }
};

class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();

    if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
      return success();
    }

    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    // Compute the reassociation maps for the linalg operation.
    ArrayRef<int64_t> expandedShape =
        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
                                                  : resultTy.getShape());
    ArrayRef<int64_t> collapsedShape =
        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
                                                  : operandTy.getShape());
    unsigned currSrcDim = 0, currDstDim = 0;
    SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());

    // First scan all dimensions in the source shapes to see whether we have a
    // perfect case where consecutive dimensions in source are collapsed. For
    // such a case we can just generate one single linalg.reshape.
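    //
    // For example (illustrative): collapsing tensor<2x3x4xf32> into
    // tensor<6x4xf32> walks the source dims {2, 3, 4} against the destination
    // dims {6, 4} and produces the reassociation [[d0, d1], [d2]].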
    bool isCollapsingSource = true;
    while (currSrcDim < expandedShape.size() &&
           currDstDim < collapsedShape.size()) {
      int64_t dstSize = collapsedShape[currDstDim];
      int64_t srcSize = expandedShape[currSrcDim];
      while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        srcSize *= expandedShape[currSrcDim];
      }
      if (srcSize == dstSize) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        // If the next dim in collapsedShape is not 1, treat subsequent dims in
        // expandedShape which are 1 to be collapsed.
        if (currDstDim == collapsedShape.size() - 1 ||
            collapsedShape[currDstDim + 1] != 1) {
          while (currSrcDim < expandedShape.size() &&
                 expandedShape[currSrcDim] == 1) {
            reassociationMap[currDstDim].push_back(
                rewriter.getAffineDimExpr(currSrcDim++));
          }
        }
      } else {
        isCollapsingSource = false;
        break;
      }
      currDstDim++;
    }

    // Check whether any dimensions remain. If either shape is rank-0 we only
    // require the direct lowering.
    if (currSrcDim != expandedShape.size() ||
        currDstDim != collapsedShape.size())
      isCollapsingSource = collapsedShape.empty() || expandedShape.empty();

    // Otherwise, we need to first reduce all source dimensions into one and
    // then expand to the destination dimensions.
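    //
    // For example (illustrative): tensor<2x3xf32> -> tensor<3x2xf32> admits no
    // collapsing reassociation, so it is lowered as a collapse to
    // tensor<6xf32> followed by an expand to tensor<3x2xf32>.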
    if (!isCollapsingSource) {
      auto getIdentityExprs = [&rewriter](int n) {
        SmallVector<AffineExpr, 4> exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
      Location loc = reshape.getLoc();
      int64_t totalElems =
          std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
                          std::multiplies<int64_t>());
      auto elemTy = operandTy.getElementType();
      SmallVector<ReassociationExprs, 4> collapsingMap = {
          // Use operandTy here because we need to collapse all operand
          // dimensions.
          getIdentityExprs(operandTy.getShape().size())};
      SmallVector<ReassociationExprs, 4> expandingMap = {
          // Use resultTy here because we need to expand to all result
          // dimensions.
          getIdentityExprs(resultTy.getShape().size())};

      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
      Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, collapsedTy, adaptor.getOperands()[0], collapsingMap);
      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
          reshape, resultTy, collapsedOp, expandingMap);

      return success();
    }

    if (resultTy.getRank() <
        adaptor.getOperands()[0].getType().cast<ShapedType>().getRank())
      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
          reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
    else
      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
          reshape, resultTy, adaptor.getOperands()[0], reassociationMap);

    return success();
  }
};

class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    if (!matchPattern(op.perms(), m_Constant(&perms))) {
      return failure();
    }

    auto loc = op.getLoc();
    auto input = op->getOperand(0);
    auto resultTy = op.getType().cast<ShapedType>();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());
    auto operandTy = input.getType().cast<ShapedType>();
    for (auto permutation : llvm::enumerate(perms.getValues<APInt>())) {
      auto index = permutation.index();
      auto value = permutation.value().getZExtValue();
      if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) {
        dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index);
      }
      inputExprs[value] = rewriter.getAffineDimExpr(index);
    }

    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultTy.getElementType());

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
        });
    return success();
  }
};

class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.input();
    auto inputTy = op.input().getType().cast<ShapedType>();
    auto outputTy = op.output().getType().cast<ShapedType>();
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration: terminate and log an error.
    if (op.double_round() && !op.scale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!outputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa to linalg conversion expects statically shaped tensors");

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues;
    getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues);

    SmallVector<int8_t> shiftValues;
    getValuesFromIntArrayAttribute(op.shift(), shiftValues);

    // Double rounding only takes effect when a shift is greater than 31, so
    // check whether that is ever the case.
    bool doubleRound =
        op.double_round() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the indexing maps needed for linalg.generic ops.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, ArrayRef<Value>({}), outputTy.getShape(),
        outputTy.getElementType());

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // We perform the rescale math at 32 bits, or at 48 bits when the
          // input type is wider than 32 bits. This is not optimal, but should
          // be correct for now; consider computing the exact required bit
          // depth later.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (valueTy.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(
                              nestedLoc,
                              nestedBuilder.getIntegerType(
                                  valueTy.getIntOrFloatBitWidth()),
                              value)
                          .getResult(0);
              value = nestedBuilder.create<ZeroExtendIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<SignExtendIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value = nestedBuilder.create<SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value = nestedBuilder.create<AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              blockArgs.back().getType().cast<IntegerType>();
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output range.
          if (outIntType.isUnsignedInteger()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<ConstantOp>(
              loc,
              nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMin));
          auto intMaxVal = nestedBuilder.create<ConstantOp>(
              loc,
              nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMax));

          value =
              clampHelper<mlir::CmpIOp>(nestedLoc, value, intMinVal, intMaxVal,
                                        CmpIPredicate::slt, nestedBuilder);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<TruncateIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);

            if (outIntType.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(nestedLoc,
                                                              outIntType, value)
                          .getResult(0);
            }
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};

class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto input = op.input();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();
    auto resultElementTy = resultTy.getElementType();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    if (!resultTy.hasStaticShape())
      return failure();
    if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
      return failure();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
                           TypeRange({resultElementTy}));
      Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
      Value y = rewriter.create<linalg::IndexOp>(loc, 1);
      Value x = rewriter.create<linalg::IndexOp>(loc, 2);
      Value channel = rewriter.create<linalg::IndexOp>(loc, 3);

      auto hwMin =
          rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
      auto hMax = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(imageH - 1));
      auto wMax = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(imageW - 1));

      Value inY = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(), y);
      Value inX = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(), x);

      int32_t shift = op.shift();
      bool floatingPointMode = shift == 0;

      Value yStride, xStride, yOffset, xOffset;
      if (floatingPointMode) {
        yStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[0]);
        xStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[1]);
        yOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[0]);
        xOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[1]);
      } else {
        SmallVector<int32_t> stride, offset;
        getValuesFromIntArrayAttribute(op.stride(), stride);
        getValuesFromIntArrayAttribute(op.offset(), offset);

        yStride = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(stride[0]));
        xStride = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(stride[1]));
        yOffset = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(offset[0]));
        xOffset = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(offset[1]));
      }

      // Compute the integer index and partial offset:
      //   x = x * stride + offset
      //   ix = floor(x)
      //   dx = x - ix
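      //
      // In integer mode the stride and offset are fixed-point values with
      // `shift` fractional bits. For example (illustrative, assuming
      // shift = 8): a stride of 1.5 is stored as 384, so for x = 1 we get
      // ix = 384 >> 8 = 1 and dx = 384 - (1 << 8) = 128, i.e. 0.5.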
      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        Value y = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inY);
        Value x = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inX);

        y = rewriter.create<MulFOp>(loc, y, yStride);
        x = rewriter.create<MulFOp>(loc, x, xStride);

        y = rewriter.create<AddFOp>(loc, y, yOffset);
        x = rewriter.create<AddFOp>(loc, x, xOffset);

        iy = rewriter.create<FloorFOp>(loc, y);
        ix = rewriter.create<FloorFOp>(loc, x);

        dy = rewriter.create<SubFOp>(loc, y, iy);
        dx = rewriter.create<SubFOp>(loc, x, ix);

        iy = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), iy);
        ix = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), ix);
      } else {
        Value shiftVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(shift));

        Value y = rewriter.create<MulIOp>(loc, inY, yStride);
        Value x = rewriter.create<MulIOp>(loc, inX, xStride);

        y = rewriter.create<AddIOp>(loc, y, yOffset);
        x = rewriter.create<AddIOp>(loc, x, xOffset);

        iy = rewriter.create<SignedShiftRightOp>(loc, y, shiftVal);
        ix = rewriter.create<SignedShiftRightOp>(loc, x, shiftVal);

        Value yTrunc = rewriter.create<ShiftLeftOp>(loc, iy, shiftVal);
        Value xTrunc = rewriter.create<ShiftLeftOp>(loc, ix, shiftVal);

        dy = rewriter.create<SubIOp>(loc, y, yTrunc);
        dx = rewriter.create<SubIOp>(loc, x, xTrunc);
      }

      if (op.mode() == "NEAREST_NEIGHBOR") {
        Value yPred, xPred;
        // Round the index position towards the closest pixel location.
        if (floatingPointMode) {
          auto halfVal =
              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));
          yPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dy,
                                                halfVal);
          xPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dx,
                                                halfVal);
        } else {
          auto halfVal = rewriter.create<ConstantOp>(
              loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
          yPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dy,
                                                halfVal);
          xPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dx,
                                                halfVal);
        }

        auto zeroVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
        auto oneVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));

        auto yOffset =
            rewriter.create<mlir::SelectOp>(loc, yPred, oneVal, zeroVal);
        auto xOffset =
            rewriter.create<mlir::SelectOp>(loc, xPred, oneVal, zeroVal);

        iy = rewriter.create<AddIOp>(loc, iy, yOffset);
        ix = rewriter.create<AddIOp>(loc, ix, xOffset);

        // Clamp to be within the bounds of the input image.
        iy = clampHelper<mlir::CmpIOp>(loc, iy, hwMin, hMax, CmpIPredicate::slt,
                                       rewriter);
        ix = clampHelper<mlir::CmpIOp>(loc, ix, hwMin, wMax, CmpIPredicate::slt,
                                       rewriter);

        // Read the value from the input array.
        iy = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), iy);
        ix = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), ix);

        Value result = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, iy, ix, channel});

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }

      if (op.mode() == "BILINEAR") {
        Value y0 = iy;
        Value x0 = ix;

        auto oneVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
        Value y1 = rewriter.create<AddIOp>(loc, y0, oneVal);
        Value x1 = rewriter.create<AddIOp>(loc, x0, oneVal);

        y0 = clampHelper<mlir::CmpIOp>(loc, y0, hwMin, hMax,
                                       CmpIPredicate::slt, rewriter);
        y1 = clampHelper<mlir::CmpIOp>(loc, y1, hwMin, hMax,
                                       CmpIPredicate::slt, rewriter);

        x0 = clampHelper<mlir::CmpIOp>(loc, x0, hwMin, wMax,
                                       CmpIPredicate::slt, rewriter);
        x1 = clampHelper<mlir::CmpIOp>(loc, x1, hwMin, wMax,
                                       CmpIPredicate::slt, rewriter);

        y0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y0);
        y1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y1);
        x0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x0);
        x1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x1);

        Value y0x0 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.f));
          Value rightPart = dx;
          Value leftPart = rewriter.create<SubFOp>(loc, oneVal, dx);

          y0x0 = rewriter.create<MulFOp>(loc, y0x0, leftPart);
          y0x1 = rewriter.create<MulFOp>(loc, y0x1, rightPart);
          Value topAcc = rewriter.create<AddFOp>(loc, y0x0, y0x1);

          y1x0 = rewriter.create<MulFOp>(loc, y1x0, leftPart);
          y1x1 = rewriter.create<MulFOp>(loc, y1x1, rightPart);
          Value bottomAcc = rewriter.create<AddFOp>(loc, y1x0, y1x1);

          Value bottomPart = dy;
          Value topPart = rewriter.create<SubFOp>(loc, oneVal, dy);
          topAcc = rewriter.create<MulFOp>(loc, topAcc, topPart);
          bottomAcc = rewriter.create<MulFOp>(loc, bottomAcc, bottomPart);
          Value result = rewriter.create<AddFOp>(loc, topAcc, bottomAcc);

          rewriter.create<linalg::YieldOp>(loc, result);
          return success();
        } else {
          y0x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x0);
          y0x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x1);
          y1x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x0);
          y1x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x1);

          if (resultElementTy.getIntOrFloatBitWidth() > 32) {
            dx = rewriter.create<SignExtendIOp>(loc, resultElementTy, dx);
            dy = rewriter.create<SignExtendIOp>(loc, resultElementTy, dy);
          }

          auto unitVal = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
          Value rightPart = dx;
          Value leftPart = rewriter.create<SubIOp>(loc, unitVal, dx);

          y0x0 = rewriter.create<MulIOp>(loc, y0x0, leftPart);
          y0x1 = rewriter.create<MulIOp>(loc, y0x1, rightPart);
          Value topAcc = rewriter.create<AddIOp>(loc, y0x0, y0x1);

          y1x0 = rewriter.create<MulIOp>(loc, y1x0, leftPart);
          y1x1 = rewriter.create<MulIOp>(loc, y1x1, rightPart);
          Value bottomAcc = rewriter.create<AddIOp>(loc, y1x0, y1x1);

          Value bottomPart = dy;
          Value topPart = rewriter.create<SubIOp>(loc, unitVal, dy);
          topAcc = rewriter.create<MulIOp>(loc, topAcc, topPart);
          bottomAcc = rewriter.create<MulIOp>(loc, bottomAcc, bottomPart);
          Value result = rewriter.create<AddIOp>(loc, topAcc, bottomAcc);

          rewriter.create<linalg::YieldOp>(loc, result);
          return success();
        }
      }

      return failure();
    }

    return success();
  }
};

// At the codegen level any identity operations should be removed. Any cases
// where identity is load-bearing (e.g. cross-device computation) should be
// handled before lowering to codegen.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter);
  }
};

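// Lowers tosa.concat by zero-filling an init tensor of the concatenated shape
// and inserting each operand with tensor.insert_slice at an increasing offset
// along the concat axis. For example (illustrative): concatenating
// tensor<2x3xf32> and tensor<4x3xf32> on axis 0 inserts the operands at row
// offsets 0 and 2 of a tensor<6x3xf32>.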
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = op.getType().dyn_cast<RankedTensorType>();
    if (!resultType || !resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected static shaped tensor type");
    }

    Location loc = op.getLoc();
    int axis = op.axis();
    Value axisValue =
        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int rank = resultType.getRank();
    SmallVector<Value, 3> offsets, sizes, strides;
    sizes.reserve(rank);
    strides.resize(rank, rewriter.create<ConstantIndexOp>(loc, 1));
    offsets.resize(rank, rewriter.create<ConstantIndexOp>(loc, 0));

    for (int i = 0; i < rank; ++i) {
      sizes.push_back(
          rewriter.create<tensor::DimOp>(loc, adaptor.getOperands()[0], i));
    }

    Value resultDimSize = sizes[axis];
    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
      resultDimSize = rewriter.create<AddIOp>(loc, resultDimSize, size);
    }
    sizes[axis] = resultDimSize;

    Value init = rewriter.create<linalg::InitTensorOp>(
        loc, resultType.getShape(), resultType.getElementType());

    Value zeroVal = rewriter.create<ConstantOp>(
        loc, rewriter.getZeroAttr(resultType.getElementType()));
    Value result =
        rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0);

    for (auto arg : adaptor.getOperands()) {
      sizes[axis] = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
      result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
                                                      sizes, strides);
      offsets[axis] = rewriter.create<AddIOp>(loc, offsets[axis], sizes[axis]);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    auto inputTy = input.getType().template cast<ShapedType>();
    auto resultTy = op.getType().template cast<ShapedType>();
    auto axis = op.axis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // Create the init tensor for the reversed output.
    auto initTensor = rewriter
                          .create<linalg::InitTensorOp>(
                              loc, ArrayRef<Value>({dynDims}),
                              inputTy.getShape(), inputTy.getElementType())
                          .result();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            auto index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              // Reverse along the axis: index = (size - 1) - index.
              auto one = rewriter.create<ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<SubIOp>(nestedLoc, sizeMinusOne, index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};

// This converter translates a tile operation into a reshape, broadcast, and
// reshape. The first reshape minimally expands each tiled dimension to include
// a preceding size-1 dim. This dim is then broadcast to the appropriate
// multiple.
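//
// For example (illustrative): tiling tensor<2x3xf32> by multiples [3, 2]
// broadcasts the input into a 3x2x2x3 generic (dims d1 and d3 read the input)
// whose result is reshaped to tensor<6x6xf32>.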
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.input1();
    auto inputTy = input.getType().cast<ShapedType>();
    auto inputShape = inputTy.getShape();
    auto resultTy = op.getType().cast<ShapedType>();
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    SmallVector<int64_t> multiples;
    getValuesFromIntArrayAttribute(op.multiples(), multiples);

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      genericShape.push_back(multiples[i]);
      genericShape.push_back(inputShape[i]);
    }

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);

    // We need to map the input shape to the non-broadcast dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0),
        rewriter.getI64ArrayAttr(resultTy.getShape()));
    return success();
  }
};

class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.input1();
    auto padding = padOp.padding();

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType paddingTy = padding.getType().cast<ShapedType>();
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          padOp,
          "Pad converter requires static shaped input / padding values.");
    }

    Attribute constantAttr;
    if (elementTy.isa<FloatType>())
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
      auto value = padOp.quantization_info().getValue().input_zp().getValue();
      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          padOp,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value inputIndex = rewriter.createOrFold<ConstantIndexOp>(loc, i);
      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, lowIndex}));
      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, highIndex}));

      lowVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
                                                  lowVal);
      highVal = rewriter.createOrFold<IndexCastOp>(
          loc, rewriter.getIndexType(), highVal);

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    Value constant = rewriter.create<ConstantOp>(loc, constantAttr);

    auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
        padOp.getType(), input, constant, lowValues, highValues,
        /*nofold=*/false, loc, rewriter);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

// The TOSA argmax lowering represents the ArgMax op as a linalg.generic op,
// producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and has the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by the generic op, this buffer is discarded as only the index is
// requested.
//
// The generic op updates both the maximum value and index if the current value
// exceeds the running max.
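//
// For example (illustrative): tosa.argmax with axis = 1 over a tensor<3x4xf32>
// reduces to a tensor<3xi32> of per-row column indices, with a temporary
// tensor<3xf32> holding the running maxima.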
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.input();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.axis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires statically shaped input");

    if (!outElementTy.isa<IntegerType>())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // First fill the output buffer for the index.
    auto initTensorIdx =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
                                          resultTy.getShape(), outElementTy)
            .result();
    auto fillValueIdx = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter.create<linalg::FillOp>(loc, fillValueIdx, initTensorIdx)
            .result();

    // Second fill the output buffer for the running max.
    auto initTensorMax =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
                                          resultTy.getShape(), inElementTy)
            .result();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<StringRef, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
    iteratorTypes[axis] = getReductionIteratorTypeName();

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = rewriter.create<IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (inElementTy.isa<FloatType>()) {
            predicate = rewriter.create<mlir::CmpFOp>(
                nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
          } else if (inElementTy.isa<IntegerType>()) {
            predicate = rewriter.create<mlir::CmpIOp>(
                nestedLoc, CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<mlir::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<mlir::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};

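// Lowers tosa.gather to a linalg.generic that iterates over the indices
// tensor; each output element extracts input[batch, indices[batch, n],
// channel]. For example (illustrative): gathering with indices tensor<1x4xi32>
// from a tensor<1x8x2xf32> input yields a tensor<1x4x2xf32>.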
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto inputTy = input.getType().cast<ShapedType>();
    auto indicesTy = indices.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    Value table = op.table();
    auto inputTy = input.getType().cast<ShapedType>();
    auto tableTy = table.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block =
          rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
                               TypeRange({inputElementTy, resultElementTy}));

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(),
                                                   inputValue);
        Value offset = rewriter.create<ConstantIndexOp>(loc, 128);
        index = rewriter.create<AddIOp>(loc, rewriter.getIndexType(), index,
                                        offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<SignExtendIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(32768));
        auto seven =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(7));
        auto one =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        //   value = value + 32768
        //   index = value >> 7
        //   fraction = value & 0x7f (the low 7 bits)
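        //
        // For example (illustrative): an i16 input of 0 becomes 32768, giving
        // index = 32768 >> 7 = 256 and fraction = 0, so the result below is
        // table[256] << 7 with no interpolation toward table[257].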
        auto extendAdd = rewriter.create<AddIOp>(loc, extend, offset);
        Value index =
            rewriter.create<UnsignedShiftRightOp>(loc, extendAdd, seven);
        Value fraction = rewriter.create<mlir::AndOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table:
        //   base = (int32_t)table[index]
        //   next = (int32_t)table[index + 1]
        Value indexPlusOne = rewriter.create<AddIOp>(loc, index, one);

        index =
            rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), base);
        next = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the table values:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<ShiftLeftOp>(loc, base, seven);
        Value diff = rewriter.create<SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<MulIOp>(loc, diff, fraction);
        Value result = rewriter.create<AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};

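// Lowers tosa.max_pool2d to linalg.pooling_nhwc_max: the input is padded with
// the smallest representable value of the element type, the init tensor is
// filled with that same value, and the pooling op then takes the running
// maximum over each kernel window.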
class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    ShapedType inputTy = input.getType().cast<ShapedType>();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type resultETy = inputTy.getElementType();

    if (!inputTy.hasStaticShape())
      return failure();

    // Determine what the initial value needs to be for the max pool op.
    Attribute initialAttr;
    if (resultETy.isF32())
      initialAttr = rewriter.getFloatAttr(
          resultETy,
          APFloat::getLargest(resultETy.cast<FloatType>().getFloatSemantics(),
                              true));

    if (resultETy.isa<IntegerType>())
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.pad(), pad);
    pad.resize(pad.size() + 2, 0);
    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultTy.getElementType());

    Value filledInitTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, resultETy);

    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
        filledInitTensor, strideAttr, dilationAttr);
    return success();
  }
};
2786
2787 class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
2788 public:
2789 using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;
2790
matchAndRewrite(tosa::AvgPool2dOp op,PatternRewriter & rewriter) const2791 LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
2792 PatternRewriter &rewriter) const final {
2793 Location loc = op.getLoc();
2794 Value input = op.input();
2795 ShapedType inputTy = input.getType().cast<ShapedType>();
2796 Type inElementTy = inputTy.getElementType();
2797
2798 ShapedType resultTy = op.getType().template cast<ShapedType>();
2799 Type resultETy = inputTy.getElementType();
2800
2801 Type accETy =
2802 inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
2803 ShapedType accTy = resultTy.clone(accETy);
2804
2805 if (!inputTy.hasStaticShape())
2806 return failure();
2807
2808 // Apply padding as necessary.
2809 llvm::SmallVector<int64_t> pad;
2810 pad.resize(2, 0);
2811 getValuesFromIntArrayAttribute(op.pad(), pad);
2812 pad.resize(pad.size() + 2, 0);
2813 Attribute initialAttr = rewriter.getZeroAttr(accETy);
2814 Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
2815
2816 Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
2817
2818 SmallVector<int64_t> kernel, stride;
2819 getValuesFromIntArrayAttribute(op.kernel(), kernel);
2820 getValuesFromIntArrayAttribute(op.stride(), stride);
2821
2822 Attribute strideAttr = rewriter.getI64VectorAttr(stride);
2823 Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
2824
2825 // Create the linalg op that performs pooling.
2826 Value poolInitTensor =
2827 rewriter.create<linalg::InitTensorOp>(loc, accTy.getShape(), accETy);
2828
2829 Value filledInitTensor =
2830 rewriter.create<linalg::FillOp>(loc, initialValue, poolInitTensor)
2831 .result();
2832
2833 Value fakeWindowDims =
2834 rewriter.create<linalg::InitTensorOp>(loc, kernel, accETy);
2835
2836 // Sum across the pooled region.
2837 Value poolingOp = rewriter
2838 .create<linalg::PoolingNhwcSumOp>(
2839 loc, ArrayRef<Type>{accTy},
2840 ValueRange{paddedInput, fakeWindowDims},
2841 filledInitTensor, strideAttr, dilationAttr)
2842 .getResult(0);
2843
2844 // Normalize the summed value by the number of elements grouped in each
2845 // pool.
2846 auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
2847 auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
2848
2849 Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
2850 loc, resultTy.getShape(), resultETy);
2851
2852 auto genericOp = rewriter.create<linalg::GenericOp>(
2853 loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
2854 ValueRange{genericInitTensor},
2855 ArrayRef<AffineMap>({affineMap, affineMap}),
2856 getNParallelLoopsAttrs(resultTy.getRank()),
2857 [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
          auto one = rewriter.create<ConstantIndexOp>(loc, 1);
          auto iH = rewriter.create<ConstantIndexOp>(
              loc, poolingOpTy.getDimSize(1) - 1);
          auto iW = rewriter.create<ConstantIndexOp>(
              loc, poolingOpTy.getDimSize(2) - 1);

          // Compute the indices from either end.
          auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
          auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
          auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
          auto x1 = rewriter.create<SubIOp>(loc, iW, x0);

          // Determine what portion of the kernel covers valid input: when
          // the window hangs over a padded edge, shrink the coverage by the
          // amount that falls within the padding.
          auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
            if (pad == 0)
              return v;

            auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
            Value dx = rewriter.create<SubIOp>(loc, x, padVal);

            Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                      dx, zero);
            Value offset = rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
            return rewriter.create<mlir::AddIOp>(loc, v, offset)->getResult(0);
          };

          // Compute the vertical component of coverage.
          auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
          auto kH1 = padFn(kH0, y0, pad[2]);
          auto kH2 = padFn(kH1, y1, pad[3]);
          auto kHCmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
          auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);

          // Compute the horizontal component of coverage.
          auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
          auto kW1 = padFn(kW0, x0, pad[4]);
          auto kW2 = padFn(kW1, x1, pad[5]);
          auto kWCmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
          auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);

          // Compute the total number of elements and normalize.
          Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
          auto countI = rewriter.create<mlir::IndexCastOp>(
              loc, rewriter.getI32Type(), count);

          // Divide by the number of summed values. For floats this is a
          // plain division; for quantized values the input normalization
          // must also be applied.
          Value poolVal = args[0];
          if (accETy.isa<FloatType>()) {
            auto countF =
                rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);
            poolVal =
                rewriter.create<DivFOp>(loc, poolVal, countF)->getResult(0);
          } else {
            // If we have quantization information, apply an offset for the
            // input zero point.
            if (op.quantization_info()) {
              auto quantizationInfo = op.quantization_info().getValue();
              auto inputZp = rewriter.create<mlir::ConstantOp>(
                  loc, quantizationInfo.input_zp());
              Value offset =
                  rewriter.create<mlir::MulIOp>(loc, accETy, countI, inputZp);
              poolVal = rewriter.create<SubIOp>(loc, accETy, poolVal, offset);
            }

            // Compute the multiplier and shift values for the quantization
            // normalization. Ideally we would compute with more bits, but
            // 32 bits should be sufficient for the compute; a direct integer
            // division would be a simpler alternative.
            int64_t numerator = ((1 << 30) + 1);
            int64_t shift = 30;
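            // For example, with a fully covered 3x3 window count == 9, so
            // multiplier == ((1 << 30) + 1) / 9 == 119304647 and
            // (sum * 119304647) >> 30 approximates sum / 9.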

            Value numeratorVal = rewriter.create<ConstantOp>(
                loc, rewriter.getI32IntegerAttr(numerator));
            Value multiplierVal =
                rewriter
                    .create<UnsignedDivIOp>(loc, rewriter.getI32Type(),
                                            numeratorVal, countI)
                    .getResult();
            Value shiftVal = rewriter.create<ConstantOp>(
                loc, rewriter.getI8IntegerAttr(shift));

            auto scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(
                        loc, rewriter.getI32Type(), poolVal, multiplierVal,
                        shiftVal, rewriter.getBoolAttr(false))
                    .getResult();

            // If we have quantization information, apply the output zero
            // point.
            if (op.quantization_info()) {
              auto quantizationInfo = op.quantization_info().getValue();
              auto outputZp = rewriter.create<mlir::ConstantOp>(
                  loc, quantizationInfo.output_zp());
              scaled =
                  rewriter.create<AddIOp>(loc, scaled, outputZp).getResult();
            }

            // Clamp to the signed range of the output bitwidth.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<ConstantOp>(
                loc, rewriter.getIntegerAttr(
                         accETy,
                         APInt::getSignedMinValue(outBitwidth).getSExtValue()));
            auto max = rewriter.create<ConstantOp>(
                loc, rewriter.getIntegerAttr(
                         accETy,
                         APInt::getSignedMaxValue(outBitwidth).getSExtValue()));
            auto clamp = clampHelper<mlir::CmpIOp>(
                loc, scaled, min, max, CmpIPredicate::slt, rewriter);

            // Convert to the result element type.
            poolVal = rewriter.create<TruncateIOp>(loc, resultETy, clamp);
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::DivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::ReluNOp>,
      PointwiseConverter<tosa::SigmoidOp>,
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      ConcatConverter,
      ConvConverter,
      DepthwiseConvConverter,
      TransposeConvConverter,
      GatherConverter,
      PadConverter,
      ReshapeConverter,
      RescaleConverter,
      ResizeConverter,
      ReverseConverter,
      TableConverter,
      TileConverter,
      TransposeConverter,
      MatMulConverter,
      MaxPool2dConverter,
      AvgPool2dConverter,
      FullyConnectedConverter>(patterns->getContext());
  // clang-format on
}