//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

using namespace mlir;

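// Returns `nParallelLoops` copies of the "parallel" iterator type name, used
// to mark every dimension of a generated linalg.generic op as parallel.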
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

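// Materializes a std.constant of type `requiredAttrType` from the integer
// attribute named `attrName` on `op`, first narrowing the value through the
// template type T.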
template <typename T>
static mlir::ConstantOp
createConstFromIntAttribute(Operation *op, std::string attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
  return rewriter.create<mlir::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}

template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  for (Attribute val : attr.getValue()) {
    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

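// Clamps `arg` to the inclusive range [min, max] using two compare+select
// pairs. `pred` is expected to be a less-than predicate (e.g. slt / OLT) so
// that `pred(arg, min)` detects underflow and `pred(max, arg)` detects
// overflow.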
template <typename T, typename P>
static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min,
                                  mlir::ConstantOp max, P pred,
                                  OpBuilder &rewriter) {
  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
  auto minOrArg =
      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
}

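// Pads `input` with the scalar `padAttr` value using linalg.pad_tensor. `pad`
// holds interleaved (low, high) padding amounts for each dimension; the input
// is returned unchanged when every entry is zero.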
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            Attribute padAttr, OpBuilder &rewriter) {
  // No padding is necessary when every pad amount is zero.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (int i = 0, s = inputShape.size(); i < s; i++) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<ConstantOp>(loc, padAttr);

  return linalg::PadTensorOp::createPadScalarOp(
             RankedTensorType::get(paddedShape, inputETy), input, padValue,
             lowIndices, highIndices, loc, rewriter)
      .result();
}

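// Builds the scalar computation that forms the body of the linalg.generic op
// implementing a TOSA elementwise op. Dispatches on the op kind and the
// element type of the first operand and returns the value to yield, or
// nullptr (after notifying a match failure) when the combination is
// unsupported.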
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, rewriter.getZeroAttr(elementTy));
    auto cmp =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0], zero);
    auto neg = rewriter.create<mlir::SubIOp>(loc, zero, args[0]);
    return rewriter.create<mlir::SelectOp>(loc, cmp, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    if (cast<tosa::MulOp>(op).shift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
  }

  // tosa::DivOp
  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SignedDivIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, args[0]);
  }

  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<TruncateIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<SignExtendIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<SignExtendIOp>(loc, resultTypes[0], b);

    return rewriter.create<mlir::MulIOp>(loc, resultTypes, a, b);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      !cast<tosa::NegateOp>(op).quantization_info()) {
    auto constant =
        rewriter.create<ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<SubIOp>(loc, resultTypes, constant, args[0]);
  }

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      cast<tosa::NegateOp>(op).quantization_info()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp =
        quantizationInfo.getValue().input_zp().getValue().getSExtValue();
    int64_t outZp =
        quantizationInfo.getValue().output_zp().getValue().getSExtValue();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //  outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<SignExtendIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    auto min = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMinValue(inputBitWidth).getSExtValue()));
    auto max = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMaxValue(inputBitWidth).getSExtValue()));
    auto clamp = clampHelper<mlir::CmpIOp>(loc, sub, min, max,
                                           CmpIPredicate::slt, rewriter);

    // Truncate to the final value.
    return rewriter.create<TruncateIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnesValue(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<ConstantOp>(loc, allOnesAttr);
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::ShiftLeftOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
    auto result =
        rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, args);
    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Check that the shift amount (input2) is greater than zero.
    auto shiftValueGreaterThanZero =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[1], zero);

    // Check whether the last bit shifted out of input1 is 1.
    auto subtract =
        rewriter.create<mlir::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted = rewriter
                       .create<mlir::SignedShiftRightOp>(loc, resultTypes,
                                                         args[0], subtract)
                       ->getResults();
    auto truncated =
        rewriter.create<mlir::TruncateIOp>(loc, i1Ty, shifted, mlir::None);
    auto isInputOdd = rewriter.create<mlir::AndOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<mlir::AndOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<ZeroExtendIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<mlir::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT, args[0],
                                         args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0],
                                         args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, args[0],
                                         args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
                                         args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OEQ, args[0],
                                         args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0],
                                         args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CeilFOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    auto min = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("min_fp"));
    auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], min, max, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
                                                    rewriter);
    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                    rewriter);
    return clampHelper<mlir::CmpIOp>(loc, args[0], min, max, CmpIPredicate::slt,
                                     rewriter);
  }

  // tosa::ReluNOp
  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
    auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                               op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], zero, n, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                  rewriter);
    return clampHelper<mlir::CmpIOp>(loc, args[0], zero, n, CmpIPredicate::slt,
                                     rewriter);
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<mlir::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<mlir::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
      return rewriter.create<mlir::FPExtOp>(loc, resultTypes, args, mlir::None);

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
      return rewriter.create<mlir::FPTruncOp>(loc, resultTypes, args,
                                              mlir::None);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::UIToFPOp>(loc, resultTypes, args,
                                             mlir::None);

    if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<mlir::ZeroExtendIOp>(loc, resultTypes, args,
                                                  mlir::None);

    // All other si-to-fp conversions should be handled by SIToFP.
    if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::SIToFPOp>(loc, resultTypes, args,
                                             mlir::None);

    // When casting to boolean, floats only need to be checked as not equal
    // to zero.
    if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
      Value zero =
          rewriter.create<ConstantOp>(loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::UNE,
                                           args.front(), zero);
    }

    if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto zero =
          rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
      auto half =
          rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));

      auto intMin = rewriter.create<ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      auto intMax = rewriter.create<ConstantOp>(
          loc, rewriter.getF32FloatAttr(
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      // Round away from zero (add or subtract 0.5 depending on sign), then
      // clamp to the destination range before converting.
      auto added = rewriter.create<AddFOp>(loc, args[0], half);
      auto subbed = rewriter.create<SubFOp>(loc, args[0], half);
      auto negative =
          rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT, args[0], zero);
      auto rounded =
          rewriter.create<mlir::SelectOp>(loc, negative, subbed, added);

      auto clamped = clampHelper<mlir::CmpFOp>(loc, rounded, intMin, intMax,
                                               CmpFPredicate::OLT, rewriter);

      return rewriter.create<mlir::FPToSIOp>(loc, dstTy, clamped);
    }

    // When casting to boolean, integers only need to be checked as not equal
    // to zero.
    if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
      Value zero =
          rewriter.create<ConstantIntOp>(loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(),
                                           zero);
    }

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
                                                  mlir::None);

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
      auto intMin = rewriter.create<ConstantIntOp>(
          loc,
          APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto intMax = rewriter.create<ConstantIntOp>(
          loc,
          APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto clamped = clampHelper<mlir::CmpIOp>(loc, args[0], intMin, intMax,
                                               CmpIPredicate::slt, rewriter);
      return rewriter.create<mlir::TruncateIOp>(loc, dstTy, clamped);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}

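// Converts a TOSA elementwise op into a linalg.generic op whose body is built
// by createLinalgBodyCalculationForElementwiseOp. Operands whose shape does
// not match the result are reshaped and assigned broadcasting indexing maps.
// Schematically, a float tosa.add lowers along the lines of (a sketch, not
// the verbatim output):
//
//   %init = linalg.init_tensor [4] : tensor<4xf32>
//   %0 = linalg.generic {indexing_maps = [#map, #map, #map],
//                        iterator_types = ["parallel"]}
//          ins(%lhs, %rhs : tensor<4xf32>, tensor<4xf32>)
//          outs(%init : tensor<4xf32>) {
//            ^bb0(%a: f32, %b: f32, %out: f32):
//              %sum = addf %a, %b : f32
//              linalg.yield %sum : f32
//          } -> tensor<4xf32>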
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  auto loc = operation->getLoc();

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  auto results = operation->getResults();
  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();

  if (!resultTy)
    return rewriter.notifyMatchFailure(operation,
                                       "All results must be a shaped type");

  unsigned rank = resultTy.getRank();

  // Construct the indexing maps needed for linalg.generic ops.
  SmallVector<Type> bodyArgTypes;

  for (Value in : operation->getOperands())
    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));

  SmallVector<Type> opResultTypes;
  SmallVector<Value> initTensors;
  for (auto result : results) {
    auto resultTy = result.getType().template cast<ShapedType>();
    if (!resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          operation,
          "tosa to linalg conversion expects statically shaped tensors");

    initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, ArrayRef<Value>({}), resultTy.getShape(),
        resultTy.getElementType()));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      initTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  SmallVector<Value, 2> operands;
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcasted.
  for (Value operand : operation->getOperands()) {
    ShapedType type = operand.getType().cast<ShapedType>();

    if (type.getShape() == resultTy.getShape()) {
      operands.push_back(operand);
      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
      continue;
    }

    SmallVector<int64_t, 5> newShape;
    SmallVector<AffineExpr, 4> affineExprs;
    newShape.reserve(type.getRank());
    for (auto it : llvm::enumerate(type.getShape())) {
      if (it.value() == resultTy.getDimSize(it.index())) {
        newShape.push_back(it.value());
        affineExprs.push_back(
            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
      }
    }

    if (newShape.size() != rank) {
      operand = rewriter.create<tosa::ReshapeOp>(
          loc, RankedTensorType::get(newShape, type.getElementType()), operand);
    }

    operands.push_back(operand);
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs,
        rewriter.getContext()));
  }

  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(rank));

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operands, initTensors, indexingMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        if (!opResult) {
          didEncounterError = true;
          return;
        }
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnesValue(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getNullValue(1));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       elementTy.cast<FloatType>().getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}

// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
    return rewriter.create<MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
    return rewriter.create<MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::AndOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::OrOp>(loc, args);

  return {};
}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
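//
// Schematically, lowering tosa.reduce_sum along axis 1 of a tensor<4x5xf32>
// yields IR along the lines of (a sketch, not the verbatim output):
//
//   %init = linalg.init_tensor [4] : tensor<4xf32>
//   %fill = linalg.fill(%zero, %init) : f32, tensor<4xf32> -> tensor<4xf32>
//   %sum = linalg.generic {iterator_types = ["parallel", "reduction"], ...}
//            ins(%arg0 : tensor<4x5xf32>) outs(%fill : tensor<4xf32>) {
//              ^bb0(%in: f32, %acc: f32):
//                %0 = addf %in, %acc : f32
//                linalg.yield %0 : f32
//          } -> tensor<4xf32>
//   %result = "tosa.reshape"(%sum) ... -> tensor<4x1xf32>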
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
  auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  llvm::SmallVector<int64_t> reduceShape;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i)
      reduceShape.push_back(inputTy.getDimSize(i));
  }

  Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());

  // First fill the output buffer with the init value.
  auto initTensor =
      rewriter
          .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}), reduceShape,
                                        resultTy.getElementType())
          .result();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<ConstantOp>(loc, fillValueAttr);
  auto filledTensor =
      rewriter.create<linalg::FillOp>(loc, fillValue, initTensor).result();

  SmallVector<AffineExpr, 2> srcExprs;
  SmallVector<AffineExpr, 2> dstExprs;
  SmallVector<StringRef, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));

    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
                                      : getParallelIteratorTypeName());
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

  // Tracks whether a reduction body was successfully created for the op.
  bool didCreateReduceBody = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (result)
          didCreateReduceBody = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didCreateReduceBody)
    return failure();

  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
                                               linalgOp.getResults());
  return success();
}

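// Shared lowering for tosa.conv2d and tosa.depthwise_conv2d: pads the input
// when requested, broadcasts the bias into the result-shaped init tensor, and
// then emits the matching linalg convolution op (quantized variants receive
// the input and weight zero points as extra operands).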
static LogicalResult
convolutionMatchAndRewriterHelper(Operation *op,
                                  ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  Value input = op->getOperand(0);
  Value weight = op->getOperand(1);
  Value bias = op->getOperand(2);

  ShapedType inputTy = input.getType().cast<ShapedType>();
  ShapedType weightTy = weight.getType().cast<ShapedType>();
  ShapedType biasTy = bias.getType().cast<ShapedType>();
  ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

  Type inputETy = inputTy.getElementType();
  Type resultETy = resultTy.getElementType();

  auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
  auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
  auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();

  bool isQuantized = op->hasAttr("quantization_info");
  IntegerAttr iZp;
  IntegerAttr kZp;
  if (isQuantized) {
    auto quantizationInfo =
        op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
    iZp = rewriter.getI32IntegerAttr(
        quantizationInfo.input_zp().getValue().getSExtValue());
    kZp = rewriter.getI32IntegerAttr(
        quantizationInfo.weight_zp().getValue().getSExtValue());
  }

  if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
      !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
    return rewriter.notifyMatchFailure(op,
                                       "tosa.conv ops require static shapes");

  auto weightShape = weightTy.getShape();
  auto resultShape = resultTy.getShape();

  // Apply padding as necessary. The TOSA pad attribute only covers the H and
  // W dimensions, so surround it with zero entries for the batch and channel
  // dimensions of the NHWC input.
  Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
  llvm::SmallVector<int64_t> pad;
  pad.resize(2, 0);
  getValuesFromIntArrayAttribute(padAttr, pad);
  pad.resize(pad.size() + 2, 0);

  input = applyPad(loc, input, pad, zeroAttr, rewriter);

  // Broadcast the initial value to the output tensor before convolving.
  SmallVector<AffineMap, 4> indexingMaps;
  indexingMaps.push_back(AffineMap::get(
      /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
      {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

  Value initTensor = rewriter.create<linalg::InitTensorOp>(
      loc, resultTy.getShape(), resultTy.getElementType());

  Value biasBroadcast =
      rewriter
          .create<linalg::GenericOp>(
              loc, resultTy, bias, initTensor, indexingMaps,
              getNParallelLoopsAttrs(resultTy.getRank()),
              [&](OpBuilder &nestedBuilder, Location nestedLoc,
                  ValueRange args) {
                nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
              })
          .getResult(0);

  // Extract the attributes for convolution.
  llvm::SmallVector<int64_t> stride, dilation;
  getValuesFromIntArrayAttribute(strideTosaAttr, stride);
  getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

  // Create the convolution op.
  auto strideAttr = DenseIntElementsAttr::get(
      RankedTensorType::get({2}, rewriter.getI64Type()), stride);
  auto dilationAttr = DenseIntElementsAttr::get(
      RankedTensorType::get({2}, rewriter.getI64Type()), dilation);

  if (isa<tosa::Conv2DOp>(op) && !isQuantized) {
    rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyOp>(
        op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast},
        strideAttr, dilationAttr);
    return success();
  }

  if (isa<tosa::Conv2DOp>(op) && isQuantized) {
    auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
    auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
    rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyQOp>(
        op, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
        ValueRange{biasBroadcast}, strideAttr, dilationAttr);
    return success();
  }

  if (isa<tosa::DepthwiseConv2DOp>(op) && !isQuantized) {
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);

    Value biasReshape =
        rewriter.create<tosa::ReshapeOp>(loc, linalgConvTy, biasBroadcast);
    Value conv = rewriter
                     .create<linalg::DepthwiseConvInputNHWCFilterHWCFOp>(
                         loc, linalgConvTy, ValueRange{input, weight},
                         ValueRange{biasReshape}, dilationAttr, strideAttr)
                     .getResult(0);

    Value reshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
    rewriter.replaceOp(op, reshape);
    return success();
  }

  return failure();
}

namespace {

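// Converts TOSA elementwise (pointwise) ops by delegating to
// elementwiseMatchAndRewriteHelper.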
template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(op, rewriter);
  }
};

template <typename T>
class ConvConverter : public OpConversionPattern<T> {
public:
  using OpConversionPattern<T>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(T op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    return convolutionMatchAndRewriterHelper(op, rewriter);
  }
};

class TransposeConvConverter
    : public OpConversionPattern<tosa::TransposeConv2DOp> {
public:
  using OpConversionPattern<tosa::TransposeConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::TransposeConv2DOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;
    llvm::SmallVector<int64_t> dilation;

    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);

    // We have not solved for stride / dilation yet. Dilation should be
    // straightforward, but stride is more complicated. Linalg work is likely
    // required for an efficient implementation.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();
    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t inputHeight = inputTy.getDimSize(1);
    int64_t inputWidth = inputTy.getDimSize(2);
    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);
    int64_t outputHeight = resultTy.getDimSize(1);
    int64_t outputWidth = resultTy.getDimSize(2);

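    // With unit stride and dilation, a transpose convolution is equivalent to
    // a regular convolution over a suitably padded input with the kernel
    // reversed along both spatial dimensions. Compute the padding that makes
    // the forward convolution produce the requested output size.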
    int64_t requiredInputHeight = outputHeight + kernelHeight - 1;
    int64_t requiredInputWidth = outputWidth + kernelWidth - 1;

    llvm::SmallVector<int64_t> newPad(4, 0);
    newPad[0] = kernelHeight - 1 - pad[0];
    newPad[2] = kernelWidth - 1 - pad[1];

    newPad[1] = requiredInputHeight - newPad[0] - inputHeight;
    newPad[3] = requiredInputWidth - newPad[2] - inputWidth;

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));

    Value conv2d;
    if (op.quantization_info().hasValue()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr(dilation),
          op.quantization_info().getValue());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr(dilation));
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

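// Lowers tosa.matmul to linalg.batch_matmul (or linalg.quantized_batch_matmul
// when quantization info is present), accumulating into a zero-filled init
// tensor.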
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    tosa::MatMulOp::Adaptor adaptor(args);

    Location loc = op.getLoc();

    auto outputTy = op.getType().cast<ShapedType>();
    auto outputElementTy = outputTy.getElementType();
    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, outputTy.getShape(), outputTy.getElementType());
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
    if (!op.quantization_info()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
          ValueRange{zeroTensor});
      return success();
    }

    auto quantizationInfo = op.quantization_info().getValue();
    auto aZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.a_zp().getValue().getSExtValue()));
    auto bZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.b_zp().getValue().getSExtValue()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);

    return success();
  }
};

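// Lowers tosa.fully_connected by broadcasting the bias into the output
// tensor, transposing the weights from TOSA's [out_channels, in_channels]
// layout into matmul order, and emitting linalg.matmul (or
// linalg.quantized_matmul when quantization info is present).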
class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = op.getType().cast<ShapedType>();
    auto input = op.input();
    auto weight = op.weight();
    auto bias = op.bias();

    auto weightTy = weight.getType().cast<ShapedType>();
    auto weightShape = weightTy.getShape();

    // Create indexing maps for the bias broadcast and the matmul output.
    SmallVector<AffineMap, 4> indexingMaps;

    // Broadcast the bias.
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(1)},
                                          rewriter.getContext()));

    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, outputTy.getShape(),
                                          outputTy.getElementType())
            ->getResults();

    auto linalgOp =
        rewriter
            .create<linalg::GenericOp>(
                loc, outputTy, bias, initTensor, indexingMaps,
                getNParallelLoopsAttrs(outputTy.getRank()),
                [&](OpBuilder &nested_builder, Location nested_loc,
                    ValueRange args) {
                  nested_builder.create<linalg::YieldOp>(loc, *args.begin());
                })
            ->getResults();

    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    if (!op.quantization_info()) {
      rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
          op, TypeRange{op.getType()}, ValueRange{input, transposedWeight},
          linalgOp);
      return success();
    }

    auto quantizationInfo = op.quantization_info().getValue();
    auto inputZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.input_zp().getValue().getSExtValue()));
    auto weightZp = rewriter.create<ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.weight_zp().getValue().getSExtValue()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{input, transposedWeight, inputZp, weightZp}, linalgOp);

    return success();
  }
};

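// Lowers tosa.reshape to linalg collapse/expand reshape ops. When consecutive
// source dimensions collapse cleanly into the destination dimensions a single
// reshape suffices; otherwise the operand is first collapsed to 1-D and then
// expanded to the result shape.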
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    typename tosa::ReshapeOp::Adaptor operands(args);

    ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();

    if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, args[0]);
      return success();
    }

    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    // Compute the reassociation maps for the linalg operation.
    ArrayRef<int64_t> expandedShape =
        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
                                                  : resultTy.getShape());
    ArrayRef<int64_t> collapsedShape =
        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
                                                  : operandTy.getShape());
    unsigned currSrcDim = 0, currDstDim = 0;
    SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());

    // First scan all dimensions in the source shapes to see whether we have a
    // perfect case where consecutive dimensions in the source are collapsed.
    // For such cases we can generate a single linalg reshape op.
    bool isCollapsingSource = true;
    while (currSrcDim < expandedShape.size() &&
           currDstDim < collapsedShape.size()) {
      int64_t dstSize = collapsedShape[currDstDim];
      int64_t srcSize = expandedShape[currSrcDim];
      while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        srcSize *= expandedShape[currSrcDim];
      }
      if (srcSize == dstSize) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        // If the next dim in collapsedShape is not 1, treat subsequent dims in
        // expandedShape which are 1 to be collapsed.
        if (currDstDim == collapsedShape.size() - 1 ||
            collapsedShape[currDstDim + 1] != 1) {
          while (currSrcDim < expandedShape.size() &&
                 expandedShape[currSrcDim] == 1) {
            reassociationMap[currDstDim].push_back(
                rewriter.getAffineDimExpr(currSrcDim++));
          }
        }
      } else {
        isCollapsingSource = false;
        break;
      }
      currDstDim++;
    }

1252     // Check if any remaining dimensions exist. If either is rank-0 we only
1253     // require the directly lowering.
1254     if (currSrcDim != expandedShape.size() ||
1255         currDstDim != collapsedShape.size())
1256       isCollapsingSource = collapsedShape.empty() || expandedShape.empty();
1257 
1258     // Otherwise, we need to first reduce all source dimensions into one and
1259     // then expand to the destination dimensions.
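    // For example (illustrative): tensor<2x3xf32> -> tensor<3x2xf32> is
    // neither a pure collapse nor a pure expand, so it is lowered to a
    // collapse into tensor<6xf32> followed by an expand to tensor<3x2xf32>.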
    if (!isCollapsingSource) {
      auto getIdentityExprs = [&rewriter](int n) {
        SmallVector<AffineExpr, 4> exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
      Location loc = reshape.getLoc();
      int64_t totalElems =
          std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
                          std::multiplies<int64_t>());
      auto elemTy = operandTy.getElementType();
      SmallVector<ReassociationExprs, 4> collapsingMap = {
          // Use operandTy here because we need to collapse all operand
          // dimensions.
          getIdentityExprs(operandTy.getShape().size())};
      SmallVector<ReassociationExprs, 4> expandingMap = {
          // Use resultTy here because we need to expand to all result
          // dimensions.
          getIdentityExprs(resultTy.getShape().size())};

      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
      Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, collapsedTy, args[0], collapsingMap);
      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
          reshape, resultTy, collapsedOp, expandingMap);

      return success();
    }

    if (resultTy.getRank() < args[0].getType().cast<ShapedType>().getRank())
      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
          reshape, resultTy, args[0], reassociationMap);
    else
      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
          reshape, resultTy, args[0], reassociationMap);

    return success();
  }
};

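// Converts tosa.transpose with a constant permutation into a linalg.generic
// that copies elements through permuted indexing maps. An illustrative
// example (assuming perms = [1, 0]): tensor<2x3xf32> -> tensor<3x2xf32>,
// reading the input through the map (d0, d1) -> (d1, d0).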
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    if (!matchPattern(op.perms(), m_Constant(&perms))) {
      return failure();
    }

    auto resultTy = op.getType().cast<ShapedType>();
    if (!resultTy.hasStaticShape())
      return failure();

    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());
    for (auto permutation : llvm::enumerate(perms.getIntValues())) {
      inputExprs[permutation.value().getZExtValue()] =
          rewriter.getAffineDimExpr(permutation.index());
    }

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
        resultTy.getElementType());

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });
    return success();
  }
};

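// Lowers tosa.rescale to a linalg.generic whose body performs, per element
// (a sketch of the region built below):
//   v = in - input_zp
//   v = apply_scale(v, multiplier, shift)  // via tosa.apply_scale
//   v = clamp(v + output_zp, int_min(out type), int_max(out type))
// The multiplier and shift are either splat constants or, for per-channel
// rescales, constant tensors indexed along the innermost dimension.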
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.input();
    auto inputTy = op.input().getType().cast<ShapedType>();
    auto outputTy = op.output().getType().cast<ShapedType>();
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration; terminate and log an error.
    if (op.double_round() && !op.scale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!outputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa to linalg conversion expects statically shaped tensors");

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues;
    getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues);

    SmallVector<int8_t> shiftValues;
    getValuesFromIntArrayAttribute(op.shift(), shiftValues);

    // Double rounding only occurs if the shift is greater than 31; check
    // whether that is ever the case.
    bool doubleRound =
        op.double_round() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the indexing maps needed for linalg.generic ops.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, ArrayRef<Value>({}), outputTy.getShape(),
        outputTy.getElementType());

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];

          // For now we do all of our math in wide integer types (32 or 48
          // bits). This is not optimal but should be correct for now;
          // consider computing the required bit depth later.
          int32_t inBitwidth =
              value.getType().getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (value.getType().getIntOrFloatBitWidth() < 32) {
            value = nestedBuilder.create<SignExtendIOp>(
                nestedLoc, nestedBuilder.getI32Type(), value);
          }

          value = nestedBuilder.create<SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value = nestedBuilder.create<AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              blockArgs.back().getType().cast<IntegerType>();
          unsigned outBitWidth = outIntType.getWidth();
          auto intMin = nestedBuilder.create<ConstantOp>(
              loc, nestedBuilder.getIntegerAttr(
                       nestedBuilder.getI32Type(),
                       APInt::getSignedMinValue(outBitWidth).getSExtValue()));
          auto intMax = nestedBuilder.create<ConstantOp>(
              loc, nestedBuilder.getIntegerAttr(
                       nestedBuilder.getI32Type(),
                       APInt::getSignedMaxValue(outBitWidth).getSExtValue()));

          value = clampHelper<mlir::CmpIOp>(nestedLoc, value, intMin, intMax,
                                            CmpIPredicate::slt, nestedBuilder);

          if (outIntType.getWidth() < 32) {
            value =
                nestedBuilder.create<TruncateIOp>(nestedLoc, outIntType, value);
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};

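// Lowers tosa.resize to a linalg.generic over the result tensor. The region
// maps each output (y, x) position back to a source coordinate using the
// stride and offset attributes (floating-point when shift == 0, fixed-point
// otherwise), then gathers with tensor.extract: either the nearest pixel or
// a bilinear interpolation of the four neighbors.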
class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto input = op.input();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();
    auto resultElementTy = resultTy.getElementType();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    if (!resultTy.hasStaticShape())
      return failure();
    if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
      return failure();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
                           TypeRange({resultElementTy}));
      Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
      Value y = rewriter.create<linalg::IndexOp>(loc, 1);
      Value x = rewriter.create<linalg::IndexOp>(loc, 2);
      Value channel = rewriter.create<linalg::IndexOp>(loc, 3);

      auto hwMin =
          rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
      auto hMax = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(imageH - 1));
      auto wMax = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(imageW - 1));

      Value inY = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(), y);
      Value inX = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(), x);

      int32_t shift = op.shift();
      bool floatingPointMode = shift == 0;

      Value yStride, xStride, yOffset, xOffset;
      if (floatingPointMode) {
        yStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[0]);
        xStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[1]);
        yOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[0]);
        xOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[1]);
      } else {
        SmallVector<int32_t> stride, offset;
        getValuesFromIntArrayAttribute(op.stride(), stride);
        getValuesFromIntArrayAttribute(op.offset(), offset);

        yStride = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(stride[0]));
        xStride = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(stride[1]));
        yOffset = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(offset[0]));
        xOffset = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(offset[1]));
      }

      // Compute the integer index and partial offset:
      // x = x * stride + offset
      // ix = floor(x)
      // dx = x - ix
      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        Value y = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inY);
        Value x = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inX);

        y = rewriter.create<MulFOp>(loc, y, yStride);
        x = rewriter.create<MulFOp>(loc, x, xStride);

        y = rewriter.create<AddFOp>(loc, y, yOffset);
        x = rewriter.create<AddFOp>(loc, x, xOffset);

        iy = rewriter.create<FloorFOp>(loc, y);
        ix = rewriter.create<FloorFOp>(loc, x);

        dy = rewriter.create<SubFOp>(loc, y, iy);
        dx = rewriter.create<SubFOp>(loc, x, ix);

        iy = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), iy);
        ix = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), ix);
      } else {
        Value shiftVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(shift));

        Value y = rewriter.create<MulIOp>(loc, inY, yStride);
        Value x = rewriter.create<MulIOp>(loc, inX, xStride);

        y = rewriter.create<AddIOp>(loc, y, yOffset);
        x = rewriter.create<AddIOp>(loc, x, xOffset);

        iy = rewriter.create<SignedShiftRightOp>(loc, y, shiftVal);
        ix = rewriter.create<SignedShiftRightOp>(loc, x, shiftVal);

        Value yTrunc = rewriter.create<ShiftLeftOp>(loc, iy, shiftVal);
        Value xTrunc = rewriter.create<ShiftLeftOp>(loc, ix, shiftVal);

        dy = rewriter.create<SubIOp>(loc, y, yTrunc);
        dx = rewriter.create<SubIOp>(loc, x, xTrunc);
      }

      if (op.mode() == "NEAREST_NEIGHBOR") {
        Value yPred, xPred;
        // Round the index position towards the closest pixel location.
        if (floatingPointMode) {
          auto halfVal =
              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));
          yPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dy,
                                                halfVal);
          xPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dx,
                                                halfVal);
        } else {
          auto halfVal = rewriter.create<ConstantOp>(
              loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
          yPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dy,
                                                halfVal);
          xPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dx,
                                                halfVal);
        }

        auto zeroVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
        auto oneVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));

        auto yOffset =
            rewriter.create<mlir::SelectOp>(loc, yPred, oneVal, zeroVal);
        auto xOffset =
            rewriter.create<mlir::SelectOp>(loc, xPred, oneVal, zeroVal);

        iy = rewriter.create<AddIOp>(loc, iy, yOffset);
        ix = rewriter.create<AddIOp>(loc, ix, xOffset);

        // Clamp to be within the bounds of the input image.
        iy = clampHelper<mlir::CmpIOp>(loc, iy, hwMin, hMax, CmpIPredicate::slt,
                                       rewriter);
        ix = clampHelper<mlir::CmpIOp>(loc, ix, hwMin, wMax, CmpIPredicate::slt,
                                       rewriter);

        // Read the value from the input array.
        iy = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), iy);
        ix = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), ix);

        Value result = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, iy, ix, channel});

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }

      if (op.mode() == "BILINEAR") {
        Value y0 = iy;
        Value x0 = ix;

        auto oneVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
        Value y1 = rewriter.create<AddIOp>(loc, y0, oneVal);
        Value x1 = rewriter.create<AddIOp>(loc, x0, oneVal);

        y0 = clampHelper<mlir::CmpIOp>(loc, y0, hwMin, hMax, CmpIPredicate::slt,
                                       rewriter);
        y1 = clampHelper<mlir::CmpIOp>(loc, y1, hwMin, hMax, CmpIPredicate::slt,
                                       rewriter);

        x0 = clampHelper<mlir::CmpIOp>(loc, x0, hwMin, wMax, CmpIPredicate::slt,
                                       rewriter);
        x1 = clampHelper<mlir::CmpIOp>(loc, x1, hwMin, wMax, CmpIPredicate::slt,
                                       rewriter);

        y0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y0);
        y1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y1);
        x0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x0);
        x1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x1);

        Value y0x0 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y1, x1, channel});

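        // Bilinearly interpolate between the four neighbors (a sketch of the
        // arithmetic below, in both the float and fixed-point paths):
        //   top    = y0x0 * (1 - dx) + y0x1 * dx
        //   bottom = y1x0 * (1 - dx) + y1x1 * dx
        //   result = top * (1 - dy) + bottom * dy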
        if (floatingPointMode) {
          auto oneVal =
              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.f));
          Value rightPart = dx;
          Value leftPart = rewriter.create<SubFOp>(loc, oneVal, dx);

          y0x0 = rewriter.create<MulFOp>(loc, y0x0, leftPart);
          y0x1 = rewriter.create<MulFOp>(loc, y0x1, rightPart);
          Value topAcc = rewriter.create<AddFOp>(loc, y0x0, y0x1);

          y1x0 = rewriter.create<MulFOp>(loc, y1x0, leftPart);
          y1x1 = rewriter.create<MulFOp>(loc, y1x1, rightPart);
          Value bottomAcc = rewriter.create<AddFOp>(loc, y1x0, y1x1);

          Value bottomPart = dy;
          Value topPart = rewriter.create<SubFOp>(loc, oneVal, dy);
          topAcc = rewriter.create<MulFOp>(loc, topAcc, topPart);
          bottomAcc = rewriter.create<MulFOp>(loc, bottomAcc, bottomPart);
          Value result = rewriter.create<AddFOp>(loc, topAcc, bottomAcc);

          rewriter.create<linalg::YieldOp>(loc, result);
          return success();
        } else {
          y0x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x0);
          y0x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x1);
          y1x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x0);
          y1x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x1);

          if (resultElementTy.getIntOrFloatBitWidth() > 32) {
            dx = rewriter.create<SignExtendIOp>(loc, resultElementTy, dx);
            dy = rewriter.create<SignExtendIOp>(loc, resultElementTy, dy);
          }

          auto unitVal = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
          Value rightPart = dx;
          Value leftPart = rewriter.create<SubIOp>(loc, unitVal, dx);

          y0x0 = rewriter.create<MulIOp>(loc, y0x0, leftPart);
          y0x1 = rewriter.create<MulIOp>(loc, y0x1, rightPart);
          Value topAcc = rewriter.create<AddIOp>(loc, y0x0, y0x1);

          y1x0 = rewriter.create<MulIOp>(loc, y1x0, leftPart);
          y1x1 = rewriter.create<MulIOp>(loc, y1x1, rightPart);
          Value bottomAcc = rewriter.create<AddIOp>(loc, y1x0, y1x1);

          Value bottomPart = dy;
          Value topPart = rewriter.create<SubIOp>(loc, unitVal, dy);
          topAcc = rewriter.create<MulIOp>(loc, topAcc, topPart);
          bottomAcc = rewriter.create<MulIOp>(loc, bottomAcc, bottomPart);
          Value result = rewriter.create<AddIOp>(loc, topAcc, bottomAcc);

          rewriter.create<linalg::YieldOp>(loc, result);
          return success();
        }
      }

      return failure();
    }

    return success();
  }
};

// At the codegen level any identity operations should be removed. Any cases
// where identity is load-bearing (e.g. cross device computation) should be
// handled before lowering to codegen.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter);
  }
};

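// Lowers tosa.concat by zero-filling an init tensor of the result shape and
// tensor.insert_slice-ing each input along the concatenation axis, advancing
// the insertion offset by each input's size on that axis.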
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = op.getType().dyn_cast<RankedTensorType>();
    if (!resultType || !resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected static shaped tensor type");
    }

    Location loc = op.getLoc();
    int axis = op.axis();
    Value axisValue =
        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int rank = resultType.getRank();
    SmallVector<Value, 3> offsets, sizes, strides;
    sizes.reserve(rank);
    strides.resize(rank, rewriter.create<ConstantIndexOp>(loc, 1));
    offsets.resize(rank, rewriter.create<ConstantIndexOp>(loc, 0));

    for (int i = 0; i < rank; ++i) {
      sizes.push_back(rewriter.create<tensor::DimOp>(loc, args[0], i));
    }

    Value resultDimSize = sizes[axis];
    for (auto arg : args.drop_front()) {
      auto size = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
      resultDimSize = rewriter.create<AddIOp>(loc, resultDimSize, size);
    }
    sizes[axis] = resultDimSize;

    Value init = rewriter.create<linalg::InitTensorOp>(
        loc, resultType.getShape(), resultType.getElementType());

    Value zeroVal = rewriter.create<ConstantOp>(
        loc, rewriter.getZeroAttr(resultType.getElementType()));
    Value result =
        rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0);

    for (auto arg : args) {
      sizes[axis] = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
      result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
                                                      sizes, strides);
      offsets[axis] = rewriter.create<AddIOp>(loc, offsets[axis], sizes[axis]);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    auto inputTy = input.getType().template cast<ShapedType>();
    auto resultTy = op.getType().template cast<ShapedType>();
    auto rank = resultTy.getRank();
    auto axis = op.axis();

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.reverse requires a statically shaped input");

    // Create the init tensor for the result of the linalg.generic.
    auto initTensor = rewriter
                          .create<linalg::InitTensorOp>(
                              loc, ArrayRef<Value>({}), inputTy.getShape(),
                              inputTy.getElementType())
                          .result();

    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());

    for (int i = 0; i < rank; i++)
      inputExprs[i] = rewriter.getAffineDimExpr(i);

    inputExprs[axis] =
        rewriter.getAffineConstantExpr(inputTy.getDimSize(axis) - 1) -
        inputExprs[axis];

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });
    return success();
  }
};

// This converter translates a tile operation into a reshape, broadcast, and
// reshape. The first reshape minimally expands each tiled dimension to
// include a preceding size-1 dim. This dim is then broadcast to the
// appropriate multiple.
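// For example (illustrative): tiling tensor<2x3xf32> with multiples [5, 7]
// broadcasts to the intermediate tensor<5x2x7x3xf32>, which is then reshaped
// to the result tensor<10x21xf32>.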
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.input1();
    auto inputTy = input.getType().cast<ShapedType>();
    auto inputShape = inputTy.getShape();
    auto resultTy = op.getType().cast<ShapedType>();
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    SmallVector<int64_t> multiples;
    getValuesFromIntArrayAttribute(op.multiples(), multiples);

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      genericShape.push_back(multiples[i]);
      genericShape.push_back(inputShape[i]);
    }

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);

    // We need to map the input shape to the non-broadcast dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0),
        rewriter.getI64ArrayAttr(resultTy.getShape()));
    return success();
  }
};

class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.input1();
    auto padding = padOp.padding();

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType paddingTy = padding.getType().cast<ShapedType>();
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          padOp,
          "Pad converter requires static shaped input / padding values.");
    }

    Attribute constantAttr;
    if (elementTy.isa<FloatType>())
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
      auto value = padOp.quantization_info().getValue().input_zp().getValue();
      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          padOp,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value inputIndex = rewriter.createOrFold<ConstantIndexOp>(loc, i);
      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, lowIndex}));
      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, highIndex}));

      lowVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
                                                  lowVal);
      highVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
                                                   highVal);

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    Value constant = rewriter.create<ConstantOp>(loc, constantAttr);

    auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
        padOp.getType(), input, constant, lowValues, highValues, loc, rewriter);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It
// is initialized to 0 and is of the resulting integer type.
//
// The second output buffer contains the maximum value found. It is
// initialized to the minimum representable value of the input element type.
// After being populated by indexed_generic, this buffer is discarded as only
// the index is requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
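// A sketch of the per-element update performed by the region built below:
//   if (in > maxValue) { maxValue = in; maxIndex = index along 'axis'; }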
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.input();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.axis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires statically shaped input");

    if (!outElementTy.isa<IntegerType>())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // First fill the output buffer for the index.
    auto initTensorIdx =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
                                          resultTy.getShape(), outElementTy)
            .result();
    auto fillValueIdx = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter.create<linalg::FillOp>(loc, fillValueIdx, initTensorIdx)
            .result();

    // Second fill the output buffer for the running max.
    auto initTensorMax =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
                                          resultTy.getShape(), inElementTy)
            .result();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<StringRef, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
    iteratorTypes[axis] = getReductionIteratorTypeName();

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = rewriter.create<IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (inElementTy.isa<FloatType>()) {
            predicate = rewriter.create<mlir::CmpFOp>(
                nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
          } else if (inElementTy.isa<IntegerType>()) {
            predicate = rewriter.create<mlir::CmpIOp>(
                nestedLoc, CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
                                                           newValue, oldValue);
          auto resultIndex = rewriter.create<mlir::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};

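// Lowers tosa.gather to a linalg.generic that extracts from the input, i.e.
// (illustrative):
//   result[batch, i, channel] = input[batch, indices[batch, i], channel]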
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = args[0];
    auto indices = args[1];

    auto inputTy = input.getType().cast<ShapedType>();
    auto indicesTy = indices.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant,
// this simplifies to a single gather operation.
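// A sketch of the i16 fixed-point path built below, for an input value v:
//   index = (v + 32768) >> 7, fraction = (v + 32768) & 0x7f
//   result = (table[index] << 7) + (table[index + 1] - table[index]) * fraction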
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    Value table = op.table();
    auto inputTy = input.getType().cast<ShapedType>();
    auto tableTy = table.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block =
          rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
                               TypeRange({inputElementTy, resultElementTy}));

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(),
                                                   inputValue);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<SignExtendIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(32768));
        auto seven =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(7));
        auto one =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        // value = value + 32768
        // index = value >> 7
        // fraction = value & 0x7f
        auto extendAdd = rewriter.create<AddIOp>(loc, extend, offset);
        Value index =
            rewriter.create<UnsignedShiftRightOp>(loc, extendAdd, seven);
        Value fraction = rewriter.create<mlir::AndOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        // base = (int32_t) table[index];
        // next = (int32_t) table[index + 1];
        Value indexPlusOne = rewriter.create<AddIOp>(loc, index, one);

        index =
            rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), base);
        next = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the input values:
        // result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<ShiftLeftOp>(loc, base, seven);
        Value diff = rewriter.create<SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<MulIOp>(loc, diff, fraction);
        Value result = rewriter.create<AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};

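// tosa.max_pool2d lowers directly to linalg.pooling_nhwc_max over a padded
// input, while tosa.avg_pool2d (f32) lowers to linalg.pooling_nhwc_sum
// followed by a linalg.generic that divides each sum by the number of valid
// (non-padding) input elements covered by its kernel window.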
template <typename SrcOp>
class Pool2dConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    ShapedType inputTy = input.getType().cast<ShapedType>();
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type outElementTy = inputTy.getElementType();

    if (!inputTy.hasStaticShape())
      return failure();

    // Determine what the initial value needs to be for the pooling op.
    Attribute initialAttr;
    if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isF32())
      initialAttr = rewriter.getFloatAttr(
          outElementTy,
          APFloat::getLargest(
              outElementTy.cast<FloatType>().getFloatSemantics(), true));

    if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isa<IntegerType>())
      initialAttr = rewriter.getIntegerAttr(
          outElementTy,
          APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));

    if (isa<tosa::AvgPool2dOp>(op) && outElementTy.isa<FloatType>())
      initialAttr = rewriter.getZeroAttr(outElementTy);

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.pad(), pad);
    pad.resize(pad.size() + 2, 0);
    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultTy.getElementType());

    Value filledInitTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);

    if (isa<tosa::MaxPool2dOp>(op)) {
      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
          filledInitTensor, strideAttr, dilationAttr);
      return success();
    }

    if (isa<tosa::AvgPool2dOp>(op) && inElementTy.isF32()) {
      Value poolingOp = rewriter
                            .create<linalg::PoolingNhwcSumOp>(
                                loc, ArrayRef<Type>{resultTy},
                                ValueRange{paddedInput, fakeWindowDims},
                                filledInitTensor, strideAttr, dilationAttr)
                            .getResult(0);
      auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
      auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
      auto genericOp = rewriter.create<linalg::GenericOp>(
          loc, ArrayRef<Type>({resultTy}), ValueRange{}, ValueRange{poolingOp},
          ArrayRef<AffineMap>({affineMap}),
          getNParallelLoopsAttrs(resultTy.getRank()),
          [&](OpBuilder &b, Location loc, ValueRange args) {
            auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
            auto one = rewriter.create<ConstantIndexOp>(loc, 1);
            auto iH = rewriter.create<ConstantIndexOp>(
                loc, poolingOpTy.getDimSize(1) - 1);
            auto iW = rewriter.create<ConstantIndexOp>(
                loc, poolingOpTy.getDimSize(2) - 1);

            // Compute the indices from either end.
            auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
            auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
            auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
            auto x1 = rewriter.create<SubIOp>(loc, iW, x0);

            // Determine what portion of the valid input is covered by the
            // kernel.
            auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
              if (pad == 0)
                return v;

              auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
              Value dx = rewriter.create<SubIOp>(loc, x, padVal);

              Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                        dx, zero);
              Value offset =
                  rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
              return rewriter.create<mlir::AddIOp>(loc, v, offset)
                  ->getResult(0);
            };
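            // For example (illustrative): with kernel height 3, pad == 1, and
            // y0 == 0, dx == -1, so one kernel row falls in the padding and
            // the coverage shrinks from 3 to 2.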

            // Compute the vertical component of coverage.
            auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
            auto kH1 = padFn(kH0, y0, pad[2]);
            auto kH2 = padFn(kH1, y1, pad[3]);
            auto kHCmp =
                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
            auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);

            // Compute the horizontal component of coverage.
            auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
            auto kW1 = padFn(kW0, x0, pad[4]);
            auto kW2 = padFn(kW1, x1, pad[5]);
            auto kWCmp =
                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
            auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);

            // Compute the total number of elements and normalize.
            Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
            auto countI = rewriter.create<mlir::IndexCastOp>(
                loc, rewriter.getI32Type(), count);
            auto countF =
                rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);

            auto div =
                rewriter.create<DivFOp>(loc, args[0], countF)->getResult(0);

            rewriter.create<linalg::YieldOp>(loc, div);
          });

      rewriter.replaceOp(op, genericOp.getResult(0));
      return success();
    }

    return failure();
  }
};

} // namespace

void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::DivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::ReluNOp>,
      PointwiseConverter<tosa::SigmoidOp>,
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      ConcatConverter,
      ConvConverter<tosa::Conv2DOp>,
      ConvConverter<tosa::DepthwiseConv2DOp>,
      TransposeConvConverter,
      GatherConverter,
      PadConverter,
      ReshapeConverter,
      RescaleConverter,
      ResizeConverter,
      ReverseConverter,
      TableConverter,
      TileConverter,
      TransposeConverter,
      MatMulConverter,
      Pool2dConverter<tosa::AvgPool2dOp>,
      Pool2dConverter<tosa::MaxPool2dOp>,
      FullyConnectedConverter>(patterns->getContext());
  // clang-format on
}