1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Traits.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Transforms/InliningUtils.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/raw_ostream.h"
23 
24 using namespace mlir;
25 using namespace mlir::shape;
26 
27 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
28 
29 namespace {
30 #include "ShapeCanonicalization.inc"
31 }
32 
getExtentTensorType(MLIRContext * ctx,int64_t rank)33 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
34   return RankedTensorType::get({rank}, IndexType::get(ctx));
35 }
36 
isExtentTensorType(Type type)37 bool shape::isExtentTensorType(Type type) {
38   auto ranked = type.dyn_cast<RankedTensorType>();
39   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
40 }
41 
getShapeVec(Value input,SmallVectorImpl<int64_t> & shapeValues)42 LogicalResult shape::getShapeVec(Value input,
43                                  SmallVectorImpl<int64_t> &shapeValues) {
44   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
45     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
46     if (!type.hasRank())
47       return failure();
48     shapeValues = llvm::to_vector<6>(type.getShape());
49     return success();
50   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
51     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
52     return success();
53   } else if (auto inputOp = input.getDefiningOp<ConstantOp>()) {
54     shapeValues = llvm::to_vector<6>(
55         inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>());
56     return success();
57   } else {
58     return failure();
59   }
60 }
61 
isErrorPropagationPossible(TypeRange operandTypes)62 static bool isErrorPropagationPossible(TypeRange operandTypes) {
63   return llvm::any_of(operandTypes, [](Type ty) {
64     return ty.isa<SizeType, ShapeType, ValueShapeType>();
65   });
66 }
67 
verifySizeOrIndexOp(Operation * op)68 static LogicalResult verifySizeOrIndexOp(Operation *op) {
69   assert(op != nullptr && op->getNumResults() == 1);
70   Type resultTy = op->getResultTypes().front();
71   if (isErrorPropagationPossible(op->getOperandTypes())) {
72     if (!resultTy.isa<SizeType>())
73       return op->emitOpError()
74              << "if at least one of the operands can hold error values then "
75                 "the result must be of type `size` to propagate them";
76   }
77   return success();
78 }
79 
verifyShapeOrExtentTensorOp(Operation * op)80 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
81   assert(op != nullptr && op->getNumResults() == 1);
82   Type resultTy = op->getResultTypes().front();
83   if (isErrorPropagationPossible(op->getOperandTypes())) {
84     if (!resultTy.isa<ShapeType>())
85       return op->emitOpError()
86              << "if at least one of the operands can hold error values then "
87                 "the result must be of type `shape` to propagate them";
88   }
89   return success();
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // InlinerInterface
94 //===----------------------------------------------------------------------===//
95 
96 namespace {
97 /// This class defines the interface for inlining shape dialect ops.
98 struct ShapeInlinerInterface : public DialectInlinerInterface {
99   using DialectInlinerInterface::DialectInlinerInterface;
100 
101   // Returns true if the given region 'src' can be inlined into the region
102   // 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline__anonf25128a30311::ShapeInlinerInterface103   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
104                        BlockAndValueMapping &) const final {
105     return true;
106   }
107 
108   // Returns true if the given operation 'op', that is registered to this
109   // dialect, can be inlined into the region 'dest' that is attached to an
110   // operation registered to the current dialect.
isLegalToInline__anonf25128a30311::ShapeInlinerInterface111   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
112                        BlockAndValueMapping &) const final {
113     return true;
114   }
115 };
116 } // namespace
117 
initialize()118 void ShapeDialect::initialize() {
119   addOperations<
120 #define GET_OP_LIST
121 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
122       >();
123   addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
124   addInterfaces<ShapeInlinerInterface>();
125   // Allow unknown operations during prototyping and testing. As the dialect is
126   // still evolving it makes it simple to start with an unregistered ops and
127   // try different variants before actually defining the op.
128   allowUnknownOperations();
129 }
130 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)131 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
132                                              Attribute value, Type type,
133                                              Location loc) {
134   if (type.isa<ShapeType>() || isExtentTensorType(type))
135     return builder.create<ConstShapeOp>(loc, type,
136                                         value.cast<DenseIntElementsAttr>());
137   if (type.isa<SizeType>())
138     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
139   if (type.isa<WitnessType>())
140     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
141   if (ConstantOp::isBuildableWith(value, type))
142     return builder.create<ConstantOp>(loc, type, value);
143   return nullptr;
144 }
145 
146 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const147 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
148   StringRef keyword;
149   if (parser.parseKeyword(&keyword))
150     return Type();
151 
152   if (keyword == "shape")
153     return ShapeType::get(getContext());
154   if (keyword == "size")
155     return SizeType::get(getContext());
156   if (keyword == "value_shape")
157     return ValueShapeType::get(getContext());
158   if (keyword == "witness")
159     return WitnessType::get(getContext());
160 
161   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
162   return Type();
163 }
164 
165 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const166 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
167   TypeSwitch<Type>(type)
168       .Case<ShapeType>([&](Type) { os << "shape"; })
169       .Case<SizeType>([&](Type) { os << "size"; })
170       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
171       .Case<WitnessType>([&](Type) { os << "witness"; })
172       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
173 }
174 
verifyOperationAttribute(Operation * op,NamedAttribute attribute)175 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
176                                                      NamedAttribute attribute) {
177   // Verify shape.lib attribute.
178   if (attribute.first == "shape.lib") {
179     if (!op->hasTrait<OpTrait::SymbolTable>())
180       return op->emitError(
181           "shape.lib attribute may only be on op implementing SymbolTable");
182 
183     if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
184       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
185       if (!symbol)
186         return op->emitError("shape function library ")
187                << symbolRef << " not found";
188       return isa<shape::FunctionLibraryOp>(symbol)
189                  ? success()
190                  : op->emitError()
191                        << symbolRef << " required to be shape function library";
192     }
193 
194     if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
195       // Verify all entries are function libraries and mappings in libraries
196       // refer to unique ops.
197       DenseSet<Identifier> key;
198       for (auto it : arr) {
199         if (!it.isa<SymbolRefAttr>())
200           return op->emitError(
201               "only SymbolRefAttr allowed in shape.lib attribute array");
202 
203         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
204             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
205         if (!shapeFnLib)
206           return op->emitError()
207                  << it << " does not refer to FunctionLibraryOp";
208         for (auto mapping : shapeFnLib.mapping()) {
209           if (!key.insert(mapping.first).second) {
210             return op->emitError("only one op to shape mapping allowed, found "
211                                  "multiple for `")
212                    << mapping.first << "`";
213           }
214         }
215       }
216       return success();
217     }
218 
219     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
220                          "allowed as shape.lib attribute");
221   }
222   return success();
223 }
224 
225 //===----------------------------------------------------------------------===//
226 // AnyOp
227 //===----------------------------------------------------------------------===//
228 
229 // TODO: Canonicalization should be implemented for shapes that can be
230 // determined through mixtures of the known dimensions of the inputs.
fold(ArrayRef<Attribute> operands)231 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
232   // Only the last operand is checked because AnyOp is commutative.
233   if (operands.back())
234     return operands.back();
235 
236   return nullptr;
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // AssumingOp
241 //===----------------------------------------------------------------------===//
242 
parseAssumingOp(OpAsmParser & parser,OperationState & result)243 static ParseResult parseAssumingOp(OpAsmParser &parser,
244                                    OperationState &result) {
245   result.regions.reserve(1);
246   Region *doRegion = result.addRegion();
247 
248   auto &builder = parser.getBuilder();
249   OpAsmParser::OperandType cond;
250   if (parser.parseOperand(cond) ||
251       parser.resolveOperand(cond, builder.getType<WitnessType>(),
252                             result.operands))
253     return failure();
254 
255   // Parse optional results type list.
256   if (parser.parseOptionalArrowTypeList(result.types))
257     return failure();
258 
259   // Parse the region and add a terminator if elided.
260   if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
261     return failure();
262   AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
263 
264   // Parse the optional attribute list.
265   if (parser.parseOptionalAttrDict(result.attributes))
266     return failure();
267   return success();
268 }
269 
print(OpAsmPrinter & p,AssumingOp op)270 static void print(OpAsmPrinter &p, AssumingOp op) {
271   bool yieldsResults = !op.results().empty();
272 
273   p << AssumingOp::getOperationName() << " " << op.witness();
274   if (yieldsResults) {
275     p << " -> (" << op.getResultTypes() << ")";
276   }
277   p.printRegion(op.doRegion(),
278                 /*printEntryBlockArgs=*/false,
279                 /*printBlockTerminators=*/yieldsResults);
280   p.printOptionalAttrDict(op->getAttrs());
281 }
282 
283 namespace {
284 // Removes AssumingOp with a passing witness and inlines the region.
285 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
286   using OpRewritePattern<AssumingOp>::OpRewritePattern;
287 
matchAndRewrite__anonf25128a30911::AssumingWithTrue288   LogicalResult matchAndRewrite(AssumingOp op,
289                                 PatternRewriter &rewriter) const override {
290     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
291     if (!witness || !witness.passingAttr())
292       return failure();
293 
294     AssumingOp::inlineRegionIntoParent(op, rewriter);
295     return success();
296   }
297 };
298 
299 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
300   using OpRewritePattern<AssumingOp>::OpRewritePattern;
301 
matchAndRewrite__anonf25128a30911::AssumingOpRemoveUnusedResults302   LogicalResult matchAndRewrite(AssumingOp op,
303                                 PatternRewriter &rewriter) const override {
304     Block *body = op.getBody();
305     auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
306 
307     // Find used values.
308     SmallVector<Value, 4> newYieldOperands;
309     Value opResult, yieldOperand;
310     for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) {
311       std::tie(opResult, yieldOperand) = it;
312       if (!opResult.getUses().empty()) {
313         newYieldOperands.push_back(yieldOperand);
314       }
315     }
316 
317     // Rewrite only if redundant results exist.
318     if (newYieldOperands.size() == yieldOp->getNumOperands())
319       return failure();
320 
321     // Replace yield op in the old assuming op's body and move the entire region
322     // to the new assuming op.
323     rewriter.setInsertionPointToEnd(body);
324     auto newYieldOp =
325         rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
326     rewriter.setInsertionPoint(op);
327     auto newOp = rewriter.create<AssumingOp>(
328         op.getLoc(), newYieldOp->getOperandTypes(), op.witness());
329     newOp.doRegion().takeBody(op.doRegion());
330 
331     // Use the new results to replace the previously used ones.
332     SmallVector<Value, 4> replacementValues;
333     auto src = newOp.getResults().begin();
334     for (auto it : op.getResults()) {
335       if (it.getUses().empty())
336         replacementValues.push_back(nullptr);
337       else
338         replacementValues.push_back(*src++);
339     }
340     rewriter.replaceOp(op, replacementValues);
341     return success();
342   }
343 };
344 } // namespace
345 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)346 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
347                                              MLIRContext *context) {
348   patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
349 }
350 
351 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)352 void AssumingOp::getSuccessorRegions(
353     Optional<unsigned> index, ArrayRef<Attribute> operands,
354     SmallVectorImpl<RegionSuccessor> &regions) {
355   // AssumingOp has unconditional control flow into the region and back to the
356   // parent, so return the correct RegionSuccessor purely based on the index
357   // being None or 0.
358   if (index.hasValue()) {
359     regions.push_back(RegionSuccessor(getResults()));
360     return;
361   }
362 
363   regions.push_back(RegionSuccessor(&doRegion()));
364 }
365 
inlineRegionIntoParent(AssumingOp & op,PatternRewriter & rewriter)366 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
367                                         PatternRewriter &rewriter) {
368   auto *blockBeforeAssuming = rewriter.getInsertionBlock();
369   auto *assumingBlock = op.getBody();
370   auto initPosition = rewriter.getInsertionPoint();
371   auto *blockAfterAssuming =
372       rewriter.splitBlock(blockBeforeAssuming, initPosition);
373 
374   // Remove the AssumingOp and AssumingYieldOp.
375   auto &yieldOp = assumingBlock->back();
376   rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
377   rewriter.replaceOp(op, yieldOp.getOperands());
378   rewriter.eraseOp(&yieldOp);
379 
380   // Merge blocks together as there was no branching behavior from the
381   // AssumingOp.
382   rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
383   rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
384 }
385 
build(OpBuilder & builder,OperationState & result,Value witness,function_ref<SmallVector<Value,2> (OpBuilder &,Location)> bodyBuilder)386 void AssumingOp::build(
387     OpBuilder &builder, OperationState &result, Value witness,
388     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
389 
390   result.addOperands(witness);
391   Region *bodyRegion = result.addRegion();
392   bodyRegion->push_back(new Block);
393   Block &bodyBlock = bodyRegion->front();
394 
395   // Build body.
396   OpBuilder::InsertionGuard guard(builder);
397   builder.setInsertionPointToStart(&bodyBlock);
398   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
399   builder.create<AssumingYieldOp>(result.location, yieldValues);
400 
401   SmallVector<Type, 2> assumingTypes;
402   for (Value v : yieldValues)
403     assumingTypes.push_back(v.getType());
404   result.addTypes(assumingTypes);
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // AssumingAllOp
409 //===----------------------------------------------------------------------===//
410 
411 namespace {
412 struct AssumingAllToCstrEqCanonicalization
413     : public OpRewritePattern<AssumingAllOp> {
414   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
415 
matchAndRewrite__anonf25128a30a11::AssumingAllToCstrEqCanonicalization416   LogicalResult matchAndRewrite(AssumingAllOp op,
417                                 PatternRewriter &rewriter) const override {
418     SmallVector<Value, 8> shapes;
419     for (Value w : op.inputs()) {
420       auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
421       if (!cstrEqOp)
422         return failure();
423       bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) {
424         return llvm::is_contained(shapes, s);
425       });
426       if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes)
427         return failure();
428       shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end());
429     }
430     rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
431     return success();
432   }
433 };
434 
435 template <typename OpTy>
436 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
437   using OpRewritePattern<OpTy>::OpRewritePattern;
438 
matchAndRewrite__anonf25128a30a11::RemoveDuplicateOperandsPattern439   LogicalResult matchAndRewrite(OpTy op,
440                                 PatternRewriter &rewriter) const override {
441     // Find unique operands.
442     SmallVector<Value, 2> unique;
443     for (Value v : op.getOperands()) {
444       if (!llvm::is_contained(unique, v))
445         unique.push_back(v);
446     }
447 
448     // Reduce op to equivalent with unique operands.
449     if (unique.size() < op.getNumOperands()) {
450       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
451                                         op->getAttrs());
452       return success();
453     }
454 
455     return failure();
456   }
457 };
458 } // namespace
459 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)460 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
461                                                 MLIRContext *context) {
462   patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
463                RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
464 }
465 
fold(ArrayRef<Attribute> operands)466 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
467   // Iterate in reverse to first handle all constant operands. They are
468   // guaranteed to be the tail of the inputs because this is commutative.
469   for (int idx = operands.size() - 1; idx >= 0; idx--) {
470     Attribute a = operands[idx];
471     // Cannot fold if any inputs are not constant;
472     if (!a)
473       return nullptr;
474 
475     // We do not need to keep statically known values after handling them in
476     // this method.
477     getOperation()->eraseOperand(idx);
478 
479     // Always false if any input is statically known false
480     if (!a.cast<BoolAttr>().getValue())
481       return a;
482   }
483   // If this is reached, all inputs were statically known passing.
484   return BoolAttr::get(getContext(), true);
485 }
486 
verify(AssumingAllOp op)487 static LogicalResult verify(AssumingAllOp op) {
488   // Ensure that AssumingAllOp contains at least one operand
489   if (op.getNumOperands() == 0)
490     return op.emitOpError("no operands specified");
491 
492   return success();
493 }
494 
build(OpBuilder & b,OperationState & state,ValueRange inputs)495 void AssumingAllOp::build(OpBuilder &b, OperationState &state,
496                           ValueRange inputs) {
497   build(b, state, b.getType<WitnessType>(), inputs);
498 }
499 
500 //===----------------------------------------------------------------------===//
501 // BroadcastOp
502 //===----------------------------------------------------------------------===//
503 
fold(ArrayRef<Attribute> operands)504 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
505   if (shapes().size() == 1) {
506     // Otherwise, we need a cast which would be a canonicalization, not folding.
507     if (shapes().front().getType() != getType())
508       return nullptr;
509     return shapes().front();
510   }
511 
512   // TODO: Support folding with more than 2 input shapes
513   if (shapes().size() > 2)
514     return nullptr;
515 
516   if (!operands[0] || !operands[1])
517     return nullptr;
518   auto lhsShape = llvm::to_vector<6>(
519       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
520   auto rhsShape = llvm::to_vector<6>(
521       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
522   SmallVector<int64_t, 6> resultShape;
523 
524   // If the shapes are not compatible, we can't fold it.
525   // TODO: Fold to an "error".
526   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
527     return nullptr;
528 
529   Builder builder(getContext());
530   return builder.getIndexTensorAttr(resultShape);
531 }
532 
verify(BroadcastOp op)533 static LogicalResult verify(BroadcastOp op) {
534   return verifyShapeOrExtentTensorOp(op);
535 }
536 
537 namespace {
538 template <typename OpTy>
539 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
540   using OpRewritePattern<OpTy>::OpRewritePattern;
541 
matchAndRewrite__anonf25128a30c11::RemoveEmptyShapeOperandsPattern542   LogicalResult matchAndRewrite(OpTy op,
543                                 PatternRewriter &rewriter) const override {
544     auto isPotentiallyNonEmptyShape = [](Value shape) {
545       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
546         if (extentTensorTy.getDimSize(0) == 0)
547           return false;
548       }
549       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
550         if (constShape.shape().empty())
551           return false;
552       }
553       return true;
554     };
555     auto newOperands = llvm::to_vector<8>(
556         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
557 
558     // Reduce op to equivalent without empty shape operands.
559     if (newOperands.size() < op.getNumOperands()) {
560       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
561                                         op->getAttrs());
562       return success();
563     }
564 
565     return failure();
566   }
567 };
568 
569 struct BroadcastForwardSingleOperandPattern
570     : public OpRewritePattern<BroadcastOp> {
571   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
572 
matchAndRewrite__anonf25128a30c11::BroadcastForwardSingleOperandPattern573   LogicalResult matchAndRewrite(BroadcastOp op,
574                                 PatternRewriter &rewriter) const override {
575     if (op.getNumOperands() != 1)
576       return failure();
577     Value replacement = op.shapes().front();
578 
579     // Insert cast if needed.
580     if (replacement.getType() != op.getType()) {
581       auto loc = op.getLoc();
582       if (op.getType().isa<ShapeType>()) {
583         replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
584       } else {
585         assert(!op.getType().isa<ShapeType>() &&
586                !replacement.getType().isa<ShapeType>() &&
587                "expect extent tensor cast");
588         replacement =
589             rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
590       }
591     }
592 
593     rewriter.replaceOp(op, replacement);
594     return success();
595   }
596 };
597 
598 struct BroadcastFoldConstantOperandsPattern
599     : public OpRewritePattern<BroadcastOp> {
600   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
601 
matchAndRewrite__anonf25128a30c11::BroadcastFoldConstantOperandsPattern602   LogicalResult matchAndRewrite(BroadcastOp op,
603                                 PatternRewriter &rewriter) const override {
604     SmallVector<int64_t, 8> foldedConstantShape;
605     SmallVector<Value, 8> newShapeOperands;
606     for (Value shape : op.shapes()) {
607       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
608         SmallVector<int64_t, 8> newFoldedConstantShape;
609         if (OpTrait::util::getBroadcastedShape(
610                 foldedConstantShape,
611                 llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
612                 newFoldedConstantShape)) {
613           foldedConstantShape = newFoldedConstantShape;
614           continue;
615         }
616       }
617       newShapeOperands.push_back(shape);
618     }
619 
620     // Need at least two constant operands to fold anything.
621     if (op.getNumOperands() - newShapeOperands.size() < 2)
622       return failure();
623 
624     auto foldedConstantOperandsTy = RankedTensorType::get(
625         {static_cast<int64_t>(foldedConstantShape.size())},
626         rewriter.getIndexType());
627     newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
628         op.getLoc(), foldedConstantOperandsTy,
629         rewriter.getIndexTensorAttr(foldedConstantShape)));
630     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
631                                              newShapeOperands);
632     return success();
633   }
634 };
635 
636 template <typename OpTy>
637 struct CanonicalizeCastExtentTensorOperandsPattern
638     : public OpRewritePattern<OpTy> {
639   using OpRewritePattern<OpTy>::OpRewritePattern;
640 
matchAndRewrite__anonf25128a30c11::CanonicalizeCastExtentTensorOperandsPattern641   LogicalResult matchAndRewrite(OpTy op,
642                                 PatternRewriter &rewriter) const override {
643     // Canonicalize operands.
644     bool anyChange = false;
645     auto canonicalizeOperand = [&](Value operand) {
646       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
647         // Only eliminate the cast if it holds no shape information.
648         bool isInformationLoosingCast =
649             castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
650         if (isInformationLoosingCast) {
651           anyChange = true;
652           return castOp.source();
653         }
654       }
655       return operand;
656     };
657     auto newOperands = llvm::to_vector<8>(
658         llvm::map_range(op.getOperands(), canonicalizeOperand));
659 
660     // Rewrite op if any change required.
661     if (!anyChange)
662       return failure();
663     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
664     return success();
665   }
666 };
667 
668 struct BroadcastConcretizeResultTypePattern
669     : public OpRewritePattern<BroadcastOp> {
670   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
671 
matchAndRewrite__anonf25128a30c11::BroadcastConcretizeResultTypePattern672   LogicalResult matchAndRewrite(BroadcastOp op,
673                                 PatternRewriter &rewriter) const override {
674     // Only concretize dynamic extent tensor result types.
675     auto resultTy = op.getType().dyn_cast<RankedTensorType>();
676     if (!resultTy || !resultTy.isDynamicDim(0))
677       return failure();
678 
679     // Infer resulting shape rank if possible.
680     int64_t maxRank = 0;
681     for (Value shape : op.shapes()) {
682       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
683         // Cannot infer resulting shape rank if any operand is dynamically
684         // ranked.
685         if (extentTensorTy.isDynamicDim(0))
686           return failure();
687         maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
688       }
689     }
690 
691     auto newOp = rewriter.create<BroadcastOp>(
692         op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
693     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
694     return success();
695   }
696 };
697 } // namespace
698 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)699 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
700                                               MLIRContext *context) {
701   patterns.add<BroadcastConcretizeResultTypePattern,
702                BroadcastFoldConstantOperandsPattern,
703                BroadcastForwardSingleOperandPattern,
704                CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
705                RemoveDuplicateOperandsPattern<BroadcastOp>,
706                RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
707 }
708 
709 //===----------------------------------------------------------------------===//
710 // ConcatOp
711 //===----------------------------------------------------------------------===//
712 
fold(ArrayRef<Attribute> operands)713 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
714   if (!operands[0] || !operands[1])
715     return nullptr;
716   auto lhsShape = llvm::to_vector<6>(
717       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
718   auto rhsShape = llvm::to_vector<6>(
719       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
720   SmallVector<int64_t, 6> resultShape;
721   resultShape.append(lhsShape.begin(), lhsShape.end());
722   resultShape.append(rhsShape.begin(), rhsShape.end());
723   Builder builder(getContext());
724   return builder.getIndexTensorAttr(resultShape);
725 }
726 
727 //===----------------------------------------------------------------------===//
728 // ConstShapeOp
729 //===----------------------------------------------------------------------===//
730 
print(OpAsmPrinter & p,ConstShapeOp & op)731 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
732   p << "shape.const_shape ";
733   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
734   p << "[";
735   interleaveComma(op.shape().getValues<int64_t>(), p,
736                   [&](int64_t i) { p << i; });
737   p << "] : ";
738   p.printType(op.getType());
739 }
740 
parseConstShapeOp(OpAsmParser & parser,OperationState & result)741 static ParseResult parseConstShapeOp(OpAsmParser &parser,
742                                      OperationState &result) {
743   if (parser.parseOptionalAttrDict(result.attributes))
744     return failure();
745   // We piggy-back on ArrayAttr parsing, though we don't internally store the
746   // shape as an ArrayAttr.
747   // TODO: Implement custom parser and maybe make syntax a bit more concise.
748   Attribute extentsRaw;
749   NamedAttrList dummy;
750   if (parser.parseAttribute(extentsRaw, "dummy", dummy))
751     return failure();
752   auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
753   if (!extentsArray)
754     return failure();
755   SmallVector<int64_t, 6> ints;
756   for (Attribute extent : extentsArray) {
757     IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
758     if (!attr)
759       return failure();
760     ints.push_back(attr.getInt());
761   }
762   Builder &builder = parser.getBuilder();
763   result.addAttribute("shape", builder.getIndexTensorAttr(ints));
764   Type resultTy;
765   if (parser.parseColonType(resultTy))
766     return failure();
767   result.types.push_back(resultTy);
768   return success();
769 }
770 
fold(ArrayRef<Attribute>)771 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
772 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)773 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
774                                                MLIRContext *context) {
775   patterns.add<TensorCastConstShape>(context);
776 }
777 
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)778 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
779     MLIRContext *context, Optional<Location> location, ValueRange operands,
780     DictionaryAttr attributes, RegionRange regions,
781     SmallVectorImpl<Type> &inferredReturnTypes) {
782   Builder b(context);
783   auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
784   if (!shape)
785     return emitOptionalError(location, "missing shape attribute");
786   inferredReturnTypes.assign({RankedTensorType::get(
787       {static_cast<int64_t>(shape.size())}, b.getIndexType())});
788   return success();
789 }
790 
isCompatibleReturnTypes(TypeRange l,TypeRange r)791 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
792                                                         TypeRange r) {
793   if (l.size() != 1 || r.size() != 1)
794     return false;
795 
796   Type lhs = l.front();
797   Type rhs = r.front();
798 
799   if (lhs == rhs)
800     return true;
801 
802   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
803     // Shape type is compatible with all other valid return types.
804     return true;
805 
806   return succeeded(verifyCompatibleShapes(lhs, rhs));
807 }
808 
809 //===----------------------------------------------------------------------===//
810 // CstrBroadcastableOp
811 //===----------------------------------------------------------------------===//
812 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)813 void CstrBroadcastableOp::getCanonicalizationPatterns(
814     RewritePatternSet &patterns, MLIRContext *context) {
815   // Canonicalization patterns have overlap with the considerations during
816   // folding in case additional shape information is inferred at some point that
817   // does not result in folding.
818   patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
819                CstrBroadcastableEqOps,
820                RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
821                RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
822 }
823 
824 // Return true if there is exactly one attribute not representing a scalar
825 // broadcast.
hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes)826 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
827   bool nonScalarSeen = false;
828   for (Attribute a : attributes) {
829     if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
830       if (nonScalarSeen)
831         return false;
832       nonScalarSeen = true;
833     }
834   }
835   return true;
836 }
837 
fold(ArrayRef<Attribute> operands)838 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
839   // No broadcasting is needed if all operands but one are scalar.
840   if (hasAtMostSingleNonScalar(operands))
841     return BoolAttr::get(getContext(), true);
842 
843   if ([&] {
844         SmallVector<SmallVector<int64_t, 6>, 6> extents;
845         for (const auto &operand : operands) {
846           if (!operand)
847             return false;
848           extents.push_back(llvm::to_vector<6>(
849               operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
850         }
851         return OpTrait::util::staticallyKnownBroadcastable(extents);
852       }())
853     return BoolAttr::get(getContext(), true);
854 
855   // Lastly, see if folding can be completed based on what constraints are known
856   // on the input shapes.
857   if ([&] {
858         SmallVector<SmallVector<int64_t, 6>, 6> extents;
859         for (auto shapeValue : shapes()) {
860           extents.emplace_back();
861           if (failed(getShapeVec(shapeValue, extents.back())))
862             return false;
863         }
864         return OpTrait::util::staticallyKnownBroadcastable(extents);
865       }())
866     return BoolAttr::get(getContext(), true);
867 
868   // Because a failing witness result here represents an eventual assertion
869   // failure, we do not replace it with a constant witness.
870   return nullptr;
871 }
872 
verify(CstrBroadcastableOp op)873 static LogicalResult verify(CstrBroadcastableOp op) {
874   // Ensure that AssumingAllOp contains at least one operand
875   if (op.getNumOperands() < 2)
876     return op.emitOpError("required at least 2 input shapes");
877   return success();
878 }
879 
880 //===----------------------------------------------------------------------===//
881 // CstrEqOp
882 //===----------------------------------------------------------------------===//
883 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)884 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
885                                            MLIRContext *context) {
886   // If inputs are equal, return passing witness
887   patterns.add<CstrEqEqOps>(context);
888 }
889 
fold(ArrayRef<Attribute> operands)890 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
891   if (llvm::all_of(operands,
892                    [&](Attribute a) { return a && a == operands[0]; }))
893     return BoolAttr::get(getContext(), true);
894 
895   // Because a failing witness result here represents an eventual assertion
896   // failure, we do not try to replace it with a constant witness. Similarly, we
897   // cannot if there are any non-const inputs.
898   return nullptr;
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // ConstSizeOp
903 //===----------------------------------------------------------------------===//
904 
build(OpBuilder & builder,OperationState & result,int64_t value)905 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
906                         int64_t value) {
907   build(builder, result, builder.getIndexAttr(value));
908 }
909 
fold(ArrayRef<Attribute>)910 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
911 
getAsmResultNames(llvm::function_ref<void (Value,StringRef)> setNameFn)912 void ConstSizeOp::getAsmResultNames(
913     llvm::function_ref<void(Value, StringRef)> setNameFn) {
914   SmallString<4> buffer;
915   llvm::raw_svector_ostream os(buffer);
916   os << "c" << value();
917   setNameFn(getResult(), os.str());
918 }
919 
920 //===----------------------------------------------------------------------===//
921 // ConstWitnessOp
922 //===----------------------------------------------------------------------===//
923 
fold(ArrayRef<Attribute>)924 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
925 
926 //===----------------------------------------------------------------------===//
927 // CstrRequireOp
928 //===----------------------------------------------------------------------===//
929 
fold(ArrayRef<Attribute> operands)930 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
931   return operands[0];
932 }
933 
934 //===----------------------------------------------------------------------===//
935 // DivOp
936 //===----------------------------------------------------------------------===//
937 
fold(ArrayRef<Attribute> operands)938 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
939   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
940   if (!lhs)
941     return nullptr;
942   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
943   if (!rhs)
944     return nullptr;
945 
946   // Division in APInt does not follow floor(lhs, rhs) when the result is
947   // negative. Rather, APInt rounds toward zero.
948   APInt quotient, remainder;
949   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
950   if (quotient.isNegative() && !remainder.isNullValue()) {
951     quotient -= 1;
952   }
953 
954   Type indexTy = IndexType::get(getContext());
955   return IntegerAttr::get(indexTy, quotient);
956 }
957 
958 //===----------------------------------------------------------------------===//
959 // ShapeEqOp
960 //===----------------------------------------------------------------------===//
961 
fold(ArrayRef<Attribute> operands)962 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
963   bool allSame = true;
964   if (!operands.empty() && !operands[0])
965     return {};
966   for (Attribute operand : operands.drop_front(1)) {
967     if (!operand)
968       return {};
969     allSame = allSame && operand == operands[0];
970   }
971   return BoolAttr::get(getContext(), allSame);
972 }
973 
974 //===----------------------------------------------------------------------===//
975 // IndexToSizeOp
976 //===----------------------------------------------------------------------===//
977 
fold(ArrayRef<Attribute> operands)978 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
979   // Constant values of both types, `shape.size` and `index`, are represented as
980   // `IntegerAttr`s which makes constant folding simple.
981   if (Attribute arg = operands[0])
982     return arg;
983   return {};
984 }
985 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)986 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
987                                                 MLIRContext *context) {
988   patterns.add<SizeToIndexToSizeCanonicalization>(context);
989 }
990 
991 //===----------------------------------------------------------------------===//
992 // FromExtentsOp
993 //===----------------------------------------------------------------------===//
994 
fold(ArrayRef<Attribute> operands)995 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
996   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
997     return nullptr;
998   SmallVector<int64_t, 6> extents;
999   for (auto attr : operands)
1000     extents.push_back(attr.cast<IntegerAttr>().getInt());
1001   Builder builder(getContext());
1002   return builder.getIndexTensorAttr(extents);
1003 }
1004 
1005 //===----------------------------------------------------------------------===//
1006 // FunctionLibraryOp
1007 //===----------------------------------------------------------------------===//
1008 
build(OpBuilder & builder,OperationState & result,StringRef name)1009 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1010                               StringRef name) {
1011   result.attributes.push_back(builder.getNamedAttr(
1012       ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1013 }
1014 
getShapeFunction(Operation * op)1015 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1016   auto attr = mapping()
1017                   .get(op->getName().getIdentifier())
1018                   .dyn_cast_or_null<FlatSymbolRefAttr>();
1019   if (!attr)
1020     return nullptr;
1021   return lookupSymbol<FuncOp>(attr);
1022 }
1023 
parseFunctionLibraryOp(OpAsmParser & parser,OperationState & result)1024 ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
1025                                    OperationState &result) {
1026   // Parse the op name.
1027   StringAttr nameAttr;
1028   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1029                              result.attributes))
1030     return failure();
1031 
1032   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1033     return failure();
1034 
1035   auto *bodyRegion = result.addRegion();
1036   if (parser.parseRegion(*bodyRegion))
1037     return failure();
1038 
1039   if (parser.parseKeyword("mapping"))
1040     return failure();
1041 
1042   DictionaryAttr mappingAttr;
1043   if (parser.parseAttribute(mappingAttr,
1044                             parser.getBuilder().getType<NoneType>(), "mapping",
1045                             result.attributes))
1046     return failure();
1047   return success();
1048 }
1049 
print(OpAsmPrinter & p,FunctionLibraryOp op)1050 void print(OpAsmPrinter &p, FunctionLibraryOp op) {
1051   p << op.getOperationName() << ' ';
1052   p.printSymbolName(op.getName());
1053   p.printOptionalAttrDictWithKeyword(
1054       op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
1055   p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
1056                 /*printBlockTerminators=*/false);
1057   p << " mapping ";
1058   p.printAttributeWithoutType(op.mappingAttr());
1059 }
1060 
1061 //===----------------------------------------------------------------------===//
1062 // GetExtentOp
1063 //===----------------------------------------------------------------------===//
1064 
getConstantDim()1065 Optional<int64_t> GetExtentOp::getConstantDim() {
1066   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
1067     return constSizeOp.value().getLimitedValue();
1068   if (auto constantOp = dim().getDefiningOp<ConstantOp>())
1069     return constantOp.value().cast<IntegerAttr>().getInt();
1070   return llvm::None;
1071 }
1072 
fold(ArrayRef<Attribute> operands)1073 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1074   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1075   if (!elements)
1076     return nullptr;
1077   Optional<int64_t> dim = getConstantDim();
1078   if (!dim.hasValue())
1079     return nullptr;
1080   if (dim.getValue() >= elements.getNumElements())
1081     return nullptr;
1082   return elements.getValue({(uint64_t)dim.getValue()});
1083 }
1084 
build(OpBuilder & builder,OperationState & result,Value shape,int64_t dim)1085 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1086                         int64_t dim) {
1087   auto loc = result.location;
1088   auto dimAttr = builder.getIndexAttr(dim);
1089   if (shape.getType().isa<ShapeType>()) {
1090     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1091     build(builder, result, builder.getType<SizeType>(), shape, dim);
1092   } else {
1093     Value dim =
1094         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
1095     build(builder, result, builder.getIndexType(), shape, dim);
1096   }
1097 }
1098 
1099 //===----------------------------------------------------------------------===//
1100 // IsBroadcastableOp
1101 //===----------------------------------------------------------------------===//
1102 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1103 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1104                                                     MLIRContext *context) {
1105   patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1106 }
1107 
fold(ArrayRef<Attribute> operands)1108 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1109   // Can always broadcast fewer than two shapes.
1110   if (operands.size() < 2) {
1111     return BoolAttr::get(getContext(), true);
1112   }
1113 
1114   return nullptr;
1115 }
1116 
1117 //===----------------------------------------------------------------------===//
1118 // RankOp
1119 //===----------------------------------------------------------------------===//
1120 
fold(ArrayRef<Attribute> operands)1121 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1122   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1123   if (!shape)
1124     return {};
1125   int64_t rank = shape.getNumElements();
1126   Builder builder(getContext());
1127   return builder.getIndexAttr(rank);
1128 }
1129 
1130 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1131 /// Constant folding fails in cases where only the rank is constant, not the
1132 /// shape itself.
1133 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1134 ///
1135 /// Example:
1136 ///
1137 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1138 /// %rank = shape.rank %shape
1139 ///
1140 /// becomes
1141 ///
1142 /// %rank = shape.const_size 3
1143 
1144 namespace {
1145 struct RankShapeOfCanonicalizationPattern
1146     : public OpRewritePattern<shape::RankOp> {
1147   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1148 
matchAndRewrite__anonf25128a31411::RankShapeOfCanonicalizationPattern1149   LogicalResult matchAndRewrite(shape::RankOp op,
1150                                 PatternRewriter &rewriter) const override {
1151     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
1152     if (!shapeOfOp)
1153       return failure();
1154     auto rankedTensorType =
1155         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
1156     if (!rankedTensorType)
1157       return failure();
1158     int64_t rank = rankedTensorType.getRank();
1159     if (op.getType().isa<IndexType>()) {
1160       rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
1161     } else if (op.getType().isa<shape::SizeType>()) {
1162       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1163     } else {
1164       return failure();
1165     }
1166     return success();
1167   }
1168 };
1169 } // namespace
1170 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1171 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1172                                                 MLIRContext *context) {
1173   patterns.add<RankShapeOfCanonicalizationPattern>(context);
1174 }
1175 
1176 //===----------------------------------------------------------------------===//
1177 // NumElementsOp
1178 //===----------------------------------------------------------------------===//
1179 
fold(ArrayRef<Attribute> operands)1180 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
1181 
1182   // Fold only when argument constant.
1183   Attribute shape = operands[0];
1184   if (!shape)
1185     return {};
1186 
1187   APInt product(64, 1);
1188   for (auto value : shape.cast<DenseIntElementsAttr>())
1189     product *= value;
1190   Builder builder(getContext());
1191   return builder.getIndexAttr(product.getLimitedValue());
1192 }
1193 
build(OpBuilder & builder,OperationState & result,Value shape)1194 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
1195                           Value shape) {
1196   if (shape.getType().isa<ShapedType>()) {
1197     auto type = builder.getIndexType();
1198     return build(builder, result, type, shape);
1199   }
1200   auto type = SizeType::get(builder.getContext());
1201   return build(builder, result, type, shape);
1202 }
1203 
1204 //===----------------------------------------------------------------------===//
1205 // MaxOp
1206 //===----------------------------------------------------------------------===//
1207 
fold(llvm::ArrayRef<mlir::Attribute> operands)1208 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1209   // If operands are equal, just propagate one.
1210   if (lhs() == rhs())
1211     return lhs();
1212   return nullptr;
1213 }
1214 
1215 //===----------------------------------------------------------------------===//
1216 // MinOp
1217 //===----------------------------------------------------------------------===//
1218 
fold(llvm::ArrayRef<mlir::Attribute> operands)1219 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1220   // If operands are equal, just propagate one.
1221   if (lhs() == rhs())
1222     return lhs();
1223   return nullptr;
1224 }
1225 
1226 //===----------------------------------------------------------------------===//
1227 // MulOp
1228 //===----------------------------------------------------------------------===//
1229 
fold(ArrayRef<Attribute> operands)1230 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1231   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1232   if (!lhs)
1233     return nullptr;
1234   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1235   if (!rhs)
1236     return nullptr;
1237   APInt folded = lhs.getValue() * rhs.getValue();
1238   Type indexTy = IndexType::get(getContext());
1239   return IntegerAttr::get(indexTy, folded);
1240 }
1241 
1242 //===----------------------------------------------------------------------===//
1243 // ShapeOfOp
1244 //===----------------------------------------------------------------------===//
1245 
fold(ArrayRef<Attribute>)1246 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1247   auto type = getOperand().getType().dyn_cast<ShapedType>();
1248   if (!type || !type.hasStaticShape())
1249     return nullptr;
1250   Builder builder(getContext());
1251   return builder.getIndexTensorAttr(type.getShape());
1252 }
1253 
build(OpBuilder & builder,OperationState & result,Value arg)1254 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
1255   if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
1256     int64_t rank =
1257         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1258     Type indexTy = builder.getIndexType();
1259     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1260     return ShapeOfOp::build(builder, result, extentTensorTy, arg);
1261   }
1262   Type shapeTy = builder.getType<ShapeType>();
1263   return ShapeOfOp::build(builder, result, shapeTy, arg);
1264 }
1265 
1266 namespace {
1267 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
1268   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1269 
matchAndRewrite__anonf25128a31511::ShapeOfWithTensor1270   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1271                                 PatternRewriter &rewriter) const override {
1272     if (!op.arg().getType().isa<ShapedType>())
1273       return failure();
1274     if (op.getType().isa<ShapedType>())
1275       return failure();
1276 
1277     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
1278     return success();
1279   }
1280 };
1281 
1282 // Canonicalize
1283 // ```
1284 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1285 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1286 // ```
1287 // to
1288 // ```
1289 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1290 // ```
1291 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1292   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1293 
matchAndRewrite__anonf25128a31511::ShapeOfCastExtentTensor1294   LogicalResult matchAndRewrite(tensor::CastOp op,
1295                                 PatternRewriter &rewriter) const override {
1296     auto ty = op.getType().dyn_cast<RankedTensorType>();
1297     if (!ty || ty.getRank() != 1)
1298       return failure();
1299 
1300     auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
1301     if (!shapeOfOp)
1302       return failure();
1303 
1304     // Argument type must be ranked and must not conflict.
1305     auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
1306     if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1307       return failure();
1308 
1309     rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
1310     return success();
1311   }
1312 };
1313 } // namespace
1314 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1315 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1316                                             MLIRContext *context) {
1317   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
1318 }
1319 
1320 //===----------------------------------------------------------------------===//
1321 // SizeToIndexOp
1322 //===----------------------------------------------------------------------===//
1323 
fold(ArrayRef<Attribute> operands)1324 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1325   // Constant values of both types, `shape.size` and `index`, are represented as
1326   // `IntegerAttr`s which makes constant folding simple.
1327   if (Attribute arg = operands[0])
1328     return arg;
1329   return impl::foldCastOp(*this);
1330 }
1331 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1332 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1333                                                 MLIRContext *context) {
1334   patterns.add<IndexToSizeToIndexCanonicalization>(context);
1335 }
1336 
1337 //===----------------------------------------------------------------------===//
1338 // YieldOp
1339 //===----------------------------------------------------------------------===//
1340 
verify(shape::YieldOp op)1341 static LogicalResult verify(shape::YieldOp op) {
1342   auto *parentOp = op->getParentOp();
1343   auto results = parentOp->getResults();
1344   auto operands = op.getOperands();
1345 
1346   if (parentOp->getNumResults() != op.getNumOperands())
1347     return op.emitOpError() << "number of operands does not match number of "
1348                                "results of its parent";
1349   for (auto e : llvm::zip(results, operands))
1350     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1351       return op.emitOpError()
1352              << "types mismatch between yield op and its parent";
1353 
1354   return success();
1355 }
1356 
1357 //===----------------------------------------------------------------------===//
1358 // SplitAtOp
1359 //===----------------------------------------------------------------------===//
1360 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1361 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1362                               SmallVectorImpl<OpFoldResult> &results) {
1363   if (!operands[0] || !operands[1])
1364     return failure();
1365   auto shapeVec = llvm::to_vector<6>(
1366       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1367   auto shape = llvm::makeArrayRef(shapeVec);
1368   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1369   // Verify that the split point is in the correct range.
1370   // TODO: Constant fold to an "error".
1371   int64_t rank = shape.size();
1372   if (!(-rank <= splitPoint && splitPoint <= rank))
1373     return failure();
1374   if (splitPoint < 0)
1375     splitPoint += shape.size();
1376   Builder builder(operands[0].getContext());
1377   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1378   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1379   return success();
1380 }
1381 
1382 //===----------------------------------------------------------------------===//
1383 // ToExtentTensorOp
1384 //===----------------------------------------------------------------------===//
1385 
fold(ArrayRef<Attribute> operands)1386 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1387   if (!operands[0])
1388     return impl::foldCastOp(*this);
1389   Builder builder(getContext());
1390   auto shape = llvm::to_vector<6>(
1391       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1392   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1393                                     builder.getIndexType());
1394   return DenseIntElementsAttr::get(type, shape);
1395 }
1396 
1397 //===----------------------------------------------------------------------===//
1398 // ReduceOp
1399 //===----------------------------------------------------------------------===//
1400 
build(OpBuilder & builder,OperationState & result,Value shape,ValueRange initVals)1401 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1402                      ValueRange initVals) {
1403   result.addOperands(shape);
1404   result.addOperands(initVals);
1405 
1406   Region *bodyRegion = result.addRegion();
1407   bodyRegion->push_back(new Block);
1408   Block &bodyBlock = bodyRegion->front();
1409   bodyBlock.addArgument(builder.getIndexType());
1410 
1411   Type elementType;
1412   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1413     elementType = tensorType.getElementType();
1414   else
1415     elementType = SizeType::get(builder.getContext());
1416   bodyBlock.addArgument(elementType);
1417 
1418   for (Type initValType : initVals.getTypes()) {
1419     bodyBlock.addArgument(initValType);
1420     result.addTypes(initValType);
1421   }
1422 }
1423 
verify(ReduceOp op)1424 static LogicalResult verify(ReduceOp op) {
1425   // Verify block arg types.
1426   Block &block = op.region().front();
1427 
1428   // The block takes index, extent, and aggregated values as arguments.
1429   auto blockArgsCount = op.initVals().size() + 2;
1430   if (block.getNumArguments() != blockArgsCount)
1431     return op.emitOpError() << "ReduceOp body is expected to have "
1432                             << blockArgsCount << " arguments";
1433 
1434   // The first block argument is the index and must always be of type `index`.
1435   if (!block.getArgument(0).getType().isa<IndexType>())
1436     return op.emitOpError(
1437         "argument 0 of ReduceOp body is expected to be of IndexType");
1438 
1439   // The second block argument is the extent and must be of type `size` or
1440   // `index`, depending on whether the reduce operation is applied to a shape or
1441   // to an extent tensor.
1442   Type extentTy = block.getArgument(1).getType();
1443   if (op.shape().getType().isa<ShapeType>()) {
1444     if (!extentTy.isa<SizeType>())
1445       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
1446                             "SizeType if the ReduceOp operates on a ShapeType");
1447   } else {
1448     if (!extentTy.isa<IndexType>())
1449       return op.emitOpError(
1450           "argument 1 of ReduceOp body is expected to be of IndexType if the "
1451           "ReduceOp operates on an extent tensor");
1452   }
1453 
1454   for (auto type : llvm::enumerate(op.initVals()))
1455     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1456       return op.emitOpError()
1457              << "type mismatch between argument " << type.index() + 2
1458              << " of ReduceOp body and initial value " << type.index();
1459   return success();
1460 }
1461 
parseReduceOp(OpAsmParser & parser,OperationState & result)1462 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
1463   // Parse operands.
1464   SmallVector<OpAsmParser::OperandType, 3> operands;
1465   Type shapeOrExtentTensorType;
1466   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1467                               OpAsmParser::Delimiter::Paren) ||
1468       parser.parseColonType(shapeOrExtentTensorType) ||
1469       parser.parseOptionalArrowTypeList(result.types))
1470     return failure();
1471 
1472   // Resolve operands.
1473   auto initVals = llvm::makeArrayRef(operands).drop_front();
1474   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1475                             result.operands) ||
1476       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1477                              result.operands))
1478     return failure();
1479 
1480   // Parse the body.
1481   Region *body = result.addRegion();
1482   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1483     return failure();
1484 
1485   // Parse attributes.
1486   if (parser.parseOptionalAttrDict(result.attributes))
1487     return failure();
1488 
1489   return success();
1490 }
1491 
print(OpAsmPrinter & p,ReduceOp op)1492 static void print(OpAsmPrinter &p, ReduceOp op) {
1493   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
1494     << ") : " << op.shape().getType();
1495   p.printOptionalArrowTypeList(op.getResultTypes());
1496   p.printRegion(op.region());
1497   p.printOptionalAttrDict(op->getAttrs());
1498 }
1499 
1500 #define GET_OP_CLASSES
1501 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1502