//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://developer.mlplatform.org/w/tosa/
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// Tosa dialect structs and interface includes.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc"

namespace {
//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       BlockAndValueMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
Region &tosa::WhileOp::getLoopBody() { return body(); }

bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) {
  return !body().isAncestor(value.getParentRegion());
}

LogicalResult WhileOp::moveOutOfLoop(ArrayRef<mlir::Operation *> ops) {
  if (ops.empty())
    return success();

  Operation *tosaWhileOp = this->getOperation();
  for (auto *op : ops)
    op->moveBefore(tosaWhileOp);

  return success();
}

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addInterfaces<TosaInlinerInterface>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // Tosa dialect constants only support ElementsAttr, unlike the standard
  // dialect constant which supports all attributes.
  if (value.isa<ElementsAttr>())
    return builder.create<tosa::ConstOp>(loc, type, value.cast<ElementsAttr>());
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");
  return valueAttr();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOp(T op) {
  // All TOSA conv ops have an input() and weight().
  auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
  auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();

  // Must be ranked tensor types
  if (!inputType || !weightType)
    return failure();

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();

  bool inputIsQuant = !inputEType.template isa<FloatType>();
  bool weightIsQuant = !weightEType.template isa<FloatType>();

  // Either both must be quantized or both unquantized.
  if (inputIsQuant != weightIsQuant)
    return failure();
  // Quantized types must have constructed the quantization attribute, and
  // unquantized types should not have one.
  if ((inputIsQuant && !op.quantization_info()) ||
      (!inputIsQuant && op.quantization_info()))
    return failure();

  return success();
}
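
// Illustrative note on the rule above (hedged, not from the spec text): a
// conv2d whose input and weight are both float must omit quantization_info,
// while one whose input and weight both carry quantized element types must
// have it; mixing a quantized input with a float weight fails verification.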

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bitwidth of the output given the bitwidths of the input and weight content.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, ArrayAttr pad,
                                     ArrayAttr stride, ArrayAttr dilation) {

  result.addOperands({input, weight, bias});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);

  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static void
buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input, Value weight,
                              Value bias, ArrayAttr outpad, ArrayAttr stride,
                              ArrayAttr dilation, ArrayAttr outputShape) {
  result.addOperands({input, weight, bias});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);
  result.addAttribute("out_shape", outputShape);
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.fully_connected op has its own builder as it does not have
/// strides/dilation/padding.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {

  result.addOperands({input, weight, bias});
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.matmul op is also used when a tosa.fully_connected op needs to be
/// constructed but the weight is not a constant. In that case the
/// fully_connected op must be expressed as a matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  result.addOperands({a, b});
  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);

    auto inputType = a.getType().dyn_cast<RankedTensorType>();
    assert(inputType && "Input must be a ranked tensor type!");

    auto inputQType = inputType.getElementType()
                          .dyn_cast<mlir::quant::UniformQuantizedType>();
    assert(inputQType && "Tensor must have quantized datatype!");

    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

    auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
    assert(outputShapedType && "Output must be a ranked tensor type");

    auto outputShape = outputShapedType.getShape();

    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();
    auto accType = RankedTensorType::get(outputShape, accElementType);
    result.addTypes(accType);
  } else {
    result.addTypes(outputType);
  }
}
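
// Worked example of the accumulator rule above (illustrative, assumed types):
// for a = tensor<1x4x8x!quant.uniform<i8:f32, 0.5>> and a matching i8 b of
// shape 1x8x4, the 8-bit storage width selects an i32 accumulator, so the
// result type becomes tensor<1x4x4xi32>; 16-bit storage selects i48 instead.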

/// Both the tosa.avg_pool2d and unary ops use the same
/// UnaryOpQuantizationAttr, but the avg_pool operator has its own builder as
/// it has additional parameters not part of the unary ops.
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
                                          OperationState &result,
                                          Type outputType, Value input,
                                          ArrayAttr kernel, ArrayAttr stride,
                                          ArrayAttr pad) {
  result.addOperands(input);
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on single-parameter unary operators that have a
/// scale relationship between their input and output, expressed by the
/// UnaryOpQuantizationAttr.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                      OperationState &result, Type outputType,
                                      Value input) {
  result.addOperands(input);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator, which needs to create its
/// own OptionalAttr quantization_info parameter to scale the padding values
/// correctly.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  result.addOperands({input, paddings});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//

static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
  for (auto it : arrayAttr) {
    values.push_back(it.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
  IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputTy.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outShape;
  outShape.reserve(inputTy.getRank() - 1);
  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputTy.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
  int32_t axis =
      attributes.get("axis").cast<IntegerAttr>().getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : operands) {
    ShapedType operandTy = operand.getType().cast<ShapedType>();
    if (!operandTy.hasRank())
      continue;

    // Copy the operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandTy.getRank(), -1);

    // Copy static dimensions (other than the concat axis), failing on a
    // mismatch between ranked operands.
    for (int i = 0, s = operandTy.getRank(); i < s; i++) {
      if (i == axis || operandTy.isDynamicDim(i))
        continue;
      if (outputShape[i] == -1)
        outputShape[i] = operandTy.getDimSize(i);
      if (outputShape[i] != operandTy.getDimSize(i))
        return failure();
    }

    hasRankedInput = true;
  }

  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // Determine the dimension size along the concatenation axis.
  int concatDimSize = 0;
  for (auto operand : operands) {
    ShapedType operandTy = operand.getType().cast<ShapedType>();

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandTy.hasRank() || operandTy.isDynamicDim(axis)) {
      concatDimSize = -1;
      break;
    }

    concatDimSize += operandTy.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
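
// Illustrative example (assumed shapes): concatenating tensor<2x3xf32> and
// tensor<2x5xf32> on axis = 1 infers tensor<2x8xf32>; if any operand's axis
// dimension is dynamic, the output's axis dimension is dynamic as well.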

LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
  ShapedType weightTy = operands[1].getType().cast<ShapedType>();
  ShapedType biasTy = operands[2].getType().cast<ShapedType>();

  // All shapes are dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(2, -1);

  if (inputTy.hasRank()) {
    outShape[0] = inputTy.getDimSize(0);
  }

  if (weightTy.hasRank()) {
    outShape[1] = weightTy.getDimSize(0);
  }

  if (biasTy.hasRank()) {
    outShape[1] = outShape[1] == -1 ? biasTy.getDimSize(0) : outShape[1];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType lhsTy = operands[0].getType().cast<ShapedType>();
  ShapedType rhsTy = operands[1].getType().cast<ShapedType>();

  // All shapes are dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(3, -1);

  if (lhsTy.hasRank()) {
    outShape[0] = lhsTy.getDimSize(0);
    outShape[1] = lhsTy.getDimSize(1);
  }

  if (rhsTy.hasRank()) {
    outShape[0] = outShape[0] == -1 ? rhsTy.getDimSize(0) : outShape[0];
    outShape[2] = rhsTy.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
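
// Illustrative example (assumed shapes): with lhs tensor<?x4x5xf32> and rhs
// tensor<2x5x6xf32>, the batch dimension is recovered from the rhs, so the
// inferred result is tensor<2x4x6xf32>.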

LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
  ShapedType paddingTy = operands[1].getType().cast<ShapedType>();
  SmallVector<int64_t> outputShape;

  // If both inputs have unknown shape, we cannot determine the shape of the
  // output.
  if (!inputTy.hasRank() && !paddingTy.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // If the input rank is unknown, we can infer the output rank using the
  // padding shape's first dim.
  if (!inputTy.hasRank()) {
    if (paddingTy.isDynamicDim(0)) {
      inferredReturnShapes.push_back(ShapedTypeComponents());
      return success();
    }

    outputShape.resize(paddingTy.getDimSize(0), -1);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  DenseIntElementsAttr paddings;
  // If the paddings value is not a constant, all dimensions must be dynamic.
  if (!matchPattern(operands[1], m_Constant(&paddings))) {
    outputShape.resize(inputTy.getRank(), -1);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  SmallVector<int64_t> paddingValues;
  for (auto val : paddings) {
    paddingValues.push_back(val.getSExtValue());
  }

  outputShape.reserve(inputTy.getRank());
  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
    if (inputTy.isDynamicDim(i)) {
      outputShape.push_back(-1);
      continue;
    }

    outputShape.push_back(inputTy.getDimSize(i) + paddingValues[i * 2] +
                          paddingValues[i * 2 + 1]);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
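
// Worked example of the arithmetic above (illustrative, assumed values): a
// tensor<3x4xf32> input with constant paddings [[1, 1], [2, 2]] yields
// 3 + 1 + 1 = 5 and 4 + 2 + 2 = 8, i.e. tensor<5x8xf32>.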

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  auto sizes = attributes.get("size").cast<ArrayAttr>().getValue();
  SmallVector<int64_t> outputShape;
  outputShape.reserve(sizes.size());
  for (auto val : sizes) {
    outputShape.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();

  if (!inputTy.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  inferredReturnShapes.push_back(inputTy.getShape());
  return success();
}

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  auto multiples = attributes.get("multiples").cast<ArrayAttr>().getValue();
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
  SmallVector<int64_t> outputShape;
  if (!inputTy.hasRank()) {
    outputShape.resize(multiples.size(), -1);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // We need the multiple values to determine the output shape.
  SmallVector<int64_t> multipleValues;
  multipleValues.reserve(multiples.size());
  for (auto val : multiples) {
    multipleValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }

  // Any non-dynamic dimension can be multiplied out to a known size.
  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
    int64_t dim = inputTy.getDimSize(i);
    if (dim != -1)
      dim *= multipleValues[i];
    outputShape.push_back(dim);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
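
// Illustrative example (assumed shapes): tiling tensor<2x?x3xf32> with
// multiples = [3, 2, 2] scales the static dims and keeps the dynamic one,
// inferring tensor<6x?x6xf32>.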

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType type = operands.front().getType().cast<ShapedType>();

  auto newShape = attributes.get("new_shape").cast<ArrayAttr>();
  llvm::SmallVector<int64_t> newShapeValue;
  getI64Values(newShape, newShapeValue);

  // We cannot infer from the total number of elements so we must take the
  // shape attribute as exact.
  if (!type.hasRank() || !type.hasStaticShape()) {
    inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
    return success();
  }

  // Determine the number of elements covered by the slice of all static
  // dimensions. This allows us to infer the length of the remaining dynamic
  // dimension.
  int64_t numElements = type.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (val != -1) {
      staticMul *= val;
    }
  }

  // Determine the length of the dynamic dimension.
  for (auto &val : newShapeValue) {
    if (val == -1)
      val = numElements / staticMul;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
  return success();
}
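
// Worked example of the dynamic-dimension inference above (illustrative): a
// static tensor<6x4xf32> holds 24 elements, so new_shape = [-1, 8] gives
// staticMul = 8 and the -1 resolves to 24 / 8 = 3, i.e. tensor<3x8xf32>.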

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
  ShapedType permsTy = operands[1].getType().cast<ShapedType>();

  // If the input rank and permutation length are unknown, the output rank is
  // unknown.
  if (!inputTy.hasRank() && (!permsTy.hasRank() || permsTy.isDynamicDim(0))) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // Without the input dims we cannot determine the output dim sizes but we
  // can determine the output rank.
  SmallVector<int64_t> outputShape;
  if (!inputTy.hasRank()) {
    outputShape.resize(permsTy.getDimSize(0), -1);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Rank-0 means no permutations matter.
  if (inputTy.getRank() == 0) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Check whether the input dimensions are all the same.
  bool allTheSame = true;
  for (int i = 1, s = inputTy.getRank(); i < s; i++) {
    if (inputTy.getDimSize(0) != inputTy.getDimSize(i)) {
      allTheSame = false;
      break;
    }
  }

  // If all of the input dimensions are the same we don't care about the
  // permutation.
  if (allTheSame) {
    outputShape.resize(inputTy.getRank(), inputTy.getDimSize(0));
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  DenseIntElementsAttr perms;
  outputShape.resize(inputTy.getRank(), -1);
  // If the permutations are a constant we can directly determine the output
  // shape.
  if (matchPattern(operands[1], m_Constant(&perms))) {
    llvm::SmallVector<int64_t> permValues;
    for (auto val : perms) {
      permValues.push_back(val.getSExtValue());
    }

    outputShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; i++) {
      outputShape[i] = inputTy.getDimSize(permValues[i]);
    }
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
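
// Illustrative example (assumed values): transposing tensor<2x3x4xf32> with a
// constant perms of [2, 0, 1] infers tensor<4x2x3xf32>; with non-constant
// perms only the rank-3 all-dynamic result can be inferred.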

LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, -1);

  if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = ty.getDimSize(0);
    outputShape[2] = ty.getDimSize(2);
  }

  if (auto ty = operands[1].getType().dyn_cast<RankedTensorType>()) {
    if (outputShape[0] == -1)
      outputShape[0] = ty.getDimSize(0);
    if (outputShape[1] == -1)
      outputShape[1] = ty.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, -1);

  if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = ty.getDimSize(0);
    outputShape[1] = ty.getDimSize(1);
    outputShape[2] = ty.getDimSize(2);
  }

  if (auto ty = operands[1].getType().dyn_cast<RankedTensorType>()) {
    if (outputShape[0] == -1)
      outputShape[0] = ty.getDimSize(0);
  }

  if (auto ty = operands[2].getType().dyn_cast<RankedTensorType>()) {
    if (outputShape[0] == -1)
      outputShape[0] = ty.getDimSize(0);
    if (outputShape[2] == -1)
      outputShape[2] = ty.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

static LogicalResult ReduceInferReturnTypes(
    Value operand, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  auto operandTy = operand.getType().cast<ShapedType>();
  if (!operandTy.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  int64_t axisVal = axis.getValue().getSExtValue();
  SmallVector<int64_t> outputShape;
  outputShape.reserve(operandTy.getRank());
  for (auto dim : operandTy.getShape()) {
    outputShape.push_back(dim);
  }

  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
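
// Illustrative example (assumed values): reducing tensor<2x3x4xf32> along
// axis = 1 keeps the rank and collapses that dimension, giving
// tensor<2x1x4xf32>.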

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::llvm::Optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      RegionRange regions,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return ReduceInferReturnTypes(operands[0],                                 \
                                  attributes.get("axis").cast<IntegerAttr>(),  \
                                  inferredReturnShapes);                       \
  }

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER

static LogicalResult resolveBroadcastShape(ValueRange operands,
                                           SmallVector<int64_t> &outShape) {
  int64_t outRank = 0;
  for (auto operand : operands) {
    auto type = operand.getType().cast<ShapedType>();
    if (!type.hasRank())
      return failure();
    outRank = std::max<int64_t>(outRank, type.getRank());
  }

  outShape.resize(outRank, 1);

  for (auto operand : operands) {
    auto type = operand.getType().cast<ShapedType>();
    auto shape = type.getShape();
    auto rankDiff = outShape.size() - shape.size();

    for (size_t i = 0; i < shape.size(); i++) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape[i];
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }

  return success();
}
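
// Illustrative example of the resolution above (assumed shapes): [2, 1] and
// [3] broadcast to [2, 3]; the rank-1 shape is right-aligned and every
// size-1 dimension stretches to match its counterpart.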

static LogicalResult NAryInferReturnTypes(
    ValueRange operands,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outShape;
  if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
  } else {
    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  }
  return success();
}

#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::llvm::Optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      RegionRange regions,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
  }

NARY_SHAPE_INFER(tosa::AbsOp)
NARY_SHAPE_INFER(tosa::AddOp)
NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
NARY_SHAPE_INFER(tosa::BitwiseAndOp)
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
NARY_SHAPE_INFER(tosa::DivOp)
NARY_SHAPE_INFER(tosa::EqualOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
NARY_SHAPE_INFER(tosa::GreaterOp)
NARY_SHAPE_INFER(tosa::LogOp)
NARY_SHAPE_INFER(tosa::LogicalAndOp)
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
NARY_SHAPE_INFER(tosa::LogicalNotOp)
NARY_SHAPE_INFER(tosa::LogicalOrOp)
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::ReluNOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef NARY_SHAPE_INFER

static LogicalResult poolingInferReturnTypes(
    ValueRange operands, DictionaryAttr attributes,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  RankedTensorType inputTy = operands[0].getType().dyn_cast<RankedTensorType>();
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, -1);

  // If the input is unranked, all we know is that the output is rank 4 with
  // dynamic dimensions.
  if (!inputTy) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Batch and number of channels are identical for a pooling layer.
  outputShape[0] = inputTy.getDimSize(0);
  outputShape[3] = inputTy.getDimSize(3);

  int32_t height = inputTy.getDimSize(1);
  int32_t width = inputTy.getDimSize(2);

  llvm::SmallVector<int64_t> kernel;
  llvm::SmallVector<int64_t> stride;
  llvm::SmallVector<int64_t> pad;

  getI64Values(attributes.get("kernel").cast<ArrayAttr>(), kernel);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
  getI64Values(attributes.get("pad").cast<ArrayAttr>(), pad);

  if (height != -1) {
    int32_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (width != -1) {
    int32_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
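
// Worked example of the pooling arithmetic above (illustrative, assumed
// values): height 32 with pad = [1, 1, 1, 1], kernel = [3, 3] and
// stride = [2, 2] gives padded = 32 + 1 + 1 - 3 = 31, so
// outputShape[1] = 31 / 2 + 1 = 16.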

LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
  Conv2DOp::Adaptor adaptor(operands.getValues());

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = inputTy.getDimSize(0);
    inputHeight = inputTy.getDimSize(1);
    inputWidth = inputTy.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  if (auto weightTy = adaptor.weight().getType().dyn_cast<RankedTensorType>()) {
    outputShape[3] = weightTy.getDimSize(0);
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasTy.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
  getI64Values(attributes.get("pad").cast<ArrayAttr>(), padding);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }
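
  // Worked example for the height computation above (illustrative, assumed
  // values): inputHeight = 16, pad = [1, 1, 1, 1], weightHeight = 3,
  // dilation = [1, 1], stride = [2, 2] gives inputSize = 18, filterSize = 3,
  // unstridedResult = 16, so outputShape[1] = (16 - 1) / 2 + 1 = 8. The width
  // computation below is analogous.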

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
  Conv3DOp::Adaptor adaptor(operands.getValues());

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t inputDepth = ShapedType::kDynamicSize;

  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;
  int32_t weightDepth = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = inputTy.getDimSize(0);
    inputHeight = inputTy.getDimSize(1);
    inputWidth = inputTy.getDimSize(2);
    inputDepth = inputTy.getDimSize(3);
  }

  // Weight shape describes the filter width/height and the output channels.
  if (auto weightTy = adaptor.weight().getType().dyn_cast<RankedTensorType>()) {
    outputShape[4] = weightTy.getDimSize(0);
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);
    weightDepth = weightTy.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
    outputShape[4] =
        (outputShape[4] == -1) ? biasTy.getDimSize(0) : outputShape[4];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
  getI64Values(attributes.get("pad").cast<ArrayAttr>(), padding);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + padding[4] + padding[5];
    int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}

LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
  DepthwiseConv2DOp::Adaptor adaptor(operands.getValues());

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t inputChannels = ShapedType::kDynamicSize;

  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;
  int32_t depthChannels = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = inputTy.getDimSize(0);
    inputHeight = inputTy.getDimSize(1);
    inputWidth = inputTy.getDimSize(2);
    inputChannels = inputTy.getDimSize(3);
  }

  // Weight shape describes the filter width/height and the depth multiplier.
  if (auto weightTy = adaptor.weight().getType().dyn_cast<RankedTensorType>()) {
    weightHeight = weightTy.getDimSize(0);
    weightWidth = weightTy.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightTy.getDimSize(2)
                        : inputChannels;
    depthChannels = weightTy.getDimSize(3);
  }

  // If both inputChannels and depthChannels are available we can determine
  // the output channels.
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasTy.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
  getI64Values(attributes.get("pad").cast<ArrayAttr>(), padding);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
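
// Illustrative example of the channel rule above (assumed shapes): an input
// of tensor<1x?x?x8xf32> with weight tensor<3x3x8x2xf32> has a depth
// multiplier of 2, so outputShape[3] = 8 * 2 = 16.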

LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  TransposeConv2DOp::Adaptor adaptor(operands.getValues());
  llvm::SmallVector<int64_t> outputShape;
  getI64Values(attributes.get("out_shape").cast<ArrayAttr>(), outputShape);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputTy.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputTy.getDimSize(1);
    inputWidth = inputTy.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  if (auto weightTy = adaptor.filter().getType().dyn_cast<RankedTensorType>()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightTy.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasTy.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
  getI64Values(attributes.get("out_pad").cast<ArrayAttr>(), padding);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t dilated = (weightHeight - 1) * dilation[0] + 1;
    int32_t calculateSize =
        (inputHeight - 1) * stride[0] - padding[0] + dilated;
    outputShape[1] = outputShape[1] == -1 ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t dilated = (weightWidth - 1) * dilation[1] + 1;
    int32_t calculateSize = (inputWidth - 1) * stride[1] - padding[1] + dilated;
    outputShape[2] = outputShape[2] == -1 ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
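
// Worked example of the size computation above (illustrative, assumed
// values): inputHeight = 8, stride = [2, 2], out_pad = [0, 0],
// dilation = [1, 1] and weightHeight = 3 give dilated = 3 and
// outputShape[1] = (8 - 1) * 2 - 0 + 3 = 17, unless out_shape already pins
// that dimension.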

//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"