//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::shape;

namespace {
#include "ShapeCanonicalization.inc"
}
25 
ShapeDialect::ShapeDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // Register every op generated from the ODS definitions in ShapeOps.td.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  // Register the dialect's custom types (parsed/printed below).
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
           WitnessType>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
39 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)40 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
41                                              Attribute value, Type type,
42                                              Location loc) {
43   if (auto shapeType = type.dyn_cast<ShapeType>())
44     return builder.create<ConstShapeOp>(loc, type,
45                                         value.cast<DenseIntElementsAttr>());
46   if (auto sizeType = type.dyn_cast<SizeType>())
47     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
48   if (auto witnessType = type.dyn_cast<WitnessType>())
49     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
50   return nullptr;
51 }
52 
53 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const54 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
55   StringRef keyword;
56   if (parser.parseKeyword(&keyword))
57     return Type();
58 
59   if (keyword == "component")
60     return ComponentType::get(getContext());
61   if (keyword == "element")
62     return ElementType::get(getContext());
63   if (keyword == "shape")
64     return ShapeType::get(getContext());
65   if (keyword == "size")
66     return SizeType::get(getContext());
67   if (keyword == "value_shape")
68     return ValueShapeType::get(getContext());
69   if (keyword == "witness")
70     return WitnessType::get(getContext());
71 
72   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
73   return Type();
74 }
75 
76 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const77 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
78   switch (type.getKind()) {
79   case ShapeTypes::Component:
80     os << "component";
81     return;
82   case ShapeTypes::Element:
83     os << "element";
84     return;
85   case ShapeTypes::Size:
86     os << "size";
87     return;
88   case ShapeTypes::Shape:
89     os << "shape";
90     return;
91   case ShapeTypes::ValueShape:
92     os << "value_shape";
93     return;
94   case ShapeTypes::Witness:
95     os << "witness";
96     return;
97   default:
98     llvm_unreachable("unexpected 'shape' type kind");
99   }
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // AnyOp
104 //===----------------------------------------------------------------------===//
105 
106 // TODO: Canonicalization should be implemented for shapes that can be
107 // determined through mixtures of the known dimensions of the inputs.
fold(ArrayRef<Attribute> operands)108 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
109   // Only the last operand is checked because AnyOp is commutative.
110   if (operands.back())
111     return operands.back();
112 
113   return nullptr;
114 }
115 
//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//

// Parses `shape.assuming %witness [-> (types)] { region } [attr-dict]`.
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The single operand is the witness guarding the region; it always has
  // witness type, so no type list is parsed for it.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
146 
print(OpAsmPrinter & p,AssumingOp op)147 static void print(OpAsmPrinter &p, AssumingOp op) {
148   bool yieldsResults = !op.results().empty();
149 
150   p << AssumingOp::getOperationName() << " " << op.witness();
151   if (yieldsResults) {
152     p << " -> (" << op.getResultTypes() << ")";
153   }
154   p.printRegion(op.doRegion(),
155                 /*printEntryBlockArgs=*/false,
156                 /*printBlockTerminators=*/yieldsResults);
157   p.printOptionalAttrDict(op.getAttrs());
158 }
159 
160 namespace {
161 // Removes AssumingOp with a passing witness and inlines the region.
162 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
163   using OpRewritePattern<AssumingOp>::OpRewritePattern;
164 
matchAndRewrite__anonb0061b370211::AssumingWithTrue165   LogicalResult matchAndRewrite(AssumingOp op,
166                                 PatternRewriter &rewriter) const override {
167     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
168     if (!witness || !witness.passingAttr())
169       return failure();
170 
171     AssumingOp::inlineRegionIntoParent(op, rewriter);
172     return success();
173   }
174 };
175 } // namespace
176 
// Registers canonicalization patterns for AssumingOp.
void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *context) {
  // If taking a passing witness, inline region.
  patterns.insert<AssumingWithTrue>(context);
}
182 
// Replaces `op` by the contents of its (single-block) region, splicing the
// block into the parent block at the op's position. All rewriter calls are
// order-sensitive: the yield's operands must be captured before it is erased,
// and the op must be replaced before the blocks are merged.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  // Split the parent block at the op so the region body can go in between.
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
202 
//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    // NOTE(review): mutating the op inside fold() is unusual; handled constant
    // witnesses are deliberately dropped from the operand list here.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(true, getContext());
}
226 
verify(AssumingAllOp op)227 static LogicalResult verify(AssumingAllOp op) {
228   // Ensure that AssumingAllOp contains at least one operand
229   if (op.getNumOperands() == 0)
230     return op.emitOpError("no operands specified");
231 
232   return success();
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // BroadcastOp
237 //===----------------------------------------------------------------------===//
238 
fold(ArrayRef<Attribute> operands)239 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
240   if (!operands[0] || !operands[1])
241     return nullptr;
242   auto lhsShape = llvm::to_vector<6>(
243       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
244   auto rhsShape = llvm::to_vector<6>(
245       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
246   SmallVector<int64_t, 6> resultShape;
247   // If the shapes are not compatible, we can't fold it.
248   // TODO: Fold to an "error".
249   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
250     return nullptr;
251   Builder builder(getContext());
252   return builder.getIndexTensorAttr(resultShape);
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // ConcatOp
257 //===----------------------------------------------------------------------===//
258 
fold(ArrayRef<Attribute> operands)259 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
260   if (!operands[0] || !operands[1])
261     return nullptr;
262   auto lhsShape = llvm::to_vector<6>(
263       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
264   auto rhsShape = llvm::to_vector<6>(
265       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
266   SmallVector<int64_t, 6> resultShape;
267   resultShape.append(lhsShape.begin(), lhsShape.end());
268   resultShape.append(rhsShape.begin(), rhsShape.end());
269   Builder builder(getContext());
270   return builder.getIndexTensorAttr(resultShape);
271 }
272 
//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

// Prints e.g. `shape.const_shape [1, 2, 3]`; the "shape" attribute is elided
// from the attribute dictionary because it is printed inline as the list.
static void print(OpAsmPrinter &p, ConstShapeOp &op) {
  p << "shape.const_shape ";
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(op.shape().getValues<int64_t>(), p,
                  [&](int64_t i) { p << i; });
  p << "]";
}
285 
// Parses `shape.const_shape [attr-dict] [e0, e1, ...]` and stores the extents
// canonically as an index-tensor attribute named "shape".
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Every element of the array must be an integer extent.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));

  // The result type is always !shape.shape.
  result.types.push_back(ShapeType::get(builder.getContext()));
  return success();
}
313 
// Folding a const_shape simply materializes its "shape" attribute.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
315 
316 //===----------------------------------------------------------------------===//
317 // CstrBroadcastableOp
318 //===----------------------------------------------------------------------===//
319 
320 namespace {
321 // Given an input shape Value, try to obtain the shape's values.
getShapeVec(Value input,SmallVectorImpl<int64_t> & shapeValues)322 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
323   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
324     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
325     if (!type.hasRank())
326       return failure();
327     shapeValues = llvm::to_vector<6>(type.getShape());
328     return success();
329   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
330     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
331     return success();
332   } else {
333     return failure();
334   }
335 }
336 
337 // For shapes that were created by some operations, we can obtain partial
338 // information on the shapes and sometimes determine if they will be
339 // broadcastable with that.
340 struct CstrBroadcastablePartialInfo
341     : public OpRewritePattern<CstrBroadcastableOp> {
342   using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
343 
matchAndRewrite__anonb0061b370411::CstrBroadcastablePartialInfo344   LogicalResult matchAndRewrite(CstrBroadcastableOp op,
345                                 PatternRewriter &rewriter) const override {
346     SmallVector<int64_t, 6> lhsShape, rhsShape;
347     if (failed(getShapeVec(op.lhs(), lhsShape)))
348       return failure();
349     if (failed(getShapeVec(op.rhs(), rhsShape)))
350       return failure();
351     if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
352       return failure();
353 
354     rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
355     return success();
356   }
357 };
358 
359 // Scalars are always broadcastable.
360 struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> {
361   using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
362 
matchAndRewrite__anonb0061b370411::CstrBroadcastableScalar363   LogicalResult matchAndRewrite(CstrBroadcastableOp op,
364                                 PatternRewriter &rewriter) const override {
365     SmallVector<int64_t, 6> shape;
366     if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0)
367       return failure();
368     if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0)
369       return failure();
370 
371     rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
372     return success();
373   }
374 };
375 
376 } // namespace
377 
// Registers canonicalization patterns for CstrBroadcastableOp.
void CstrBroadcastableOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
                  CstrBroadcastableScalar>(context);
}
386 
// Folds to a passing witness when broadcastability can be proven: either one
// side is a scalar, both constant shapes broadcast, or the statically derived
// shapes broadcast. Never folds to a failing witness.
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // Both operands are not needed if one is a scalar.
  if (operands[0] &&
      operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
    return BoolAttr::get(true, getContext());
  if (operands[1] &&
      operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
    return BoolAttr::get(true, getContext());

  // With both shapes constant, check broadcastability directly.
  if (operands[0] && operands[1]) {
    auto lhsShape = llvm::to_vector<6>(
        operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
    auto rhsShape = llvm::to_vector<6>(
        operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
    SmallVector<int64_t, 6> resultShape;
    if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
      return BoolAttr::get(true, getContext());
  }

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  SmallVector<int64_t, 6> lhsShape, rhsShape;
  if (failed(getShapeVec(lhs(), lhsShape)))
    return nullptr;
  if (failed(getShapeVec(rhs(), rhsShape)))
    return nullptr;

  if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
    return BoolAttr::get(true, getContext());

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
421 
//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

// Registers canonicalization patterns for CstrEqOp.
void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.insert<CstrEqEqOps>(context);
}
431 
fold(ArrayRef<Attribute> operands)432 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
433   if (llvm::all_of(operands,
434                    [&](Attribute a) { return a && a == operands[0]; }))
435     return BoolAttr::get(true, getContext());
436 
437   // Because a failing witness result here represents an eventual assertion
438   // failure, we do not try to replace it with a constant witness. Similarly, we
439   // cannot if there are any non-const inputs.
440   return nullptr;
441 }
442 
//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

// Convenience builder: wraps the raw integer in an index attribute and
// delegates to the generated builder.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
451 
// Folding a const_size simply materializes its "value" attribute.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
453 
getAsmResultNames(llvm::function_ref<void (Value,StringRef)> setNameFn)454 void ConstSizeOp::getAsmResultNames(
455     llvm::function_ref<void(Value, StringRef)> setNameFn) {
456   SmallString<4> buffer;
457   llvm::raw_svector_ostream os(buffer);
458   os << "c" << value();
459   setNameFn(getResult(), os.str());
460 }
461 
//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

// Folding a const_witness simply materializes its "passing" attribute.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
467 
468 //===----------------------------------------------------------------------===//
469 // IndexToSizeOp
470 //===----------------------------------------------------------------------===//
471 
fold(ArrayRef<Attribute> operands)472 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
473   // Constant values of both types, `shape.size` and `index`, are represented as
474   // `IntegerAttr`s which makes constant folding simple.
475   if (Attribute arg = operands[0])
476     return arg;
477   return {};
478 }
479 
// Registers canonicalization patterns for IndexToSizeOp.
void IndexToSizeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // size_to_index(index_to_size(x)) -> x, and vice versa.
  patterns.insert<SizeToIndexToSizeCanonicalization>(context);
}
484 
485 //===----------------------------------------------------------------------===//
486 // FromExtentsOp
487 //===----------------------------------------------------------------------===//
488 
fold(ArrayRef<Attribute> operands)489 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
490   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
491     return nullptr;
492   SmallVector<int64_t, 6> extents;
493   for (auto attr : operands)
494     extents.push_back(attr.cast<IntegerAttr>().getInt());
495   Builder builder(getContext());
496   return builder.getIndexTensorAttr(extents);
497 }
498 
499 //===----------------------------------------------------------------------===//
500 // GetExtentOp
501 //===----------------------------------------------------------------------===//
502 
getConstantDim()503 Optional<int64_t> GetExtentOp::getConstantDim() {
504   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) {
505     return constSizeOp.value().getLimitedValue();
506   }
507   return llvm::None;
508 }
509 
fold(ArrayRef<Attribute> operands)510 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
511   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
512   if (!elements)
513     return nullptr;
514   Optional<int64_t> dim = getConstantDim();
515   if (!dim.hasValue())
516     return nullptr;
517   if (dim.getValue() >= elements.getNumElements())
518     return nullptr;
519   return elements.getValue({(uint64_t)dim.getValue()});
520 }
521 
build(OpBuilder & builder,OperationState & result,Value shape,int64_t dim)522 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
523                         int64_t dim) {
524   auto loc = result.location;
525   auto dimAttr = builder.getIndexAttr(dim);
526   Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
527   build(builder, result, shape, dimValue);
528 }
529 
530 //===----------------------------------------------------------------------===//
531 // RankOp
532 //===----------------------------------------------------------------------===//
533 
fold(ArrayRef<Attribute> operands)534 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
535   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
536   if (!shape)
537     return {};
538   int64_t rank = shape.getNumElements();
539   Builder builder(getContext());
540   return builder.getIndexAttr(rank);
541 }
542 
/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3

557 namespace {
558 struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
559   using OpRewritePattern<RankOp>::OpRewritePattern;
560 
matchAndRewrite__anonb0061b370711::RankShapeOfCanonicalizationPattern561   LogicalResult matchAndRewrite(RankOp op,
562                                 PatternRewriter &rewriter) const override {
563     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
564     if (!shapeOfOp)
565       return failure();
566     auto rankedTensorType =
567         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
568     if (!rankedTensorType)
569       return failure();
570     int64_t rank = rankedTensorType.getRank();
571     rewriter.replaceOpWithNewOp<ConstSizeOp>(op.getOperation(), rank);
572     return success();
573   }
574 };
575 } // namespace
576 
// Registers canonicalization patterns for RankOp.
void RankOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                         MLIRContext *context) {
  // rank(shape_of(%ranked_tensor)) -> const_size.
  patterns.insert<RankShapeOfCanonicalizationPattern>(context);
}
581 
582 //===----------------------------------------------------------------------===//
583 // NumElementsOp
584 //===----------------------------------------------------------------------===//
585 
fold(ArrayRef<Attribute> operands)586 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
587 
588   // Fold only when argument constant.
589   Attribute shape = operands[0];
590   if (!shape)
591     return {};
592 
593   APInt product(64, 1);
594   for (auto value : shape.cast<DenseIntElementsAttr>())
595     product *= value;
596   Builder builder(getContext());
597   return builder.getIndexAttr(product.getLimitedValue());
598 }
599 
600 //===----------------------------------------------------------------------===//
601 // ShapeOfOp
602 //===----------------------------------------------------------------------===//
603 
fold(ArrayRef<Attribute>)604 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
605   auto type = getOperand().getType().dyn_cast<ShapedType>();
606   if (!type || !type.hasStaticShape())
607     return nullptr;
608   Builder builder(getContext());
609   return builder.getIndexTensorAttr(type.getShape());
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // SizeToIndexOp
614 //===----------------------------------------------------------------------===//
615 
fold(ArrayRef<Attribute> operands)616 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
617   // Constant values of both types, `shape.size` and `index`, are represented as
618   // `IntegerAttr`s which makes constant folding simple.
619   if (Attribute arg = operands[0])
620     return arg;
621   return {};
622 }
623 
// Registers canonicalization patterns for SizeToIndexOp.
void SizeToIndexOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // index_to_size(size_to_index(x)) -> x, and vice versa.
  patterns.insert<IndexToSizeToIndexCanonicalization>(context);
}
628 
629 //===----------------------------------------------------------------------===//
630 // YieldOp
631 //===----------------------------------------------------------------------===//
632 
verify(YieldOp op)633 static LogicalResult verify(YieldOp op) {
634   auto *parentOp = op.getParentOp();
635   auto results = parentOp->getResults();
636   auto operands = op.getOperands();
637 
638   if (parentOp->getNumResults() != op.getNumOperands())
639     return op.emitOpError() << "number of operands does not match number of "
640                                "results of its parent";
641   for (auto e : llvm::zip(results, operands))
642     if (std::get<0>(e).getType() != std::get<1>(e).getType())
643       return op.emitOpError()
644              << "types mismatch between yield op and its parent";
645 
646   return success();
647 }
648 
649 //===----------------------------------------------------------------------===//
650 // SplitAtOp
651 //===----------------------------------------------------------------------===//
652 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)653 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
654                               SmallVectorImpl<OpFoldResult> &results) {
655   if (!operands[0] || !operands[1])
656     return failure();
657   auto shapeVec = llvm::to_vector<6>(
658       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
659   auto shape = llvm::makeArrayRef(shapeVec);
660   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
661   // Verify that the split point is in the correct range.
662   // TODO: Constant fold to an "error".
663   int64_t rank = shape.size();
664   if (!(-rank <= splitPoint && splitPoint <= rank))
665     return failure();
666   if (splitPoint < 0)
667     splitPoint += shape.size();
668   Builder builder(operands[0].getContext());
669   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
670   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
671   return success();
672 }
673 
674 //===----------------------------------------------------------------------===//
675 // ToExtentTensorOp
676 //===----------------------------------------------------------------------===//
677 
fold(ArrayRef<Attribute> operands)678 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
679   if (!operands[0])
680     return nullptr;
681   Builder builder(getContext());
682   auto shape = llvm::to_vector<6>(
683       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
684   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
685                                     builder.getIndexType());
686   return DenseIntElementsAttr::get(type, shape);
687 }
688 
//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

// Builds a ReduceOp and creates its body block with the implicit block
// arguments: the iteration index, the current extent (shape.size), and one
// argument per initial value. The op's results mirror the init value types.
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  result.addOperands(shape);
  result.addOperands(initVals);

  // The region takes ownership of the block.
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  bodyBlock.addArgument(builder.getIndexType());
  bodyBlock.addArgument(SizeType::get(builder.getContext()));

  for (Type initValType : initVals.getTypes()) {
    bodyBlock.addArgument(initValType);
    result.addTypes(initValType);
  }
}
709 
// Verifies that the body block's argument list matches the contract set up by
// ReduceOp::build: index, size, then one argument per init value.
static LogicalResult verify(ReduceOp op) {
  // Verify block arg types.
  Block &block = op.region().front();

  // Two implicit leading arguments (index, size) plus the init values.
  auto blockArgsCount = op.initVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return op.emitOpError() << "ReduceOp body is expected to have "
                            << blockArgsCount << " arguments";

  if (block.getArgument(0).getType() != IndexType::get(op.getContext()))
    return op.emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  if (block.getArgument(1).getType() != SizeType::get(op.getContext()))
    return op.emitOpError(
        "argument 1 of ReduceOp body is expected to be of SizeType");

  // Each remaining block argument must match its init value's type.
  for (auto type : llvm::enumerate(op.initVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return op.emitOpError()
             << "type mismatch between argument " << type.index() + 2
             << " of ReduceOp body and initial value " << type.index();
  return success();
}
734 
// Parses `shape.reduce(%shape, %init...) [-> (types)] { region } [attr-dict]`.
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
  auto *ctx = parser.getBuilder().getContext();
  // Parse operands.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands. The first operand is always the shape; the init values
  // take their types from the result type list.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), ShapeType::get(ctx),
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
763 
// Prints `shape.reduce(%shape, %init...) -> (types) { region } [attr-dict]`,
// mirroring parseReduceOp above.
static void print(OpAsmPrinter &p, ReduceOp op) {
  p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
    << ") ";
  p.printOptionalArrowTypeList(op.getResultTypes());
  p.printRegion(op.region());
  p.printOptionalAttrDict(op.getAttrs());
}
771 
namespace mlir {
namespace shape {

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

} // namespace shape
} // namespace mlir
780