//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

using namespace mlir;

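// Returns `nParallelLoops` copies of the parallel iterator-type name, used to
// mark every dimension of an elementwise linalg.generic as parallel.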
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

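// Reads the integer attribute `attrName` from `op`, narrows it through T, and
// materializes it as a constant of `requiredAttrType`.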
template <typename T>
static mlir::ConstantOp
createConstFromIntAttribute(Operation *op, std::string attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
  return rewriter.create<mlir::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}

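// Appends the signed integer values held by `attr` to `arrayValues`.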
template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  for (Attribute val : attr.getValue()) {
    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

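// Clamps `arg` to [min, max] with two compare+select pairs; T is the compare
// op (CmpIOp or CmpFOp) and `pred` its less-than predicate, so one helper
// serves both integer and floating-point clamping.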
template <typename T, typename P>
static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min,
                                  mlir::ConstantOp max, P pred,
                                  OpBuilder &rewriter) {
  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
  auto minOrArg =
      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
}

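// Pads `input` by the interleaved low/high amounts in `pad` (two entries per
// dimension) using `padAttr` as the padding value, or returns the input
// unchanged when all padding is zero.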
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            Attribute padAttr, OpBuilder &rewriter) {
  // Input should be padded if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (int i = 0, s = inputShape.size(); i < s; i++) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<ConstantOp>(loc, padAttr);

  return linalg::PadTensorOp::createPadScalarOp(
             RankedTensorType::get(paddedShape, inputETy), input, padValue,
             lowIndices, highIndices, /*nofold=*/false, loc, rewriter)
      .result();
}

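// Drops the null entries from `dynDims`, keeping only dimensions for which a
// dynamic size value was materialized.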
static SmallVector<Value> filterDynamicDims(SmallVector<Value> dynDims) {
  SmallVector<Value> filteredDims;
  for (auto dim : dynDims)
    if (dim)
      filteredDims.push_back(dim);
  return filteredDims;
}

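// Emits the scalar computation for one TOSA elementwise op as ops on the
// linalg.generic block arguments, e.g. tosa.add on floats becomes a single
// addf. Returns a null Value (after notifying the match failure) for
// unsupported op/element-type combinations.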
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, rewriter.getZeroAttr(elementTy));
    auto cmp =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0], zero);
    auto neg = rewriter.create<mlir::SubIOp>(loc, zero, args[0]);
    return rewriter.create<mlir::SelectOp>(loc, cmp, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    if (dyn_cast<tosa::MulOp>(op).shift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
  }

  // tosa::DivOp
  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SignedDivIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, args[0]);
  }

  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<TruncateIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<SignExtendIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<SignExtendIOp>(loc, resultTypes[0], b);

    return rewriter.create<mlir::MulIOp>(loc, resultTypes, a, b);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      !cast<tosa::NegateOp>(op).quantization_info()) {
    auto constant =
        rewriter.create<ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<SubIOp>(loc, resultTypes, constant, args[0]);
  }

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      cast<tosa::NegateOp>(op).quantization_info()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp =
        quantizationInfo.getValue().input_zp().getValue().getSExtValue();
    int64_t outZp =
        quantizationInfo.getValue().output_zp().getValue().getSExtValue();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //  outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<SignExtendIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    auto min = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMinValue(inputBitWidth).getSExtValue()));
    auto max = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMaxValue(inputBitWidth).getSExtValue()));
    auto clamp = clampHelper<mlir::CmpIOp>(loc, sub, min, max,
                                           CmpIPredicate::slt, rewriter);

    // Truncate to the final value.
    return rewriter.create<TruncateIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<ConstantOp>(loc, allOnesAttr);
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::ShiftLeftOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
    auto result =
        rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, args);
    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that input2 != 0
    auto shiftValueGreaterThanZero =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[1], zero);

    // Checking for the last bit of input1 to be 1
    auto subtract =
        rewriter.create<mlir::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted = rewriter
                       .create<mlir::SignedShiftRightOp>(loc, resultTypes,
                                                         args[0], subtract)
                       ->getResults();
    auto truncated =
        rewriter.create<mlir::TruncateIOp>(loc, i1Ty, shifted, mlir::None);
    auto isInputOdd = rewriter.create<mlir::AndOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<mlir::AndOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<ZeroExtendIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
    int bitWidth = elementTy.getIntOrFloatBitWidth();
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto leadingZeros = rewriter.create<mlir::ConstantOp>(
        loc, IntegerAttr::get(elementTy, bitWidth));

    SmallVector<Value> operands = {args[0], leadingZeros, zero};
    SmallVector<Type> types = {elementTy, elementTy, elementTy};

    auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
    Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
    Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

    // The conditional block of the while loop.
    {
      rewriter.setInsertionPointToStart(&whileOp.before().front());
      Value input = before->getArgument(0);
      Value zero = before->getArgument(2);

      Value inputLargerThanZero =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, input, zero);
      rewriter.create<scf::ConditionOp>(loc, inputLargerThanZero,
                                        before->getArguments());
    }

    // The body of the while loop: shift right until reaching a value of 0.
    {
      rewriter.setInsertionPointToStart(&whileOp.after().front());
      Value input = after->getArgument(0);
      Value leadingZeros = after->getArgument(1);

      auto one = rewriter.create<mlir::ConstantOp>(
          loc, IntegerAttr::get(elementTy, 1));
      auto shifted = rewriter.create<mlir::UnsignedShiftRightOp>(
          loc, resultTypes, input, one);
      auto leadingZerosMinusOne =
          rewriter.create<mlir::SubIOp>(loc, resultTypes, leadingZeros, one);

      rewriter.create<scf::YieldOp>(
          loc,
          ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
    }

    rewriter.setInsertionPointAfter(whileOp);
    return whileOp->getResult(1);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<mlir::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT, args[0],
                                         args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0],
                                         args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, args[0],
                                         args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
                                         args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OEQ, args[0],
                                         args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0],
                                         args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CeilFOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    auto min = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("min_fp"));
    auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], min, max, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto intTy = elementTy.cast<IntegerType>();
    int32_t min = static_cast<int32_t>(
        op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
    int32_t max = static_cast<int32_t>(
        op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());

    if (intTy.isUnsignedInteger()) {
      min = std::max<int32_t>(min, 0);
      max = std::min<int32_t>(
          max,
          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
    } else {
      min = std::max<int32_t>(
          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
      max = std::min<int32_t>(
          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
    }

    auto minVal =
        rewriter.create<ConstantIntOp>(loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal =
        rewriter.create<ConstantIntOp>(loc, max, intTy.getIntOrFloatBitWidth());
    return clampHelper<mlir::CmpIOp>(loc, args[0], minVal, maxVal,
                                     CmpIPredicate::slt, rewriter);
  }

  // tosa::ReluNOp
  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
    auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                               op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], zero, n, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                  rewriter);
    return clampHelper<mlir::CmpIOp>(loc, args[0], zero, n, CmpIPredicate::slt,
                                     rewriter);
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<mlir::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<mlir::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
      return rewriter.create<mlir::FPExtOp>(loc, resultTypes, args, mlir::None);

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
      return rewriter.create<mlir::FPTruncOp>(loc, resultTypes, args,
                                              mlir::None);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::UIToFPOp>(loc, resultTypes, args,
                                             mlir::None);

    if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<mlir::ZeroExtendIOp>(loc, resultTypes, args,
                                                  mlir::None);

    // All other si-to-fp conversions should be handled by SIToFP.
    if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::SIToFPOp>(loc, resultTypes, args,
                                             mlir::None);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && dstTy.isa<FloatType>()) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<mlir::UIToFPOp>(loc, resultTypes[0],
                                             unrealizedCast);
    }

    // Casting to boolean, floats need to only be checked as not-equal to zero.
    if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
      Value zero =
          rewriter.create<ConstantOp>(loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::UNE,
                                           args.front(), zero);
    }

    if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto zero =
          rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
      auto half =
          rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));

      auto intMin = rewriter.create<ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto intMax = rewriter.create<ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto added = rewriter.create<AddFOp>(loc, args[0], half);
      auto subbed = rewriter.create<SubFOp>(loc, args[0], half);
      auto negative =
          rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT, args[0], zero);
      auto rounded =
          rewriter.create<mlir::SelectOp>(loc, negative, subbed, added);

      auto clamped = clampHelper<mlir::CmpFOp>(loc, rounded, intMin, intMax,
                                               CmpFPredicate::OLT, rewriter);

      return rewriter.create<mlir::FPToSIOp>(loc, dstTy, clamped);
    }

    // Casting to boolean, integers need to only be checked as not-equal to
    // zero.
    if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
      Value zero =
          rewriter.create<ConstantIntOp>(loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(),
                                           zero);
    }

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
                                                  mlir::None);

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
      auto intMin = rewriter.create<ConstantIntOp>(
          loc,
          APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto intMax = rewriter.create<ConstantIntOp>(
          loc,
          APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto clamped = clampHelper<mlir::CmpIOp>(loc, args[0], intMin, intMax,
                                               CmpIPredicate::slt, rewriter);
      return rewriter.create<mlir::TruncateIOp>(loc, dstTy, clamped);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}

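// Lowers a TOSA elementwise op to a single linalg.generic, reshaping
// broadcastable operands and giving them partial indexing maps so size-1
// dimensions broadcast against the result. Roughly (illustrative, shapes
// abbreviated):
//
//   %0 = "tosa.add"(%a, %b) : (tensor<2x3xf32>, tensor<2x3xf32>)
//                           -> tensor<2x3xf32>
//
// becomes a linalg.generic over parallel loops whose body is a single addf
// on the scalar block arguments.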
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  auto loc = operation->getLoc();

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  auto results = operation->getResults();
  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();

  if (!resultTy)
    return rewriter.notifyMatchFailure(operation,
                                       "All results must be a shaped type");

  unsigned rank = resultTy.getRank();

  // Construct the indexing maps needed for linalg.generic ops.
  SmallVector<Type> bodyArgTypes;

  for (Value in : operation->getOperands())
    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));

  SmallVector<Type> opResultTypes;
  SmallVector<Value> initTensors;

  SmallVector<Value> dynDims;
  dynDims.resize(results.front().getType().cast<ShapedType>().getRank());

  for (auto arg : operation->getOperands()) {
    auto operandTy = arg.getType().cast<ShapedType>();
    for (int i = 0; i < operandTy.getRank(); i++) {
      if (operandTy.isDynamicDim(i) && !dynDims[i])
        dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
    }
  }

  SmallVector<Value> filteredDims = filterDynamicDims(dynDims);

  for (auto result : results) {
    auto resultTy = result.getType().template cast<ShapedType>();
    initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultTy.getElementType()));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      initTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  SmallVector<Value, 2> operands;
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcasted.
  for (Value operand : operation->getOperands()) {
    ShapedType type = operand.getType().cast<ShapedType>();

    if (type.getShape() == resultTy.getShape()) {
      operands.push_back(operand);
      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
      continue;
    }

    SmallVector<int64_t, 5> newShape;
    SmallVector<AffineExpr, 4> affineExprs;
    newShape.reserve(type.getRank());
    for (auto it : llvm::enumerate(type.getShape())) {
      if (it.value() == resultTy.getDimSize(it.index())) {
        newShape.push_back(it.value());
        affineExprs.push_back(
            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
      }
    }

    if (newShape.size() != rank) {
      operand = rewriter.create<tosa::ReshapeOp>(
          loc, RankedTensorType::get(newShape, type.getElementType()), operand,
          rewriter.getI64ArrayAttr(newShape));
    }

    operands.push_back(operand);
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs,
        rewriter.getContext()));
  }

  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(rank));

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operands, initTensors, indexingMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        if (!opResult) {
          didEncounterError = true;
          return;
        }
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}

// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::AndOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::OrOp>(loc, args);

  return {};
}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
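// For example (illustrative), reducing axis 1 of a tensor<4x8xf32> with
// tosa.reduce_sum fills a tensor<4xf32> init tensor with 0.0, accumulates
// with addf in the generic's body, and reshapes to the TOSA result type
// tensor<4x1xf32>.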
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
  auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  llvm::SmallVector<int64_t> reduceShape;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i)
      reduceShape.push_back(inputTy.getDimSize(i));
  }

  Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());

  // First fill the output buffer with the init value.
  auto initTensor =
      rewriter
          .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}), reduceShape,
                                        resultTy.getElementType())
          .result();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<ConstantOp>(loc, fillValueAttr);
  auto filledTensor =
      rewriter.create<linalg::FillOp>(loc, fillValue, initTensor).result();

  SmallVector<AffineExpr, 2> srcExprs;
  SmallVector<AffineExpr, 2> dstExprs;
  SmallVector<StringRef, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));

    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
                                      : getParallelIteratorTypeName());
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

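  // Note: despite its name, `didEncounterError` is set to true when a
  // reduction body was successfully created; it stays false (yielding the
  // failure below) for unsupported op/element-type combinations.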
  bool didEncounterError = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return failure();

  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
                                               linalgOp.getResults());
  return success();
}

namespace {

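// Converts any TOSA elementwise op by delegating to the shared
// elementwiseMatchAndRewriteHelper above.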
template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(op, rewriter);
  }
};

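// Lowers tosa.conv2d to linalg::Conv2DNhwcHwcfOp (or Conv2DNhwcHwcfQOp when
// quantized): the TOSA kernel is transposed from [OC, KH, KW, IC] to linalg's
// [KH, KW, IC, OC] ordering, padding and a zero-filled accumulator are
// materialized, and the bias is added with a broadcasting linalg.generic.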
class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
public:
  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
    bool isQuantized = op->hasAttr("quantization_info");

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(op,
                                         "tosa.conv ops require static shapes");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = quantizationInfo.input_zp().getValue().getSExtValue();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(padAttr, pad);
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Transpose the kernel to match dimension ordering of the linalg
    // convolution operation.
    // TODO(suderman): See if this can be efficiently folded - check whether
    // the input is used anywhere else, if not fold the constant.
    SmallVector<int64_t> weightPerm{1, 2, 3, 0};
    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[2],
                                        weightShape[3], weightShape[0]};
    auto weightPermAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm);
    Value weightPermValue = rewriter.create<ConstantOp>(loc, weightPermAttr);
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());
    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                weightPermValue);

    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultETy);
    Value zero = rewriter.create<ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation;
    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

    // Create the convolution op.
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);

    // Create maps for the bias broadcasting
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultETy);

    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = rewriter.getI32IntegerAttr(
          quantizationInfo.input_zp().getValue().getSExtValue());
      auto kZp = rewriter.getI32IntegerAttr(
          quantizationInfo.weight_zp().getValue().getSExtValue());

      auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::Conv2DNhwcHwcfQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              ->getResult(0);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
                  indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added =
                        nestedBuilder.create<AddIOp>(loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    Value conv = rewriter
                     .create<linalg::Conv2DNhwcHwcfOp>(
                         loc, resultTy, ValueRange{input, weight},
                         ValueRange{zeroTensor}, strideAttr, dilationAttr)
                     ->getResult(0);

    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
                indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  Value added =
                      nestedBuilder.create<AddFOp>(loc, args[0], args[1]);
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                })
            .getResult(0);

    rewriter.replaceOp(op, result);
    return success();
  }
};

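// Lowers tosa.depthwise_conv2d to linalg::DepthwiseConv2DNhwcOp (or the
// quantized DepthwiseConv2DNhwcQOp). The linalg op produces a 5-D tensor with
// a separate channel-multiplier dimension, which is reshaped back to the 4-D
// TOSA result before the broadcast bias addition.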
class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();

    bool isQuantized = op->hasAttr("quantization_info");
    IntegerAttr iZp;
    IntegerAttr kZp;
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      iZp = rewriter.getI32IntegerAttr(
          quantizationInfo.input_zp().getValue().getSExtValue());
      kZp = rewriter.getI32IntegerAttr(
          quantizationInfo.weight_zp().getValue().getSExtValue());
    }

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(op,
                                         "tosa.conv ops require static shapes");

    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = quantizationInfo.input_zp().getValue().getSExtValue();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(padAttr, pad);
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation;
    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

    // Create the convolution op.
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, linalgConvTy.getShape(), resultETy);
    Value zero = rewriter.create<ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);

    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultETy);
    if (!isQuantized) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);
      Value convReshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added =
                        nestedBuilder.create<AddFOp>(loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcQOp>(
                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);
      Value convReshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added =
                        nestedBuilder.create<AddIOp>(loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};

1282 class TransposeConvConverter
1283     : public OpConversionPattern<tosa::TransposeConv2DOp> {
1284 public:
1285   using OpConversionPattern<tosa::TransposeConv2DOp>::OpConversionPattern;
1286   LogicalResult
matchAndRewrite(tosa::TransposeConv2DOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1287   matchAndRewrite(tosa::TransposeConv2DOp op, OpAdaptor adaptor,
1288                   ConversionPatternRewriter &rewriter) const final {
1289     Location loc = op->getLoc();
1290     Value input = op->getOperand(0);
1291     Value weight = op->getOperand(1);
1292     Value bias = op->getOperand(2);
1293 
1294     ShapedType inputTy = input.getType().cast<ShapedType>();
1295     ShapedType weightTy = weight.getType().cast<ShapedType>();
1296     ShapedType biasTy = bias.getType().cast<ShapedType>();
1297     ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
1298 
1299     llvm::SmallVector<int64_t> pad;
1300     llvm::SmallVector<int64_t> stride;
1301     llvm::SmallVector<int64_t> dilation;
1302 
1303     getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
1304     getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
1305     getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
1306 
    // If striding is all 1 we can modify padding and reverse the kernel along
    // the x/y direction to make it a regular convolution. This is much simpler
    // than handling striding.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) {
      if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
          !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
        return failure();

      int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
      int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
      int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
      int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;

      llvm::SmallVector<int64_t> convPad(4, 0);
      convPad[0] = kernelHeight - 1 - pad[0];
      convPad[2] = kernelWidth - 1 - pad[1];
      convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
      convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
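      // As a small worked example (illustrative only): a 3x3 kernel with
      // dilation 1 gives kernelHeight = kernelWidth = 3, so with out_pad
      // {0, 0} the top/left conv padding becomes 3 - 1 - 0 = 2, and the
      // bottom/right padding is whatever remains to reach the input size a
      // stride-1 forward convolution needs to produce the result shape.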

      auto reverse1 = rewriter.create<tosa::ReverseOp>(
          loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
      auto reverse2 = rewriter.create<tosa::ReverseOp>(
          loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));

      Value conv2d;
      if (op.quantization_info().hasValue()) {
        conv2d = rewriter.create<tosa::Conv2DOp>(
            loc, resultTy, input, reverse2, bias,
            rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
            rewriter.getI64ArrayAttr(dilation),
            op.quantization_info().getValue());
      } else {
        conv2d = rewriter.create<tosa::Conv2DOp>(
            loc, resultTy, input, reverse2, bias,
            rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
            rewriter.getI64ArrayAttr(dilation));
      }

      rewriter.replaceOp(op, conv2d);
      return success();
    }

    return failure();
  }
};

class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = op.getType().cast<ShapedType>();
    auto outputElementTy = outputTy.getElementType();

    auto firstOperandTy = op->getOperand(0).getType().cast<ShapedType>();
    auto secondOperandTy = op->getOperand(1).getType().cast<ShapedType>();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
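    // tosa.matmul is a batched matmul over rank-3 tensors: the result dims
    // are (batch, M, N), where batch and M come from the first operand and N
    // from the second. Collect whichever of those are dynamic.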

    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
    if (!op.quantization_info()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
          ValueRange{zeroTensor});
      return success();
    }

    auto quantizationInfo = op.quantization_info().getValue();
    auto aZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.a_zp().getValue().getSExtValue()));
    auto bZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.b_zp().getValue().getSExtValue()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);

    return success();
  }
};

class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = op.getType().cast<ShapedType>();
    auto input = op.input();
    auto inputTy = input.getType().cast<ShapedType>();

    auto bias = op.bias();

    auto weight = op.weight();
    auto weightTy = weight.getType().cast<ShapedType>();
    auto weightShape = weightTy.getShape();

    auto outputETy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
    }

    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
    }

    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);

    // Create the indexing maps for the matmul output and the bias.
    SmallVector<AffineMap, 4> indexingMaps;

    // Broadcast the bias.
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(1)},
                                          rewriter.getContext()));
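    // The bias map is (d0, d1) -> (d1): every row of the rank-2 result reads
    // the same per-channel bias element.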

    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());

    // When quantized, the input element type is not the same as the output.
    Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
    Value zero = rewriter.create<ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);

    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    auto biasInitTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, filteredDims,
                                          outputTy.getShape(), outputETy)
            ->getResults();

    if (!op.quantization_info()) {
      Value matmul = rewriter
                         .create<linalg::MatmulOp>(
                             loc, TypeRange{op.getType()},
                             ValueRange{input, transposedWeight}, zeroTensor)
                         ->getResult(0);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                  indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added =
                        nestedBuilder.create<AddFOp>(loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    auto quantizationInfo = op.quantization_info().getValue();
    auto inputZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.input_zp().getValue().getSExtValue()));
    auto weightZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.weight_zp().getValue().getSExtValue()));
    Value matmul =
        rewriter
            .create<linalg::QuantizedMatmulOp>(
                loc, TypeRange{op.getType()},
                ValueRange{input, transposedWeight, inputZp, weightZp},
                zeroTensor)
            ->getResult(0);
    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  Value added =
                      nestedBuilder.create<AddIOp>(loc, args[0], args[1]);
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                })
            .getResult(0);
    rewriter.replaceOp(op, result);
    return success();
  }
};

class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();

    if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
      return success();
    }

    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    // Compute the reassociation maps for the linalg operation.
    ArrayRef<int64_t> expandedShape =
        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
                                                  : resultTy.getShape());
    ArrayRef<int64_t> collapsedShape =
        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
                                                  : operandTy.getShape());
    unsigned currSrcDim = 0, currDstDim = 0;
    SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());

    // First scan all dimensions in the source shapes to see whether we have a
    // perfect case where consecutive dimensions in the source are collapsed.
    // In that case we can generate a single linalg reshape.
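    // For instance (illustrative only), collapsing tensor<2x3x4xf32> into
    // tensor<6x4xf32> walks the source dims and produces the reassociation
    // [[d0, d1], [d2]]: 2*3 accumulates to the destination size 6, and 4
    // matches 4 directly.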
    bool isCollapsingSource = true;
    while (currSrcDim < expandedShape.size() &&
           currDstDim < collapsedShape.size()) {
      int64_t dstSize = collapsedShape[currDstDim];
      int64_t srcSize = expandedShape[currSrcDim];
      while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        srcSize *= expandedShape[currSrcDim];
      }
      if (srcSize == dstSize) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        // If the next dim in collapsedShape is not 1, treat subsequent dims in
        // expandedShape which are 1 to be collapsed.
        if (currDstDim == collapsedShape.size() - 1 ||
            collapsedShape[currDstDim + 1] != 1) {
          while (currSrcDim < expandedShape.size() &&
                 expandedShape[currSrcDim] == 1) {
            reassociationMap[currDstDim].push_back(
                rewriter.getAffineDimExpr(currSrcDim++));
          }
        }
      } else {
        isCollapsingSource = false;
        break;
      }
      currDstDim++;
    }

    // Check if any dimensions remain. If either shape is rank-0 we only
    // require the direct lowering.
    if (currSrcDim != expandedShape.size() ||
        currDstDim != collapsedShape.size())
      isCollapsingSource = collapsedShape.empty() || expandedShape.empty();

    // Otherwise, we need to first reduce all source dimensions into one and
    // then expand to the destination dimensions.
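    // E.g. (illustrative only) tensor<2x3xf32> -> tensor<3x2xf32> cannot be
    // expressed as a single collapse or expand, so it is lowered as a
    // collapse to tensor<6xf32> followed by an expand to tensor<3x2xf32>.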
    if (!isCollapsingSource) {
      auto getIdentityExprs = [&rewriter](int n) {
        SmallVector<AffineExpr, 4> exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
      Location loc = reshape.getLoc();
      int64_t totalElems =
          std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
                          std::multiplies<int64_t>());
      auto elemTy = operandTy.getElementType();
      SmallVector<ReassociationExprs, 4> collapsingMap = {
          // Use operandTy here because we need to collapse all operand
          // dimensions.
          getIdentityExprs(operandTy.getShape().size())};
      SmallVector<ReassociationExprs, 4> expandingMap = {
          // Use resultTy here because we need to expand to all result
          // dimensions.
          getIdentityExprs(resultTy.getShape().size())};

      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
      Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, collapsedTy, adaptor.getOperands()[0], collapsingMap);
      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
          reshape, resultTy, collapsedOp, expandingMap);

      return success();
    }

    if (resultTy.getRank() <
        adaptor.getOperands()[0].getType().cast<ShapedType>().getRank())
      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
          reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
    else
      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
          reshape, resultTy, adaptor.getOperands()[0], reassociationMap);

    return success();
  }
};

class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    if (!matchPattern(op.perms(), m_Constant(&perms))) {
      return failure();
    }

    auto loc = op.getLoc();
    auto input = op->getOperand(0);
    auto resultTy = op.getType().cast<ShapedType>();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());
    auto operandTy = input.getType().cast<ShapedType>();
    for (auto permutation : llvm::enumerate(perms.getValues<APInt>())) {
      auto index = permutation.index();
      auto value = permutation.value().getZExtValue();
      if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) {
        dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index);
      }
      inputExprs[value] = rewriter.getAffineDimExpr(index);
    }
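    // The loop above inverts the permutation: for perms = [1, 0] it builds
    // inputExprs = [d1, d0], i.e. an input map of (d0, d1) -> (d1, d0), so
    // output element (i, j) reads input element (j, i).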

    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultTy.getElementType());

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
        });
    return success();
  }
};

class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.input();
    auto inputTy = op.input().getType().cast<ShapedType>();
    auto outputTy = op.output().getType().cast<ShapedType>();
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration; terminate and log an error.
    if (op.double_round() && !op.scale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!outputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa to linalg conversion expects statically shaped tensors");

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues;
    getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues);

    SmallVector<int8_t> shiftValues;
    getValuesFromIntArrayAttribute(op.shift(), shiftValues);

    // Double rounding only matters if the shift is greater than 31; check
    // whether any shift value ever exceeds that.
    bool doubleRound =
        op.double_round() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the init tensor that the linalg.generic op writes into.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, ArrayRef<Value>({}), outputTy.getShape(),
        outputTy.getElementType());

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // For now we do all of our math in 32- or 48-bit precision rather
          // than computing the minimal required bit depth. This is not
          // optimal but should be correct for now; consider computing the
          // exact bit depth later.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (valueTy.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(
                              nestedLoc,
                              nestedBuilder.getIntegerType(
                                  valueTy.getIntOrFloatBitWidth()),
                              value)
                          .getResult(0);
              value = nestedBuilder.create<ZeroExtendIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<SignExtendIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value = nestedBuilder.create<SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value = nestedBuilder.create<AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              blockArgs.back().getType().cast<IntegerType>();
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output range.
          if (outIntType.isUnsignedInteger()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<ConstantOp>(
              loc,
              nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMin));
          auto intMaxVal = nestedBuilder.create<ConstantOp>(
              loc,
              nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMax));

          value =
              clampHelper<mlir::CmpIOp>(nestedLoc, value, intMinVal, intMaxVal,
                                        CmpIPredicate::slt, nestedBuilder);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<TruncateIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);

            if (outIntType.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(nestedLoc,
                                                              outIntType, value)
                          .getResult(0);
            }
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};

class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto input = op.input();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();
    auto resultElementTy = resultTy.getElementType();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    if (!resultTy.hasStaticShape())
      return failure();
    if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
      return failure();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
                           TypeRange({resultElementTy}));
      Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
      Value y = rewriter.create<linalg::IndexOp>(loc, 1);
      Value x = rewriter.create<linalg::IndexOp>(loc, 2);
      Value channel = rewriter.create<linalg::IndexOp>(loc, 3);

      auto hwMin =
          rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
      auto hMax = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(imageH - 1));
      auto wMax = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(imageW - 1));

      Value inY = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(), y);
      Value inX = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(), x);

      int32_t shift = op.shift();
      bool floatingPointMode = shift == 0;

      Value yStride, xStride, yOffset, xOffset;
      if (floatingPointMode) {
        yStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[0]);
        xStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[1]);
        yOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[0]);
        xOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[1]);
      } else {
        SmallVector<int32_t> stride, offset;
        getValuesFromIntArrayAttribute(op.stride(), stride);
        getValuesFromIntArrayAttribute(op.offset(), offset);

        yStride = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(stride[0]));
        xStride = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(stride[1]));
        yOffset = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(offset[0]));
        xOffset = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(offset[1]));
      }

      // Compute the integer index and partial offset:
      //   x = x * stride + offset
      //   ix = floor(x)
      //   dx = x - ix
      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        Value y = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inY);
        Value x = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inX);

        y = rewriter.create<MulFOp>(loc, y, yStride);
        x = rewriter.create<MulFOp>(loc, x, xStride);

        y = rewriter.create<AddFOp>(loc, y, yOffset);
        x = rewriter.create<AddFOp>(loc, x, xOffset);

        iy = rewriter.create<FloorFOp>(loc, y);
        ix = rewriter.create<FloorFOp>(loc, x);

        dy = rewriter.create<SubFOp>(loc, y, iy);
        dx = rewriter.create<SubFOp>(loc, x, ix);

        iy = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), iy);
        ix = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), ix);
      } else {
        Value shiftVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(shift));

        Value y = rewriter.create<MulIOp>(loc, inY, yStride);
        Value x = rewriter.create<MulIOp>(loc, inX, xStride);

        y = rewriter.create<AddIOp>(loc, y, yOffset);
        x = rewriter.create<AddIOp>(loc, x, xOffset);

        iy = rewriter.create<SignedShiftRightOp>(loc, y, shiftVal);
        ix = rewriter.create<SignedShiftRightOp>(loc, x, shiftVal);

        Value yTrunc = rewriter.create<ShiftLeftOp>(loc, iy, shiftVal);
        Value xTrunc = rewriter.create<ShiftLeftOp>(loc, ix, shiftVal);

        dy = rewriter.create<SubIOp>(loc, y, yTrunc);
        dx = rewriter.create<SubIOp>(loc, x, xTrunc);
      }

      if (op.mode() == "NEAREST_NEIGHBOR") {
        Value yPred, xPred;
        // Round the index position towards the closest pixel location.
        if (floatingPointMode) {
          auto halfVal =
              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));
          yPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dy,
                                                halfVal);
          xPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dx,
                                                halfVal);
        } else {
          auto halfVal = rewriter.create<ConstantOp>(
              loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
          yPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dy,
                                                halfVal);
          xPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dx,
                                                halfVal);
        }

        auto zeroVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
        auto oneVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));

        auto yOffset =
            rewriter.create<mlir::SelectOp>(loc, yPred, oneVal, zeroVal);
        auto xOffset =
            rewriter.create<mlir::SelectOp>(loc, xPred, oneVal, zeroVal);

        iy = rewriter.create<AddIOp>(loc, iy, yOffset);
        ix = rewriter.create<AddIOp>(loc, ix, xOffset);

        // Clamp the indices to be within the bounds of the input image.
        iy = clampHelper<mlir::CmpIOp>(loc, iy, hwMin, hMax, CmpIPredicate::slt,
                                       rewriter);
        ix = clampHelper<mlir::CmpIOp>(loc, ix, hwMin, wMax, CmpIPredicate::slt,
                                       rewriter);

        // Read the value from the input array.
        iy = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), iy);
        ix = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), ix);

        Value result = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, iy, ix, channel});

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }

      if (op.mode() == "BILINEAR") {
        Value y0 = iy;
        Value x0 = ix;

        auto oneVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
        Value y1 = rewriter.create<AddIOp>(loc, y0, oneVal);
        Value x1 = rewriter.create<AddIOp>(loc, x0, oneVal);

        y0 = clampHelper<mlir::CmpIOp>(loc, y0, hwMin, hMax, CmpIPredicate::slt,
                                       rewriter);
        y1 = clampHelper<mlir::CmpIOp>(loc, y1, hwMin, hMax, CmpIPredicate::slt,
                                       rewriter);

        x0 = clampHelper<mlir::CmpIOp>(loc, x0, hwMin, wMax, CmpIPredicate::slt,
                                       rewriter);
        x1 = clampHelper<mlir::CmpIOp>(loc, x1, hwMin, wMax, CmpIPredicate::slt,
                                       rewriter);

        y0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y0);
        y1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y1);
        x0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x0);
        x1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x1);

        Value y0x0 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y1, x1, channel});
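
        // Standard bilinear interpolation over the four neighbors:
        //   top    = y0x0 * (1 - dx) + y0x1 * dx
        //   bottom = y1x0 * (1 - dx) + y1x1 * dx
        //   result = top * (1 - dy) + bottom * dy
        // (in the fixed-point path, "1" is the unit value 1 << shift).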
        if (floatingPointMode) {
          auto oneVal =
              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.f));
          Value rightPart = dx;
          Value leftPart = rewriter.create<SubFOp>(loc, oneVal, dx);

          y0x0 = rewriter.create<MulFOp>(loc, y0x0, leftPart);
          y0x1 = rewriter.create<MulFOp>(loc, y0x1, rightPart);
          Value topAcc = rewriter.create<AddFOp>(loc, y0x0, y0x1);

          y1x0 = rewriter.create<MulFOp>(loc, y1x0, leftPart);
          y1x1 = rewriter.create<MulFOp>(loc, y1x1, rightPart);
          Value bottomAcc = rewriter.create<AddFOp>(loc, y1x0, y1x1);

          Value bottomPart = dy;
          Value topPart = rewriter.create<SubFOp>(loc, oneVal, dy);
          topAcc = rewriter.create<MulFOp>(loc, topAcc, topPart);
          bottomAcc = rewriter.create<MulFOp>(loc, bottomAcc, bottomPart);
          Value result = rewriter.create<AddFOp>(loc, topAcc, bottomAcc);

          rewriter.create<linalg::YieldOp>(loc, result);
          return success();
        } else {
          y0x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x0);
          y0x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x1);
          y1x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x0);
          y1x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x1);

          if (resultElementTy.getIntOrFloatBitWidth() > 32) {
            dx = rewriter.create<SignExtendIOp>(loc, resultElementTy, dx);
            dy = rewriter.create<SignExtendIOp>(loc, resultElementTy, dy);
          }

          auto unitVal = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
          Value rightPart = dx;
          Value leftPart = rewriter.create<SubIOp>(loc, unitVal, dx);

          y0x0 = rewriter.create<MulIOp>(loc, y0x0, leftPart);
          y0x1 = rewriter.create<MulIOp>(loc, y0x1, rightPart);
          Value topAcc = rewriter.create<AddIOp>(loc, y0x0, y0x1);

          y1x0 = rewriter.create<MulIOp>(loc, y1x0, leftPart);
          y1x1 = rewriter.create<MulIOp>(loc, y1x1, rightPart);
          Value bottomAcc = rewriter.create<AddIOp>(loc, y1x0, y1x1);

          Value bottomPart = dy;
          Value topPart = rewriter.create<SubIOp>(loc, unitVal, dy);
          topAcc = rewriter.create<MulIOp>(loc, topAcc, topPart);
          bottomAcc = rewriter.create<MulIOp>(loc, bottomAcc, bottomPart);
          Value result = rewriter.create<AddIOp>(loc, topAcc, bottomAcc);

          rewriter.create<linalg::YieldOp>(loc, result);
          return success();
        }
      }

      return failure();
    }

    return success();
  }
};

// At the codegen level any identity operations should be removed. Any cases
// where identity is load-bearing (e.g. cross-device computation) should be
// handled before lowering to codegen.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter);
  }
};

struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = op.getType().dyn_cast<RankedTensorType>();
    if (!resultType || !resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected static shaped tensor type");
    }

    Location loc = op.getLoc();
    int axis = op.axis();
    Value axisValue =
        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int rank = resultType.getRank();
    SmallVector<Value, 3> offsets, sizes, strides;
    sizes.reserve(rank);
    strides.resize(rank, rewriter.create<ConstantIndexOp>(loc, 1));
    offsets.resize(rank, rewriter.create<ConstantIndexOp>(loc, 0));

    for (int i = 0; i < rank; ++i) {
      sizes.push_back(
          rewriter.create<tensor::DimOp>(loc, adaptor.getOperands()[0], i));
    }

    Value resultDimSize = sizes[axis];
    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
      resultDimSize = rewriter.create<AddIOp>(loc, resultDimSize, size);
    }
    sizes[axis] = resultDimSize;

    Value init = rewriter.create<linalg::InitTensorOp>(
        loc, resultType.getShape(), resultType.getElementType());

    Value zeroVal = rewriter.create<ConstantOp>(
        loc, rewriter.getZeroAttr(resultType.getElementType()));
    Value result =
        rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0);

    for (auto arg : adaptor.getOperands()) {
      sizes[axis] = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
      result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
                                                      sizes, strides);
      offsets[axis] = rewriter.create<AddIOp>(loc, offsets[axis], sizes[axis]);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    auto inputTy = input.getType().template cast<ShapedType>();
    auto resultTy = op.getType().template cast<ShapedType>();
    auto axis = op.axis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
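    // Reversing along the axis maps index i to (axisDimSize - 1) - i; all
    // other dimensions are read through unchanged.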

    // First fill the output buffer with the init value.
    auto initTensor = rewriter
                          .create<linalg::InitTensorOp>(
                              loc, ArrayRef<Value>({dynDims}),
                              inputTy.getShape(), inputTy.getElementType())
                          .result();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            auto index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              auto one = rewriter.create<ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<SubIOp>(nestedLoc, sizeMinusOne, index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};

// This converter translates a tile operation to a reshape, broadcast,
// reshape. The first reshape minimally expands each tiled dimension to
// include a preceding size-1 dim. This dim is then broadcast to the
// appropriate multiple.
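// For example (illustrative only): tiling tensor<2x3xf32> with multiples
// [5, 7] builds a generic of shape [5, 2, 7, 3] that broadcasts the input
// into the added dims, then reshapes to the final result tensor<10x21xf32>.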
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.input1();
    auto inputTy = input.getType().cast<ShapedType>();
    auto inputShape = inputTy.getShape();
    auto resultTy = op.getType().cast<ShapedType>();
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    SmallVector<int64_t> multiples;
    getValuesFromIntArrayAttribute(op.multiples(), multiples);

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      genericShape.push_back(multiples[i]);
      genericShape.push_back(inputShape[i]);
    }

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);

    // We need to map the input shape to the non-broadcast dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0),
        rewriter.getI64ArrayAttr(resultTy.getShape()));
    return success();
  }
};

class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.input1();
    auto padding = padOp.padding();

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType paddingTy = padding.getType().cast<ShapedType>();
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          padOp,
          "Pad converter requires static shaped input / padding values.");
    }

    Attribute constantAttr;
    if (elementTy.isa<FloatType>())
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
      auto value = padOp.quantization_info().getValue().input_zp().getValue();
      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          padOp,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
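
    // The padding operand has shape [rank, 2]: column 0 holds the low
    // (before) padding and column 1 the high (after) padding per dimension.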

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value inputIndex = rewriter.createOrFold<ConstantIndexOp>(loc, i);
      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, lowIndex}));
      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, highIndex}));

      lowVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
                                                  lowVal);
      highVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
                                                   highVal);

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    Value constant = rewriter.create<ConstantOp>(loc, constantAttr);

    auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
        padOp.getType(), input, constant, lowValues, highValues,
        /*nofold=*/false, loc, rewriter);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is of the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
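//
// For example (illustrative only), an argmax over axis 1 of tensor<3x4xf32>
// yields a tensor<3xi32> of indices plus a discarded tensor<3xf32> of maxima.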
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.input();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.axis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires statically shaped input");

    if (!outElementTy.isa<IntegerType>())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // First fill the output buffer for the index.
    auto initTensorIdx =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
                                          resultTy.getShape(), outElementTy)
            .result();
    auto fillValueIdx = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter.create<linalg::FillOp>(loc, fillValueIdx, initTensorIdx)
            .result();

    // Second fill the output buffer for the running max.
    auto initTensorMax =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
                                          resultTy.getShape(), inElementTy)
            .result();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<StringRef, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
    iteratorTypes[axis] = getReductionIteratorTypeName();

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }
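    // For example, a rank-2 input with axis = 0 produces the maps
    // (d0, d1) -> (d0, d1) for the input and (d0, d1) -> (d1) for both
    // outputs, with d0 iterated as the reduction dimension.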

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = rewriter.create<IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (inElementTy.isa<FloatType>()) {
            predicate = rewriter.create<mlir::CmpFOp>(
                nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
          } else if (inElementTy.isa<IntegerType>()) {
            predicate = rewriter.create<mlir::CmpIOp>(
                nestedLoc, CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
                                                           newValue, oldValue);
          auto resultIndex = rewriter.create<mlir::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};

class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto inputTy = input.getType().cast<ShapedType>();
    auto indicesTy = indices.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
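    // The first map reads the indices operand at (batch, index); the second
    // writes the result at (batch, index, channel). The channel loop never
    // indexes the indices operand; it only selects which element of the input
    // the tensor.extract below pulls out.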

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    Value table = op.table();
    auto inputTy = input.getType().cast<ShapedType>();
    auto tableTy = table.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block =
          rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
                               TypeRange({inputElementTy, resultElementTy}));

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
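        // Signed i8 inputs span [-128, 127]; adding an offset of 128 rebases
        // them to [0, 255] so the value can directly index the 256-entry
        // table.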
        Value index = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(),
                                                   inputValue);
        Value offset = rewriter.create<ConstantIndexOp>(loc, 128);
        index = rewriter.create<AddIOp>(loc, rewriter.getIndexType(), index,
                                        offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<SignExtendIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(32768));
        auto seven =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(7));
        auto one =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        // value = value + 32768
        // index = value >> 7;
        // fraction = value & 0b1111111
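        // For example, an input of 0 becomes 32768 after the offset, giving
        // index = 32768 >> 7 = 256 and fraction = 0, an exact hit at the
        // midpoint of the 513-entry i16 table.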
        auto extendAdd = rewriter.create<AddIOp>(loc, extend, offset);
        Value index =
            rewriter.create<UnsignedShiftRightOp>(loc, extendAdd, seven);
        Value fraction = rewriter.create<mlir::AndOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        // base = (int32_t) table[index];
        // next = (int32_t) table[index + 1];
        Value indexPlusOne = rewriter.create<AddIOp>(loc, index, one);

        index =
            rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), base);
        next = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the input values:
        // result = (base << 7) + (next - base) * fraction
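        // For example, with base = 100, next = 228, and fraction = 64 (half
        // of 128), the result is (100 << 7) + 128 * 64 = 20992, exactly
        // halfway between base << 7 = 12800 and next << 7 = 29184.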
        Value baseScaled = rewriter.create<ShiftLeftOp>(loc, base, seven);
        Value diff = rewriter.create<SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<MulIOp>(loc, diff, fraction);
        Value result = rewriter.create<AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};

class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    ShapedType inputTy = input.getType().cast<ShapedType>();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type resultETy = inputTy.getElementType();

    if (!inputTy.hasStaticShape())
      return failure();

    // Determine what the initial value needs to be for the max pool op.
    Attribute initialAttr;
    if (resultETy.isF32())
      initialAttr = rewriter.getFloatAttr(
          resultETy,
          APFloat::getLargest(resultETy.cast<FloatType>().getFloatSemantics(),
                              true));

    if (resultETy.isa<IntegerType>())
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.pad(), pad);
    pad.resize(pad.size() + 2, 0);
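    // pad now holds one (low, high) pair per NHWC dimension:
    // [0, 0, top, bottom, left, right, 0, 0], so only H and W are padded.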
    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultTy.getElementType());

    Value filledInitTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, resultETy);

    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
        filledInitTensor, strideAttr, dilationAttr);
    return success();
  }
};

class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    ShapedType inputTy = input.getType().cast<ShapedType>();
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type resultETy = inputTy.getElementType();

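    // Integer inputs are accumulated in i32 so the running sum does not
    // overflow the narrower element type; float inputs accumulate as-is.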
    Type accETy =
        inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
    ShapedType accTy = resultTy.clone(accETy);

    if (!inputTy.hasStaticShape())
      return failure();

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.pad(), pad);
    pad.resize(pad.size() + 2, 0);
    Attribute initialAttr = rewriter.getZeroAttr(accETy);
    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolInitTensor =
        rewriter.create<linalg::InitTensorOp>(loc, accTy.getShape(), accETy);

    Value filledInitTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, poolInitTensor)
            .result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledInitTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());

    Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultETy);

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericInitTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
          auto one = rewriter.create<ConstantIndexOp>(loc, 1);
          auto iH = rewriter.create<ConstantIndexOp>(
              loc, poolingOpTy.getDimSize(1) - 1);
          auto iW = rewriter.create<ConstantIndexOp>(
              loc, poolingOpTy.getDimSize(2) - 1);

          // Compute the indices from either end.
          auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
          auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
          auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
          auto x1 = rewriter.create<SubIOp>(loc, iW, x0);

          // Determine what portion of the valid input is covered by the
          // kernel.
          auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
            if (pad == 0)
              return v;

            auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
            Value dx = rewriter.create<SubIOp>(loc, x, padVal);

            Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                      dx, zero);
            Value offset = rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
            return rewriter.create<mlir::AddIOp>(loc, v, offset)->getResult(0);
          };

          // Compute the vertical component of coverage.
          auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
          auto kH1 = padFn(kH0, y0, pad[2]);
          auto kH2 = padFn(kH1, y1, pad[3]);
          auto kHCmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
          auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
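          // For example, with kernel height 3 and pad[2] = 1, the first
          // output row has y0 = 0, giving dx = -1 inside padFn and a
          // clamped coverage kH3 of 2 rows; interior rows keep all 3.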

          // Compute the horizontal component of coverage.
          auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
          auto kW1 = padFn(kW0, x0, pad[4]);
          auto kW2 = padFn(kW1, x1, pad[5]);
          auto kWCmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
          auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);

          // Compute the total number of elements and normalize.
          Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
          auto countI = rewriter.create<mlir::IndexCastOp>(
              loc, rewriter.getI32Type(), count);

          // Divide by the number of summed values. For floats this is just a
          // division; for quantized values the input zero point must first be
          // normalized away, then the sum rescaled by 1 / count.
          Value poolVal = args[0];
          if (accETy.isa<FloatType>()) {
            auto countF =
                rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);
            poolVal =
                rewriter.create<DivFOp>(loc, poolVal, countF)->getResult(0);
          } else {

            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (op.quantization_info()) {
              auto quantizationInfo = op.quantization_info().getValue();
              auto inputZp = rewriter.create<mlir::ConstantOp>(
                  loc, quantizationInfo.input_zp());
              Value offset =
                  rewriter.create<mlir::MulIOp>(loc, accETy, countI, inputZp);
              poolVal = rewriter.create<SubIOp>(loc, accETy, poolVal, offset);
            }

            // Compute the multiplier and shift values for the quantization
            // normalization. Ideally we would compute with more bits, but
            // 32 bits should be enough here. Arguably a plain integer
            // division would be simpler.
            int64_t numerator = ((1 << 30) + 1);
            int64_t shift = 30;
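            // For example, a fully covered 3x3 pool has count = 9, so the
            // multiplier becomes (2^30 + 1) / 9 = 119304647 and apply_scale
            // computes roughly (sum * 119304647) >> 30, i.e. sum / 9.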

            Value numeratorVal = rewriter.create<ConstantOp>(
                loc, rewriter.getI32IntegerAttr(numerator));
            Value multiplierVal =
                rewriter
                    .create<UnsignedDivIOp>(loc, rewriter.getI32Type(),
                                            numeratorVal, countI)
                    .getResult();
            Value shiftVal = rewriter.create<ConstantOp>(
                loc, rewriter.getI8IntegerAttr(shift));

            auto scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(
                        loc, rewriter.getI32Type(), poolVal, multiplierVal,
                        shiftVal, rewriter.getBoolAttr(false))
                    .getResult();

            // If we have quantization information we need to apply the output
            // zero point.
            if (op.quantization_info()) {
              auto quantizationInfo = op.quantization_info().getValue();
              auto outputZp = rewriter.create<mlir::ConstantOp>(
                  loc, quantizationInfo.output_zp());
              scaled =
                  rewriter.create<AddIOp>(loc, scaled, outputZp).getResult();
            }

            // Apply Clip.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<ConstantOp>(
                loc, rewriter.getIntegerAttr(
                         accETy,
                         APInt::getSignedMinValue(outBitwidth).getSExtValue()));
            auto max = rewriter.create<ConstantOp>(
                loc, rewriter.getIntegerAttr(
                         accETy,
                         APInt::getSignedMaxValue(outBitwidth).getSExtValue()));
            auto clamp = clampHelper<mlir::CmpIOp>(
                loc, scaled, min, max, CmpIPredicate::slt, rewriter);

            // Convert type.
            poolVal = rewriter.create<TruncateIOp>(loc, resultETy, clamp);
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::DivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::ReluNOp>,
      PointwiseConverter<tosa::SigmoidOp>,
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      ConcatConverter,
      ConvConverter,
      DepthwiseConvConverter,
      TransposeConvConverter,
      GatherConverter,
      PadConverter,
      ReshapeConverter,
      RescaleConverter,
      ResizeConverter,
      ReverseConverter,
      TableConverter,
      TileConverter,
      TransposeConverter,
      MatMulConverter,
      MaxPool2dConverter,
      AvgPool2dConverter,
      FullyConnectedConverter>(patterns->getContext());
  // clang-format on
}