//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::shape;

#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"

namespace {
#include "ShapeCanonicalization.inc"
} // namespace

RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
  return RankedTensorType::get({rank}, IndexType::get(ctx));
}

bool shape::isExtentTensorType(Type type) {
  auto ranked = type.dyn_cast<RankedTensorType>();
  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
}

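/// Extracts the static extents that `input` represents into `shapeValues`.
/// Succeeds if `input` is the shape of a ranked value, a const_shape, or a
/// constant extent tensor; fails otherwise.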
LogicalResult shape::getShapeVec(Value input,
                                 SmallVectorImpl<int64_t> &shapeValues) {
  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
    auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
    if (!type || !type.hasRank())
      return failure();
    shapeValues = llvm::to_vector<6>(type.getShape());
    return success();
  }
  if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
    shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
    return success();
  }
  if (auto inputOp = input.getDefiningOp<ConstantOp>()) {
    shapeValues = llvm::to_vector<6>(
        inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>());
    return success();
  }
  return failure();
}

static bool isErrorPropagationPossible(TypeRange operandTypes) {
  return llvm::any_of(operandTypes, [](Type ty) {
    return ty.isa<SizeType, ShapeType, ValueShapeType>();
  });
}

static LogicalResult verifySizeOrIndexOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!resultTy.isa<SizeType>())
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `size` to propagate them";
  }
  return success();
}

static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!resultTy.isa<ShapeType>())
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `shape` to propagate them";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace

void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect
  // is still evolving, this makes it simple to start with unregistered ops
  // and try different variants before actually defining an op.
  allowUnknownOperations();
}

Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, type, value);
  return nullptr;
}

/// Parse a type registered to this dialect.
Type ShapeDialect::parseType(DialectAsmParser &parser) const {
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();

  if (keyword == "shape")
    return ShapeType::get(getContext());
  if (keyword == "size")
    return SizeType::get(getContext());
  if (keyword == "value_shape")
    return ValueShapeType::get(getContext());
  if (keyword == "witness")
    return WitnessType::get(getContext());

  parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
  return Type();
}

/// Print a type registered to this dialect.
void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<ShapeType>([&](Type) { os << "shape"; })
      .Case<SizeType>([&](Type) { os << "size"; })
      .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
      .Case<WitnessType>([&](Type) { os << "witness"; })
      .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
}

LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
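  // For example (illustrative), a module may carry
  //   module attributes {shape.lib = @shape_lib} { ... }
  // where @shape_lib names a shape.function_library op within the module.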
  if (attribute.first == "shape.lib") {
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      DenseSet<Identifier> key;
      for (auto it : arr) {
        if (!it.isa<SymbolRefAttr>())
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        for (auto mapping : shapeFnLib.mapping()) {
          if (!key.insert(mapping.first).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.first << "`";
          }
        }
      }
      return success();
    }

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//

// TODO: Canonicalization should be implemented for shapes that can be
// determined through mixtures of the known dimensions of the inputs.
OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
  // Only the last operand is checked because AnyOp is commutative; constant
  // operands are canonicalized to the end of the operand list.
  if (operands.back())
    return operands.back();

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//

static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

static void print(OpAsmPrinter &p, AssumingOp op) {
  bool yieldsResults = !op.results().empty();

  p << AssumingOp::getOperationName() << " " << op.witness();
  if (yieldsResults) {
    p << " -> (" << op.getResultTypes() << ")";
  }
  p.printRegion(op.doRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict(op->getAttrs());
}

namespace {
// Removes AssumingOp with a passing witness and inlines the region.
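// For example (illustrative):
//   %w = shape.const_witness true
//   %0 = shape.assuming %w -> (!shape.shape) {
//     ...
//     shape.assuming_yield %s : !shape.shape
//   }
// is replaced by the inlined body, with %0 replaced by %s.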
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.passingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};

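// Removes results of an AssumingOp that are unused, both from the op itself
// and from its terminating yield. For example (illustrative), if %0#1 below
// has no uses,
//   %0:2 = shape.assuming %w -> (!shape.size, !shape.size) { ... }
// is rewritten to yield (and produce) only the first value.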
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace the yield op in the old assuming op's body and move the entire
    // region to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.witness());
    newOp.doRegion().takeBody(op.doRegion());

    // Use the new results to replace the previously used ones.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
} // namespace

void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}

// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // AssumingOp has unconditional control flow into the region and back to the
  // parent, so return the correct RegionSuccessor purely based on the index
  // being None or 0.
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&doRegion()));
}

void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}

void AssumingOp::build(
    OpBuilder &builder, OperationState &result, Value witness,
    function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
  result.addOperands(witness);
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();

  // Build body.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&bodyBlock);
  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
  builder.create<AssumingYieldOp>(result.location, yieldValues);

  SmallVector<Type, 2> assumingTypes;
  for (Value v : yieldValues)
    assumingTypes.push_back(v.getType());
  result.addTypes(assumingTypes);
}

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//

namespace {
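// Merges multiple shape.cstr_eq witnesses whose shape operands transitively
// overlap into a single shape.cstr_eq. Schematically (illustrative, types
// omitted):
//   %w0 = shape.cstr_eq %a, %b
//   %w1 = shape.cstr_eq %b, %c
//   %w  = shape.assuming_all %w0, %w1
// becomes
//   %w = shape.cstr_eq %a, %b, %b, %c
// (duplicate operands are cleaned up by RemoveDuplicateOperandsPattern).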
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.inputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};

template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Find unique operands.
    SmallVector<Value, 2> unique;
    for (Value v : op.getOperands()) {
      if (!llvm::is_contained(unique, v))
        unique.push_back(v);
    }

    // Reduce op to equivalent with unique operands.
    if (unique.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
                                        op->getAttrs());
      return success();
    }

    return failure();
  }
};
} // namespace

void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
               RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}

OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false.
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}

static LogicalResult verify(AssumingAllOp op) {
  // Ensure that AssumingAllOp contains at least one operand.
  if (op.getNumOperands() == 0)
    return op.emitOpError("no operands specified");

  return success();
}

void AssumingAllOp::build(OpBuilder &b, OperationState &state,
                          ValueRange inputs) {
  build(b, state, b.getType<WitnessType>(), inputs);
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

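// For example (illustrative), folding the constant input shapes [1, 2] and
// [3, 1] produces the broadcasted shape [3, 2].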
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (shapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not
    // folding.
    if (shapes().front().getType() != getType())
      return nullptr;
    return shapes().front();
  }

  // TODO: Support folding with more than 2 input shapes.
  if (shapes().size() > 2)
    return nullptr;

  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

static LogicalResult verify(BroadcastOp op) {
  return verifyShapeOrExtentTensorOp(op);
}

namespace {
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto isPotentiallyNonEmptyShape = [](Value shape) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        if (extentTensorTy.getDimSize(0) == 0)
          return false;
      }
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        if (constShape.shape().empty())
          return false;
      }
      return true;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));

    // Reduce op to equivalent without empty shape operands.
    if (newOperands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
                                        op->getAttrs());
      return success();
    }

    return failure();
  }
};

struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.shapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (op.getType().isa<ShapeType>()) {
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        assert(!op.getType().isa<ShapeType>() &&
               !replacement.getType().isa<ShapeType>() &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};

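// Replaces all constant operands of a broadcast with a single,
// pre-broadcasted constant shape. For example (illustrative), with constant
// shapes [1, 4] and [3, 1] among the operands,
//   %0 = shape.broadcast %a, %c0, %b, %c1 : ...
// becomes a broadcast of %a, %b, and a const_shape [3, 4].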
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.shapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};

template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information.
        bool isInformationLosingCast =
            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
        if (isInformationLosingCast) {
          anyChange = true;
          return castOp.source();
        }
      }
      return operand;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));

    // Rewrite op if any change required.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};

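// Concretizes a dynamic extent tensor result type when all operand ranks are
// known. For example (illustrative):
//   %0 = shape.broadcast %a, %b
//       : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
// becomes a broadcast with result type tensor<3xindex>, followed by a
// tensor.cast back to the original tensor<?xindex> type.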
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.shapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, ConstShapeOp &op) {
  p << "shape.const_shape ";
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(op.shape().getValues<int64_t>(), p,
                  [&](int64_t i) { p << i; });
  p << "] : ";
  p.printType(op.getType());
}

static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}

OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }

void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}

LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
  if (!shape)
    return emitOptionalError(location, "missing shape attribute");
  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(shape.size())}, b.getIndexType())});
  return success();
}

bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
                                                        TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs == rhs)
    return true;

  // Shape type is compatible with all other valid return types.
  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    return true;

  return succeeded(verifyCompatibleShapes(lhs, rhs));
}

//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns overlap with the considerations during folding,
  // in case additional shape information is inferred at some point that does
  // not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}

// Return true if at most one of the attributes may represent a non-scalar
// broadcast, i.e., is unknown or has a nonzero number of extents.
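// For example, broadcasting the shapes [], [], and [3, 4] can never fail:
// scalar shapes broadcast with anything.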
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}

OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are
  // known on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : shapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}

static LogicalResult verify(CstrBroadcastableOp op) {
  // Ensure that CstrBroadcastableOp contains at least two operands.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness.
  patterns.add<CstrEqEqOps>(context);
}

OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
  if (llvm::all_of(operands,
                   [&](Attribute a) { return a && a == operands[0]; }))
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly,
  // we cannot if there are any non-const inputs.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }

void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << value();
  setNameFn(getResult(), os.str());
}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;

  // Division in APInt does not follow floor(lhs / rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
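  // For example, sdivrem(-7, 2) yields quotient -3 and remainder -1, whereas
  // floor division gives -4; hence the quotient is decremented below whenever
  // the result is negative and the remainder is nonzero.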
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  if (quotient.isNegative() && !remainder.isNullValue()) {
    quotient -= 1;
  }

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}

//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
  bool allSame = true;
  if (!operands.empty() && !operands[0])
    return {};
  for (Attribute operand : operands.drop_front(1)) {
    if (!operand)
      return {};
    allSame = allSame && operand == operands[0];
  }
  return BoolAttr::get(getContext(), allSame);
}

//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//

OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented
  // as `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return {};
}

void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// FromExtentsOp
//===----------------------------------------------------------------------===//

OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
  if (llvm::any_of(operands, [](Attribute a) { return !a; }))
    return nullptr;
  SmallVector<int64_t, 6> extents;
  for (auto attr : operands)
    extents.push_back(attr.cast<IntegerAttr>().getInt());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(extents);
}

//===----------------------------------------------------------------------===//
// FunctionLibraryOp
//===----------------------------------------------------------------------===//

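// A function library maps op names to shape functions, e.g. (illustrative):
//   shape.function_library @shape_lib {
//     func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
//       ...
//     }
//   } mapping {
//     some_dialect.some_op = @same_result_shape
//   }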
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = mapping()
                  .get(op->getName().getIdentifier())
                  .dyn_cast_or_null<FlatSymbolRefAttr>();
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}

ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}

void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.mappingAttr());
}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//

Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.value().getLimitedValue();
  if (auto constantOp = dim().getDefiningOp<ConstantOp>())
    return constantOp.value().cast<IntegerAttr>().getInt();
  return llvm::None;
}

OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!elements)
    return nullptr;
  Optional<int64_t> dim = getConstantDim();
  if (!dim.hasValue())
    return nullptr;
  if (dim.getValue() >= elements.getNumElements())
    return nullptr;
  return elements.getValue({(uint64_t)dim.getValue()});
}

void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
                        int64_t dim) {
  auto loc = result.location;
  auto dimAttr = builder.getIndexAttr(dim);
  if (shape.getType().isa<ShapeType>()) {
    Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
    build(builder, result, builder.getType<SizeType>(), shape, dim);
  } else {
    Value dim =
        builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
    build(builder, result, builder.getIndexType(), shape, dim);
  }
}

//===----------------------------------------------------------------------===//
// IsBroadcastableOp
//===----------------------------------------------------------------------===//

void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}

OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // Can always broadcast fewer than two shapes.
  if (operands.size() < 2) {
    return BoolAttr::get(getContext(), true);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
  auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!shape)
    return {};
  int64_t rank = shape.getNumElements();
  Builder builder(getContext());
  return builder.getIndexAttr(rank);
}

/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3
namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
  // Fold only when the argument is constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}

void NumElementsOp::build(OpBuilder &builder, OperationState &result,
                          Value shape) {
  if (shape.getType().isa<ShapedType>()) {
    auto type = builder.getIndexType();
    return build(builder, result, type, shape);
  }
  auto type = SizeType::get(builder.getContext());
  return build(builder, result, type, shape);
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (lhs() == rhs())
    return lhs();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//

OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (lhs() == rhs())
    return lhs();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
  auto type = getOperand().getType().dyn_cast<ShapedType>();
  if (!type || !type.hasStaticShape())
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(type.getShape());
}

void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
  if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
    Type indexTy = builder.getIndexType();
    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
    return ShapeOfOp::build(builder, result, extentTensorTy, arg);
  }
  Type shapeTy = builder.getType<ShapeType>();
  return ShapeOfOp::build(builder, result, shapeTy, arg);
}

namespace {
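// Rewrites shape.shape_of on a shaped value so that it yields an extent
// tensor rather than a !shape.shape. For example (illustrative):
//   %0 = shape.shape_of %arg : tensor<1x2xf32> -> !shape.shape
// becomes
//   %0 = shape.shape_of %arg : tensor<1x2xf32> -> tensor<2xindex>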
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.arg().getType().isa<ShapedType>())
      return failure();
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
    return success();
  }
};
} // namespace

void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
}

//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented
  // as `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return impl::foldCastOp(*this);
}

void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(shape::YieldOp op) {
  auto *parentOp = op->getParentOp();
  auto results = parentOp->getResults();
  auto operands = op.getOperands();

  if (parentOp->getNumResults() != op.getNumOperands())
    return op.emitOpError() << "number of operands does not match number of "
                               "results of its parent";
  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return op.emitOpError()
             << "types mismatch between yield op and its parent";

  return success();
}

//===----------------------------------------------------------------------===//
// SplitAtOp
//===----------------------------------------------------------------------===//

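// Folds only when both the shape and the split point are constant. For
// example (illustrative), splitting [a, b, c] at index 1 yields [a] and
// [b, c], while index -1 yields [a, b] and [c].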
LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
  if (!operands[0] || !operands[1])
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto shape = llvm::makeArrayRef(shapeVec);
  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (!(-rank <= splitPoint && splitPoint <= rank))
    return failure();
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(operands[0].getContext());
  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
  return success();
}

//===----------------------------------------------------------------------===//
// ToExtentTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0])
    return impl::foldCastOp(*this);
  Builder builder(getContext());
  auto shape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                    builder.getIndexType());
  return DenseIntElementsAttr::get(type, shape);
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  result.addOperands(shape);
  result.addOperands(initVals);

  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  bodyBlock.addArgument(builder.getIndexType());

  Type elementType;
  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock.addArgument(elementType);

  for (Type initValType : initVals.getTypes()) {
    bodyBlock.addArgument(initValType);
    result.addTypes(initValType);
  }
}

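// An illustrative well-formed reduction over a shape (syntax sketched; the
// body combines the running value with each extent):
//   %r = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
//     ^bb0(%index: index, %extent: !shape.size, %acc: !shape.size):
//       %new_acc = ... // combine %acc and %extent
//       shape.yield %new_acc : !shape.size
//   }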
static LogicalResult verify(ReduceOp op) {
  // Verify block arg types.
  Block &block = op.region().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = op.initVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return op.emitOpError() << "ReduceOp body is expected to have "
                            << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of type `index`.
  if (!block.getArgument(0).getType().isa<IndexType>())
    return op.emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of type `size` or
  // `index`, depending on whether the reduce operation is applied to a shape
  // or to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (op.shape().getType().isa<ShapeType>()) {
    if (!extentTy.isa<SizeType>())
      return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
                            "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!extentTy.isa<IndexType>())
      return op.emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  for (auto type : llvm::enumerate(op.initVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return op.emitOpError()
             << "type mismatch between argument " << type.index() + 2
             << " of ReduceOp body and initial value " << type.index();
  return success();
}

static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

static void print(OpAsmPrinter &p, ReduceOp op) {
  p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
    << ") : " << op.shape().getType();
  p.printOptionalArrowTypeList(op.getResultTypes());
  p.printRegion(op.region());
  p.printOptionalAttrDict(op->getAttrs());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"