//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://developer.mlplatform.org/w/tosa/
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// Tosa dialect structs and interface includes.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc"
namespace {
//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       BlockAndValueMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
Region &tosa::WhileOp::getLoopBody() { return body(); }

bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) {
  return !body().isAncestor(value.getParentRegion());
}

LogicalResult WhileOp::moveOutOfLoop(ArrayRef<mlir::Operation *> ops) {
  if (ops.empty())
    return success();

  Operation *tosaWhileOp = this->getOperation();
  for (auto *op : ops)
    op->moveBefore(tosaWhileOp);

  return success();
}

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addInterfaces<TosaInlinerInterface>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // Tosa dialect constants only support ElementsAttr, unlike the standard
  // dialect constant which supports all attributes.
  if (value.isa<ElementsAttr>())
    return builder.create<tosa::ConstOp>(loc, type, value.cast<ElementsAttr>());
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");
  return valueAttr();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOp(T op) {
  // All TOSA conv ops have an input() and weight().
  auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
  auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();

  // Must be ranked tensor types.
  if (!inputType || !weightType)
    return failure();

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();

  bool inputIsQuant = !inputEType.template isa<FloatType>();
  bool weightIsQuant = !weightEType.template isa<FloatType>();

  // Either both must be quantized or both unquantized.
  if (inputIsQuant != weightIsQuant)
    return failure();

  // Quantized types must have a constructed quantization attribute, and
  // unquantized types must not have one.
  if ((inputIsQuant && !op.quantization_info()) ||
      (!inputIsQuant && op.quantization_info()))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bitwidth of the output given the bitwidths of the input and weight content.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, ArrayAttr pad,
                                     ArrayAttr stride, ArrayAttr dilation) {

  result.addOperands({input, weight, bias});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);

  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static void
buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input, Value weight,
                              Value bias, ArrayAttr outpad, ArrayAttr stride,
                              ArrayAttr dilation, ArrayAttr outputShape) {
  result.addOperands({input, weight, bias});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);
  result.addAttribute("out_shape", outputShape);
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.fully_connected op has its own builder as it does not have
/// strides/dilation/padding.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {

  result.addOperands({input, weight, bias});
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.matmul op is also used where a fully_connected op would be
/// constructed but the weight is not a constant; in that case the operation
/// must be expressed as a matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  result.addOperands({a, b});
  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);

    auto inputType = a.getType().dyn_cast<RankedTensorType>();
    assert(inputType && "Input must be a ranked tensor type!");

    auto inputQType = inputType.getElementType()
                          .dyn_cast<mlir::quant::UniformQuantizedType>();
    assert(inputQType && "Tensor must have quantized datatype!");

    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

    auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
    assert(outputShapedType && "Output must be a ranked tensor type");

    auto outputShape = outputShapedType.getShape();

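    // TOSA widens the matmul accumulator beyond the operand storage type:
    // 16-bit quantized inputs accumulate into i48, narrower (e.g. i8)
    // inputs into i32.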
    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();
    auto accType = RankedTensorType::get(outputShape, accElementType);
    result.addTypes(accType);
  } else {
    result.addTypes(outputType);
  }
}

/// Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr
/// but the avg_pool2d operator has its own builder as it has additional
/// parameters not part of the unary ops.
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
                                          OperationState &result,
                                          Type outputType, Value input,
                                          ArrayAttr kernel, ArrayAttr stride,
                                          ArrayAttr pad) {
  result.addOperands(input);
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on single-parameter unary operators that have a
/// scale relationship between their input and output, expressed by the
/// UnaryOpQuantizationAttr.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                      OperationState &result, Type outputType,
                                      Value input) {
  result.addOperands(input);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator, which needs to create its
/// own OptionalAttr quantization_info parameter to scale the padding values
/// correctly.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  result.addOperands({input, paddings});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//

static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
  for (auto it : arrayAttr) {
    values.push_back(it.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
  IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputTy.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

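  // The result shape is the input shape with the axis dimension removed,
  // e.g. a {2, 3, 4} input with axis = 1 yields a {2, 4} result.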
  SmallVector<int64_t> outShape;
  outShape.reserve(inputTy.getRank() - 1);
  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputTy.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
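  // For example, concatenating {2, 3} and {2, 5} along axis = 1 yields
  // {2, 8}; a dynamic size along the axis in any input makes the result
  // dynamic there as well.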
  int32_t axis =
      attributes.get("axis").cast<IntegerAttr>().getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : operands) {
    ShapedType operandTy = operand.getType().cast<ShapedType>();
    if (!operandTy.hasRank())
      continue;

    // Copy the operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandTy.getRank(), -1);

    // Copy the static dimension sizes, checking that all ranked inputs agree.
    for (int i = 0, s = operandTy.getRank(); i < s; i++) {
      if (i == axis || operandTy.isDynamicDim(i))
        continue;
      if (outputShape[i] == -1)
        outputShape[i] = operandTy.getDimSize(i);
      if (outputShape[i] != operandTy.getDimSize(i))
        return failure();
    }

    hasRankedInput = true;
  }

  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // Determine the dimension size along the concatenation axis.
  int concatDimSize = 0;
  for (auto operand : operands) {
    ShapedType operandTy = operand.getType().cast<ShapedType>();

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandTy.hasRank() || operandTy.isDynamicDim(axis)) {
      concatDimSize = -1;
      break;
    }

    concatDimSize += operandTy.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
  ShapedType weightTy = operands[1].getType().cast<ShapedType>();
  ShapedType biasTy = operands[2].getType().cast<ShapedType>();

  // Initialize all output dimensions as dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(2, -1);

  if (inputTy.hasRank()) {
    outShape[0] = inputTy.getDimSize(0);
  }

  if (weightTy.hasRank()) {
    outShape[1] = weightTy.getDimSize(0);
  }

  if (biasTy.hasRank()) {
    outShape[1] = outShape[1] == -1 ? biasTy.getDimSize(0) : outShape[1];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType lhsTy = operands[0].getType().cast<ShapedType>();
  ShapedType rhsTy = operands[1].getType().cast<ShapedType>();

  // Initialize all output dimensions as dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(3, -1);

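  // tosa.matmul is batched: {N, H, C} x {N, C, W} yields {N, H, W}, taking
  // each dimension from whichever operand provides a static size.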
  if (lhsTy.hasRank()) {
    outShape[0] = lhsTy.getDimSize(0);
    outShape[1] = lhsTy.getDimSize(1);
  }

  if (rhsTy.hasRank()) {
    outShape[0] = outShape[0] == -1 ? rhsTy.getDimSize(0) : outShape[0];
    outShape[2] = rhsTy.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
  ShapedType paddingTy = operands[1].getType().cast<ShapedType>();
  SmallVector<int64_t> outputShape;

  // If both inputs have unknown shape, we cannot determine the shape of the
  // output.
  if (!inputTy.hasRank() && !paddingTy.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // If the input rank is unknown, we can infer the output rank from the
  // padding shape's first dim.
  if (!inputTy.hasRank()) {
    if (paddingTy.isDynamicDim(0)) {
      inferredReturnShapes.push_back(ShapedTypeComponents());
      return success();
    }

    outputShape.resize(paddingTy.getDimSize(0), -1);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  DenseIntElementsAttr paddings;
  // If the paddings value is not a constant, all dimensions must be dynamic.
  if (!matchPattern(operands[1], m_Constant(&paddings))) {
    outputShape.resize(inputTy.getRank(), -1);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  SmallVector<int64_t> paddingValues;
  for (auto val : paddings) {
    paddingValues.push_back(val.getSExtValue());
  }

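  // Each dimension grows by its low and high padding, e.g. a {4, 9} input
  // with paddings flattened to {1, 1, 2, 0} yields {4 + 1 + 1, 9 + 2 + 0} =
  // {6, 11}.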
  outputShape.reserve(inputTy.getRank());
  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
    if (inputTy.isDynamicDim(i)) {
      outputShape.push_back(-1);
      continue;
    }

    outputShape.push_back(inputTy.getDimSize(i) + paddingValues[i * 2] +
                          paddingValues[i * 2 + 1]);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  auto sizes = attributes.get("size").cast<ArrayAttr>().getValue();
  SmallVector<int64_t> outputShape;
  outputShape.reserve(sizes.size());
  for (auto val : sizes) {
    outputShape.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();

  if (!inputTy.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  inferredReturnShapes.push_back(inputTy.getShape());
  return success();
}

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  auto multiples = attributes.get("multiples").cast<ArrayAttr>().getValue();
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
  SmallVector<int64_t> outputShape;
  if (!inputTy.hasRank()) {
    outputShape.resize(multiples.size(), -1);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // We need the multiple values to determine the output shape.
  SmallVector<int64_t> multipleValues;
  multipleValues.reserve(multiples.size());
  for (auto val : multiples) {
    multipleValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }

  // Any non-dynamic dimension can be multiplied to a known size.
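  // For example, a {2, -1} input with multiples = {3, 2} yields {6, -1}.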
  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
    int dim = inputTy.getDimSize(i);
    if (dim != -1)
      dim *= multipleValues[i];
    outputShape.push_back(dim);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType type = operands.front().getType().cast<ShapedType>();

  auto newShape = attributes.get("new_shape").cast<ArrayAttr>();
  llvm::SmallVector<int64_t> newShapeValue;
  getI64Values(newShape, newShapeValue);

  // We cannot infer from the total number of elements, so we must take the
  // shape attribute as exact.
  if (!type.hasRank() || !type.hasStaticShape()) {
    inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
    return success();
  }

  // Determine the number of elements covered by the static dimensions. This
  // allows us to infer the size of the remaining dynamic dimension.
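  // For example, reshaping a static 2x3x4 input (24 elements) with
  // new_shape = {6, -1} gives staticMul = 6 and a dynamic size of
  // 24 / 6 = 4, i.e. a {6, 4} result.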
  int64_t numElements = type.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (val != -1) {
      staticMul *= val;
    }
  }

  // Determine the length of the dynamic dimension.
  for (auto &val : newShapeValue) {
    if (val == -1)
      val = numElements / staticMul;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
  return success();
}

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
  ShapedType permsTy = operands[1].getType().cast<ShapedType>();

  // If the input rank and the permutation length are unknown, the output rank
  // is unknown.
  if (!inputTy.hasRank() && (!permsTy.hasRank() || permsTy.isDynamicDim(0))) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // Without the input dims we cannot determine the output dim sizes, but we
  // can determine the output rank.
  SmallVector<int64_t> outputShape;
  if (!inputTy.hasRank()) {
    outputShape.resize(permsTy.getDimSize(0), -1);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Rank-0 means no permutations matter.
  if (inputTy.getRank() == 0) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Check whether the input dimensions are all the same.
  bool allTheSame = true;
  for (int i = 1, s = inputTy.getRank(); i < s; i++) {
    if (inputTy.getDimSize(0) != inputTy.getDimSize(i)) {
      allTheSame = false;
      break;
    }
  }

  // If all of the input dimensions are the same, we don't care about the
  // permutation.
  if (allTheSame) {
    outputShape.resize(inputTy.getRank(), inputTy.getDimSize(0));
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  DenseIntElementsAttr perms;
  outputShape.resize(inputTy.getRank(), -1);
  // If the permutations are a constant, we can directly determine the output
  // shape.
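  // For example, a {1, 2, 3} input with perms = {2, 0, 1} yields {3, 1, 2}.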
  if (matchPattern(operands[1], m_Constant(&perms))) {
    llvm::SmallVector<int64_t> permValues;
    for (auto val : perms) {
      permValues.push_back(val.getSExtValue());
    }

    outputShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; i++) {
      outputShape[i] = inputTy.getDimSize(permValues[i]);
    }
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, -1);

  if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = ty.getDimSize(0);
    outputShape[2] = ty.getDimSize(2);
  }

  if (auto ty = operands[1].getType().dyn_cast<RankedTensorType>()) {
    if (outputShape[0] == -1)
      outputShape[0] = ty.getDimSize(0);
    if (outputShape[1] == -1)
      outputShape[1] = ty.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, -1);

  if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = ty.getDimSize(0);
    outputShape[1] = ty.getDimSize(1);
    outputShape[2] = ty.getDimSize(2);
  }

  if (auto ty = operands[1].getType().dyn_cast<RankedTensorType>()) {
    if (outputShape[0] == -1)
      outputShape[0] = ty.getDimSize(0);
  }

  if (auto ty = operands[2].getType().dyn_cast<RankedTensorType>()) {
    if (outputShape[0] == -1)
      outputShape[0] = ty.getDimSize(0);
    if (outputShape[2] == -1)
      outputShape[2] = ty.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

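/// Reductions keep the operand rank and shrink the reduced axis to size one,
/// e.g. reducing a {2, 3, 4} tensor along axis = 1 yields a {2, 1, 4} result.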
static LogicalResult ReduceInferReturnTypes(
    Value operand, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  auto operandTy = operand.getType().cast<ShapedType>();
  if (!operandTy.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  int64_t axisVal = axis.getValue().getSExtValue();
  SmallVector<int64_t> outputShape;
  outputShape.reserve(operandTy.getRank());
  for (auto dim : operandTy.getShape()) {
    outputShape.push_back(dim);
  }

  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::llvm::Optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      RegionRange regions,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return ReduceInferReturnTypes(operands[0],                                 \
                                  attributes.get("axis").cast<IntegerAttr>(),  \
                                  inferredReturnShapes);                       \
  }

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER

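/// Resolves the broadcast of all operand shapes, e.g. {1, 3} and {2, 4, 1}
/// broadcast to {2, 4, 3}; mismatched non-unit dimensions fail to resolve.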
static LogicalResult resolveBroadcastShape(ValueRange operands,
                                           SmallVector<int64_t> &outShape) {
  int64_t outRank = 0;
  for (auto operand : operands) {
    auto type = operand.getType().cast<ShapedType>();
    if (!type.hasRank())
      return failure();
    outRank = std::max<int64_t>(outRank, type.getRank());
  }

  outShape.resize(outRank, 1);

  for (auto operand : operands) {
    auto type = operand.getType().cast<ShapedType>();
    auto shape = type.getShape();
    auto rankDiff = outShape.size() - shape.size();

    for (size_t i = 0; i < shape.size(); i++) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape[i];
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }

  return success();
}

static LogicalResult NAryInferReturnTypes(
    ValueRange operands,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outShape;
  if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
  } else {
    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  }
  return success();
}

#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::llvm::Optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      RegionRange regions,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
  }

NARY_SHAPE_INFER(tosa::AbsOp)
NARY_SHAPE_INFER(tosa::AddOp)
NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
NARY_SHAPE_INFER(tosa::BitwiseAndOp)
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
NARY_SHAPE_INFER(tosa::DivOp)
NARY_SHAPE_INFER(tosa::EqualOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
NARY_SHAPE_INFER(tosa::GreaterOp)
NARY_SHAPE_INFER(tosa::LogOp)
NARY_SHAPE_INFER(tosa::LogicalAndOp)
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
NARY_SHAPE_INFER(tosa::LogicalNotOp)
NARY_SHAPE_INFER(tosa::LogicalOrOp)
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::ReluNOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef NARY_SHAPE_INFER

static LogicalResult poolingInferReturnTypes(
    ValueRange operands, DictionaryAttr attributes,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  RankedTensorType inputTy = operands[0].getType().dyn_cast<RankedTensorType>();
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, -1);

  // If the input is unranked, we only know the output rank, not the sizes.
  if (!inputTy) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Batch and number of channels are unchanged by pooling.
  outputShape[0] = inputTy.getDimSize(0);
  outputShape[3] = inputTy.getDimSize(3);

  int32_t height = inputTy.getDimSize(1);
  int32_t width = inputTy.getDimSize(2);

  llvm::SmallVector<int64_t> kernel;
  llvm::SmallVector<int64_t> stride;
  llvm::SmallVector<int64_t> pad;

  getI64Values(attributes.get("kernel").cast<ArrayAttr>(), kernel);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
  getI64Values(attributes.get("pad").cast<ArrayAttr>(), pad);

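  // For example, height = 14 with pad[0] = pad[1] = 1, kernel height 3 and
  // stride 2 gives (14 + 1 + 1 - 3) / 2 + 1 = 7 output rows.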
  if (height != -1) {
    int32_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (width != -1) {
    int32_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
  Conv2DOp::Adaptor adaptor(operands.getValues());

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = inputTy.getDimSize(0);
    inputHeight = inputTy.getDimSize(1);
    inputWidth = inputTy.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  if (auto weightTy = adaptor.weight().getType().dyn_cast<RankedTensorType>()) {
    outputShape[3] = weightTy.getDimSize(0);
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasTy.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
  getI64Values(attributes.get("pad").cast<ArrayAttr>(), padding);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);

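  // For example, inputHeight = 16 with padding {1, 1}, a 3-tap filter,
  // dilation 1 and stride 2 gives inputSize = 18, filterSize = 3, and
  // (18 - 3 + 1 - 1) / 2 + 1 = 8 output rows.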
  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
  Conv3DOp::Adaptor adaptor(operands.getValues());

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t inputDepth = ShapedType::kDynamicSize;

  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;
  int32_t weightDepth = ShapedType::kDynamicSize;

  // Input shape describes input width/height/depth and batch.
  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = inputTy.getDimSize(0);
    inputHeight = inputTy.getDimSize(1);
    inputWidth = inputTy.getDimSize(2);
    inputDepth = inputTy.getDimSize(3);
  }

  // Weight shape describes the filter width/height/depth and the output
  // channels.
  if (auto weightTy = adaptor.weight().getType().dyn_cast<RankedTensorType>()) {
    outputShape[4] = weightTy.getDimSize(0);
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);
    weightDepth = weightTy.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
    outputShape[4] =
        (outputShape[4] == -1) ? biasTy.getDimSize(0) : outputShape[4];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
  getI64Values(attributes.get("pad").cast<ArrayAttr>(), padding);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + padding[4] + padding[5];
    int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}

LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
  DepthwiseConv2DOp::Adaptor adaptor(operands.getValues());

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t inputChannels = ShapedType::kDynamicSize;

  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;
  int32_t depthChannels = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = inputTy.getDimSize(0);
    inputHeight = inputTy.getDimSize(1);
    inputWidth = inputTy.getDimSize(2);
    inputChannels = inputTy.getDimSize(3);
  }

  // Weight shape describes the filter width/height and the channel multiplier.
  if (auto weightTy = adaptor.weight().getType().dyn_cast<RankedTensorType>()) {
    weightHeight = weightTy.getDimSize(0);
    weightWidth = weightTy.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightTy.getDimSize(2)
                        : inputChannels;
    depthChannels = weightTy.getDimSize(3);
  }

  // If both inputChannels and depthChannels are available, we can determine
  // the output channels.
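  // For example, a {3, 3, 8, 2} weight over an 8-channel input produces
  // 8 * 2 = 16 output channels.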
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasTy.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
  getI64Values(attributes.get("pad").cast<ArrayAttr>(), padding);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  TransposeConv2DOp::Adaptor adaptor(operands.getValues());
  llvm::SmallVector<int64_t> outputShape;
  getI64Values(attributes.get("out_shape").cast<ArrayAttr>(), outputShape);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputTy.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputTy.getDimSize(1);
    inputWidth = inputTy.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  if (auto weightTy = adaptor.filter().getType().dyn_cast<RankedTensorType>()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightTy.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasTy.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
  getI64Values(attributes.get("out_pad").cast<ArrayAttr>(), padding);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);

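  // For example, inputHeight = 8 with stride 2, out_pad 0, dilation 1 and a
  // 3-tap filter gives (8 - 1) * 2 - 0 + 3 = 17 output rows.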
  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t dilated = (weightHeight - 1) * dilation[0] + 1;
    int32_t calculateSize =
        (inputHeight - 1) * stride[0] - padding[0] + dilated;
    outputShape[1] = outputShape[1] == -1 ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t dilated = (weightWidth - 1) * dilation[1] + 1;
    int32_t calculateSize = (inputWidth - 1) * stride[1] - padding[1] + dilated;
    outputShape[2] = outputShape[2] == -1 ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"