1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/FoldUtils.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "mlir/Transforms/RegionUtils.h"
26 #include "llvm/ADT/DenseMap.h"
27
28 using namespace mlir;
29 using namespace mlir::tosa;
30
31 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
32
33 //===----------------------------------------------------------------------===//
34 // Tosa dialect structs and interface includes.
35 //===----------------------------------------------------------------------===//
36 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
37 #include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc"
38
39 namespace {
40 //===----------------------------------------------------------------------===//
41 // Dialect Function Inliner Interface.
42 //===----------------------------------------------------------------------===//
43 struct TosaInlinerInterface : public DialectInlinerInterface {
44 using DialectInlinerInterface::DialectInlinerInterface;
45
46 //===--------------------------------------------------------------------===//
47 // Analysis Hooks.
48 //===--------------------------------------------------------------------===//
49
50 /// All operations can be inlined by default.
isLegalToInline__anon4d6485000111::TosaInlinerInterface51 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
52 BlockAndValueMapping &map) const final {
53 return true;
54 }
55
56 /// All regions with If and While parent operators can be inlined.
isLegalToInline__anon4d6485000111::TosaInlinerInterface57 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
58 BlockAndValueMapping &map) const final {
59 return (isa<tosa::IfOp>(dest->getParentOp()) ||
60 isa<tosa::WhileOp>(dest->getParentOp()));
61 }
62 };
63 } // end anonymous namespace
64
65 //===----------------------------------------------------------------------===//
66 // TOSA control flow support.
67 //===----------------------------------------------------------------------===//
68
69 /// Returns the while loop body.
getLoopBody()70 Region &tosa::WhileOp::getLoopBody() { return body(); }
71
isDefinedOutsideOfLoop(Value value)72 bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) {
73 return !body().isAncestor(value.getParentRegion());
74 }
75
moveOutOfLoop(ArrayRef<mlir::Operation * > ops)76 LogicalResult WhileOp::moveOutOfLoop(ArrayRef<mlir::Operation *> ops) {
77 if (ops.empty())
78 return success();
79
80 Operation *tosaWhileOp = this->getOperation();
81 for (auto *op : ops)
82 op->moveBefore(tosaWhileOp);
83
84 return success();
85 }
86
87 //===----------------------------------------------------------------------===//
88 // Tosa dialect initialization.
89 //===----------------------------------------------------------------------===//
90
/// Dialect initialization: registers every TOSA operation listed in the
/// generated TosaOps.cpp.inc (via GET_OP_LIST) and attaches the dialect's
/// inliner interface.
void TosaDialect::initialize() {
  // The generated op list is spliced directly into the template argument list.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addInterfaces<TosaInlinerInterface>();
}
98
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)99 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
100 Type type, Location loc) {
101 // Tosa dialect constants only support ElementsAttr unlike standard dialect
102 // constant which supports all attributes.
103 if (value.isa<ElementsAttr>())
104 return builder.create<tosa::ConstOp>(loc, type, value.cast<ElementsAttr>());
105 return nullptr;
106 }
107
108 //===----------------------------------------------------------------------===//
109 // Operator Canonicalizers.
110 //===----------------------------------------------------------------------===//
111
112 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
113 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
114
matchAndRewriteConcatOptimization115 LogicalResult matchAndRewrite(tosa::ConcatOp op,
116 PatternRewriter &rewriter) const override {
117 if (op.input1().size() != 1)
118 return failure();
119 if (op.input1().front().getType() != op.getType()) {
120 rewriter
121 .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
122 op.input1().front())
123 .getResult();
124 return success();
125 }
126
127 rewriter.replaceOp(op, op.input1().front());
128 return success();
129 }
130 };
131
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)132 void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
133 MLIRContext *context) {
134 results.insert<ConcatOptimization>(context);
135 }
136
137 struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
138 using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
139
matchAndRewriteReshapeReshapeOptimization140 LogicalResult matchAndRewrite(tosa::ReshapeOp op,
141 PatternRewriter &rewriter) const override {
142 Value input = op.input1();
143 Operation *definingOp = input.getDefiningOp();
144 if (!definingOp)
145 return failure();
146
147 if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
148 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
149 op, op.getType(), reshapeOp.input1(), op.new_shape());
150 return success();
151 }
152
153 return failure();
154 }
155 };
156
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)157 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
158 MLIRContext *context) {
159 results.insert<ReshapeReshapeOptimization>(context);
160 }
161
162 struct ConstantTransposeOptimization
163 : public OpRewritePattern<tosa::TransposeOp> {
164 using OpRewritePattern::OpRewritePattern;
165
matchAndRewriteConstantTransposeOptimization166 LogicalResult matchAndRewrite(tosa::TransposeOp op,
167 PatternRewriter &rewriter) const override {
168 auto outputType = op.getType().cast<ShapedType>();
169 ArrayRef<int64_t> outputShape = outputType.getShape();
170 // TOSA supports quantized types.
171 if (!outputType.getElementType().isIntOrIndexOrFloat())
172 return failure();
173
174 DenseElementsAttr inputValues;
175 if (!matchPattern(op.input1(), m_Constant(&inputValues)))
176 return failure();
177 // Make sure the input is a constant that has a single user.
178 if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
179 return failure();
180
181 DenseIntElementsAttr permAttr;
182 if (!matchPattern(op.perms(), m_Constant(&permAttr)))
183 return failure();
184 auto permValues = llvm::to_vector<6>(llvm::map_range(
185 // TOSA allows both 32- and 64-bit integer tensors here.
186 permAttr.getValues<APInt>(),
187 [](const APInt &val) { return val.getZExtValue(); }));
188
189 auto inputType = op.input1().getType().cast<ShapedType>();
190 ArrayRef<int64_t> inputShape = inputType.getShape();
191 int64_t numElements = inputType.getNumElements();
192
193 SmallVector<Attribute, 4> outputValues;
194 outputValues.resize(numElements);
195
196 // Transpose the input constant. Because we don't know its rank in advance,
197 // we need to loop over the range [0, element count) and delinearize the
198 // index.
199 for (int srcLinearIndex = 0; srcLinearIndex < numElements;
200 ++srcLinearIndex) {
201 SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
202 int totalCount = srcLinearIndex;
203 for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
204 srcIndices[dim] = totalCount % inputShape[dim];
205 totalCount /= inputShape[dim];
206 }
207
208 SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
209 for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
210 dstIndices[dim] = srcIndices[permValues[dim]];
211
212 uint64_t dstLinearIndex = dstIndices.front();
213 for (int dim = 1; dim < outputType.getRank(); ++dim)
214 dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
215
216 outputValues[dstLinearIndex] = inputValues.getValue(srcIndices);
217 }
218
219 rewriter.replaceOpWithNewOp<tosa::ConstOp>(
220 op, outputType, DenseElementsAttr::get(outputType, outputValues));
221 return success();
222 }
223 };
224
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)225 void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
226 MLIRContext *context) {
227 results.insert<ConstantTransposeOptimization>(context);
228 }
229
230 //===----------------------------------------------------------------------===//
231 // Operator Folders.
232 //===----------------------------------------------------------------------===//
233
fold(ArrayRef<Attribute> operands)234 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
235 if (input().getType() == getType())
236 return input();
237 return {};
238 }
239
fold(ArrayRef<Attribute> operands)240 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
241 assert(operands.empty() && "constant has no operands");
242 return valueAttr();
243 }
244
245 #define ReduceFolder(OP) \
246 OpFoldResult OP::fold(ArrayRef<Attribute> operands) { \
247 ShapedType inputTy = input().getType().cast<ShapedType>(); \
248 if (!inputTy.hasRank()) \
249 return {}; \
250 if (inputTy.getDimSize(axis()) == 1) \
251 return input(); \
252 return {}; \
253 }
254
ReduceFolder(ReduceAnyOp)255 ReduceFolder(ReduceAllOp) ReduceFolder(ReduceAnyOp) ReduceFolder(ReduceMaxOp)
256 ReduceFolder(ReduceMinOp) ReduceFolder(ReduceProdOp)
257 ReduceFolder(ReduceSumOp)
258 #undef ReduceFolder
259
260 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
261 auto inputTy = input1().getType().dyn_cast<RankedTensorType>();
262 auto outputTy = getType().dyn_cast<RankedTensorType>();
263
264 if (!inputTy || !outputTy || inputTy != outputTy)
265 return {};
266 return input1();
267 }
268
fold(ArrayRef<Attribute> operands)269 OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
270 auto inputTy = input().getType().dyn_cast<RankedTensorType>();
271 auto outputTy = getType().dyn_cast<RankedTensorType>();
272
273 if (!inputTy || !outputTy || inputTy != outputTy)
274 return {};
275 if (inputTy.hasStaticShape())
276 return input();
277
278 return {};
279 }
280
fold(ArrayRef<Attribute> operands)281 OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
282 bool allOnes = true;
283 for (Attribute val : multiples().getValue()) {
284 allOnes = allOnes && val.cast<IntegerAttr>().getValue().getSExtValue() == 1;
285 }
286
287 if (allOnes && input1().getType() == getType())
288 return input1();
289 return {};
290 }
291
fold(ArrayRef<Attribute> operands)292 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
293 if (!operands[1])
294 return {};
295
296 // Transposing splat values just means reshaping.
297 if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
298 if (input.isSplat())
299 return input.reshape(getType().cast<ShapedType>());
300 }
301
302 auto perms = llvm::to_vector<6>(llvm::map_range(
303 operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
304 [](const APInt &val) { return val.getSExtValue(); }));
305
306 if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
307 input1().getType() == getType())
308 return input1();
309 return {};
310 }
311
312 //===----------------------------------------------------------------------===//
313 // TOSA Operator Verifiers.
314 //===----------------------------------------------------------------------===//
315
316 template <typename T>
verifyConvOp(T op)317 static LogicalResult verifyConvOp(T op) {
318 // All TOSA conv ops have an input() and weight().
319 auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
320 auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();
321
322 // Must be ranked tensor types
323 if (!inputType || !weightType)
324 return failure();
325
326 auto inputEType = inputType.getElementType();
327 auto weightEType = weightType.getElementType();
328
329 bool inputIsQuant = !inputEType.template isa<FloatType>();
330 bool weightIsQuant = !weightEType.template isa<FloatType>();
331
332 // Either both must be quantized or both unquantized.
333 if (inputIsQuant != weightIsQuant)
334 return failure();
335
336 // Quantized type must have constructed the quantizationattr, and unquantized
337 // types should not have a quantizationattr.
338 if ((inputIsQuant && !op.quantization_info()) ||
339 (!inputIsQuant && op.quantization_info()))
340 return failure();
341
342 return success();
343 }
344
345 //===----------------------------------------------------------------------===//
346 // TOSA Operator Quantization Builders.
347 //===----------------------------------------------------------------------===//
348
349 /// This builder is called on all convolution operators except TransposeConv,
350 /// which has specialized output shape semantics. The builder also defines the
351 /// bitwidth of the output given the bit width of the input & weight content.
buildConvOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,Value weight,Value bias,ArrayAttr pad,ArrayAttr stride,ArrayAttr dilation)352 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
353 Type outputType, Value input, Value weight,
354 Value bias, ArrayAttr pad,
355 ArrayAttr stride, ArrayAttr dilation) {
356
357 result.addOperands({input, weight, bias});
358 result.addAttribute("pad", pad);
359 result.addAttribute("stride", stride);
360 result.addAttribute("dilation", dilation);
361
362 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
363 if (quantAttr) {
364 result.addAttribute("quantization_info", quantAttr);
365 result.addTypes(
366 buildConvOpResultTypeInfo(builder, outputType, input, weight));
367 } else {
368 result.addTypes(outputType);
369 }
370 }
371
372 /// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
373 static void
buildTransConvOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,Value weight,Value bias,ArrayAttr outpad,ArrayAttr stride,ArrayAttr dilation,ArrayAttr outputShape)374 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
375 Type outputType, Value input, Value weight,
376 Value bias, ArrayAttr outpad, ArrayAttr stride,
377 ArrayAttr dilation, ArrayAttr outputShape) {
378 result.addOperands({input, weight, bias});
379 result.addAttribute("out_pad", outpad);
380 result.addAttribute("stride", stride);
381 result.addAttribute("dilation", dilation);
382 result.addAttribute("out_shape", outputShape);
383 auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
384
385 if (quantAttr) {
386 result.addAttribute("quantization_info", quantAttr);
387 result.addTypes(
388 buildConvOpResultTypeInfo(builder, outputType, input, weight));
389 } else {
390 result.addTypes(outputType);
391 }
392 }
393
394 /// The tosa.fully_connected op has its own builder as it does not have
395 /// strides/dilation/padding.
buildFCOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,Value weight,Value bias)396 static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
397 Type outputType, Value input, Value weight,
398 Value bias) {
399
400 result.addOperands({input, weight, bias});
401 auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
402 if (quantAttr) {
403 result.addAttribute("quantization_info", quantAttr);
404 result.addTypes(
405 buildConvOpResultTypeInfo(builder, outputType, input, weight));
406 } else {
407 result.addTypes(outputType);
408 }
409 }
410
411 /// The tosa.matmul op is also intended to be generated where a fully_connected
412 /// op must be constructed where the weight is not a constant. In this case,
413 /// the fully_connected op must be expressed using matmul.
414 /// TODO: Add link to the leglization document explaining this.
buildMatMulOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value a,Value b)415 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
416 OperationState &result, Type outputType,
417 Value a, Value b) {
418 result.addOperands({a, b});
419 auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
420
421 if (quantAttr) {
422 result.addAttribute("quantization_info", quantAttr);
423
424 auto inputType = a.getType().dyn_cast<ShapedType>();
425 assert(inputType && "Input must be a shaped tensor type!");
426
427 auto inputQType = inputType.getElementType()
428 .dyn_cast<mlir::quant::UniformQuantizedType>();
429 assert(inputQType && "Tensor must have quantized datatype!");
430
431 unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
432
433 auto outputShapedType = outputType.dyn_cast<ShapedType>();
434 assert(outputShapedType && "Output must be a shaped type");
435
436 IntegerType accElementType;
437 if (inputBits == 16)
438 accElementType = builder.getIntegerType(48);
439 else
440 accElementType = builder.getI32Type();
441 auto accType = outputShapedType.clone(accElementType);
442 result.addTypes(accType);
443 } else {
444 result.addTypes(outputType);
445 }
446 }
447
448 /// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr
449 /// but avg_pool operator has its own builder as it has additional parameters
450 /// not part of the unary ops.
buildAvgPool2dOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,ArrayAttr kernel,ArrayAttr stride,ArrayAttr pad)451 static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
452 OperationState &result,
453 Type outputType, Value input,
454 ArrayAttr kernel, ArrayAttr stride,
455 ArrayAttr pad) {
456 result.addOperands(input);
457 result.addAttribute("kernel", kernel);
458 result.addAttribute("stride", stride);
459 result.addAttribute("pad", pad);
460 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
461 if (quantAttr)
462 result.addAttribute("quantization_info", quantAttr);
463 result.types.push_back(outputType);
464 }
465
466 /// This builder is called on single-parameter unary operators that have scale
467 /// relationship between their input and output, expressed by the
468 /// UnaryOpQuantizationAttr.
buildUnaryOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input)469 static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
470 OperationState &result, Type outputType,
471 Value input) {
472 result.addOperands(input);
473 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
474 if (quantAttr)
475 result.addAttribute("quantization_info", quantAttr);
476 result.types.push_back(outputType);
477 }
478
479 /// This builder is called on TOSA pad operator that needs to create its own
480 /// OptionalAttr quantization_attr parameter to scale the padding values
481 /// correctly.
buildPadOpWithQuantInfo(OpBuilder & builder,OperationState & result,Type outputType,Value input,Value paddings)482 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
483 Type outputType, Value input,
484 Value paddings) {
485 result.addOperands({input, paddings});
486 auto quantAttr = buildPadOpQuantizationAttr(builder, input);
487 if (quantAttr)
488 result.addAttribute("quantization_info", quantAttr);
489 result.types.push_back(outputType);
490 }
491
492 //===----------------------------------------------------------------------===//
493 // TOSA Operator Return Type Inference.
494 //===----------------------------------------------------------------------===//
495
getI64Values(ArrayAttr arrayAttr,SmallVector<int64_t> & values)496 static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
497 for (auto it : arrayAttr) {
498 values.push_back(it.cast<IntegerAttr>().getValue().getSExtValue());
499 }
500 }
501
getF64Values(ArrayAttr arrayAttr,SmallVector<double> & values)502 static void getF64Values(ArrayAttr arrayAttr, SmallVector<double> &values) {
503 for (auto it : arrayAttr) {
504 values.push_back(it.cast<FloatAttr>().getValueAsDouble());
505 }
506 }
507
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)508 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
509 MLIRContext *context, ::llvm::Optional<Location> location,
510 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
511 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
512 ShapeAdaptor inputShape = operands.getShape(0);
513 IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
514 int32_t axisVal = axis.getValue().getSExtValue();
515
516 if (!inputShape.hasRank()) {
517 inferredReturnShapes.push_back(ShapedTypeComponents());
518 return success();
519 }
520
521 SmallVector<int64_t> outShape;
522 outShape.reserve(inputShape.getRank() - 1);
523 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
524 if (i == axisVal)
525 continue;
526 outShape.push_back(inputShape.getDimSize(i));
527 }
528
529 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
530 return success();
531 }
532
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)533 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
534 MLIRContext *context, ::llvm::Optional<Location> location,
535 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
536 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
537 // Infer all dimension sizes by reducing based on inputs.
538 int32_t axis =
539 attributes.get("axis").cast<IntegerAttr>().getValue().getSExtValue();
540 llvm::SmallVector<int64_t> outputShape;
541 bool hasRankedInput = false;
542 for (auto operand : operands) {
543 ShapeAdaptor operandShape = operands.getShape(operand);
544 if (!operandShape.hasRank())
545 continue;
546
547 // Copy the Operand's rank.
548 if (!hasRankedInput)
549 outputShape.resize(operandShape.getRank(), ShapedType::kDynamicSize);
550
551 // Copy shapes until the dim is non-dynamic.
552 for (int i = 0, s = operandShape.getRank(); i < s; i++) {
553 if (i == axis || operandShape.isDynamicDim(i))
554 continue;
555 if (outputShape[i] == ShapedType::kDynamicSize)
556 outputShape[i] = operandShape.getDimSize(i);
557 if (outputShape[i] != operandShape.getDimSize(i))
558 return failure();
559 }
560
561 hasRankedInput = true;
562 }
563
564 if (!hasRankedInput) {
565 inferredReturnShapes.push_back(ShapedTypeComponents());
566 return success();
567 }
568
569 // Determine the dimension size along the concatenation axis.
570 int concatDimSize = 0;
571 for (auto operand : operands) {
572 ShapeAdaptor operandShape = operands.getShape(operand);
573
574 // We need to know the length of the concatenation axis of all inputs to
575 // determine the dimension size of the output shape.
576 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
577 concatDimSize = ShapedType::kDynamicSize;
578 break;
579 }
580
581 concatDimSize += operandShape.getDimSize(axis);
582 }
583
584 outputShape[axis] = concatDimSize;
585
586 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
587 return success();
588 }
589
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)590 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
591 MLIRContext *context, ::llvm::Optional<Location> location,
592 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
593 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
594 ShapeAdaptor inputShape = operands.getShape(0);
595 ShapeAdaptor weightShape = operands.getShape(1);
596 ShapeAdaptor biasShape = operands.getShape(2);
597
598 // All shapes are dynamic.
599 SmallVector<int64_t> outShape;
600 outShape.resize(2, ShapedType::kDynamicSize);
601
602 if (inputShape.hasRank()) {
603 outShape[0] = inputShape.getDimSize(0);
604 }
605
606 if (weightShape.hasRank()) {
607 outShape[1] = weightShape.getDimSize(0);
608 }
609
610 if (biasShape.hasRank()) {
611 outShape[1] = outShape[1] == ShapedType::kDynamicSize
612 ? biasShape.getDimSize(0)
613 : outShape[1];
614 }
615
616 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
617 return success();
618 }
619
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)620 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
621 MLIRContext *context, ::llvm::Optional<Location> location,
622 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
623 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
624 ShapeAdaptor lhsShape = operands.getShape(0);
625 ShapeAdaptor rhsShape = operands.getShape(1);
626
627 // All shapes are dynamic.
628 SmallVector<int64_t> outShape;
629 outShape.resize(3, ShapedType::kDynamicSize);
630
631 if (lhsShape.hasRank()) {
632 outShape[0] = lhsShape.getDimSize(0);
633 outShape[1] = lhsShape.getDimSize(1);
634 }
635
636 if (rhsShape.hasRank()) {
637 outShape[0] = outShape[0] == ShapedType::kDynamicSize
638 ? rhsShape.getDimSize(0)
639 : outShape[0];
640 outShape[2] = rhsShape.getDimSize(2);
641 }
642
643 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
644 return success();
645 }
646
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)647 LogicalResult tosa::PadOp::inferReturnTypeComponents(
648 MLIRContext *context, ::llvm::Optional<Location> location,
649 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
650 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
651 ShapeAdaptor inputShape = operands.getShape(0);
652 ShapeAdaptor paddingShape = operands.getShape(1);
653 SmallVector<int64_t> outputShape;
654
655 // If both inputs have unknown shape, we cannot determine the shape of the
656 // output.
657 if (!inputShape.hasRank() && !paddingShape.hasRank()) {
658 inferredReturnShapes.push_back(ShapedTypeComponents());
659 return success();
660 }
661
662 // If the input rank is unknown we can info the output rank using the padding
663 // shape's first dim.
664 if (!inputShape.hasRank()) {
665 if (paddingShape.isDynamicDim(0)) {
666 inferredReturnShapes.push_back(ShapedTypeComponents());
667 return success();
668 }
669
670 outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamicSize);
671 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
672 return success();
673 }
674
675 DenseIntElementsAttr paddings;
676 // If the paddings value is not a constant, all dimensions must be dynamic.
677 if (!matchPattern(operands[1], m_Constant(&paddings))) {
678 outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
679 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
680 return success();
681 }
682
683 SmallVector<int64_t> paddingValues;
684 for (auto val : paddings) {
685 paddingValues.push_back(val.getSExtValue());
686 }
687
688 outputShape.reserve(inputShape.getRank());
689 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
690 if (inputShape.isDynamicDim(i)) {
691 outputShape.push_back(ShapedType::kDynamicSize);
692 continue;
693 }
694
695 outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
696 paddingValues[i * 2 + 1]);
697 }
698
699 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
700 return success();
701 }
702
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)703 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
704 MLIRContext *context, ::llvm::Optional<Location> location,
705 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
706 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
707 ArrayAttr sizes = SliceOpAdaptor(operands, attributes).size();
708 SmallVector<int64_t> outputShape;
709 outputShape.reserve(sizes.size());
710 for (auto val : sizes) {
711 outputShape.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
712 }
713
714 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
715 return success();
716 }
717
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)718 LogicalResult tosa::TableOp::inferReturnTypeComponents(
719 MLIRContext *context, ::llvm::Optional<Location> location,
720 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
721 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
722 ShapeAdaptor inputShape = operands.getShape(0);
723
724 if (!inputShape.hasRank()) {
725 inferredReturnShapes.push_back(ShapedTypeComponents());
726 return success();
727 }
728
729 inferredReturnShapes.resize(1);
730 inputShape.getDims(inferredReturnShapes[0]);
731 return success();
732 }
733
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)734 LogicalResult tosa::TileOp::inferReturnTypeComponents(
735 MLIRContext *context, ::llvm::Optional<Location> location,
736 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
737 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
738 TileOpAdaptor adaptor(operands, attributes);
739 ArrayAttr multiples = adaptor.multiples();
740 ShapeAdaptor inputShape = operands.getShape(0);
741 SmallVector<int64_t> outputShape;
742 if (!inputShape.hasRank()) {
743 outputShape.resize(multiples.size(), ShapedType::kDynamicSize);
744 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
745 return success();
746 }
747
748 // We need the multiple values to determine the output shape.
749 SmallVector<int64_t> multipleValues;
750 multipleValues.reserve(multiples.size());
751 for (auto val : multiples) {
752 multipleValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
753 }
754
755 // Any non dynamic dimension can be multiplied to a known size.
756 outputShape.reserve(multiples.size());
757 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
758 int dim = inputShape.getDimSize(i);
759 if (dim != ShapedType::kDynamicSize)
760 dim *= multipleValues[i];
761 outputShape.push_back(dim);
762 }
763
764 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
765 return success();
766 }
767
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)768 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
769 MLIRContext *context, ::llvm::Optional<Location> location,
770 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
771 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
772 ReshapeOpAdaptor adaptor(operands, attributes);
773 ShapeAdaptor inputShape = operands.getShape(0);
774
775 ArrayAttr newShape = adaptor.new_shape();
776 llvm::SmallVector<int64_t> newShapeValue;
777 getI64Values(newShape, newShapeValue);
778
779 // We cannot infer from the total number of elements so we must take the
780 // shape attribute as exact.
781 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
782 inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
783 return success();
784 }
785
786 // Determine the number of elements covered by the slice of all static
787 // dimensions. This allows us to infer the length of the remaining dynamic
788 // dimension.
789 int64_t numElements = inputShape.getNumElements();
790 int64_t staticMul = 1;
791 for (auto val : newShapeValue) {
792 if (val != ShapedType::kDynamicSize) {
793 staticMul *= val;
794 }
795 }
796
797 // Determine the length of the dynamic dimension.
798 for (auto &val : newShapeValue) {
799 if (val == ShapedType::kDynamicSize)
800 val = numElements / staticMul;
801 }
802
803 inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
804 return success();
805 }
806
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)807 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
808 MLIRContext *context, ::llvm::Optional<Location> location,
809 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
810 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
811 ShapeAdaptor inputShape = operands.getShape(0);
812 ShapeAdaptor permsShape = operands.getShape(1);
813
814 // If input rank and permutation length is unknown, the output rank is
815 // unknown.
816 if (!inputShape.hasRank() || !permsShape.hasRank() ||
817 permsShape.isDynamicDim(0)) {
818 inferredReturnShapes.push_back(ShapedTypeComponents());
819 return success();
820 }
821
822 // This would imply the number of permutations does not match the rank of the
823 // input which is illegal.
824 if (permsShape.getDimSize(0) != inputShape.getRank()) {
825 return failure();
826 }
827
828 // Without the input dims we cannot determine the output dim sizes but we
829 // can determine the output rank.
830 SmallVector<int64_t> outputShape;
831 if (!inputShape.hasRank()) {
832 outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamicSize);
833 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
834 return success();
835 }
836
837 // Rank-0 means no permutations matter.
838 if (inputShape.getRank() == 0) {
839 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
840 return success();
841 }
842
843 // Check whether the input dimensions are all the same.
844 bool allTheSame = true;
845 for (int i = 1, s = inputShape.getRank(); i < s; i++) {
846 if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
847 allTheSame = false;
848 break;
849 }
850 }
851
852 // If all of the input dimensions are the same we don't care about the
853 // permutation.
854 if (allTheSame) {
855 outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
856 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
857 return success();
858 }
859
860 outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
861 // If the permuations are a constant we can directly determine the output
862 // shape.
863 if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
864 outputShape.reserve(inputShape.getRank());
865 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
866 outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
867 }
868 }
869
870 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
871 return success();
872 }
873
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)874 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
875 MLIRContext *context, ::llvm::Optional<Location> location,
876 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
877 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
878 llvm::SmallVector<int64_t> outputShape;
879 outputShape.resize(3, ShapedType::kDynamicSize);
880
881 ShapeAdaptor valuesShape = operands.getShape(0);
882 if (valuesShape.hasRank()) {
883 outputShape[0] = valuesShape.getDimSize(0);
884 outputShape[2] = valuesShape.getDimSize(2);
885 }
886
887 ShapeAdaptor indicesShape = operands.getShape(1);
888 if (indicesShape.hasRank()) {
889 if (outputShape[0] == ShapedType::kDynamicSize)
890 outputShape[0] = indicesShape.getDimSize(0);
891 if (outputShape[1] == ShapedType::kDynamicSize)
892 outputShape[1] = indicesShape.getDimSize(1);
893 }
894
895 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
896 return success();
897 }
898
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)899 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
900 MLIRContext *context, ::llvm::Optional<Location> location,
901 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
902 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
903 ResizeOpAdaptor adaptor(operands, attributes);
904 llvm::SmallVector<int64_t, 4> outputShape;
905 outputShape.resize(4, ShapedType::kDynamicSize);
906
907 int32_t inHeight = ShapedType::kDynamicSize;
908 int32_t inWidth = ShapedType::kDynamicSize;
909
910 ShapeAdaptor inputShape = operands.getShape(adaptor.input());
911 if (inputShape.hasRank()) {
912 outputShape[0] = inputShape.getDimSize(0);
913 outputShape[3] = inputShape.getDimSize(3);
914
915 inHeight = inputShape.getDimSize(1);
916 inWidth = inputShape.getDimSize(2);
917 }
918
919 int32_t shift = adaptor.shift().getValue().getSExtValue();
920 llvm::SmallVector<int64_t> newShape;
921 getI64Values(adaptor.output_size(), newShape);
922 outputShape[1] = newShape[0];
923 outputShape[2] = newShape[1];
924
925 llvm::SmallVector<int64_t> strideInt;
926 llvm::SmallVector<int64_t> offsetInt;
927 llvm::SmallVector<double> strideFp;
928 llvm::SmallVector<double> offsetFp;
929 getI64Values(adaptor.offset(), offsetInt);
930 getF64Values(adaptor.offset_fp(), offsetFp);
931 getI64Values(adaptor.stride(), strideInt);
932 getF64Values(adaptor.stride_fp(), strideFp);
933
934 // If we have a 0 zero in integers we know that the resize indexing needs to
935 // be performed in floating point. Use the floating point varient to compute
936 // the resize shape.
937 bool fpMode = strideInt[0] == 0;
938
939 // We can compute the output shape if attribute specifies unknown dimensions
940 // based on the offset and stride. If we perfectly line up to the last index
941 // we need to round up the size to include it.
942 if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && fpMode) {
943 float sizeFp = (inHeight - offsetFp[0] - 1) / strideFp[0];
944 float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
945 outputShape[1] = std::ceil(sizeFp) + round;
946 }
947
948 if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && fpMode) {
949 float sizeFp = (inWidth - offsetFp[1] - 1) / strideFp[1];
950 float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
951 outputShape[2] = std::ceil(sizeFp) + round;
952 }
953
954 if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && !fpMode) {
955 int64_t size = (inHeight - 1);
956 size = ((size << shift) - offsetInt[0]) / strideInt[0];
957 outputShape[1] = size + 1;
958 }
959
960 if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && !fpMode) {
961 int64_t size = (inWidth - 1);
962 size = ((size << shift) - offsetInt[1]) / strideInt[1];
963 outputShape[2] = size + 1;
964 }
965
966 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
967 return success();
968 }
969
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)970 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
971 MLIRContext *context, ::llvm::Optional<Location> location,
972 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
973 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
974 llvm::SmallVector<int64_t> outputShape;
975 outputShape.resize(3, ShapedType::kDynamicSize);
976
977 ShapeAdaptor valuesInShape = operands.getShape(0);
978 if (valuesInShape.hasRank()) {
979 outputShape[0] = valuesInShape.getDimSize(0);
980 outputShape[1] = valuesInShape.getDimSize(1);
981 outputShape[2] = valuesInShape.getDimSize(2);
982 }
983
984 ShapeAdaptor indicesShape = operands.getShape(1);
985 if (indicesShape.hasRank()) {
986 if (outputShape[0] == ShapedType::kDynamicSize)
987 outputShape[0] = indicesShape.getDimSize(0);
988 }
989
990 ShapeAdaptor inputShape = operands.getShape(2);
991 if (inputShape.hasRank()) {
992 if (outputShape[0] == ShapedType::kDynamicSize)
993 outputShape[0] = inputShape.getDimSize(0);
994 if (outputShape[2] == ShapedType::kDynamicSize)
995 outputShape[2] = inputShape.getDimSize(2);
996 }
997
998 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
999 return success();
1000 }
1001
ReduceInferReturnTypes(ShapeAdaptor operandShape,IntegerAttr axis,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1002 static LogicalResult ReduceInferReturnTypes(
1003 ShapeAdaptor operandShape, IntegerAttr axis,
1004 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1005 if (!operandShape.hasRank()) {
1006 inferredReturnShapes.push_back(ShapedTypeComponents());
1007 return success();
1008 }
1009
1010 SmallVector<int64_t> outputShape;
1011 operandShape.getDims(outputShape);
1012 int64_t axisVal = axis.getValue().getSExtValue();
1013 outputShape[axisVal] = 1;
1014 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1015 return success();
1016 }
1017
1018 #define REDUCE_SHAPE_INFER(OP) \
1019 LogicalResult OP::inferReturnTypeComponents( \
1020 MLIRContext *context, ::llvm::Optional<Location> location, \
1021 ValueShapeRange operands, DictionaryAttr attributes, \
1022 RegionRange regions, \
1023 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1024 return ReduceInferReturnTypes(operands.getShape(0), \
1025 attributes.get("axis").cast<IntegerAttr>(), \
1026 inferredReturnShapes); \
1027 }
1028
1029 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)1030 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
1031 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
1032 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
1033 REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
1034 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
1035 #undef REDUCE_SHAPE_INFER
1036
1037 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1038 SmallVector<int64_t> &outShape) {
1039 int64_t outRank = 0;
1040 for (int i = 0, e = operands.size(); i != e; ++i) {
1041 auto shape = operands.getShape(i);
1042 if (!shape.hasRank()) {
1043 return failure();
1044 }
1045 outRank = std::max<int64_t>(outRank, shape.getRank());
1046 }
1047
1048 outShape.resize(outRank, 1);
1049
1050 for (int i = 0, e = operands.size(); i != e; ++i) {
1051 auto shape = operands.getShape(i);
1052 auto rankDiff = outShape.size() - shape.getRank();
1053
1054 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1055 auto dim1 = outShape[i + rankDiff];
1056 auto dim2 = shape.getDimSize(i);
1057 auto resolvedDim = dim1;
1058
1059 if (dim1 == 1) {
1060 resolvedDim = dim2;
1061 } else if (dim2 == 1) {
1062 resolvedDim = dim1;
1063 } else if (dim1 != dim2) {
1064 return failure();
1065 }
1066 outShape[i + rankDiff] = resolvedDim;
1067 }
1068 }
1069
1070 return success();
1071 }
1072
NAryInferReturnTypes(const ValueShapeRange & operands,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1073 static LogicalResult NAryInferReturnTypes(
1074 const ValueShapeRange &operands,
1075 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1076 llvm::SmallVector<int64_t> outShape;
1077 if (resolveBroadcastShape(operands, outShape).failed()) {
1078 inferredReturnShapes.push_back(ShapedTypeComponents());
1079 } else {
1080 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1081 }
1082 return success();
1083 }
1084
1085 #define NARY_SHAPE_INFER(OP) \
1086 LogicalResult OP::inferReturnTypeComponents( \
1087 MLIRContext *context, ::llvm::Optional<Location> location, \
1088 ValueShapeRange operands, DictionaryAttr attributes, \
1089 RegionRange regions, \
1090 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1091 return NAryInferReturnTypes(operands, inferredReturnShapes); \
1092 }
1093
1094 NARY_SHAPE_INFER(tosa::AbsOp)
NARY_SHAPE_INFER(tosa::AddOp)1095 NARY_SHAPE_INFER(tosa::AddOp)
1096 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
1097 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
1098 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
1099 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
1100 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
1101 NARY_SHAPE_INFER(tosa::CastOp)
1102 NARY_SHAPE_INFER(tosa::CeilOp)
1103 NARY_SHAPE_INFER(tosa::ClampOp)
1104 NARY_SHAPE_INFER(tosa::ClzOp)
1105 NARY_SHAPE_INFER(tosa::DivOp)
1106 NARY_SHAPE_INFER(tosa::EqualOp)
1107 NARY_SHAPE_INFER(tosa::ExpOp)
1108 NARY_SHAPE_INFER(tosa::FloorOp)
1109 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
1110 NARY_SHAPE_INFER(tosa::GreaterOp)
1111 NARY_SHAPE_INFER(tosa::IdentityOp)
1112 NARY_SHAPE_INFER(tosa::LogOp)
1113 NARY_SHAPE_INFER(tosa::LogicalAndOp)
1114 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
1115 NARY_SHAPE_INFER(tosa::LogicalNotOp)
1116 NARY_SHAPE_INFER(tosa::LogicalOrOp)
1117 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
1118 NARY_SHAPE_INFER(tosa::LogicalXorOp)
1119 NARY_SHAPE_INFER(tosa::MaximumOp)
1120 NARY_SHAPE_INFER(tosa::MinimumOp)
1121 NARY_SHAPE_INFER(tosa::MulOp)
1122 NARY_SHAPE_INFER(tosa::NegateOp)
1123 NARY_SHAPE_INFER(tosa::PowOp)
1124 NARY_SHAPE_INFER(tosa::ReciprocalOp)
1125 NARY_SHAPE_INFER(tosa::ReluNOp)
1126 NARY_SHAPE_INFER(tosa::RescaleOp)
1127 NARY_SHAPE_INFER(tosa::ReverseOp)
1128 NARY_SHAPE_INFER(tosa::RsqrtOp)
1129 NARY_SHAPE_INFER(tosa::SelectOp)
1130 NARY_SHAPE_INFER(tosa::SubOp)
1131 NARY_SHAPE_INFER(tosa::TanhOp)
1132 NARY_SHAPE_INFER(tosa::SigmoidOp)
1133 #undef PRED_SHAPE_INFER
1134
1135 static LogicalResult poolingInferReturnTypes(
1136 const ValueShapeRange &operands, DictionaryAttr attributes,
1137 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1138 ShapeAdaptor inputShape = operands.getShape(0);
1139 llvm::SmallVector<int64_t> outputShape;
1140 outputShape.resize(4, -1);
1141
1142 // We only know the rank if the input type is unranked.
1143 if (!inputShape) {
1144 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1145 return success();
1146 }
1147
1148 // Batch and number of channels are identical for pooling layer.
1149 outputShape[0] = inputShape.getDimSize(0);
1150 outputShape[3] = inputShape.getDimSize(3);
1151
1152 int32_t height = inputShape.getDimSize(1);
1153 int32_t width = inputShape.getDimSize(2);
1154
1155 llvm::SmallVector<int64_t> kernel;
1156 llvm::SmallVector<int64_t> stride;
1157 llvm::SmallVector<int64_t> pad;
1158
1159 getI64Values(attributes.get("kernel").cast<ArrayAttr>(), kernel);
1160 getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
1161 getI64Values(attributes.get("pad").cast<ArrayAttr>(), pad);
1162
1163 if (height != -1) {
1164 int32_t padded = height + pad[0] + pad[1] - kernel[0];
1165 outputShape[1] = padded / stride[0] + 1;
1166 }
1167
1168 if (width != -1) {
1169 int32_t padded = width + pad[2] + pad[3] - kernel[1];
1170 outputShape[2] = padded / stride[1] + 1;
1171 }
1172
1173 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1174 return success();
1175 }
1176
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1177 LogicalResult Conv2DOp::inferReturnTypeComponents(
1178 MLIRContext *context, ::llvm::Optional<Location> location,
1179 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1180 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1181 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
1182 Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
1183
1184 int32_t inputWidth = ShapedType::kDynamicSize;
1185 int32_t inputHeight = ShapedType::kDynamicSize;
1186 int32_t weightWidth = ShapedType::kDynamicSize;
1187 int32_t weightHeight = ShapedType::kDynamicSize;
1188
1189 // Input shape describes input width/height and batch.
1190
1191 ShapeAdaptor inputShape = operands.getShape(adaptor.input());
1192 if (inputShape.hasRank()) {
1193 outputShape[0] = inputShape.getDimSize(0);
1194 inputHeight = inputShape.getDimSize(1);
1195 inputWidth = inputShape.getDimSize(2);
1196 }
1197
1198 // Weight shapes describes the filter width/height and the output channels.
1199 ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
1200 if (weightShape.hasRank()) {
1201 outputShape[3] = weightShape.getDimSize(0);
1202 weightHeight = weightShape.getDimSize(1);
1203 weightWidth = weightShape.getDimSize(2);
1204 }
1205
1206 // Bias shape can describe the output channels.
1207 ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
1208 if (biasShape.hasRank()) {
1209 outputShape[3] = ShapedType::isDynamic(outputShape[3])
1210 ? biasShape.getDimSize(0)
1211 : outputShape[3];
1212 }
1213
1214 llvm::SmallVector<int64_t> dilation;
1215 llvm::SmallVector<int64_t> padding;
1216 llvm::SmallVector<int64_t> stride;
1217
1218 getI64Values(adaptor.dilation(), dilation);
1219 getI64Values(adaptor.pad(), padding);
1220 getI64Values(adaptor.stride(), stride);
1221
1222 if (!ShapedType::isDynamic(inputHeight) &&
1223 !ShapedType::isDynamic(weightHeight)) {
1224 int32_t inputSize = inputHeight + padding[0] + padding[1];
1225 int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1226 int32_t unstridedResult = inputSize - filterSize + 1;
1227 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1228 }
1229
1230 if (!ShapedType::isDynamic(inputWidth) &&
1231 !ShapedType::isDynamic(weightWidth)) {
1232 int32_t inputSize = inputWidth + padding[2] + padding[3];
1233 int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1234 int32_t unstridedResult = inputSize - filterSize + 1;
1235 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1236 }
1237
1238 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1239 return success();
1240 }
1241
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1242 LogicalResult Conv3DOp::inferReturnTypeComponents(
1243 MLIRContext *context, ::llvm::Optional<Location> location,
1244 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1245 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1246 llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
1247 Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
1248
1249 int32_t inputWidth = ShapedType::kDynamicSize;
1250 int32_t inputHeight = ShapedType::kDynamicSize;
1251 int32_t inputDepth = ShapedType::kDynamicSize;
1252
1253 int32_t weightWidth = ShapedType::kDynamicSize;
1254 int32_t weightHeight = ShapedType::kDynamicSize;
1255 int32_t weightDepth = ShapedType::kDynamicSize;
1256
1257 // Input shape describes input width/height and batch.
1258 ShapeAdaptor inputShape = operands.getShape(adaptor.input());
1259 if (inputShape.hasRank()) {
1260 outputShape[0] = inputShape.getDimSize(0);
1261 inputHeight = inputShape.getDimSize(1);
1262 inputWidth = inputShape.getDimSize(2);
1263 inputDepth = inputShape.getDimSize(3);
1264 }
1265
1266 // Weight shapes describes the filter width/height and the output channels.
1267 ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
1268 if (weightShape.hasRank()) {
1269 outputShape[4] = weightShape.getDimSize(0);
1270 weightHeight = weightShape.getDimSize(1);
1271 weightWidth = weightShape.getDimSize(2);
1272 weightDepth = weightShape.getDimSize(3);
1273 }
1274
1275 // Bias shape can describe the output channels.
1276 ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
1277 if (biasShape.hasRank()) {
1278 outputShape[4] =
1279 (outputShape[4] == -1) ? biasShape.getDimSize(0) : outputShape[4];
1280 }
1281
1282 llvm::SmallVector<int64_t> dilation;
1283 llvm::SmallVector<int64_t> padding;
1284 llvm::SmallVector<int64_t> stride;
1285
1286 getI64Values(adaptor.dilation(), dilation);
1287 getI64Values(adaptor.pad(), padding);
1288 getI64Values(adaptor.stride(), stride);
1289
1290 if (!ShapedType::isDynamic(inputHeight) &&
1291 !ShapedType::isDynamic(weightHeight)) {
1292 int32_t inputSize = inputHeight + padding[0] + padding[1];
1293 int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1294 int32_t unstridedResult = inputSize - filterSize + 1;
1295 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1296 }
1297
1298 if (!ShapedType::isDynamic(inputWidth) &&
1299 !ShapedType::isDynamic(weightWidth)) {
1300 int32_t inputSize = inputWidth + padding[2] + padding[3];
1301 int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1302 int32_t unstridedResult = inputSize - filterSize + 1;
1303 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1304 }
1305
1306 if (!ShapedType::isDynamic(inputDepth) &&
1307 !ShapedType::isDynamic(weightDepth)) {
1308 int32_t inputSize = inputDepth + padding[4] + padding[5];
1309 int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
1310 int32_t unstridedResult = inputSize - filterSize + 1;
1311 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
1312 }
1313
1314 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1315 return success();
1316 }
1317
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1318 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
1319 MLIRContext *context, ::llvm::Optional<Location> location,
1320 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1321 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1322 return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
1323 }
1324
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1325 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
1326 MLIRContext *context, ::llvm::Optional<Location> location,
1327 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1328 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1329 return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
1330 }
1331
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1332 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
1333 MLIRContext *context, ::llvm::Optional<Location> location,
1334 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1335 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1336 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
1337 DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
1338
1339 int32_t inputWidth = ShapedType::kDynamicSize;
1340 int32_t inputHeight = ShapedType::kDynamicSize;
1341 int32_t inputChannels = ShapedType::kDynamicSize;
1342
1343 int32_t weightWidth = ShapedType::kDynamicSize;
1344 int32_t weightHeight = ShapedType::kDynamicSize;
1345 int32_t depthChannels = ShapedType::kDynamicSize;
1346
1347 // Input shape describes input width/height and batch.
1348 ShapeAdaptor inputShape = operands.getShape(adaptor.input());
1349 if (inputShape.hasRank()) {
1350 outputShape[0] = inputShape.getDimSize(0);
1351 inputHeight = inputShape.getDimSize(1);
1352 inputWidth = inputShape.getDimSize(2);
1353 inputChannels = inputShape.getDimSize(3);
1354 }
1355
1356 // Weight shapes describes the filter width/height and the output channels.
1357 ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
1358 if (weightShape.hasRank()) {
1359 weightHeight = weightShape.getDimSize(0);
1360 weightWidth = weightShape.getDimSize(1);
1361 inputChannels = ShapedType::isDynamic(inputChannels)
1362 ? weightShape.getDimSize(2)
1363 : inputChannels;
1364 depthChannels = weightShape.getDimSize(3);
1365 }
1366
1367 // If both inputChannels and depthChannels are available we can determine
1368 // the output channels.
1369 if (!ShapedType::isDynamic(inputChannels) &&
1370 !ShapedType::isDynamic(depthChannels)) {
1371 outputShape[3] = inputChannels * depthChannels;
1372 }
1373
1374 // Bias shape can describe the output channels.
1375 ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
1376 if (biasShape.hasRank()) {
1377 outputShape[3] = ShapedType::isDynamic(outputShape[3])
1378 ? biasShape.getDimSize(0)
1379 : outputShape[3];
1380 }
1381
1382 llvm::SmallVector<int64_t> dilation;
1383 llvm::SmallVector<int64_t> padding;
1384 llvm::SmallVector<int64_t> stride;
1385
1386 getI64Values(adaptor.dilation(), dilation);
1387 getI64Values(adaptor.pad(), padding);
1388 getI64Values(adaptor.stride(), stride);
1389
1390 if (!ShapedType::isDynamic(inputHeight) &&
1391 !ShapedType::isDynamic(weightHeight)) {
1392 int32_t inputSize = inputHeight + padding[0] + padding[1];
1393 int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1394 int32_t unstridedResult = inputSize - filterSize + 1;
1395 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1396 }
1397
1398 if (!ShapedType::isDynamic(inputWidth) &&
1399 !ShapedType::isDynamic(weightWidth)) {
1400 int32_t inputSize = inputWidth + padding[2] + padding[3];
1401 int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1402 int32_t unstridedResult = inputSize - filterSize + 1;
1403 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1404 }
1405
1406 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1407 return success();
1408 }
1409
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1410 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
1411 MLIRContext *context, ::llvm::Optional<Location> location,
1412 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1413 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1414 TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
1415 llvm::SmallVector<int64_t> outputShape;
1416 getI64Values(adaptor.out_shape(), outputShape);
1417
1418 int32_t inputWidth = ShapedType::kDynamicSize;
1419 int32_t inputHeight = ShapedType::kDynamicSize;
1420 int32_t weightWidth = ShapedType::kDynamicSize;
1421 int32_t weightHeight = ShapedType::kDynamicSize;
1422
1423 // Input shape describes input width/height and batch.
1424 ShapeAdaptor inputShape = operands.getShape(adaptor.input());
1425 if (inputShape.hasRank()) {
1426 outputShape[0] = ShapedType::isDynamic(outputShape[0])
1427 ? inputShape.getDimSize(0)
1428 : outputShape[0];
1429 inputHeight = inputShape.getDimSize(1);
1430 inputWidth = inputShape.getDimSize(2);
1431 }
1432
1433 // Weight shapes describes the filter width/height and the output channels.
1434 ShapeAdaptor weightShape = operands.getShape(adaptor.input());
1435 if (weightShape.hasRank()) {
1436 outputShape[3] = ShapedType::isDynamic(outputShape[3])
1437 ? weightShape.getDimSize(0)
1438 : outputShape[3];
1439 weightHeight = weightShape.getDimSize(1);
1440 weightWidth = weightShape.getDimSize(2);
1441 }
1442
1443 // Bias shape can describe the output channels.
1444 ShapeAdaptor biasShape = operands.getShape(adaptor.input());
1445 if (biasShape.hasRank()) {
1446 outputShape[3] = ShapedType::isDynamic(outputShape[3])
1447 ? biasShape.getDimSize(0)
1448 : outputShape[3];
1449 }
1450
1451 llvm::SmallVector<int64_t> dilation;
1452 llvm::SmallVector<int64_t> padding;
1453 llvm::SmallVector<int64_t> stride;
1454
1455 getI64Values(adaptor.dilation(), dilation);
1456 getI64Values(adaptor.out_pad(), padding);
1457 getI64Values(adaptor.stride(), stride);
1458
1459 if (!ShapedType::isDynamic(inputHeight) &&
1460 !ShapedType::isDynamic(weightHeight)) {
1461 int32_t dilated = (weightHeight - 1) * dilation[0] + 1;
1462 int32_t calculateSize =
1463 (inputHeight - 1) * stride[0] - padding[0] + dilated;
1464 outputShape[1] = outputShape[1] == -1 ? calculateSize : outputShape[1];
1465 }
1466
1467 if (!ShapedType::isDynamic(inputWidth) &&
1468 !ShapedType::isDynamic(weightWidth)) {
1469 int32_t dilated = (weightWidth - 1) * dilation[1] + 1;
1470 int32_t calculateSize = (inputWidth - 1) * stride[1] - padding[1] + dilated;
1471 outputShape[2] = outputShape[2] == -1 ? calculateSize : outputShape[2];
1472 }
1473
1474 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1475 return success();
1476 }
1477
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1478 LogicalResult IfOp::inferReturnTypeComponents(
1479 MLIRContext *context, ::llvm::Optional<Location> location,
1480 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1481 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1482 llvm::SmallVector<tosa::YieldOp> yieldOps;
1483 for (Region *region : regions) {
1484 for (auto &block : *region)
1485 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1486 yieldOps.push_back(returnOp);
1487 }
1488
1489 if (yieldOps.empty())
1490 return failure();
1491
1492 // Get the initial type information for the yield op.
1493 llvm::SmallVector<ValueKnowledge> resultKnowledge;
1494 resultKnowledge.reserve(yieldOps.front().getNumOperands());
1495 for (auto operand : yieldOps.front().getOperands()) {
1496 resultKnowledge.push_back(
1497 ValueKnowledge::getKnowledgeFromType(operand.getType()));
1498 }
1499
1500 for (auto yieldOp : yieldOps) {
1501 if (resultKnowledge.size() != yieldOp.getNumOperands())
1502 return failure();
1503
1504 for (auto it : llvm::enumerate(yieldOp.getOperands())) {
1505 int32_t index = it.index();
1506 auto meet = ValueKnowledge::meet(
1507 resultKnowledge[index],
1508 ValueKnowledge::getKnowledgeFromType(it.value().getType()));
1509 if (!meet)
1510 continue;
1511 resultKnowledge[index] = meet;
1512 }
1513 }
1514
1515 for (const ValueKnowledge &result : resultKnowledge) {
1516 inferredReturnShapes.push_back(result.getShapedTypeComponents());
1517 }
1518
1519 return success();
1520 }
1521
inferReturnTypeComponents(MLIRContext * context,::llvm::Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1522 LogicalResult WhileOp::inferReturnTypeComponents(
1523 MLIRContext *context, ::llvm::Optional<Location> location,
1524 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1525 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1526 llvm::SmallVector<tosa::YieldOp> yieldOps;
1527 for (auto &block : *regions[1])
1528 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1529 yieldOps.push_back(returnOp);
1530
1531 // TOSA's while must have a tosa.yield as its terminator. If not found this
1532 // tosa.while is invalid.
1533 if (yieldOps.empty())
1534 return failure();
1535
1536 // Get the initial type information from the operand types.
1537 llvm::SmallVector<ValueKnowledge> resultKnowledge;
1538 resultKnowledge.reserve(yieldOps.front().getNumOperands());
1539 for (auto operand : yieldOps.front().getOperands()) {
1540 resultKnowledge.push_back(
1541 ValueKnowledge::getKnowledgeFromType(operand.getType()));
1542 }
1543
1544 for (auto yieldOp : yieldOps) {
1545 if (resultKnowledge.size() != yieldOp.getNumOperands())
1546 return failure();
1547
1548 for (auto it : llvm::enumerate(yieldOp.getOperands())) {
1549 int32_t index = it.index();
1550 if (auto meet = ValueKnowledge::meet(
1551 resultKnowledge[index],
1552 ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
1553 resultKnowledge[index] = meet;
1554 };
1555 }
1556 }
1557
1558 for (const ValueKnowledge &result : resultKnowledge) {
1559 inferredReturnShapes.push_back(result.getShapedTypeComponents());
1560 }
1561
1562 return success();
1563 }
1564
1565 //===----------------------------------------------------------------------===//
1566 // TOSA Operator Definitions.
1567 //===----------------------------------------------------------------------===//
1568
1569 #define GET_OP_CLASSES
1570 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
1571