1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10
11 #include "mlir/Dialect/Traits.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/StandardTypes.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/Support/raw_ostream.h"
18
19 using namespace mlir;
20 using namespace mlir::shape;
21
22 namespace {
23 #include "ShapeCanonicalization.inc"
24 }
25
ShapeDialect(MLIRContext * context)26 ShapeDialect::ShapeDialect(MLIRContext *context)
27 : Dialect(getDialectNamespace(), context) {
28 addOperations<
29 #define GET_OP_LIST
30 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
31 >();
32 addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
33 WitnessType>();
34 // Allow unknown operations during prototyping and testing. As the dialect is
35 // still evolving it makes it simple to start with an unregistered ops and
36 // try different variants before actually defining the op.
37 allowUnknownOperations();
38 }
39
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)40 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
41 Attribute value, Type type,
42 Location loc) {
43 if (auto shapeType = type.dyn_cast<ShapeType>())
44 return builder.create<ConstShapeOp>(loc, type,
45 value.cast<DenseIntElementsAttr>());
46 if (auto sizeType = type.dyn_cast<SizeType>())
47 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
48 if (auto witnessType = type.dyn_cast<WitnessType>())
49 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
50 return nullptr;
51 }
52
53 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const54 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
55 StringRef keyword;
56 if (parser.parseKeyword(&keyword))
57 return Type();
58
59 if (keyword == "component")
60 return ComponentType::get(getContext());
61 if (keyword == "element")
62 return ElementType::get(getContext());
63 if (keyword == "shape")
64 return ShapeType::get(getContext());
65 if (keyword == "size")
66 return SizeType::get(getContext());
67 if (keyword == "value_shape")
68 return ValueShapeType::get(getContext());
69 if (keyword == "witness")
70 return WitnessType::get(getContext());
71
72 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
73 return Type();
74 }
75
76 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const77 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
78 switch (type.getKind()) {
79 case ShapeTypes::Component:
80 os << "component";
81 return;
82 case ShapeTypes::Element:
83 os << "element";
84 return;
85 case ShapeTypes::Size:
86 os << "size";
87 return;
88 case ShapeTypes::Shape:
89 os << "shape";
90 return;
91 case ShapeTypes::ValueShape:
92 os << "value_shape";
93 return;
94 case ShapeTypes::Witness:
95 os << "witness";
96 return;
97 default:
98 llvm_unreachable("unexpected 'shape' type kind");
99 }
100 }
101
102 //===----------------------------------------------------------------------===//
103 // AnyOp
104 //===----------------------------------------------------------------------===//
105
106 // TODO: Canonicalization should be implemented for shapes that can be
107 // determined through mixtures of the known dimensions of the inputs.
fold(ArrayRef<Attribute> operands)108 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
109 // Only the last operand is checked because AnyOp is commutative.
110 if (operands.back())
111 return operands.back();
112
113 return nullptr;
114 }
115
116 //===----------------------------------------------------------------------===//
117 // AssumingOp
118 //===----------------------------------------------------------------------===//
119
parseAssumingOp(OpAsmParser & parser,OperationState & result)120 static ParseResult parseAssumingOp(OpAsmParser &parser,
121 OperationState &result) {
122 result.regions.reserve(1);
123 Region *doRegion = result.addRegion();
124
125 auto &builder = parser.getBuilder();
126 OpAsmParser::OperandType cond;
127 if (parser.parseOperand(cond) ||
128 parser.resolveOperand(cond, builder.getType<WitnessType>(),
129 result.operands))
130 return failure();
131
132 // Parse optional results type list.
133 if (parser.parseOptionalArrowTypeList(result.types))
134 return failure();
135
136 // Parse the region and add a terminator if elided.
137 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
138 return failure();
139 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
140
141 // Parse the optional attribute list.
142 if (parser.parseOptionalAttrDict(result.attributes))
143 return failure();
144 return success();
145 }
146
print(OpAsmPrinter & p,AssumingOp op)147 static void print(OpAsmPrinter &p, AssumingOp op) {
148 bool yieldsResults = !op.results().empty();
149
150 p << AssumingOp::getOperationName() << " " << op.witness();
151 if (yieldsResults) {
152 p << " -> (" << op.getResultTypes() << ")";
153 }
154 p.printRegion(op.doRegion(),
155 /*printEntryBlockArgs=*/false,
156 /*printBlockTerminators=*/yieldsResults);
157 p.printOptionalAttrDict(op.getAttrs());
158 }
159
160 namespace {
161 // Removes AssumingOp with a passing witness and inlines the region.
162 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
163 using OpRewritePattern<AssumingOp>::OpRewritePattern;
164
matchAndRewrite__anonb0061b370211::AssumingWithTrue165 LogicalResult matchAndRewrite(AssumingOp op,
166 PatternRewriter &rewriter) const override {
167 auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
168 if (!witness || !witness.passingAttr())
169 return failure();
170
171 AssumingOp::inlineRegionIntoParent(op, rewriter);
172 return success();
173 }
174 };
175 } // namespace
176
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)177 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
178 MLIRContext *context) {
179 // If taking a passing witness, inline region.
180 patterns.insert<AssumingWithTrue>(context);
181 }
182
inlineRegionIntoParent(AssumingOp & op,PatternRewriter & rewriter)183 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
184 PatternRewriter &rewriter) {
185 auto *blockBeforeAssuming = rewriter.getInsertionBlock();
186 auto *assumingBlock = op.getBody();
187 auto initPosition = rewriter.getInsertionPoint();
188 auto *blockAfterAssuming =
189 rewriter.splitBlock(blockBeforeAssuming, initPosition);
190
191 // Remove the AssumingOp and AssumingYieldOp.
192 auto &yieldOp = assumingBlock->back();
193 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
194 rewriter.replaceOp(op, yieldOp.getOperands());
195 rewriter.eraseOp(&yieldOp);
196
197 // Merge blocks together as there was no branching behavior from the
198 // AssumingOp.
199 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
200 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
201 }
202
203 //===----------------------------------------------------------------------===//
204 // AssumingAllOp
205 //===----------------------------------------------------------------------===//
fold(ArrayRef<Attribute> operands)206 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
207 // Iterate in reverse to first handle all constant operands. They are
208 // guaranteed to be the tail of the inputs because this is commutative.
209 for (int idx = operands.size() - 1; idx >= 0; idx--) {
210 Attribute a = operands[idx];
211 // Cannot fold if any inputs are not constant;
212 if (!a)
213 return nullptr;
214
215 // We do not need to keep statically known values after handling them in
216 // this method.
217 getOperation()->eraseOperand(idx);
218
219 // Always false if any input is statically known false
220 if (!a.cast<BoolAttr>().getValue())
221 return a;
222 }
223 // If this is reached, all inputs were statically known passing.
224 return BoolAttr::get(true, getContext());
225 }
226
verify(AssumingAllOp op)227 static LogicalResult verify(AssumingAllOp op) {
228 // Ensure that AssumingAllOp contains at least one operand
229 if (op.getNumOperands() == 0)
230 return op.emitOpError("no operands specified");
231
232 return success();
233 }
234
235 //===----------------------------------------------------------------------===//
236 // BroadcastOp
237 //===----------------------------------------------------------------------===//
238
fold(ArrayRef<Attribute> operands)239 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
240 if (!operands[0] || !operands[1])
241 return nullptr;
242 auto lhsShape = llvm::to_vector<6>(
243 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
244 auto rhsShape = llvm::to_vector<6>(
245 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
246 SmallVector<int64_t, 6> resultShape;
247 // If the shapes are not compatible, we can't fold it.
248 // TODO: Fold to an "error".
249 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
250 return nullptr;
251 Builder builder(getContext());
252 return builder.getIndexTensorAttr(resultShape);
253 }
254
255 //===----------------------------------------------------------------------===//
256 // ConcatOp
257 //===----------------------------------------------------------------------===//
258
fold(ArrayRef<Attribute> operands)259 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
260 if (!operands[0] || !operands[1])
261 return nullptr;
262 auto lhsShape = llvm::to_vector<6>(
263 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
264 auto rhsShape = llvm::to_vector<6>(
265 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
266 SmallVector<int64_t, 6> resultShape;
267 resultShape.append(lhsShape.begin(), lhsShape.end());
268 resultShape.append(rhsShape.begin(), rhsShape.end());
269 Builder builder(getContext());
270 return builder.getIndexTensorAttr(resultShape);
271 }
272
273 //===----------------------------------------------------------------------===//
274 // ConstShapeOp
275 //===----------------------------------------------------------------------===//
276
print(OpAsmPrinter & p,ConstShapeOp & op)277 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
278 p << "shape.const_shape ";
279 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
280 p << "[";
281 interleaveComma(op.shape().getValues<int64_t>(), p,
282 [&](int64_t i) { p << i; });
283 p << "]";
284 }
285
parseConstShapeOp(OpAsmParser & parser,OperationState & result)286 static ParseResult parseConstShapeOp(OpAsmParser &parser,
287 OperationState &result) {
288 if (parser.parseOptionalAttrDict(result.attributes))
289 return failure();
290 // We piggy-back on ArrayAttr parsing, though we don't internally store the
291 // shape as an ArrayAttr.
292 // TODO: Implement custom parser and maybe make syntax a bit more concise.
293 Attribute extentsRaw;
294 NamedAttrList dummy;
295 if (parser.parseAttribute(extentsRaw, "dummy", dummy))
296 return failure();
297 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
298 if (!extentsArray)
299 return failure();
300 SmallVector<int64_t, 6> ints;
301 for (Attribute extent : extentsArray) {
302 IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
303 if (!attr)
304 return failure();
305 ints.push_back(attr.getInt());
306 }
307 Builder &builder = parser.getBuilder();
308 result.addAttribute("shape", builder.getIndexTensorAttr(ints));
309
310 result.types.push_back(ShapeType::get(builder.getContext()));
311 return success();
312 }
313
fold(ArrayRef<Attribute>)314 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
315
316 //===----------------------------------------------------------------------===//
317 // CstrBroadcastableOp
318 //===----------------------------------------------------------------------===//
319
320 namespace {
321 // Given an input shape Value, try to obtain the shape's values.
getShapeVec(Value input,SmallVectorImpl<int64_t> & shapeValues)322 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
323 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
324 auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
325 if (!type.hasRank())
326 return failure();
327 shapeValues = llvm::to_vector<6>(type.getShape());
328 return success();
329 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
330 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
331 return success();
332 } else {
333 return failure();
334 }
335 }
336
337 // For shapes that were created by some operations, we can obtain partial
338 // information on the shapes and sometimes determine if they will be
339 // broadcastable with that.
340 struct CstrBroadcastablePartialInfo
341 : public OpRewritePattern<CstrBroadcastableOp> {
342 using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
343
matchAndRewrite__anonb0061b370411::CstrBroadcastablePartialInfo344 LogicalResult matchAndRewrite(CstrBroadcastableOp op,
345 PatternRewriter &rewriter) const override {
346 SmallVector<int64_t, 6> lhsShape, rhsShape;
347 if (failed(getShapeVec(op.lhs(), lhsShape)))
348 return failure();
349 if (failed(getShapeVec(op.rhs(), rhsShape)))
350 return failure();
351 if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
352 return failure();
353
354 rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
355 return success();
356 }
357 };
358
359 // Scalars are always broadcastable.
360 struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> {
361 using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
362
matchAndRewrite__anonb0061b370411::CstrBroadcastableScalar363 LogicalResult matchAndRewrite(CstrBroadcastableOp op,
364 PatternRewriter &rewriter) const override {
365 SmallVector<int64_t, 6> shape;
366 if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0)
367 return failure();
368 if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0)
369 return failure();
370
371 rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
372 return success();
373 }
374 };
375
376 } // namespace
377
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)378 void CstrBroadcastableOp::getCanonicalizationPatterns(
379 OwningRewritePatternList &patterns, MLIRContext *context) {
380 // Canonicalization patterns have overlap with the considerations during
381 // folding in case additional shape information is inferred at some point that
382 // does not result in folding.
383 patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
384 CstrBroadcastableScalar>(context);
385 }
386
fold(ArrayRef<Attribute> operands)387 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
388 // Both operands are not needed if one is a scalar.
389 if (operands[0] &&
390 operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
391 return BoolAttr::get(true, getContext());
392 if (operands[1] &&
393 operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
394 return BoolAttr::get(true, getContext());
395
396 if (operands[0] && operands[1]) {
397 auto lhsShape = llvm::to_vector<6>(
398 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
399 auto rhsShape = llvm::to_vector<6>(
400 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
401 SmallVector<int64_t, 6> resultShape;
402 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
403 return BoolAttr::get(true, getContext());
404 }
405
406 // Lastly, see if folding can be completed based on what constraints are known
407 // on the input shapes.
408 SmallVector<int64_t, 6> lhsShape, rhsShape;
409 if (failed(getShapeVec(lhs(), lhsShape)))
410 return nullptr;
411 if (failed(getShapeVec(rhs(), rhsShape)))
412 return nullptr;
413
414 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
415 return BoolAttr::get(true, getContext());
416
417 // Because a failing witness result here represents an eventual assertion
418 // failure, we do not replace it with a constant witness.
419 return nullptr;
420 }
421
422 //===----------------------------------------------------------------------===//
423 // CstrEqOp
424 //===----------------------------------------------------------------------===//
425
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)426 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
427 MLIRContext *context) {
428 // If inputs are equal, return passing witness
429 patterns.insert<CstrEqEqOps>(context);
430 }
431
fold(ArrayRef<Attribute> operands)432 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
433 if (llvm::all_of(operands,
434 [&](Attribute a) { return a && a == operands[0]; }))
435 return BoolAttr::get(true, getContext());
436
437 // Because a failing witness result here represents an eventual assertion
438 // failure, we do not try to replace it with a constant witness. Similarly, we
439 // cannot if there are any non-const inputs.
440 return nullptr;
441 }
442
443 //===----------------------------------------------------------------------===//
444 // ConstSizeOp
445 //===----------------------------------------------------------------------===//
446
build(OpBuilder & builder,OperationState & result,int64_t value)447 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
448 int64_t value) {
449 build(builder, result, builder.getIndexAttr(value));
450 }
451
fold(ArrayRef<Attribute>)452 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
453
getAsmResultNames(llvm::function_ref<void (Value,StringRef)> setNameFn)454 void ConstSizeOp::getAsmResultNames(
455 llvm::function_ref<void(Value, StringRef)> setNameFn) {
456 SmallString<4> buffer;
457 llvm::raw_svector_ostream os(buffer);
458 os << "c" << value();
459 setNameFn(getResult(), os.str());
460 }
461
462 //===----------------------------------------------------------------------===//
463 // ConstWitnessOp
464 //===----------------------------------------------------------------------===//
465
fold(ArrayRef<Attribute>)466 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
467
468 //===----------------------------------------------------------------------===//
469 // IndexToSizeOp
470 //===----------------------------------------------------------------------===//
471
fold(ArrayRef<Attribute> operands)472 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
473 // Constant values of both types, `shape.size` and `index`, are represented as
474 // `IntegerAttr`s which makes constant folding simple.
475 if (Attribute arg = operands[0])
476 return arg;
477 return {};
478 }
479
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)480 void IndexToSizeOp::getCanonicalizationPatterns(
481 OwningRewritePatternList &patterns, MLIRContext *context) {
482 patterns.insert<SizeToIndexToSizeCanonicalization>(context);
483 }
484
485 //===----------------------------------------------------------------------===//
486 // FromExtentsOp
487 //===----------------------------------------------------------------------===//
488
fold(ArrayRef<Attribute> operands)489 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
490 if (llvm::any_of(operands, [](Attribute a) { return !a; }))
491 return nullptr;
492 SmallVector<int64_t, 6> extents;
493 for (auto attr : operands)
494 extents.push_back(attr.cast<IntegerAttr>().getInt());
495 Builder builder(getContext());
496 return builder.getIndexTensorAttr(extents);
497 }
498
499 //===----------------------------------------------------------------------===//
500 // GetExtentOp
501 //===----------------------------------------------------------------------===//
502
getConstantDim()503 Optional<int64_t> GetExtentOp::getConstantDim() {
504 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) {
505 return constSizeOp.value().getLimitedValue();
506 }
507 return llvm::None;
508 }
509
fold(ArrayRef<Attribute> operands)510 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
511 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
512 if (!elements)
513 return nullptr;
514 Optional<int64_t> dim = getConstantDim();
515 if (!dim.hasValue())
516 return nullptr;
517 if (dim.getValue() >= elements.getNumElements())
518 return nullptr;
519 return elements.getValue({(uint64_t)dim.getValue()});
520 }
521
build(OpBuilder & builder,OperationState & result,Value shape,int64_t dim)522 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
523 int64_t dim) {
524 auto loc = result.location;
525 auto dimAttr = builder.getIndexAttr(dim);
526 Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
527 build(builder, result, shape, dimValue);
528 }
529
530 //===----------------------------------------------------------------------===//
531 // RankOp
532 //===----------------------------------------------------------------------===//
533
fold(ArrayRef<Attribute> operands)534 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
535 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
536 if (!shape)
537 return {};
538 int64_t rank = shape.getNumElements();
539 Builder builder(getContext());
540 return builder.getIndexAttr(rank);
541 }
542
543 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
544 /// Constant folding fails in cases where only the rank is constant, not the
545 /// shape itself.
546 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
547 ///
548 /// Example:
549 ///
550 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
551 /// %rank = shape.rank %shape
552 ///
553 /// becomes
554 ///
555 /// %rank = shape.const_size 3
556
557 namespace {
558 struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
559 using OpRewritePattern<RankOp>::OpRewritePattern;
560
matchAndRewrite__anonb0061b370711::RankShapeOfCanonicalizationPattern561 LogicalResult matchAndRewrite(RankOp op,
562 PatternRewriter &rewriter) const override {
563 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
564 if (!shapeOfOp)
565 return failure();
566 auto rankedTensorType =
567 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
568 if (!rankedTensorType)
569 return failure();
570 int64_t rank = rankedTensorType.getRank();
571 rewriter.replaceOpWithNewOp<ConstSizeOp>(op.getOperation(), rank);
572 return success();
573 }
574 };
575 } // namespace
576
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)577 void RankOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
578 MLIRContext *context) {
579 patterns.insert<RankShapeOfCanonicalizationPattern>(context);
580 }
581
582 //===----------------------------------------------------------------------===//
583 // NumElementsOp
584 //===----------------------------------------------------------------------===//
585
fold(ArrayRef<Attribute> operands)586 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
587
588 // Fold only when argument constant.
589 Attribute shape = operands[0];
590 if (!shape)
591 return {};
592
593 APInt product(64, 1);
594 for (auto value : shape.cast<DenseIntElementsAttr>())
595 product *= value;
596 Builder builder(getContext());
597 return builder.getIndexAttr(product.getLimitedValue());
598 }
599
600 //===----------------------------------------------------------------------===//
601 // ShapeOfOp
602 //===----------------------------------------------------------------------===//
603
fold(ArrayRef<Attribute>)604 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
605 auto type = getOperand().getType().dyn_cast<ShapedType>();
606 if (!type || !type.hasStaticShape())
607 return nullptr;
608 Builder builder(getContext());
609 return builder.getIndexTensorAttr(type.getShape());
610 }
611
612 //===----------------------------------------------------------------------===//
613 // SizeToIndexOp
614 //===----------------------------------------------------------------------===//
615
fold(ArrayRef<Attribute> operands)616 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
617 // Constant values of both types, `shape.size` and `index`, are represented as
618 // `IntegerAttr`s which makes constant folding simple.
619 if (Attribute arg = operands[0])
620 return arg;
621 return {};
622 }
623
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)624 void SizeToIndexOp::getCanonicalizationPatterns(
625 OwningRewritePatternList &patterns, MLIRContext *context) {
626 patterns.insert<IndexToSizeToIndexCanonicalization>(context);
627 }
628
629 //===----------------------------------------------------------------------===//
630 // YieldOp
631 //===----------------------------------------------------------------------===//
632
verify(YieldOp op)633 static LogicalResult verify(YieldOp op) {
634 auto *parentOp = op.getParentOp();
635 auto results = parentOp->getResults();
636 auto operands = op.getOperands();
637
638 if (parentOp->getNumResults() != op.getNumOperands())
639 return op.emitOpError() << "number of operands does not match number of "
640 "results of its parent";
641 for (auto e : llvm::zip(results, operands))
642 if (std::get<0>(e).getType() != std::get<1>(e).getType())
643 return op.emitOpError()
644 << "types mismatch between yield op and its parent";
645
646 return success();
647 }
648
649 //===----------------------------------------------------------------------===//
650 // SplitAtOp
651 //===----------------------------------------------------------------------===//
652
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)653 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
654 SmallVectorImpl<OpFoldResult> &results) {
655 if (!operands[0] || !operands[1])
656 return failure();
657 auto shapeVec = llvm::to_vector<6>(
658 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
659 auto shape = llvm::makeArrayRef(shapeVec);
660 auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
661 // Verify that the split point is in the correct range.
662 // TODO: Constant fold to an "error".
663 int64_t rank = shape.size();
664 if (!(-rank <= splitPoint && splitPoint <= rank))
665 return failure();
666 if (splitPoint < 0)
667 splitPoint += shape.size();
668 Builder builder(operands[0].getContext());
669 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
670 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
671 return success();
672 }
673
674 //===----------------------------------------------------------------------===//
675 // ToExtentTensorOp
676 //===----------------------------------------------------------------------===//
677
fold(ArrayRef<Attribute> operands)678 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
679 if (!operands[0])
680 return nullptr;
681 Builder builder(getContext());
682 auto shape = llvm::to_vector<6>(
683 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
684 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
685 builder.getIndexType());
686 return DenseIntElementsAttr::get(type, shape);
687 }
688
689 //===----------------------------------------------------------------------===//
690 // ReduceOp
691 //===----------------------------------------------------------------------===//
692
build(OpBuilder & builder,OperationState & result,Value shape,ValueRange initVals)693 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
694 ValueRange initVals) {
695 result.addOperands(shape);
696 result.addOperands(initVals);
697
698 Region *bodyRegion = result.addRegion();
699 bodyRegion->push_back(new Block);
700 Block &bodyBlock = bodyRegion->front();
701 bodyBlock.addArgument(builder.getIndexType());
702 bodyBlock.addArgument(SizeType::get(builder.getContext()));
703
704 for (Type initValType : initVals.getTypes()) {
705 bodyBlock.addArgument(initValType);
706 result.addTypes(initValType);
707 }
708 }
709
verify(ReduceOp op)710 static LogicalResult verify(ReduceOp op) {
711 // Verify block arg types.
712 Block &block = op.region().front();
713
714 auto blockArgsCount = op.initVals().size() + 2;
715 if (block.getNumArguments() != blockArgsCount)
716 return op.emitOpError() << "ReduceOp body is expected to have "
717 << blockArgsCount << " arguments";
718
719 if (block.getArgument(0).getType() != IndexType::get(op.getContext()))
720 return op.emitOpError(
721 "argument 0 of ReduceOp body is expected to be of IndexType");
722
723 if (block.getArgument(1).getType() != SizeType::get(op.getContext()))
724 return op.emitOpError(
725 "argument 1 of ReduceOp body is expected to be of SizeType");
726
727 for (auto type : llvm::enumerate(op.initVals()))
728 if (block.getArgument(type.index() + 2).getType() != type.value().getType())
729 return op.emitOpError()
730 << "type mismatch between argument " << type.index() + 2
731 << " of ReduceOp body and initial value " << type.index();
732 return success();
733 }
734
parseReduceOp(OpAsmParser & parser,OperationState & result)735 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
736 auto *ctx = parser.getBuilder().getContext();
737 // Parse operands.
738 SmallVector<OpAsmParser::OperandType, 3> operands;
739 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
740 OpAsmParser::Delimiter::Paren) ||
741 parser.parseOptionalArrowTypeList(result.types))
742 return failure();
743
744 // Resolve operands.
745 auto initVals = llvm::makeArrayRef(operands).drop_front();
746 if (parser.resolveOperand(operands.front(), ShapeType::get(ctx),
747 result.operands) ||
748 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
749 result.operands))
750 return failure();
751
752 // Parse the body.
753 Region *body = result.addRegion();
754 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
755 return failure();
756
757 // Parse attributes.
758 if (parser.parseOptionalAttrDict(result.attributes))
759 return failure();
760
761 return success();
762 }
763
print(OpAsmPrinter & p,ReduceOp op)764 static void print(OpAsmPrinter &p, ReduceOp op) {
765 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
766 << ") ";
767 p.printOptionalArrowTypeList(op.getResultTypes());
768 p.printRegion(op.region());
769 p.printOptionalAttrDict(op.getAttrs());
770 }
771
772 namespace mlir {
773 namespace shape {
774
775 #define GET_OP_CLASSES
776 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
777
778 } // namespace shape
779 } // namespace mlir
780