//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://developer.mlplatform.org/w/tosa/
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// Tosa dialect structs and interface includes.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc"

namespace {
//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       BlockAndValueMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
Region &tosa::WhileOp::getLoopBody() { return body(); }

bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) {
  return !body().isAncestor(value.getParentRegion());
}

LogicalResult WhileOp::moveOutOfLoop(ArrayRef<mlir::Operation *> ops) {
  if (ops.empty())
    return success();

  Operation *tosaWhileOp = this->getOperation();
  for (auto *op : ops)
    op->moveBefore(tosaWhileOp);

  return success();
}

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addInterfaces<TosaInlinerInterface>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // Tosa dialect constants only support ElementsAttr, unlike the standard
  // dialect constant which supports all attributes.
  if (value.isa<ElementsAttr>())
    return builder.create<tosa::ConstOp>(loc, type, value.cast<ElementsAttr>());
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

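// Fold a single-input tosa.concat into its operand, e.g. (schematically)
//   %0 = "tosa.concat"(%arg0) {axis = 0 : i64}
//        : (tensor<2xf32>) -> tensor<2xf32>
// becomes %arg0. If the result type differs from the operand type, a
// tensor.cast is inserted to preserve the original result type.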
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    if (op.input1().size() != 1)
      return failure();
    if (op.input1().front().getType() != op.getType()) {
      rewriter
          .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                              op.input1().front())
          .getResult();
      return success();
    }

    rewriter.replaceOp(op, op.input1().front());
    return success();
  }
};

void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<ConcatOptimization>(context);
}

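// Collapse a reshape-of-reshape chain into a single reshape that uses the
// outer op's new_shape, since the intermediate shape is irrelevant.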
struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.input1();
    Operation *definingOp = input.getDefiningOp();
    if (!definingOp)
      return failure();

    if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
      rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
          op, op.getType(), reshapeOp.input1(), op.new_shape());
      return success();
    }

    return failure();
  }
};

void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<ReshapeReshapeOptimization>(context);
}

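// Materialize the transpose of a single-use constant input at compile time,
// replacing the transpose with a new tosa.const holding the permuted values.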
struct ConstantTransposeOptimization
    : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto outputType = op.getType().cast<ShapedType>();
    ArrayRef<int64_t> outputShape = outputType.getShape();
    // TOSA also supports quantized element types, which are not handled here;
    // only int, index, and float element types are folded.
    if (!outputType.getElementType().isIntOrIndexOrFloat())
      return failure();

    DenseElementsAttr inputValues;
    if (!matchPattern(op.input1(), m_Constant(&inputValues)))
      return failure();
    // Make sure the input is a constant that has a single user.
    if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
      return failure();

    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.perms(), m_Constant(&permAttr)))
      return failure();
    auto permValues = llvm::to_vector<6>(llvm::map_range(
        // TOSA allows both 32- and 64-bit integer tensors here.
        permAttr.getValues<APInt>(),
        [](const APInt &val) { return val.getZExtValue(); }));

    auto inputType = op.input1().getType().cast<ShapedType>();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t numElements = inputType.getNumElements();

    SmallVector<Attribute, 4> outputValues;
    outputValues.resize(numElements);

    // Transpose the input constant. Because we don't know its rank in advance,
    // we need to loop over the range [0, element count) and delinearize the
    // index.
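    // For example, a rank-2 input of shape [2, 3] with perms [1, 0]: linear
    // index 4 delinearizes to source indices (1, 1), which map to destination
    // indices (1, 1) and destination linear index 1 * 2 + 1 = 3.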
    for (int srcLinearIndex = 0; srcLinearIndex < numElements;
         ++srcLinearIndex) {
      SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
      int totalCount = srcLinearIndex;
      for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
        srcIndices[dim] = totalCount % inputShape[dim];
        totalCount /= inputShape[dim];
      }

      SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
      for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
        dstIndices[dim] = srcIndices[permValues[dim]];

      uint64_t dstLinearIndex = dstIndices.front();
      for (int dim = 1; dim < outputType.getRank(); ++dim)
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];

      outputValues[dstLinearIndex] = inputValues.getValue(srcIndices);
    }

    rewriter.replaceOpWithNewOp<tosa::ConstOp>(
        op, outputType, DenseElementsAttr::get(outputType, outputValues));
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<ConstantTransposeOptimization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  if (input().getType() == getType())
    return input();
  return {};
}

OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");
  return valueAttr();
}

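// Reducing along an axis whose statically-known extent is 1 is a no-op, so
// each of these folders returns the input unchanged in that case.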
#define ReduceFolder(OP)                                                       \
  OpFoldResult OP::fold(ArrayRef<Attribute> operands) {                        \
    ShapedType inputTy = input().getType().cast<ShapedType>();                 \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy.getDimSize(axis()) == 1)                                       \
      return input();                                                          \
    return {};                                                                 \
  }

ReduceFolder(ReduceAllOp)
ReduceFolder(ReduceAnyOp)
ReduceFolder(ReduceMaxOp)
ReduceFolder(ReduceMinOp)
ReduceFolder(ReduceProdOp)
ReduceFolder(ReduceSumOp)
#undef ReduceFolder

OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  auto inputTy = input1().getType().dyn_cast<RankedTensorType>();
  auto outputTy = getType().dyn_cast<RankedTensorType>();

  if (!inputTy || !outputTy || inputTy != outputTy)
    return {};
  return input1();
}

OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
  auto inputTy = input().getType().dyn_cast<RankedTensorType>();
  auto outputTy = getType().dyn_cast<RankedTensorType>();

  if (!inputTy || !outputTy || inputTy != outputTy)
    return {};
  if (inputTy.hasStaticShape())
    return input();

  return {};
}

OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
  bool allOnes = true;
  for (Attribute val : multiples().getValue()) {
    allOnes = allOnes && val.cast<IntegerAttr>().getValue().getSExtValue() == 1;
  }

  if (allOnes && input1().getType() == getType())
    return input1();
  return {};
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[1])
    return {};

  // Transposing splat values just means reshaping.
  if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
    if (input.isSplat())
      return input.reshape(getType().cast<ShapedType>());
  }

  auto perms = llvm::to_vector<6>(llvm::map_range(
      operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
      [](const APInt &val) { return val.getSExtValue(); }));

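  // A transpose whose permutation is the identity and whose result type
  // matches the input type is a no-op.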
  if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
      input1().getType() == getType())
    return input1();
  return {};
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

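/// Shared verifier for the TOSA convolution-style ops: operands must be
/// ranked tensors, the input and weight element types must both be quantized
/// or both be float, and quantization_info must be present exactly when the
/// operands are quantized.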
template <typename T>
static LogicalResult verifyConvOp(T op) {
  // All TOSA conv ops have an input() and weight().
  auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
  auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();

  // Must be ranked tensor types.
  if (!inputType || !weightType)
    return failure();

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();

  bool inputIsQuant = !inputEType.template isa<FloatType>();
  bool weightIsQuant = !weightEType.template isa<FloatType>();

  // Either both must be quantized or both unquantized.
  if (inputIsQuant != weightIsQuant)
    return failure();

  // Quantized operands must carry a quantization_info attribute, and
  // unquantized operands must not.
  if ((inputIsQuant && !op.quantization_info()) ||
      (!inputIsQuant && op.quantization_info()))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bit width of the output given the bit widths of the input and weight
/// operands.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, ArrayAttr pad,
                                     ArrayAttr stride, ArrayAttr dilation) {

  result.addOperands({input, weight, bias});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);

  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static void
buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input, Value weight,
                              Value bias, ArrayAttr outpad, ArrayAttr stride,
                              ArrayAttr dilation, ArrayAttr outputShape) {
  result.addOperands({input, weight, bias});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);
  result.addAttribute("out_shape", outputShape);
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.fully_connected op has its own builder as it does not have
/// strides/dilation/padding.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {

  result.addOperands({input, weight, bias});
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.matmul op is also intended to be generated where a
/// fully_connected op must be constructed but the weight is not a constant.
/// In that case the fully_connected op must be expressed using matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  result.addOperands({a, b});
  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);

    auto inputType = a.getType().dyn_cast<ShapedType>();
    assert(inputType && "Input must be a shaped tensor type!");

    auto inputQType = inputType.getElementType()
                          .dyn_cast<mlir::quant::UniformQuantizedType>();
    assert(inputQType && "Tensor must have quantized datatype!");

    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

    auto outputShapedType = outputType.dyn_cast<ShapedType>();
    assert(outputShapedType && "Output must be a shaped type");

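    // Per the TOSA spec, 8-bit quantized operands accumulate into i32 and
    // 16-bit quantized operands accumulate into i48.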
    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();
    auto accType = outputShapedType.clone(accElementType);
    result.addTypes(accType);
  } else {
    result.addTypes(outputType);
  }
}

/// Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr
/// but the avg_pool2d operator has its own builder as it has additional
/// parameters not part of the unary ops.
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
                                          OperationState &result,
                                          Type outputType, Value input,
                                          ArrayAttr kernel, ArrayAttr stride,
                                          ArrayAttr pad) {
  result.addOperands(input);
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on single-parameter unary operators that have a
/// scale relationship between their input and output, expressed by the
/// UnaryOpQuantizationAttr.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                      OperationState &result, Type outputType,
                                      Value input) {
  result.addOperands(input);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator, which needs to create its
/// own OptionalAttr quantization_info parameter to scale the padding values
/// correctly.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  result.addOperands({input, paddings});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//

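/// Helpers that unpack an ArrayAttr of IntegerAttr / FloatAttr elements into a
/// vector of int64_t / double values.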
static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
  for (auto it : arrayAttr) {
    values.push_back(it.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

static void getF64Values(ArrayAttr arrayAttr, SmallVector<double> &values) {
  for (auto it : arrayAttr) {
    values.push_back(it.cast<FloatAttr>().getValueAsDouble());
  }
}

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outShape;
  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
  int32_t axis =
      attributes.get("axis").cast<IntegerAttr>().getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : operands) {
    ShapeAdaptor operandShape = operands.getShape(operand);
    if (!operandShape.hasRank())
      continue;

    // Copy the operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamicSize);

    // Copy the static dimension sizes and check that ranked inputs agree on
    // the non-concatenated dimensions.
    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamicSize)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return failure();
    }

    hasRankedInput = true;
  }

  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // Determine the dimension size along the concatenation axis.
  int concatDimSize = 0;
  for (auto operand : operands) {
    ShapeAdaptor operandShape = operands.getShape(operand);

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamicSize;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  ShapeAdaptor weightShape = operands.getShape(1);
  ShapeAdaptor biasShape = operands.getShape(2);

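  // The result is [batch, out_channels]: batch from the input's first
  // dimension, out_channels from the weight's (or, failing that, the bias's)
  // first dimension.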
  // Start with both dimensions dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(2, ShapedType::kDynamicSize);

  if (inputShape.hasRank()) {
    outShape[0] = inputShape.getDimSize(0);
  }

  if (weightShape.hasRank()) {
    outShape[1] = weightShape.getDimSize(0);
  }

  if (biasShape.hasRank()) {
    outShape[1] = outShape[1] == ShapedType::kDynamicSize
                      ? biasShape.getDimSize(0)
                      : outShape[1];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor lhsShape = operands.getShape(0);
  ShapeAdaptor rhsShape = operands.getShape(1);

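  // For operands of shape [N, H, C] and [N, C, W] the result is [N, H, W]:
  // N and H come from the LHS, W from the RHS.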
  // Start with all dimensions dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(3, ShapedType::kDynamicSize);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);
  }

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamicSize
                      ? rhsShape.getDimSize(0)
                      : outShape[0];
    outShape[2] = rhsShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  ShapeAdaptor paddingShape = operands.getShape(1);
  SmallVector<int64_t> outputShape;

  // If both inputs have unknown shape, we cannot determine the shape of the
  // output.
  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // If the input rank is unknown we can infer the output rank from the
  // padding shape's first dimension.
  if (!inputShape.hasRank()) {
    if (paddingShape.isDynamicDim(0)) {
      inferredReturnShapes.push_back(ShapedTypeComponents());
      return success();
    }

    outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  DenseIntElementsAttr paddings;
  // If the paddings value is not a constant, all dimensions must be dynamic.
  if (!matchPattern(operands[1], m_Constant(&paddings))) {
    outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  SmallVector<int64_t> paddingValues;
  for (auto val : paddings) {
    paddingValues.push_back(val.getSExtValue());
  }

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamicSize);
      continue;
    }

    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
                          paddingValues[i * 2 + 1]);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ArrayAttr sizes = SliceOpAdaptor(operands, attributes).size();
  SmallVector<int64_t> outputShape;
  outputShape.reserve(sizes.size());
  for (auto val : sizes) {
    outputShape.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);
  return success();
}

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  TileOpAdaptor adaptor(operands, attributes);
  ArrayAttr multiples = adaptor.multiples();
  ShapeAdaptor inputShape = operands.getShape(0);
  SmallVector<int64_t> outputShape;
  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // We need the multiple values to determine the output shape.
  SmallVector<int64_t> multipleValues;
  multipleValues.reserve(multiples.size());
  for (auto val : multiples) {
    multipleValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }

  // Any non-dynamic dimension can be multiplied to a known size.
  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    int dim = inputShape.getDimSize(i);
    if (dim != ShapedType::kDynamicSize)
      dim *= multipleValues[i];
    outputShape.push_back(dim);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ReshapeOpAdaptor adaptor(operands, attributes);
  ShapeAdaptor inputShape = operands.getShape(0);

  ArrayAttr newShape = adaptor.new_shape();
  llvm::SmallVector<int64_t> newShapeValue;
  getI64Values(newShape, newShapeValue);

  // We cannot infer from the total number of elements so we must take the
  // shape attribute as exact.
  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
    return success();
  }

  // Determine the product of all the static dimensions in the new shape. This
  // allows us to infer the length of the remaining dynamic dimension.
  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (val != ShapedType::kDynamicSize) {
      staticMul *= val;
    }
  }

  // Determine the length of the dynamic dimension.
  for (auto &val : newShapeValue) {
    if (val == ShapedType::kDynamicSize)
      val = numElements / staticMul;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
  return success();
}

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  ShapeAdaptor permsShape = operands.getShape(1);

  // If the input rank or the permutation length is unknown, the output rank
  // is unknown.
  if (!inputShape.hasRank() || !permsShape.hasRank() ||
      permsShape.isDynamicDim(0)) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // This would imply the number of permutations does not match the rank of
  // the input, which is illegal.
  if (permsShape.getDimSize(0) != inputShape.getRank()) {
    return failure();
  }

  // Without the input dims we cannot determine the output dim sizes but we
  // can determine the output rank.
  SmallVector<int64_t> outputShape;
  if (!inputShape.hasRank()) {
    outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Rank-0 means no permutations matter.
  if (inputShape.getRank() == 0) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Check whether the input dimensions are all the same.
  bool allTheSame = true;
  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
      allTheSame = false;
      break;
    }
  }

  // If all of the input dimensions are the same we don't care about the
  // permutation.
  if (allTheSame) {
    outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
  // If the permutations are a constant we can directly determine the output
  // shape.
  if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
    outputShape.reserve(inputShape.getRank());
    for (int i = 0, s = inputShape.getRank(); i < s; i++) {
      outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
    }
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
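  // For values of shape [N, K, C] and indices of shape [N, W] the result is
  // [N, W, C]; N and C come from the values, W from the indices.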
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamicSize);

  ShapeAdaptor valuesShape = operands.getShape(0);
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape = operands.getShape(1);
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamicSize)
      outputShape[1] = indicesShape.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ResizeOpAdaptor adaptor(operands, attributes);
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamicSize);

  int32_t inHeight = ShapedType::kDynamicSize;
  int32_t inWidth = ShapedType::kDynamicSize;

  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    outputShape[3] = inputShape.getDimSize(3);

    inHeight = inputShape.getDimSize(1);
    inWidth = inputShape.getDimSize(2);
  }

  int32_t shift = adaptor.shift().getValue().getSExtValue();
  llvm::SmallVector<int64_t> newShape;
  getI64Values(adaptor.output_size(), newShape);
  outputShape[1] = newShape[0];
  outputShape[2] = newShape[1];

  llvm::SmallVector<int64_t> strideInt;
  llvm::SmallVector<int64_t> offsetInt;
  llvm::SmallVector<double> strideFp;
  llvm::SmallVector<double> offsetFp;
  getI64Values(adaptor.offset(), offsetInt);
  getF64Values(adaptor.offset_fp(), offsetFp);
  getI64Values(adaptor.stride(), strideInt);
  getF64Values(adaptor.stride_fp(), strideFp);

  // A zero integer stride means the resize parameters are given in floating
  // point. Use the floating-point variant to compute the output shape.
  bool fpMode = strideInt[0] == 0;

  // If the output_size attribute left a dimension unknown, compute it from
  // the offset and stride. If we line up exactly on the last input index we
  // round the size up by one to include it.
  if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && fpMode) {
    float sizeFp = (inHeight - offsetFp[0] - 1) / strideFp[0];
    float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
    outputShape[1] = std::ceil(sizeFp) + round;
  }

  if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && fpMode) {
    float sizeFp = (inWidth - offsetFp[1] - 1) / strideFp[1];
    float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
    outputShape[2] = std::ceil(sizeFp) + round;
  }

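  // In integer mode the stride and offset are fixed-point values scaled by
  // (1 << shift), so the last sampled position is
  // ((inHeight - 1) << shift) - offset, and dividing by the stride gives the
  // index of the last output element.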
  if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && !fpMode) {
    int64_t size = (inHeight - 1);
    size = ((size << shift) - offsetInt[0]) / strideInt[0];
    outputShape[1] = size + 1;
  }

  if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && !fpMode) {
    int64_t size = (inWidth - 1);
    size = ((size << shift) - offsetInt[1]) / strideInt[1];
    outputShape[2] = size + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamicSize);

  ShapeAdaptor valuesInShape = operands.getShape(0);
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape = operands.getShape(1);
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = indicesShape.getDimSize(0);
  }

  ShapeAdaptor inputShape = operands.getShape(2);
  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamicSize)
      outputShape[2] = inputShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

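/// Reduction ops keep the operand shape but set the reduced axis to size 1;
/// unranked operands produce an unranked result.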
static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  if (!operandShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  int64_t axisVal = axis.getValue().getSExtValue();
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::llvm::Optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      RegionRange regions,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return ReduceInferReturnTypes(operands.getShape(0),                        \
                                  attributes.get("axis").cast<IntegerAttr>(),  \
                                  inferredReturnShapes);                       \
  }

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER

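/// Compute the shape produced by numpy-style broadcasting of all operand
/// shapes into outShape. Fails if any operand is unranked or if two static
/// dimensions disagree (and neither is 1).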
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                           SmallVector<int64_t> &outShape) {
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    if (!shape.hasRank()) {
      return failure();
    }
    outRank = std::max<int64_t>(outRank, shape.getRank());
  }

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }

  return success();
}

static LogicalResult NAryInferReturnTypes(
    const ValueShapeRange &operands,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outShape;
  if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
  } else {
    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  }
  return success();
}

#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::llvm::Optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      RegionRange regions,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
  }

NARY_SHAPE_INFER(tosa::AbsOp)
NARY_SHAPE_INFER(tosa::AddOp)
NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
NARY_SHAPE_INFER(tosa::BitwiseAndOp)
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
NARY_SHAPE_INFER(tosa::CastOp)
NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
NARY_SHAPE_INFER(tosa::DivOp)
NARY_SHAPE_INFER(tosa::EqualOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
NARY_SHAPE_INFER(tosa::GreaterOp)
NARY_SHAPE_INFER(tosa::IdentityOp)
NARY_SHAPE_INFER(tosa::LogOp)
NARY_SHAPE_INFER(tosa::LogicalAndOp)
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
NARY_SHAPE_INFER(tosa::LogicalNotOp)
NARY_SHAPE_INFER(tosa::LogicalOrOp)
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::ReluNOp)
NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef NARY_SHAPE_INFER

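/// Shared NHWC output-shape inference for the 2D pooling ops: batch and
/// channel dimensions pass through unchanged, and each spatial dimension
/// follows (size + pad_before + pad_after - kernel) / stride + 1.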
static LogicalResult poolingInferReturnTypes(
    const ValueShapeRange &operands, DictionaryAttr attributes,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, -1);

  // If the input shape is unknown we can only report a rank-4 result with
  // dynamic dimensions.
  if (!inputShape) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Batch size and number of channels are unchanged by pooling.
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);

  int32_t height = inputShape.getDimSize(1);
  int32_t width = inputShape.getDimSize(2);

  llvm::SmallVector<int64_t> kernel;
  llvm::SmallVector<int64_t> stride;
  llvm::SmallVector<int64_t> pad;

  getI64Values(attributes.get("kernel").cast<ArrayAttr>(), kernel);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
  getI64Values(attributes.get("pad").cast<ArrayAttr>(), pad);

  if (height != -1) {
    int32_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (width != -1) {
    int32_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
  Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.

  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.dilation(), dilation);
  getI64Values(adaptor.pad(), padding);
  getI64Values(adaptor.stride(), stride);

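  // Output spatial size follows the standard convolution formula:
  //   (in + pad_before + pad_after - dilation * (kernel - 1) - 1) / stride + 1.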
  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
  Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t inputDepth = ShapedType::kDynamicSize;

  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;
  int32_t weightDepth = ShapedType::kDynamicSize;

  // Input shape describes input width/height/depth and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputDepth = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height/depth and the output
  // channels.
  ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
    weightDepth = weightShape.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
  if (biasShape.hasRank()) {
    outputShape[4] = ShapedType::isDynamic(outputShape[4])
                         ? biasShape.getDimSize(0)
                         : outputShape[4];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.dilation(), dilation);
  getI64Values(adaptor.pad(), padding);
  getI64Values(adaptor.stride(), stride);

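  // pad carries two entries per spatial dimension while dilation and stride
  // carry one, consumed below in the same order the output dimensions are
  // computed (height, then width, then depth).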
  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + padding[4] + padding[5];
    int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

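// AvgPool2d and MaxPool2d share the same output-shape rule, so both delegate
// to the common poolingInferReturnTypes helper.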
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}

LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
  DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t inputChannels = ShapedType::kDynamicSize;

  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;
  int32_t depthChannels = ShapedType::kDynamicSize;

  // Input shape describes input width/height, batch, and channels.
  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height, the input channels, and
  // the depth multiplier.
  ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are available we can determine
  // the output channels.
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.dilation(), dilation);
  getI64Values(adaptor.pad(), padding);
  getI64Values(adaptor.stride(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
  llvm::SmallVector<int64_t> outputShape;
  getI64Values(adaptor.out_shape(), outputShape);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape = operands.getShape(adaptor.filter());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.dilation(), dilation);
  getI64Values(adaptor.out_pad(), padding);
  getI64Values(adaptor.stride(), stride);

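  // Transpose convolution grows the spatial dimensions; for each of height
  // and width (when both operand dimensions are known):
  //   out = (in - 1) * stride - out_pad + dilation * (kernel - 1) + 1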
  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t dilated = (weightHeight - 1) * dilation[0] + 1;
    int32_t calculateSize =
        (inputHeight - 1) * stride[0] - padding[0] + dilated;
    outputShape[1] = ShapedType::isDynamic(outputShape[1]) ? calculateSize
                                                           : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t dilated = (weightWidth - 1) * dilation[1] + 1;
    int32_t calculateSize = (inputWidth - 1) * stride[1] - padding[1] + dilated;
    outputShape[2] = ShapedType::isDynamic(outputShape[2]) ? calculateSize
                                                           : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : regions) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

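  // Meet the knowledge from every yield so the inferred result is the most
  // specific shape compatible with all branches; when a meet fails, the
  // existing knowledge is kept.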
  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (auto it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
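  // regions[0] is the condition region; regions[1] is the body region, whose
  // tosa.yield operands carry the loop-carried values and hence the result
  // shapes.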
  for (auto &block : *regions[1])
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  // TOSA's while must have a tosa.yield as its terminator. If not found, this
  // tosa.while is invalid.
  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the first yield op's operand types.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (auto it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"