//===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/StandardOps/IR/Ops.h"

#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"

using namespace mlir;

//===----------------------------------------------------------------------===//
// StandardOpsDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
/// This class defines the interface for handling inlining with standard
/// operations.
struct StdInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// All operations within standard ops can be inlined.
  bool isLegalToInline(Operation *, Region *,
                       BlockAndValueMapping &) const final {
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  /// Handle the given inlined terminator by replacing it with a new operation
  /// as necessary.
  void handleTerminator(Operation *op, Block *newDest) const final {
    // Only "std.return" needs to be handled here.
    auto returnOp = dyn_cast<ReturnOp>(op);
    if (!returnOp)
      return;

    // Replace the return with a branch to the dest.
    OpBuilder builder(op);
    builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
    op->erase();
  }

  /// Handle the given inlined terminator by replacing it with a new operation
  /// as necessary.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    // Only "std.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// StandardOpsDialect
//===----------------------------------------------------------------------===//

/// A custom unary operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
  assert(op->getNumOperands() == 1 && "unary op should have one operand");
  assert(op->getNumResults() == 1 && "unary op should have one result");

  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
    << op->getOperand(0);
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op->getOperand(0).getType();
}

/// A custom binary operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
  assert(op->getNumOperands() == 2 && "binary op should have two operands");
  assert(op->getNumResults() == 1 && "binary op should have one result");

  // If not all the operand and result types are the same, just use the
  // generic assembly form to avoid omitting information in printing.
  auto resultType = op->getResult(0).getType();
  if (op->getOperand(0).getType() != resultType ||
      op->getOperand(1).getType() != resultType) {
    p.printGenericOp(op);
    return;
  }

  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
    << op->getOperand(0) << ", " << op->getOperand(1);
  p.printOptionalAttrDict(op->getAttrs());

  // Now we can output only one type for all operands and the result.
  p << " : " << op->getResult(0).getType();
}

/// A custom cast operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
    << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to "
    << op->getResult(0).getType();
}

/// A custom cast operation verifier.
template <typename T>
static LogicalResult verifyCastOp(T op) {
  auto opType = op.getOperand().getType();
  auto resType = op.getType();
  if (!T::areCastCompatible(opType, resType))
    return op.emitError("operand type ") << opType << " and result type "
                                         << resType << " are cast incompatible";

  return success();
}

StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<DmaStartOp, DmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
                >();
  addInterfaces<StdInlinerInterface>();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
                                                   Location loc) {
  return builder.create<ConstantOp>(loc, type, value);
}

void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
                                 Operation::operand_iterator end,
                                 unsigned numDims, OpAsmPrinter &p) {
  Operation::operand_range operands(begin, end);
  p << '(' << operands.take_front(numDims) << ')';
  if (operands.size() != numDims)
    p << '[' << operands.drop_front(numDims) << ']';
}

// Parses dimension and symbol list, and sets 'numDims' to the number of
// dimension operands parsed.
// Returns 'false' on success and 'true' on error.
ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
                                        SmallVectorImpl<Value> &operands,
                                        unsigned &numDims) {
  SmallVector<OpAsmParser::OperandType, 8> opInfos;
  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
    return failure();
  // Store number of dimensions for validation by caller.
  numDims = opInfos.size();

  // Parse the optional symbol operands.
  auto indexTy = parser.getBuilder().getIndexType();
  if (parser.parseOperandList(opInfos,
                              OpAsmParser::Delimiter::OptionalSquare) ||
      parser.resolveOperands(opInfos, indexTy, operands))
    return failure();
  return success();
}

/// Matches a ConstantIndexOp.
/// TODO: This should probably just be a general matcher that uses m_Constant
/// and checks the operation for an index type.
static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
  return detail::op_matcher<ConstantIndexOp>();
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
/// into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<MemRefCastOp>();
    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// AddFOp
//===----------------------------------------------------------------------===//

OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
  return constFoldBinaryOp<FloatAttr>(
      operands, [](APFloat a, APFloat b) { return a + b; });
}

//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//

OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
  /// addi(x, 0) -> x
  if (matchPattern(rhs(), m_Zero()))
    return lhs();

  return constFoldBinaryOp<IntegerAttr>(operands,
                                        [](APInt a, APInt b) { return a + b; });
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static void printAllocLikeOp(OpAsmPrinter &p, AllocLikeOp op, StringRef name) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  p << name;

  // Print dynamic dimension operands.
  MemRefType type = op.getType();
  printDimAndSymbolList(op.operand_begin(), op.operand_end(),
                        type.getNumDynamicDims(), p);
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
  p << " : " << type;
}

static void print(OpAsmPrinter &p, AllocOp op) {
  printAllocLikeOp(p, op, "alloc");
}

static void print(OpAsmPrinter &p, AllocaOp op) {
  printAllocLikeOp(p, op, "alloca");
}

static ParseResult parseAllocLikeOp(OpAsmParser &parser,
                                    OperationState &result) {
  MemRefType type;

  // Parse the dimension operands and optional symbol operands, followed by a
  // memref type.
  unsigned numDimOperands;
  if (parseDimAndSymbolList(parser, result.operands, numDimOperands) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type))
    return failure();

  // Check numDynamicDims against number of question marks in memref type.
  // Note: this check remains here (instead of in verify()), because the
  // partition between dim operands and symbol operands is lost after parsing.
  // Verification still checks that the total number of operands matches
  // the number of symbols in the affine map, plus the number of dynamic
  // dimensions in the memref.
  if (numDimOperands != type.getNumDynamicDims())
    return parser.emitError(parser.getNameLoc())
           << "dimension operand count does not equal memref dynamic dimension "
              "count";
  result.types.push_back(type);
  return success();
}

template <typename AllocLikeOp>
static LogicalResult verify(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty()) {
    // Store number of symbols used in affine map (used in subsequent check).
    AffineMap affineMap = memRefType.getAffineMaps()[0];
    numSymbols = affineMap.getNumSymbols();
  }

  // Check that the total number of operands matches the number of symbols in
  // the affine map, plus the number of dynamic dimensions specified in the
  // memref type.
  unsigned numDynamicDims = memRefType.getNumDynamicDims();
  if (op.getNumOperands() != numDynamicDims + numSymbols)
    return op.emitOpError(
        "operand count does not equal dimension plus symbol operand count");

  // Verify that all operands are of type Index.
  for (auto operandType : op.getOperandTypes())
    if (!operandType.isIndex())
      return op.emitOpError("requires operands to be of type Index");

  if (std::is_same<AllocLikeOp, AllocOp>::value)
    return success();

  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op.template getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return success();
}

namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.  If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getOperands(), [](Value operand) {
          return matchPattern(operand, m_ConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands.  Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> newOperands;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(-1);
        newOperands.push_back(alloc.getOperand(dynamicDimPos));
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(newOperands.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(alloc.getLoc(), newMemRefType,
                                                 newOperands, IntegerAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
                                                    alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no uses. Alloc has side effects on the heap,
/// but can still be deleted if it has zero uses.
struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocOp alloc,
                                PatternRewriter &rewriter) const override {
    if (alloc.use_empty()) {
      rewriter.eraseOp(alloc);
      return success();
    }
    return failure();
  }
};
} // end anonymous namespace.

void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
}

void AllocaOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<SimplifyAllocConst<AllocaOp>>(context);
}

//===----------------------------------------------------------------------===//
// AndOp
//===----------------------------------------------------------------------===//

OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
  /// and(x, 0) -> 0
  if (matchPattern(rhs(), m_Zero()))
    return rhs();
  /// and(x, allOnes) -> x
  APInt intValue;
  if (matchPattern(rhs(), m_ConstantInt(&intValue)) &&
      intValue.isAllOnesValue())
    return lhs();
  /// and(x,x) -> x
  if (lhs() == rhs())
    return rhs();

  return constFoldBinaryOp<IntegerAttr>(operands,
                                        [](APInt a, APInt b) { return a & b; });
}

//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//

namespace {
struct EraseRedundantAssertions : public OpRewritePattern<AssertOp> {
  using OpRewritePattern<AssertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssertOp op,
                                PatternRewriter &rewriter) const override {
    // Erase assertion if argument is constant true.
    if (matchPattern(op.arg(), m_One())) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};
} // namespace

void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context) {
  patterns.insert<EraseRedundantAssertions>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment().getZExtValue();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AtomicRMWOp op) {
  if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
    return op.emitOpError(
        "expects the number of subscripts to be equal to memref rank");
  switch (op.kind()) {
  case AtomicRMWKind::addf:
  case AtomicRMWKind::maxf:
  case AtomicRMWKind::minf:
  case AtomicRMWKind::mulf:
    if (!op.value().getType().isa<FloatType>())
      return op.emitOpError()
             << "with kind '" << stringifyAtomicRMWKind(op.kind())
             << "' expects a floating-point type";
    break;
  case AtomicRMWKind::addi:
  case AtomicRMWKind::maxs:
  case AtomicRMWKind::maxu:
  case AtomicRMWKind::mins:
  case AtomicRMWKind::minu:
  case AtomicRMWKind::muli:
    if (!op.value().getType().isa<IntegerType>())
      return op.emitOpError()
             << "with kind '" << stringifyAtomicRMWKind(op.kind())
             << "' expects an integer type";
    break;
  default:
    break;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    bodyRegion->push_back(new Block());
    bodyRegion->addArgument(elementType);
  }
}

static LogicalResult verify(GenericAtomicRMWOp op) {
  auto &body = op.body();
  if (body.getNumArguments() != 1)
    return op.emitOpError("expected single number of entry block arguments");

  if (op.getResult().getType() != body.getArgument(0).getType())
    return op.emitOpError(
        "expected block argument of the same type result type");

  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
                                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}

static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
                                           OperationState &result) {
  OpAsmParser::OperandType memref;
  Type memrefType;
  SmallVector<OpAsmParser::OperandType, 4> ivs;

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.parseOperand(memref) ||
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(memrefType) ||
      parser.resolveOperand(memref, memrefType, result.operands) ||
      parser.resolveOperands(ivs, indexType, result.operands))
    return failure();

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, llvm::None, llvm::None))
    return failure();
  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
  return success();
}

static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
  p << op.getOperationName() << ' ' << op.memref() << "[" << op.indices()
    << "] : " << op.memref().getType();
  p.printRegion(op.body());
  p.printOptionalAttrDict(op.getAttrs());
}

//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AtomicYieldOp op) {
  Type parentType = op.getParentOp()->getResultTypes().front();
  Type resultType = op.result().getType();
  if (parentType != resultType)
    return op.emitOpError() << "types mismatch between yield op: " << resultType
                            << " and its parent: " << parentType;
  return success();
}

//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//

/// Given a successor, try to collapse it to a new destination if it only
/// contains a passthrough unconditional branch. If the successor is
/// collapsable, `successor` and `successorOperands` are updated to reference
/// the new destination and values. `argStorage` is an optional storage to use
/// if operands to the collapsed successor need to be remapped.
static LogicalResult collapseBranch(Block *&successor,
                                    ValueRange &successorOperands,
                                    SmallVectorImpl<Value> &argStorage) {
  // Check that the successor only contains an unconditional branch.
  if (std::next(successor->begin()) != successor->end())
    return failure();
  // Check that the terminator is an unconditional branch.
  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
  if (!successorBranch)
    return failure();
  // Check that the arguments are only used within the terminator.
  for (BlockArgument arg : successor->getArguments()) {
    for (Operation *user : arg.getUsers())
      if (user != successorBranch)
        return failure();
  }
  // Don't try to collapse branches to infinite loops.
  Block *successorDest = successorBranch.getDest();
  if (successorDest == successor)
    return failure();

  // Update the operands to the successor. If the branch parent has no
  // arguments, we can use the branch operands directly.
  OperandRange operands = successorBranch.getOperands();
  if (successor->args_empty()) {
    successor = successorDest;
    successorOperands = operands;
    return success();
  }

  // Otherwise, we need to remap any argument operands.
  for (Value operand : operands) {
    BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
    if (argOperand && argOperand.getOwner() == successor)
      argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
    else
      argStorage.push_back(operand);
  }
  successor = successorDest;
  successorOperands = argStorage;
  return success();
}

namespace {
/// Simplify a branch to a block that has a single predecessor. This effectively
/// merges the two blocks.
struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
  using OpRewritePattern<BranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BranchOp op,
                                PatternRewriter &rewriter) const override {
    // Check that the successor block has a single predecessor.
    Block *succ = op.getDest();
    Block *opParent = op.getOperation()->getBlock();
    if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
      return failure();

    // Merge the successor into the current block and erase the branch.
    rewriter.mergeBlocks(succ, opParent, op.getOperands());
    rewriter.eraseOp(op);
    return success();
  }
};

///   br ^bb1
/// ^bb1
///   br ^bbN(...)
///
///  -> br ^bbN(...)
///
struct SimplifyPassThroughBr : public OpRewritePattern<BranchOp> {
  using OpRewritePattern<BranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BranchOp op,
                                PatternRewriter &rewriter) const override {
    Block *dest = op.getDest();
    ValueRange destOperands = op.getOperands();
    SmallVector<Value, 4> destOperandStorage;

    // Try to collapse the successor if it points somewhere other than this
    // block.
    if (dest == op.getOperation()->getBlock() ||
        failed(collapseBranch(dest, destOperands, destOperandStorage)))
      return failure();

    // Create a new branch with the collapsed successor.
    rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
    return success();
  }
};
} // end anonymous namespace.

Block *BranchOp::getDest() { return getSuccessor(); }

void BranchOp::setDest(Block *block) { return setSuccessor(block); }

void BranchOp::eraseOperand(unsigned index) {
  getOperation()->eraseOperand(index);
}

void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>(
      context);
}

Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  return destOperandsMutable();
}

Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(CallOp op) {
  // Check that the callee attribute was specified.
  auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
  if (!fnAttr)
    return op.emitOpError("requires a 'callee' symbol reference attribute");
  auto fn =
      op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
  if (!fn)
    return op.emitOpError() << "'" << fnAttr.getValue()
                            << "' does not reference a valid function";

  // Verify that the operand and result types match the callee.
  auto fnType = fn.getType();
  if (fnType.getNumInputs() != op.getNumOperands())
    return op.emitOpError("incorrect number of operands for callee");

  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
    if (op.getOperand(i).getType() != fnType.getInput(i))
      return op.emitOpError("operand type mismatch");

  if (fnType.getNumResults() != op.getNumResults())
    return op.emitOpError("incorrect number of results for callee");

  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
    if (op.getResult(i).getType() != fnType.getResult(i))
      return op.emitOpError("result type mismatch");

  return success();
}

FunctionType CallOp::getCalleeType() {
  SmallVector<Type, 8> argTypes(getOperandTypes());
  return FunctionType::get(argTypes, getResultTypes(), getContext());
}

//===----------------------------------------------------------------------===//
// CallIndirectOp
//===----------------------------------------------------------------------===//
namespace {
/// Fold indirect calls that have a constant function as the callee operand.
struct SimplifyIndirectCallWithKnownCallee
    : public OpRewritePattern<CallIndirectOp> {
  using OpRewritePattern<CallIndirectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CallIndirectOp indirectCall,
                                PatternRewriter &rewriter) const override {
    // Check that the callee is a constant callee.
    SymbolRefAttr calledFn;
    if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
      return failure();

    // Replace with a direct call.
    rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
                                        indirectCall.getResultTypes(),
                                        indirectCall.getArgOperands());
    return success();
  }
};
} // end anonymous namespace.

void CallIndirectOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyIndirectCallWithKnownCallee>(context);
}

//===----------------------------------------------------------------------===//
// General helpers for comparison ops
//===----------------------------------------------------------------------===//

// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) {
  auto i1Type = IntegerType::get(1, type.getContext());
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(tensorType.getShape(), i1Type);
  if (type.isa<UnrankedTensorType>())
    return UnrankedTensorType::get(i1Type);
  if (auto vectorType = type.dyn_cast<VectorType>())
    return VectorType::get(vectorType.getShape(), i1Type);
  return i1Type;
}

//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//

static void buildCmpIOp(OpBuilder &build, OperationState &result,
                        CmpIPredicate predicate, Value lhs, Value rhs) {
  result.addOperands({lhs, rhs});
  result.types.push_back(getI1SameShape(lhs.getType()));
  result.addAttribute(CmpIOp::getPredicateAttrName(),
                      build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
}

// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
// comparison predicates.
bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
                             const APInt &rhs) {
  switch (predicate) {
  case CmpIPredicate::eq:
    return lhs.eq(rhs);
  case CmpIPredicate::ne:
    return lhs.ne(rhs);
  case CmpIPredicate::slt:
    return lhs.slt(rhs);
  case CmpIPredicate::sle:
    return lhs.sle(rhs);
  case CmpIPredicate::sgt:
    return lhs.sgt(rhs);
  case CmpIPredicate::sge:
    return lhs.sge(rhs);
  case CmpIPredicate::ult:
    return lhs.ult(rhs);
  case CmpIPredicate::ule:
    return lhs.ule(rhs);
  case CmpIPredicate::ugt:
    return lhs.ugt(rhs);
  case CmpIPredicate::uge:
    return lhs.uge(rhs);
  }
  llvm_unreachable("unknown comparison predicate");
}

// Constant folding hook for comparisons.
OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "cmpi takes two arguments");

  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!lhs || !rhs)
    return {};

  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
}

//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//

static void buildCmpFOp(OpBuilder &build, OperationState &result,
                        CmpFPredicate predicate, Value lhs, Value rhs) {
  result.addOperands({lhs, rhs});
  result.types.push_back(getI1SameShape(lhs.getType()));
  result.addAttribute(CmpFOp::getPredicateAttrName(),
                      build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
}

/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates.
bool mlir::applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                             const APFloat &rhs) {
  auto cmpResult = lhs.compare(rhs);
  switch (predicate) {
  case CmpFPredicate::AlwaysFalse:
    return false;
  case CmpFPredicate::OEQ:
    return cmpResult == APFloat::cmpEqual;
  case CmpFPredicate::OGT:
    return cmpResult == APFloat::cmpGreaterThan;
  case CmpFPredicate::OGE:
    return cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case CmpFPredicate::OLT:
    return cmpResult == APFloat::cmpLessThan;
  case CmpFPredicate::OLE:
    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case CmpFPredicate::ONE:
    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
  case CmpFPredicate::ORD:
    return cmpResult != APFloat::cmpUnordered;
  case CmpFPredicate::UEQ:
    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
  case CmpFPredicate::UGT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan;
  case CmpFPredicate::UGE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case CmpFPredicate::ULT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan;
  case CmpFPredicate::ULE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case CmpFPredicate::UNE:
    return cmpResult != APFloat::cmpEqual;
  case CmpFPredicate::UNO:
    return cmpResult == APFloat::cmpUnordered;
  case CmpFPredicate::AlwaysTrue:
    return true;
  }
  llvm_unreachable("unknown comparison predicate");
}

// Constant folding hook for comparisons.
OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "cmpf takes two arguments");

  auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
  auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();

  // TODO: We could actually do some intelligent things if we know only one
  // of the operands, but it's inf or nan.
  if (!lhs || !rhs)
    return {};

  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
}

//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//

namespace {
/// cond_br true, ^bb1, ^bb2
///  -> br ^bb1
/// cond_br false, ^bb1, ^bb2
///  -> br ^bb2
///
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    if (matchPattern(condbr.getCondition(), m_NonZero())) {
      // True branch taken.
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
                                            condbr.getTrueOperands());
      return success();
    } else if (matchPattern(condbr.getCondition(), m_Zero())) {
      // False branch taken.
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
                                            condbr.getFalseOperands());
      return success();
    }
    return failure();
  }
};

///   cond_br %cond, ^bb1, ^bb2
/// ^bb1
///   br ^bbN(...)
/// ^bb2
///   br ^bbK(...)
///
///  -> cond_br %cond, ^bbN(...), ^bbK(...)
///
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest();
    ValueRange trueDestOperands = condbr.getTrueOperands();
    ValueRange falseDestOperands = condbr.getFalseOperands();
    SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;

    // Try to collapse one of the current successors.
    LogicalResult collapsedTrue =
        collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
    LogicalResult collapsedFalse =
        collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
    if (failed(collapsedTrue) && failed(collapsedFalse))
      return failure();

    // Create a new branch with the collapsed successors.
    rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
                                              trueDest, trueDestOperands,
                                              falseDest, falseDestOperands);
    return success();
  }
};

/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
///  -> br ^bb1(A, ..., N)
///
/// cond_br %cond, ^bb1(A), ^bb1(B)
///  -> %select = select %cond, A, B
///     br ^bb1(%select)
///
struct SimplifyCondBranchIdenticalSuccessors
    : public OpRewritePattern<CondBranchOp> {
  using OpRewritePattern<CondBranchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CondBranchOp condbr,
                                PatternRewriter &rewriter) const override {
    // Check that the true and false destinations are the same and have the same
    // operands.
    Block *trueDest = condbr.trueDest();
    if (trueDest != condbr.falseDest())
      return failure();

    // If all of the operands match, no selects need to be generated.
    OperandRange trueOperands = condbr.getTrueOperands();
    OperandRange falseOperands = condbr.getFalseOperands();
    if (trueOperands == falseOperands) {
      rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
      return success();
    }

    // Otherwise, if the current block is the only predecessor, insert selects
    // for any mismatched branch operands.
    if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock())
      return failure();

    // Generate a select for any operands that differ between the two.
    SmallVector<Value, 8> mergedOperands;
    mergedOperands.reserve(trueOperands.size());
    Value condition = condbr.getCondition();
    for (auto it : llvm::zip(trueOperands, falseOperands)) {
      if (std::get<0>(it) == std::get<1>(it))
        mergedOperands.push_back(std::get<0>(it));
      else
        mergedOperands.push_back(rewriter.create<SelectOp>(
            condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
    }

    rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
    return success();
  }
};
} // end anonymous namespace

void CondBranchOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
                 SimplifyCondBranchIdenticalSuccessors>(context);
}

Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  return index == trueIndex ? trueDestOperandsMutable()
                            : falseDestOperandsMutable();
}

Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
    return condAttr.getValue().isOneValue() ? trueDest() : falseDest();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Constant*Op
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, ConstantOp &op) {
  p << "constant ";
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});

  if (op.getAttrs().size() > 1)
    p << ' ';
  p << op.getValue();

  // If the value is a symbol reference, print a trailing type.
  if (op.getValue().isa<SymbolRefAttr>())
    p << " : " << op.getType();
}

static ParseResult parseConstantOp(OpAsmParser &parser,
                                   OperationState &result) {
  Attribute valueAttr;
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseAttribute(valueAttr, "value", result.attributes))
    return failure();

  // If the attribute is a symbol reference, then we expect a trailing type.
  Type type;
  if (!valueAttr.isa<SymbolRefAttr>())
    type = valueAttr.getType();
  else if (parser.parseColonType(type))
    return failure();

  // Add the attribute type to the list.
  return parser.addTypeToList(type, result.types);
}

/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
static LogicalResult verify(ConstantOp &op) {
  auto value = op.getValue();
  if (!value)
    return op.emitOpError("requires a 'value' attribute");

  auto type = op.getType();
  if (!value.getType().isa<NoneType>() && type != value.getType())
    return op.emitOpError() << "requires attribute's type (" << value.getType()
                            << ") to match op's return type (" << type << ")";

  if (type.isa<IndexType>() || value.isa<BoolAttr>())
    return success();

  if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
    // If the type has a known bitwidth we verify that the value can be
    // represented with the given bitwidth.
    auto bitwidth = type.cast<IntegerType>().getWidth();
    auto intVal = intAttr.getValue();
    if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
      return op.emitOpError("requires 'value' to be an integer within the "
                            "range of the integer result type");
    return success();
  }

  if (type.isa<FloatType>()) {
    if (!value.isa<FloatAttr>())
      return op.emitOpError("requires 'value' to be a floating point constant");
    return success();
  }

  if (type.isa<ShapedType>()) {
    if (!value.isa<ElementsAttr>())
      return op.emitOpError("requires 'value' to be a shaped constant");
    return success();
  }

  if (type.isa<FunctionType>()) {
    auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
    if (!fnAttr)
      return op.emitOpError("requires 'value' to be a function reference");

    // Try to find the referenced function.
    auto fn =
        op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
    if (!fn)
      return op.emitOpError()
             << "reference to undefined function '" << fnAttr.getValue() << "'";

    // Check that the referenced function has the correct type.
    if (fn.getType() != type)
      return op.emitOpError("reference to function with mismatched type");

    return success();
  }

  if (type.isa<NoneType>() && value.isa<UnitAttr>())
    return success();

  return op.emitOpError("unsupported 'value' attribute: ") << value;
}

OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");
  return getValue();
}

void ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  Type type = getType();
  if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
    IntegerType intTy = type.dyn_cast<IntegerType>();

    // Sugar i1 constants with 'true' and 'false'.
    if (intTy && intTy.getWidth() == 1)
      return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));

    // Otherwise, build a complex name with the value and type.
    SmallString<32> specialNameBuffer;
    llvm::raw_svector_ostream specialName(specialNameBuffer);
    specialName << 'c' << intCst.getInt();
    if (intTy)
      specialName << '_' << type;
    setNameFn(getResult(), specialName.str());

  } else if (type.isa<FunctionType>()) {
    setNameFn(getResult(), "f");
  } else {
    setNameFn(getResult(), "cst");
  }
}

/// Returns true if a constant operation can be built with the given value and
/// result type.
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
  // SymbolRefAttr can only be used with a function type.
  if (value.isa<SymbolRefAttr>())
    return type.isa<FunctionType>();
  // Otherwise, the attribute must have the same type as 'type'.
  if (value.getType() != type)
    return false;
  // Finally, check that the attribute kind is handled.
  return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
}

void ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
                            const APFloat &value, FloatType type) {
  ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value));
}

bool ConstantFloatOp::classof(Operation *op) {
  return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>();
}

/// ConstantIntOp only matches values whose result type is an IntegerType.
bool ConstantIntOp::classof(Operation *op) {
  return ConstantOp::classof(op) &&
         op->getResult(0).getType().isSignlessInteger();
}

void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                          int64_t value, unsigned width) {
  Type type = builder.getIntegerType(width);
  ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
}

/// Build a constant int op producing an integer with the specified type,
/// which must be an integer type.
void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                          int64_t value, Type type) {
1231   assert(type.isSignlessInteger() &&
1232          "ConstantIntOp can only have signless integer type");
1233   ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1234 }
1235 
1236 /// ConstantIndexOp only matches values whose result type is Index.
classof(Operation * op)1237 bool ConstantIndexOp::classof(Operation *op) {
1238   return ConstantOp::classof(op) && op->getResult(0).getType().isIndex();
1239 }
1240 
build(OpBuilder & builder,OperationState & result,int64_t value)1241 void ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
1242                             int64_t value) {
1243   Type type = builder.getIndexType();
1244   ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1245 }
1246 
1247 //===----------------------------------------------------------------------===//
1248 // DeallocOp
1249 //===----------------------------------------------------------------------===//
1250 namespace {
1251 /// Fold Dealloc operations that are deallocating an AllocOp that is only used
1252 /// by other Dealloc operations.
1253 struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
1254   using OpRewritePattern<DeallocOp>::OpRewritePattern;
1255 
matchAndRewrite__anone4b94cca0c11::SimplifyDeadDealloc1256   LogicalResult matchAndRewrite(DeallocOp dealloc,
1257                                 PatternRewriter &rewriter) const override {
1258     // Check that the memref operand's defining operation is an AllocOp.
1259     Value memref = dealloc.memref();
1260     if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
1261       return failure();
1262 
1263     // Check that all of the uses of the AllocOp are other DeallocOps.
1264     for (auto *user : memref.getUsers())
1265       if (!isa<DeallocOp>(user))
1266         return failure();
1267 
1268     // Erase the dealloc operation.
1269     rewriter.eraseOp(dealloc);
1270     return success();
1271   }
1272 };
1273 } // end anonymous namespace.
1274 
1275 static LogicalResult verify(DeallocOp op) {
1276   if (!op.memref().getType().isa<MemRefType>())
1277     return op.emitOpError("operand must be a memref");
1278   return success();
1279 }
1280 
1281 void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1282                                             MLIRContext *context) {
1283   results.insert<SimplifyDeadDealloc>(context);
1284 }
1285 
1286 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
1287                               SmallVectorImpl<OpFoldResult> &results) {
1288   /// dealloc(memrefcast) -> dealloc
1289   return foldMemRefCast(*this);
1290 }
1291 
1292 //===----------------------------------------------------------------------===//
1293 // DimOp
1294 //===----------------------------------------------------------------------===//
1295 
1296 void DimOp::build(OpBuilder &builder, OperationState &result,
1297                   Value memrefOrTensor, int64_t index) {
1298   auto loc = result.location;
1299   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
1300   build(builder, result, memrefOrTensor, indexValue);
1301 }
1302 
1303 void DimOp::build(OpBuilder &builder, OperationState &result,
1304                   Value memrefOrTensor, Value index) {
1305   auto indexTy = builder.getIndexType();
1306   build(builder, result, indexTy, memrefOrTensor, index);
1307 }
1308 
1309 Optional<int64_t> DimOp::getConstantIndex() {
1310   if (auto constantOp = index().getDefiningOp<ConstantOp>())
1311     return constantOp.getValue().cast<IntegerAttr>().getInt();
1312   return {};
1313 }
1314 
1315 static LogicalResult verify(DimOp op) {
1316 
1317   // Assume unknown index to be in range.
1318   Optional<int64_t> index = op.getConstantIndex();
1319   if (!index.hasValue())
1320     return success();
1321 
1322   // Check that constant index is not knowingly out of range.
1323   auto type = op.memrefOrTensor().getType();
1324   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
1325     if (index.getValue() >= tensorType.getRank())
1326       return op.emitOpError("index is out of range");
1327   } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
1328     if (index.getValue() >= memrefType.getRank())
1329       return op.emitOpError("index is out of range");
1330   } else if (type.isa<UnrankedTensorType>()) {
1331     // Assume index to be in range.
1332   } else {
1333     llvm_unreachable("expected operand with tensor or memref type");
1334   }
1335 
1336   return success();
1337 }
1338 
1339 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
1340   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
1341 
1342   // All forms of folding require a known index.
1343   if (!index)
1344     return {};
1345 
1346   // Fold if the shape extent along the given index is known.
1347   auto argTy = memrefOrTensor().getType();
1348   if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
1349     if (!shapedTy.isDynamicDim(index.getInt())) {
1350       Builder builder(getContext());
1351       return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
1352     }
1353   }
1354 
1355   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1356   auto memrefType = argTy.dyn_cast<MemRefType>();
1357   if (!memrefType)
1358     return {};
1359 
1360   // The size at the given index is now known to be a dynamic size of a memref.
1361   auto memref = memrefOrTensor().getDefiningOp();
1362   unsigned unsignedIndex = index.getValue().getZExtValue();
1363   if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
1364     return *(alloc.getDynamicSizes().begin() +
1365              memrefType.getDynamicDimIndex(unsignedIndex));
1366 
1367   if (auto view = dyn_cast_or_null<ViewOp>(memref))
1368     return *(view.getDynamicSizes().begin() +
1369              memrefType.getDynamicDimIndex(unsignedIndex));
1370 
1371   if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
1372     assert(subview.isDynamicSize(unsignedIndex) &&
1373            "Expected dynamic subview size");
1374     return subview.getDynamicSize(unsignedIndex);
1375   }
1376 
1377   // dim(memrefcast) -> dim
1378   if (succeeded(foldMemRefCast(*this)))
1379     return getResult();
1380 
1381   return {};
1382 }
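// Illustrative folds for the cases handled above (a sketch, not exhaustive):
//
//   %c1 = constant 1 : index
//   %0 = dim %t, %c1 : tensor<4x8xf32>    // folds to: constant 8 : index
//
//   %m = alloc(%n) : memref<4x?xf32>
//   %1 = dim %m, %c1 : memref<4x?xf32>    // folds to %n, the dynamic size.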
1383 
1384 // ---------------------------------------------------------------------------
1385 // DmaStartOp
1386 // ---------------------------------------------------------------------------
1387 
1388 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1389                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1390                        ValueRange destIndices, Value numElements,
1391                        Value tagMemRef, ValueRange tagIndices, Value stride,
1392                        Value elementsPerStride) {
1393   result.addOperands(srcMemRef);
1394   result.addOperands(srcIndices);
1395   result.addOperands(destMemRef);
1396   result.addOperands(destIndices);
1397   result.addOperands({numElements, tagMemRef});
1398   result.addOperands(tagIndices);
1399   if (stride)
1400     result.addOperands({stride, elementsPerStride});
1401 }
1402 
1403 void DmaStartOp::print(OpAsmPrinter &p) {
1404   p << "dma_start " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1405     << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1406     << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1407   if (isStrided())
1408     p << ", " << getStride() << ", " << getNumElementsPerStride();
1409 
1410   p.printOptionalAttrDict(getAttrs());
1411   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1412     << ", " << getTagMemRef().getType();
1413 }
1414 
1415 // Parse DmaStartOp.
1416 // Ex:
1417 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1418 //                       %tag[%index], %stride, %num_elt_per_stride
1419 //                     : memref<3076 x f32, 0>,
1420 //                       memref<1024 x f32, 2>,
1421 //                       memref<1 x i32>
1422 //
1423 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1424   OpAsmParser::OperandType srcMemRefInfo;
1425   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
1426   OpAsmParser::OperandType dstMemRefInfo;
1427   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
1428   OpAsmParser::OperandType numElementsInfo;
1429   OpAsmParser::OperandType tagMemrefInfo;
1430   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
1431   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
1432 
1433   SmallVector<Type, 3> types;
1434   auto indexType = parser.getBuilder().getIndexType();
1435 
1436   // Parse and resolve the following list of operands:
1437   // *) source memref followed by its indices (in square brackets).
1438   // *) destination memref followed by its indices (in square brackets).
1439   // *) number of elements to transfer (the DMA size).
1440   if (parser.parseOperand(srcMemRefInfo) ||
1441       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1442       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1443       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1444       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1445       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1446       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1447     return failure();
1448 
1449   // Parse optional stride and elements per stride.
1450   if (parser.parseTrailingOperandList(strideInfo))
1451     return failure();
1452 
1453   bool isStrided = strideInfo.size() == 2;
1454   if (!strideInfo.empty() && !isStrided) {
1455     return parser.emitError(parser.getNameLoc(),
1456                             "expected two stride related operands");
1457   }
1458 
1459   if (parser.parseColonTypeList(types))
1460     return failure();
1461   if (types.size() != 3)
1462     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1463 
1464   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1465       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1466       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1467       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1468       // size should be an index.
1469       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1470       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1471       // tag indices should be index.
1472       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1473     return failure();
1474 
1475   if (isStrided) {
1476     if (parser.resolveOperands(strideInfo, indexType, result.operands))
1477       return failure();
1478   }
1479 
1480   return success();
1481 }
1482 
1483 LogicalResult DmaStartOp::verify() {
1484   unsigned numOperands = getNumOperands();
1485 
1486   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1487   // the number of elements.
1488   if (numOperands < 4)
1489     return emitOpError("expected at least 4 operands");
1490 
1491   // Check types of operands. The order of these calls is important: the later
1492   // calls rely on some type properties to compute the operand position.
1493   // 1. Source memref.
1494   if (!getSrcMemRef().getType().isa<MemRefType>())
1495     return emitOpError("expected source to be of memref type");
1496   if (numOperands < getSrcMemRefRank() + 4)
1497     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1498                          << " operands";
1499   if (!getSrcIndices().empty() &&
1500       !llvm::all_of(getSrcIndices().getTypes(),
1501                     [](Type t) { return t.isIndex(); }))
1502     return emitOpError("expected source indices to be of index type");
1503 
1504   // 2. Destination memref.
1505   if (!getDstMemRef().getType().isa<MemRefType>())
1506     return emitOpError("expected destination to be of memref type");
1507   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1508   if (numOperands < numExpectedOperands)
1509     return emitOpError() << "expected at least " << numExpectedOperands
1510                          << " operands";
1511   if (!getDstIndices().empty() &&
1512       !llvm::all_of(getDstIndices().getTypes(),
1513                     [](Type t) { return t.isIndex(); }))
1514     return emitOpError("expected destination indices to be of index type");
1515 
1516   // 3. Number of elements.
1517   if (!getNumElements().getType().isIndex())
1518     return emitOpError("expected num elements to be of index type");
1519 
1520   // 4. Tag memref.
1521   if (!getTagMemRef().getType().isa<MemRefType>())
1522     return emitOpError("expected tag to be of memref type");
1523   numExpectedOperands += getTagMemRefRank();
1524   if (numOperands < numExpectedOperands)
1525     return emitOpError() << "expected at least " << numExpectedOperands
1526                          << " operands";
1527   if (!getTagIndices().empty() &&
1528       !llvm::all_of(getTagIndices().getTypes(),
1529                     [](Type t) { return t.isIndex(); }))
1530     return emitOpError("expected tag indices to be of index type");
1531 
1532   // Only DMAs between different memory spaces are supported.
1533   if (getSrcMemorySpace() == getDstMemorySpace())
1534     return emitOpError("DMA should be between different memory spaces");
1535 
1536   // Optional stride-related operands must be either both present or both
1537   // absent.
1538   if (numOperands != numExpectedOperands &&
1539       numOperands != numExpectedOperands + 2)
1540     return emitOpError("incorrect number of operands");
1541 
1542   // 5. Strides.
1543   if (isStrided()) {
1544     if (!getStride().getType().isIndex() ||
1545         !getNumElementsPerStride().getType().isIndex())
1546       return emitOpError(
1547           "expected stride and num elements per stride to be of type index");
1548   }
1549 
1550   return success();
1551 }
1552 
1553 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1554                                SmallVectorImpl<OpFoldResult> &results) {
1555   /// dma_start(memrefcast) -> dma_start
1556   return foldMemRefCast(*this);
1557 }
1558 
1559 // ---------------------------------------------------------------------------
1560 // DmaWaitOp
1561 // ---------------------------------------------------------------------------
1562 
1563 void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
1564                       Value tagMemRef, ValueRange tagIndices,
1565                       Value numElements) {
1566   result.addOperands(tagMemRef);
1567   result.addOperands(tagIndices);
1568   result.addOperands(numElements);
1569 }
1570 
1571 void DmaWaitOp::print(OpAsmPrinter &p) {
1572   p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], "
1573     << getNumElements();
1574   p.printOptionalAttrDict(getAttrs());
1575   p << " : " << getTagMemRef().getType();
1576 }
1577 
1578 // Parse DmaWaitOp.
1579 // Eg:
1580 //   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
1581 //
1582 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
1583   OpAsmParser::OperandType tagMemrefInfo;
1584   SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
1585   Type type;
1586   auto indexType = parser.getBuilder().getIndexType();
1587   OpAsmParser::OperandType numElementsInfo;
1588 
1589   // Parse tag memref, its indices, and dma size.
1590   if (parser.parseOperand(tagMemrefInfo) ||
1591       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
1592       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1593       parser.parseColonType(type) ||
1594       parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
1595       parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
1596       parser.resolveOperand(numElementsInfo, indexType, result.operands))
1597     return failure();
1598 
1599   return success();
1600 }
1601 
1602 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1603                               SmallVectorImpl<OpFoldResult> &results) {
1604   /// dma_wait(memrefcast) -> dma_wait
1605   return foldMemRefCast(*this);
1606 }
1607 
1608 LogicalResult DmaWaitOp::verify() {
1609   // Mandatory non-variadic operands are tag and the number of elements.
1610   if (getNumOperands() < 2)
1611     return emitOpError() << "expected at least 2 operands";
1612 
1613   // Check types of operands. The order of these calls is important: the later
1614   // calls rely on some type properties to compute the operand position.
1615   if (!getTagMemRef().getType().isa<MemRefType>())
1616     return emitOpError() << "expected tag to be of memref type";
1617 
1618   if (getNumOperands() != 2 + getTagMemRefRank())
1619     return emitOpError() << "expected " << 2 + getTagMemRefRank()
1620                          << " operands";
1621 
1622   if (!getTagIndices().empty() &&
1623       !llvm::all_of(getTagIndices().getTypes(),
1624                     [](Type t) { return t.isIndex(); }))
1625     return emitOpError() << "expected tag indices to be of index type";
1626 
1627   if (!getNumElements().getType().isIndex())
1628     return emitOpError()
1629            << "expected the number of elements to be of index type";
1630 
1631   return success();
1632 }
1633 
1634 //===----------------------------------------------------------------------===//
1635 // ExtractElementOp
1636 //===----------------------------------------------------------------------===//
1637 
1638 static LogicalResult verify(ExtractElementOp op) {
1639   // Verify the # indices match if we have a ranked type.
1640   auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
1641   if (aggregateType.hasRank() &&
1642       aggregateType.getRank() != op.getNumOperands() - 1)
1643     return op.emitOpError("incorrect number of indices for extract_element");
1644 
1645   return success();
1646 }
1647 
1648 OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
1649   assert(!operands.empty() && "extract_element takes at least one operand");
1650 
1651   // The aggregate operand must be a known constant.
1652   Attribute aggregate = operands.front();
1653   if (!aggregate)
1654     return {};
1655 
1656   // If this is a splat elements attribute, simply return the value. All of the
1657   // elements of a splat attribute are the same.
1658   if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
1659     return splatAggregate.getSplatValue();
1660 
1661   // Otherwise, collect the constant indices into the aggregate.
1662   SmallVector<uint64_t, 8> indices;
1663   for (Attribute indice : llvm::drop_begin(operands, 1)) {
1664     if (!indice || !indice.isa<IntegerAttr>())
1665       return {};
1666     indices.push_back(indice.cast<IntegerAttr>().getInt());
1667   }
1668 
1669   // If this is an elements attribute, query the value at the given indices.
1670   auto elementsAttr = aggregate.dyn_cast<ElementsAttr>();
1671   if (elementsAttr && elementsAttr.isValidIndex(indices))
1672     return elementsAttr.getValue(indices);
1673   return {};
1674 }
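// A minimal folding example, assuming both the aggregate and the index are
// constant (sketch):
//
//   %cst = constant dense<[10, 20, 30]> : tensor<3xi32>
//   %c1 = constant 1 : index
//   %e = extract_element %cst[%c1] : tensor<3xi32>   // folds to: constant 20 : i32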
1675 
1676 //===----------------------------------------------------------------------===//
1677 // TensorFromElementsOp
1678 //===----------------------------------------------------------------------===//
1679 
1680 static ParseResult parseTensorFromElementsOp(OpAsmParser &parser,
1681                                              OperationState &result) {
1682   SmallVector<OpAsmParser::OperandType, 4> elementsOperands;
1683   Type resultType;
1684   if (parser.parseLParen() || parser.parseOperandList(elementsOperands) ||
1685       parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
1686       parser.parseColon() || parser.parseType(resultType))
1687     return failure();
1688 
1689   if (parser.resolveOperands(elementsOperands,
1690                              resultType.cast<ShapedType>().getElementType(),
1691                              result.operands))
1692     return failure();
1693 
1694   result.addTypes(resultType);
1695   return success();
1696 }
1697 
1698 static void print(OpAsmPrinter &p, TensorFromElementsOp op) {
1699   p << "tensor_from_elements(" << op.elements() << ')';
1700   p.printOptionalAttrDict(op.getAttrs());
1701   p << " : " << op.result().getType();
1702 }
1703 
1704 static LogicalResult verify(TensorFromElementsOp op) {
1705   auto resultTensorType = op.result().getType().dyn_cast<RankedTensorType>();
1706   if (!resultTensorType)
1707     return op.emitOpError("expected result type to be a ranked tensor");
1708 
1709   int64_t elementsCount = static_cast<int64_t>(op.elements().size());
1710   if (resultTensorType.getRank() != 1 ||
1711       resultTensorType.getShape().front() != elementsCount)
1712     return op.emitOpError()
1713            << "expected result type to be a 1D tensor with " << elementsCount
1714            << (elementsCount == 1 ? " element" : " elements");
1715   return success();
1716 }
1717 
1718 namespace {
1719 
1720 // Canonicalizes the pattern of the form
1721 //
1722 // %tensor = "tensor_from_elements(%element) : (i32) -> tensor<1xi32>
1723 // %extracted_element = extract_element %tensor[%c0] : tensor<1xi32>
1724 //
1725 // to just %element.
1726 struct ExtractElementFromTensorFromElements
1727     : public OpRewritePattern<ExtractElementOp> {
1728   using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
1729 
1730   LogicalResult matchAndRewrite(ExtractElementOp extract,
1731                                 PatternRewriter &rewriter) const final {
1732     if (extract.indices().size() != 1)
1733       return failure();
1734 
1735     auto tensor_from_elements = dyn_cast_or_null<TensorFromElementsOp>(
1736         extract.aggregate().getDefiningOp());
1737     if (tensor_from_elements == nullptr)
1738       return failure();
1739 
1740     APInt index;
1741     if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
1742       return failure();
1743     rewriter.replaceOp(extract,
1744                        tensor_from_elements.getOperand(index.getZExtValue()));
1745     return success();
1746   }
1747 };
1748 
1749 } // namespace
1750 
1751 void TensorFromElementsOp::getCanonicalizationPatterns(
1752     OwningRewritePatternList &results, MLIRContext *context) {
1753   results.insert<ExtractElementFromTensorFromElements>(context);
1754 }
1755 
1756 //===----------------------------------------------------------------------===//
1757 // FPExtOp
1758 //===----------------------------------------------------------------------===//
1759 
1760 bool FPExtOp::areCastCompatible(Type a, Type b) {
1761   if (auto fa = a.dyn_cast<FloatType>())
1762     if (auto fb = b.dyn_cast<FloatType>())
1763       return fa.getWidth() < fb.getWidth();
1764   if (auto va = a.dyn_cast<VectorType>())
1765     if (auto vb = b.dyn_cast<VectorType>())
1766       return va.getShape().equals(vb.getShape()) &&
1767              areCastCompatible(va.getElementType(), vb.getElementType());
1768   return false;
1769 }
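// Illustrative compatibility checks (sketch): fpext must strictly widen, and
// vector casts must preserve the shape.
//
//   %0 = fpext %x : f16 to f32                      // compatible
//   %1 = fpext %v : vector<4xf16> to vector<4xf64>  // compatible
//
// f32 to f32 (same width) or f32 to f16 (narrowing) are rejected.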
1770 
1771 //===----------------------------------------------------------------------===//
1772 // FPToSIOp
1773 //===----------------------------------------------------------------------===//
1774 
1775 bool FPToSIOp::areCastCompatible(Type a, Type b) {
1776   return a.isa<FloatType>() && b.isSignlessInteger();
1777 }
1778 
1779 //===----------------------------------------------------------------------===//
1780 // FPTruncOp
1781 //===----------------------------------------------------------------------===//
1782 
1783 bool FPTruncOp::areCastCompatible(Type a, Type b) {
1784   if (auto fa = a.dyn_cast<FloatType>())
1785     if (auto fb = b.dyn_cast<FloatType>())
1786       return fa.getWidth() > fb.getWidth();
1787   if (auto va = a.dyn_cast<VectorType>())
1788     if (auto vb = b.dyn_cast<VectorType>())
1789       return va.getShape().equals(vb.getShape()) &&
1790              areCastCompatible(va.getElementType(), vb.getElementType());
1791   return false;
1792 }
1793 
1794 //===----------------------------------------------------------------------===//
1795 // IndexCastOp
1796 //===----------------------------------------------------------------------===//
1797 
1798 // Index cast is applicable from index to integer and backwards.
1799 bool IndexCastOp::areCastCompatible(Type a, Type b) {
1800   if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
1801     auto aShaped = a.cast<ShapedType>();
1802     auto bShaped = b.cast<ShapedType>();
1803 
1804     return (aShaped.getShape() == bShaped.getShape()) &&
1805            areCastCompatible(aShaped.getElementType(),
1806                              bShaped.getElementType());
1807   }
1808 
1809   return (a.isIndex() && b.isSignlessInteger()) ||
1810          (a.isSignlessInteger() && b.isIndex());
1811 }
1812 
1813 OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
1814   // Fold IndexCast(IndexCast(x)) -> x
1815   auto cast = getOperand().getDefiningOp<IndexCastOp>();
1816   if (cast && cast.getOperand().getType() == getType())
1817     return cast.getOperand();
1818 
1819   // Fold IndexCast(constant) -> constant
1820   // A little hack because we go through int.  Otherwise, the size
1821   // of the constant might need to change.
1822   if (auto value = cstOperands[0].dyn_cast_or_null<IntegerAttr>())
1823     return IntegerAttr::get(getType(), value.getInt());
1824 
1825   return {};
1826 }
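// Illustrative folds (sketch):
//
//   %a = index_cast %i : i32 to index
//   %b = index_cast %a : index to i32   // folds to %i.
//
//   %c = constant 42 : index
//   %d = index_cast %c : index to i64   // folds to: constant 42 : i64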
1827 
1828 //===----------------------------------------------------------------------===//
1829 // LoadOp
1830 //===----------------------------------------------------------------------===//
1831 
1832 static LogicalResult verify(LoadOp op) {
1833   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1834     return op.emitOpError("incorrect number of indices for load");
1835   return success();
1836 }
1837 
1838 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1839   /// load(memrefcast) -> load
1840   if (succeeded(foldMemRefCast(*this)))
1841     return getResult();
1842   return OpFoldResult();
1843 }
1844 
1845 //===----------------------------------------------------------------------===//
1846 // MemRefCastOp
1847 //===----------------------------------------------------------------------===//
1848 
1849 Value MemRefCastOp::getViewSource() { return source(); }
1850 
1851 bool MemRefCastOp::areCastCompatible(Type a, Type b) {
1852   auto aT = a.dyn_cast<MemRefType>();
1853   auto bT = b.dyn_cast<MemRefType>();
1854 
1855   auto uaT = a.dyn_cast<UnrankedMemRefType>();
1856   auto ubT = b.dyn_cast<UnrankedMemRefType>();
1857 
1858   if (aT && bT) {
1859     if (aT.getElementType() != bT.getElementType())
1860       return false;
1861     if (aT.getAffineMaps() != bT.getAffineMaps()) {
1862       int64_t aOffset, bOffset;
1863       SmallVector<int64_t, 4> aStrides, bStrides;
1864       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
1865           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
1866           aStrides.size() != bStrides.size())
1867         return false;
1868 
1869       // Strides along a dimension/offset are compatible if the value in the
1870       // source memref is static and the value in the target memref is the
1871       // same. They are also compatible if either one is dynamic (see
1872       // description of MemRefCastOp for details).
1873       auto checkCompatible = [](int64_t a, int64_t b) {
1874         return (a == MemRefType::getDynamicStrideOrOffset() ||
1875                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
1876       };
1877       if (!checkCompatible(aOffset, bOffset))
1878         return false;
1879       for (auto aStride : enumerate(aStrides))
1880         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
1881           return false;
1882     }
1883     if (aT.getMemorySpace() != bT.getMemorySpace())
1884       return false;
1885 
1886     // They must have the same rank, and any specified dimensions must match.
1887     if (aT.getRank() != bT.getRank())
1888       return false;
1889 
1890     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
1891       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
1892       if (aDim != -1 && bDim != -1 && aDim != bDim)
1893         return false;
1894     }
1895     return true;
1896   } else {
1897     if (!aT && !uaT)
1898       return false;
1899     if (!bT && !ubT)
1900       return false;
1901     // Unranked to unranked casting is unsupported
1902     if (uaT && ubT)
1903       return false;
1904 
1905     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
1906     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
1907     if (aEltType != bEltType)
1908       return false;
1909 
1910     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
1911     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
1912     if (aMemSpace != bMemSpace)
1913       return false;
1914 
1915     return true;
1916   }
1917 
1918   return false;
1919 }
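// Illustrative compatible casts (a sketch; the full stride/offset and
// memory-space rules are checked above):
//
//   %0 = memref_cast %a : memref<4x?xf32> to memref<?x8xf32>  // ranked <-> ranked
//   %1 = memref_cast %a : memref<4x?xf32> to memref<*xf32>    // ranked -> unranked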
1920 
1921 OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
1922   return impl::foldCastOp(*this);
1923 }
1924 
1925 //===----------------------------------------------------------------------===//
1926 // MulFOp
1927 //===----------------------------------------------------------------------===//
1928 
1929 OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
1930   return constFoldBinaryOp<FloatAttr>(
1931       operands, [](APFloat a, APFloat b) { return a * b; });
1932 }
1933 
1934 //===----------------------------------------------------------------------===//
1935 // MulIOp
1936 //===----------------------------------------------------------------------===//
1937 
1938 OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
1939   /// muli(x, 0) -> 0
1940   if (matchPattern(rhs(), m_Zero()))
1941     return rhs();
1942   /// muli(x, 1) -> x
1943   if (matchPattern(rhs(), m_One()))
1944     return getOperand(0);
1945 
1946   // TODO: Handle the overflow case.
1947   return constFoldBinaryOp<IntegerAttr>(operands,
1948                                         [](APInt a, APInt b) { return a * b; });
1949 }
1950 
1951 //===----------------------------------------------------------------------===//
1952 // OrOp
1953 //===----------------------------------------------------------------------===//
1954 
1955 OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
1956   /// or(x, 0) -> x
1957   if (matchPattern(rhs(), m_Zero()))
1958     return lhs();
1959   /// or(x,x) -> x
1960   if (lhs() == rhs())
1961     return rhs();
1962 
1963   return constFoldBinaryOp<IntegerAttr>(operands,
1964                                         [](APInt a, APInt b) { return a | b; });
1965 }
1966 
1967 //===----------------------------------------------------------------------===//
1968 // PrefetchOp
1969 //===----------------------------------------------------------------------===//
1970 
1971 static void print(OpAsmPrinter &p, PrefetchOp op) {
1972   p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
1973   p.printOperands(op.indices());
1974   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1975   p << ", locality<" << op.localityHint();
1976   p << ">, " << (op.isDataCache() ? "data" : "instr");
1977   p.printOptionalAttrDict(
1978       op.getAttrs(),
1979       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1980   p << " : " << op.getMemRefType();
1981 }
1982 
1983 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1984                                    OperationState &result) {
1985   OpAsmParser::OperandType memrefInfo;
1986   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1987   IntegerAttr localityHint;
1988   MemRefType type;
1989   StringRef readOrWrite, cacheType;
1990 
1991   auto indexTy = parser.getBuilder().getIndexType();
1992   auto i32Type = parser.getBuilder().getIntegerType(32);
1993   if (parser.parseOperand(memrefInfo) ||
1994       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1995       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1996       parser.parseComma() || parser.parseKeyword("locality") ||
1997       parser.parseLess() ||
1998       parser.parseAttribute(localityHint, i32Type, "localityHint",
1999                             result.attributes) ||
2000       parser.parseGreater() || parser.parseComma() ||
2001       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
2002       parser.resolveOperand(memrefInfo, type, result.operands) ||
2003       parser.resolveOperands(indexInfo, indexTy, result.operands))
2004     return failure();
2005 
2006   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
2007     return parser.emitError(parser.getNameLoc(),
2008                             "rw specifier has to be 'read' or 'write'");
2009   result.addAttribute(
2010       PrefetchOp::getIsWriteAttrName(),
2011       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
2012 
2013   if (!cacheType.equals("data") && !cacheType.equals("instr"))
2014     return parser.emitError(parser.getNameLoc(),
2015                             "cache type has to be 'data' or 'instr'");
2016 
2017   result.addAttribute(
2018       PrefetchOp::getIsDataCacheAttrName(),
2019       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
2020 
2021   return success();
2022 }
2023 
2024 static LogicalResult verify(PrefetchOp op) {
2025   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
2026     return op.emitOpError("too few indices");
2027 
2028   return success();
2029 }
2030 
2031 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
2032                                SmallVectorImpl<OpFoldResult> &results) {
2033   // prefetch(memrefcast) -> prefetch
2034   return foldMemRefCast(*this);
2035 }
2036 
2037 //===----------------------------------------------------------------------===//
2038 // RankOp
2039 //===----------------------------------------------------------------------===//
2040 
2041 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
2042   // Constant fold rank when the rank of the tensor is known.
2043   auto type = getOperand().getType();
2044   if (auto tensorType = type.dyn_cast<RankedTensorType>())
2045     return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank());
2046   return IntegerAttr();
2047 }
2048 
2049 //===----------------------------------------------------------------------===//
2050 // ReturnOp
2051 //===----------------------------------------------------------------------===//
2052 
2053 static LogicalResult verify(ReturnOp op) {
2054   auto function = cast<FuncOp>(op.getParentOp());
2055 
2056   // The operand number and types must match the function signature.
2057   const auto &results = function.getType().getResults();
2058   if (op.getNumOperands() != results.size())
2059     return op.emitOpError("has ")
2060            << op.getNumOperands() << " operands, but enclosing function (@"
2061            << function.getName() << ") returns " << results.size();
2062 
2063   for (unsigned i = 0, e = results.size(); i != e; ++i)
2064     if (op.getOperand(i).getType() != results[i])
2065       return op.emitError()
2066              << "type of return operand " << i << " ("
2067              << op.getOperand(i).getType()
2068              << ") doesn't match function result type (" << results[i] << ")"
2069              << " in function @" << function.getName();
2070 
2071   return success();
2072 }
2073 
2074 //===----------------------------------------------------------------------===//
2075 // SelectOp
2076 //===----------------------------------------------------------------------===//
2077 
2078 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
2079   auto condition = getCondition();
2080 
2081   // select true, %0, %1 => %0
2082   if (matchPattern(condition, m_One()))
2083     return getTrueValue();
2084 
2085   // select false, %0, %1 => %1
2086   if (matchPattern(condition, m_Zero()))
2087     return getFalseValue();
2088   return nullptr;
2089 }
2090 
2091 static void print(OpAsmPrinter &p, SelectOp op) {
2092   p << "select " << op.getOperands();
2093   p.printOptionalAttrDict(op.getAttrs());
2094   p << " : ";
2095   if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
2096     p << condType << ", ";
2097   p << op.getType();
2098 }
2099 
2100 static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
2101   Type conditionType, resultType;
2102   SmallVector<OpAsmParser::OperandType, 3> operands;
2103   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2104       parser.parseOptionalAttrDict(result.attributes) ||
2105       parser.parseColonType(resultType))
2106     return failure();
2107 
2108   // Check for the explicit condition type if this is a masked tensor or vector.
2109   if (succeeded(parser.parseOptionalComma())) {
2110     conditionType = resultType;
2111     if (parser.parseType(resultType))
2112       return failure();
2113   } else {
2114     conditionType = parser.getBuilder().getI1Type();
2115   }
2116 
2117   result.addTypes(resultType);
2118   return parser.resolveOperands(operands,
2119                                 {conditionType, resultType, resultType},
2120                                 parser.getNameLoc(), result.operands);
2121 }
2122 
2123 static LogicalResult verify(SelectOp op) {
2124   Type conditionType = op.getCondition().getType();
2125   if (conditionType.isSignlessInteger(1))
2126     return success();
2127 
2128   // If the result type is a vector or tensor, the type can be a mask with the
2129   // same elements.
2130   Type resultType = op.getType();
2131   if (!resultType.isa<TensorType, VectorType>())
2132     return op.emitOpError()
2133            << "expected condition to be a signless i1, but got "
2134            << conditionType;
2135   Type shapedConditionType = getI1SameShape(resultType);
2136   if (conditionType != shapedConditionType)
2137     return op.emitOpError()
2138            << "expected condition type to have the same shape "
2139               "as the result type, expected "
2140            << shapedConditionType << ", but got " << conditionType;
2141   return success();
2142 }
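// Illustrative forms accepted by the folder and verifier above (sketch):
//
//   %r = select %cond, %a, %b : i32                          // scalar i1 condition
//   %v = select %mask, %x, %y : vector<4xi1>, vector<4xf32>  // element-wise mask
//
// A condition that is a constant true/false folds to %a or %b respectively.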
2143 
2144 //===----------------------------------------------------------------------===//
2145 // SignExtendIOp
2146 //===----------------------------------------------------------------------===//
2147 
2148 static LogicalResult verify(SignExtendIOp op) {
2149   // Get the scalar type (which is either directly the type of the operand
2150   // or the vector's/tensor's element type).
2151   auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2152   auto dstType = getElementTypeOrSelf(op.getType());
2153 
2154   // For now, index is forbidden for the source and the destination type.
2155   if (srcType.isa<IndexType>())
2156     return op.emitError() << srcType << " is not a valid operand type";
2157   if (dstType.isa<IndexType>())
2158     return op.emitError() << dstType << " is not a valid result type";
2159 
2160   if (srcType.cast<IntegerType>().getWidth() >=
2161       dstType.cast<IntegerType>().getWidth())
2162     return op.emitError("result type ")
2163            << dstType << " must be wider than operand type " << srcType;
2164 
2165   return success();
2166 }
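// For example (sketch), `%0 = sexti %x : i8 to i32` passes this verifier,
// whereas i32 -> i8 (not strictly wider) or any index-typed operand or result
// is rejected.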
2167 
2168 //===----------------------------------------------------------------------===//
2169 // SignedDivIOp
2170 //===----------------------------------------------------------------------===//
2171 
2172 OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
2173   assert(operands.size() == 2 && "binary operation takes two operands");
2174 
2175   // Don't fold if it would overflow or if it requires a division by zero.
2176   bool overflowOrDiv0 = false;
2177   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2178     if (overflowOrDiv0 || !b) {
2179       overflowOrDiv0 = true;
2180       return a;
2181     }
2182     return a.sdiv_ov(b, overflowOrDiv0);
2183   });
2184 
2185   // Fold out division by one. Assumes all tensors of all ones are splats.
2186   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2187     if (rhs.getValue() == 1)
2188       return lhs();
2189   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2190     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2191       return lhs();
2192   }
2193 
2194   return overflowOrDiv0 ? Attribute() : result;
2195 }
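// Illustrative behavior (sketch): `divi_signed %x, %c1` with %c1 = constant 1
// folds to %x; INT_MIN / -1 is left unfolded because sdiv_ov reports overflow;
// division by a constant zero is likewise not folded.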
2196 
2197 //===----------------------------------------------------------------------===//
2198 // SignedRemIOp
2199 //===----------------------------------------------------------------------===//
2200 
2201 OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
2202   assert(operands.size() == 2 && "remi_signed takes two operands");
2203 
2204   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
2205   if (!rhs)
2206     return {};
2207   auto rhsValue = rhs.getValue();
2208 
2209   // x % 1 = 0
2210   if (rhsValue.isOneValue())
2211     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
2212 
2213   // Don't fold if it requires division by zero.
2214   if (rhsValue.isNullValue())
2215     return {};
2216 
2217   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
2218   if (!lhs)
2219     return {};
2220   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
2221 }
2222 
2223 //===----------------------------------------------------------------------===//
2224 // SIToFPOp
2225 //===----------------------------------------------------------------------===//
2226 
2227 // sitofp is applicable from integer types to float types.
2228 bool SIToFPOp::areCastCompatible(Type a, Type b) {
2229   return a.isSignlessInteger() && b.isa<FloatType>();
2230 }
2231 
2232 //===----------------------------------------------------------------------===//
2233 // SplatOp
2234 //===----------------------------------------------------------------------===//
2235 
2236 static LogicalResult verify(SplatOp op) {
2237   // TODO: we could replace this by a trait.
2238   if (op.getOperand().getType() !=
2239       op.getType().cast<ShapedType>().getElementType())
2240     return op.emitError("operand should be of elemental type of result type");
2241 
2242   return success();
2243 }
2244 
2245 // Constant folding hook for SplatOp.
2246 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
2247   assert(operands.size() == 1 && "splat takes one operand");
2248 
2249   auto constOperand = operands.front();
2250   if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
2251     return {};
2252 
2253   auto shapedType = getType().cast<ShapedType>();
2254   assert(shapedType.getElementType() == constOperand.getType() &&
2255          "incorrect input attribute type for folding");
2256 
2257   // SplatElementsAttr::get treats a single value for the second arg as a splat.
2258   return SplatElementsAttr::get(shapedType, {constOperand});
2259 }
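// Illustrative fold (sketch): with %cst = constant 1.0 : f32,
//
//   %v = splat %cst : vector<4xf32>
//
// folds to the splat attribute dense<1.0> : vector<4xf32>.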
2260 
2261 //===----------------------------------------------------------------------===//
2262 // StoreOp
2263 //===----------------------------------------------------------------------===//
2264 
2265 static LogicalResult verify(StoreOp op) {
2266   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
2267     return op.emitOpError("store index operand count not equal to memref rank");
2268 
2269   return success();
2270 }
2271 
2272 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
2273                             SmallVectorImpl<OpFoldResult> &results) {
2274   /// store(memrefcast) -> store
2275   return foldMemRefCast(*this);
2276 }
2277 
2278 //===----------------------------------------------------------------------===//
2279 // SubFOp
2280 //===----------------------------------------------------------------------===//
2281 
2282 OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
2283   return constFoldBinaryOp<FloatAttr>(
2284       operands, [](APFloat a, APFloat b) { return a - b; });
2285 }
2286 
2287 //===----------------------------------------------------------------------===//
2288 // SubIOp
2289 //===----------------------------------------------------------------------===//
2290 
2291 OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
2292   // subi(x,x) -> 0
2293   if (getOperand(0) == getOperand(1))
2294     return Builder(getContext()).getZeroAttr(getType());
2295   // subi(x,0) -> x
2296   if (matchPattern(rhs(), m_Zero()))
2297     return lhs();
2298 
2299   return constFoldBinaryOp<IntegerAttr>(operands,
2300                                         [](APInt a, APInt b) { return a - b; });
2301 }
2302 
2303 //===----------------------------------------------------------------------===//
2304 // SubViewOp
2305 //===----------------------------------------------------------------------===//
2306 
2307 /// Print a list with either (1) the static integer value in `arrayAttr` if
2308 /// `isDynamic` evaluates to false or (2) the next value in `values` otherwise.
2309 /// This allows idiomatic printing of mixed value and integer attributes in a
2310 /// list. E.g. `[%arg0, 7, 42, %arg42]`.
2311 static void printSubViewListOfOperandsOrIntegers(
2312     OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
2313     llvm::function_ref<bool(int64_t)> isDynamic) {
2314   p << "[";
2315   unsigned idx = 0;
2316   llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
2317     int64_t val = a.cast<IntegerAttr>().getInt();
2318     if (isDynamic(val))
2319       p << values[idx++];
2320     else
2321       p << val;
2322   });
2323   p << "] ";
2324 }
2325 
2326 /// Parse a mixed list with either (1) static integer values or (2) SSA values.
2327 /// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
2328 /// encode the position of SSA values. Add the parsed SSA values to `ssa`
2329 /// in-order.
2330 ///
2331 /// E.g. after parsing "[%arg0, 7, 42, %arg42]":
2332 ///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
2333 ///   2. `ssa` is filled with "[%arg0, %arg1]".
2334 static ParseResult
2335 parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
2336                               StringRef attrName, int64_t dynVal,
2337                               SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
2338   if (failed(parser.parseLSquare()))
2339     return failure();
2340   // 0-D.
2341   if (succeeded(parser.parseOptionalRSquare()))
2342     return success();
2343 
2344   SmallVector<int64_t, 4> attrVals;
2345   while (true) {
2346     OpAsmParser::OperandType operand;
2347     auto res = parser.parseOptionalOperand(operand);
2348     if (res.hasValue() && succeeded(res.getValue())) {
2349       ssa.push_back(operand);
2350       attrVals.push_back(dynVal);
2351     } else {
2352       Attribute attr;
2353       NamedAttrList placeholder;
2354       if (failed(parser.parseAttribute(attr, "_", placeholder)) ||
2355           !attr.isa<IntegerAttr>())
2356         return parser.emitError(parser.getNameLoc())
2357                << "expected SSA value or integer";
2358       attrVals.push_back(attr.cast<IntegerAttr>().getInt());
2359     }
2360 
2361     if (succeeded(parser.parseOptionalComma()))
2362       continue;
2363     if (failed(parser.parseRSquare()))
2364       return failure();
2365     else
2366       break;
2367   }
2368 
2369   auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
2370   result.addAttribute(attrName, arrayAttr);
2371   return success();
2372 }
2373 
2374 namespace {
2375 /// Helpers to write more idiomatic operations.
2376 namespace saturated_arith {
2377 struct Wrapper {
2378   explicit Wrapper(int64_t v) : v(v) {}
2379   operator int64_t() { return v; }
2380   int64_t v;
2381 };
2382 Wrapper operator+(Wrapper a, int64_t b) {
2383   if (ShapedType::isDynamicStrideOrOffset(a) ||
2384       ShapedType::isDynamicStrideOrOffset(b))
2385     return Wrapper(ShapedType::kDynamicStrideOrOffset);
2386   return Wrapper(a.v + b);
2387 }
2388 Wrapper operator*(Wrapper a, int64_t b) {
2389   if (ShapedType::isDynamicStrideOrOffset(a) ||
2390       ShapedType::isDynamicStrideOrOffset(b))
2391     return Wrapper(ShapedType::kDynamicStrideOrOffset);
2392   return Wrapper(a.v * b);
2393 }
2394 } // end namespace saturated_arith
2395 } // end namespace
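// For example (sketch): Wrapper(4) * 2 evaluates to 8, while
// Wrapper(ShapedType::kDynamicStrideOrOffset) * 2 (or + 2) stays the dynamic
// sentinel, so any arithmetic involving a dynamic stride/offset remains dynamic.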
2396 
2397 /// A subview result type can be fully inferred from the source type and the
2398 /// static representation of offsets, sizes and strides. Special sentinels
2399 /// encode the dynamic case.
2400 Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
2401                                        ArrayRef<int64_t> staticOffsets,
2402                                        ArrayRef<int64_t> staticSizes,
2403                                        ArrayRef<int64_t> staticStrides) {
2404   unsigned rank = sourceMemRefType.getRank();
2405   (void)rank;
2406   assert(staticOffsets.size() == rank &&
2407          "unexpected staticOffsets size mismatch");
2408   assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
2409   assert(staticStrides.size() == rank &&
2410          "unexpected staticStrides size mismatch");
2411 
2412   // Extract source offset and strides.
2413   int64_t sourceOffset;
2414   SmallVector<int64_t, 4> sourceStrides;
2415   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
2416   assert(succeeded(res) && "SubViewOp expected strided memref type");
2417   (void)res;
2418 
2419   // Compute target offset whose value is:
2420   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2421   int64_t targetOffset = sourceOffset;
2422   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2423     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
2424     using namespace saturated_arith;
2425     targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
2426   }
2427 
2428   // Compute target stride whose value is:
2429   //   `sourceStrides_i * staticStrides_i`.
2430   SmallVector<int64_t, 4> targetStrides;
2431   targetStrides.reserve(staticOffsets.size());
2432   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2433     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2434     using namespace saturated_arith;
2435     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
2436   }
2437 
2438   // The type is now known.
2439   return MemRefType::get(
2440       staticSizes, sourceMemRefType.getElementType(),
2441       makeStridedLinearLayoutMap(targetStrides, targetOffset,
2442                                  sourceMemRefType.getContext()),
2443       sourceMemRefType.getMemorySpace());
2444 }
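// Worked example (sketch): for a source memref<8x16xf32> with the identity
// layout (strides [16, 1], offset 0), staticOffsets [2, 4], staticSizes [4, 4]
// and staticStrides [1, 2], the computation above yields offset
// 0 + 2*16 + 4*1 = 36 and strides [16*1, 1*2] = [16, 2], i.e. the type
//   memref<4x4xf32, affine_map<(d0, d1) -> (d0 * 16 + d1 * 2 + 36)>>.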
2445 
2446 /// Print SubViewOp in the form:
2447 /// ```
2448 ///   subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
2449 ///     `:` strided-memref-type `to` strided-memref-type
2450 /// ```
2451 static void print(OpAsmPrinter &p, SubViewOp op) {
2452   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
2453   p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
2454   p << op.getOperand(0);
2455   printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
2456                                        ShapedType::isDynamicStrideOrOffset);
2457   printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
2458                                        ShapedType::isDynamic);
2459   printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
2460                                        ShapedType::isDynamicStrideOrOffset);
2461   p.printOptionalAttrDict(op.getAttrs(),
2462                           /*elidedAttrs=*/{SubViewOp::getSpecialAttrNames()});
2463   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2464 }
2465 
2466 /// Parse SubViewOp of the form:
2467 /// ```
2468 ///   subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
2469 ///     `:` strided-memref-type `to` strided-memref-type
2470 /// ```
2471 static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
2472   OpAsmParser::OperandType srcInfo;
2473   SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
2474   auto indexType = parser.getBuilder().getIndexType();
2475   Type srcType, dstType;
2476   if (parser.parseOperand(srcInfo))
2477     return failure();
2478   if (parseListOfOperandsOrIntegers(
2479           parser, result, SubViewOp::getStaticOffsetsAttrName(),
2480           ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
2481       parseListOfOperandsOrIntegers(parser, result,
2482                                     SubViewOp::getStaticSizesAttrName(),
2483                                     ShapedType::kDynamicSize, sizesInfo) ||
2484       parseListOfOperandsOrIntegers(
2485           parser, result, SubViewOp::getStaticStridesAttrName(),
2486           ShapedType::kDynamicStrideOrOffset, stridesInfo))
2487     return failure();
2488 
2489   auto b = parser.getBuilder();
2490   SmallVector<int, 4> segmentSizes{1, static_cast<int>(offsetsInfo.size()),
2491                                    static_cast<int>(sizesInfo.size()),
2492                                    static_cast<int>(stridesInfo.size())};
2493   result.addAttribute(SubViewOp::getOperandSegmentSizeAttr(),
2494                       b.getI32VectorAttr(segmentSizes));
2495 
2496   return failure(
2497       parser.parseOptionalAttrDict(result.attributes) ||
2498       parser.parseColonType(srcType) ||
2499       parser.resolveOperand(srcInfo, srcType, result.operands) ||
2500       parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
2501       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2502       parser.resolveOperands(stridesInfo, indexType, result.operands) ||
2503       parser.parseKeywordType("to", dstType) ||
2504       parser.addTypeToList(dstType, result.types));
2505 }
2506 
2507 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2508                             ArrayRef<int64_t> staticOffsets,
2509                             ArrayRef<int64_t> staticSizes,
2510                             ArrayRef<int64_t> staticStrides, ValueRange offsets,
2511                             ValueRange sizes, ValueRange strides,
2512                             ArrayRef<NamedAttribute> attrs) {
2513   auto sourceMemRefType = source.getType().cast<MemRefType>();
2514   auto resultType = inferSubViewResultType(sourceMemRefType, staticOffsets,
2515                                            staticSizes, staticStrides);
2516   build(b, result, resultType, source, offsets, sizes, strides,
2517         b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
2518         b.getI64ArrayAttr(staticStrides));
2519   result.addAttributes(attrs);
2520 }
2521 
2522 /// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes`
2523 /// and `staticStrides` are  automatically filled with source-memref-rank
2524 /// sentinel values that encode dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)2525 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2526                             ValueRange offsets, ValueRange sizes,
2527                             ValueRange strides,
2528                             ArrayRef<NamedAttribute> attrs) {
2529   auto sourceMemRefType = source.getType().cast<MemRefType>();
2530   unsigned rank = sourceMemRefType.getRank();
2531   SmallVector<int64_t, 4> staticOffsetsVector;
2532   staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
2533   SmallVector<int64_t, 4> staticSizesVector;
2534   staticSizesVector.assign(rank, ShapedType::kDynamicSize);
2535   SmallVector<int64_t, 4> staticStridesVector;
2536   staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
2537   build(b, result, source, staticOffsetsVector, staticSizesVector,
2538         staticStridesVector, offsets, sizes, strides, attrs);
2539 }
2540 
2541 /// Verify that a particular offset/size/stride static attribute is well-formed.
2542 static LogicalResult
verifySubViewOpPart(SubViewOp op,StringRef name,StringRef attrName,ArrayAttr attr,llvm::function_ref<bool (int64_t)> isDynamic,ValueRange values)2543 verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName,
2544                     ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
2545                     ValueRange values) {
2546   /// Check static and dynamic offsets/sizes/strides breakdown.
2547   if (attr.size() != op.getRank())
2548     return op.emitError("expected ")
2549            << op.getRank() << " " << name << " values";
2550   unsigned expectedNumDynamicEntries =
2551       llvm::count_if(attr.getValue(), [&](Attribute attr) {
2552         return isDynamic(attr.cast<IntegerAttr>().getInt());
2553       });
2554   if (values.size() != expectedNumDynamicEntries)
2555     return op.emitError("expected ")
2556            << expectedNumDynamicEntries << " dynamic " << name << " values";
2557   return success();
2558 }
2559 
2560 /// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
extractFromI64ArrayAttr(Attribute attr)2561 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
2562   return llvm::to_vector<4>(
2563       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
2564         return a.cast<IntegerAttr>().getInt();
2565       }));
2566 }
2567 
2568 /// Verifier for SubViewOp.
verify(SubViewOp op)2569 static LogicalResult verify(SubViewOp op) {
2570   auto baseType = op.getBaseMemRefType().cast<MemRefType>();
2571   auto subViewType = op.getType();
2572 
2573   // The base memref and the view memref should be in the same memory space.
2574   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2575     return op.emitError("different memory spaces specified for base memref "
2576                         "type ")
2577            << baseType << " and subview memref type " << subViewType;
2578 
2579   // Verify that the base memref type has a strided layout map.
2580   if (!isStrided(baseType))
2581     return op.emitError("base type ") << baseType << " is not strided";
2582 
2583   // Verify static attributes offsets/sizes/strides.
2584   if (failed(verifySubViewOpPart(
2585           op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
2586           ShapedType::isDynamicStrideOrOffset, op.offsets())))
2587     return failure();
2588 
2589   if (failed(verifySubViewOpPart(op, "size", op.getStaticSizesAttrName(),
2590                                  op.static_sizes(), ShapedType::isDynamic,
2591                                  op.sizes())))
2592     return failure();
2593   if (failed(verifySubViewOpPart(
2594           op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
2595           ShapedType::isDynamicStrideOrOffset, op.strides())))
2596     return failure();
2597 
2598   // Verify result type against inferred type.
2599   auto expectedType = SubViewOp::inferSubViewResultType(
2600       op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
2601       extractFromI64ArrayAttr(op.static_sizes()),
2602       extractFromI64ArrayAttr(op.static_strides()));
2603   if (op.getType() != expectedType)
2604     return op.emitError("expected result type to be ") << expectedType;
2605 
2606   return success();
2607 }
2608 
operator <<(raw_ostream & os,SubViewOp::Range & range)2609 raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
2610   return os << "range " << range.offset << ":" << range.size << ":"
2611             << range.stride;
2612 }
2613 
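/// Return the number of leading entries in `attr` (up to, but not including,
/// index `idx`) whose integer value is classified as dynamic by `isDynamic`.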
static unsigned getNumDynamicEntriesUpToIdx(
    ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
  return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx,
                       [&](Attribute attr) {
                         return isDynamic(attr.cast<IntegerAttr>().getInt());
                       });
}

bool SubViewOp::isDynamicOffset(unsigned idx) {
  return ShapedType::isDynamicStrideOrOffset(
      extractFromI64ArrayAttr(static_offsets())[idx]);
}
bool SubViewOp::isDynamicSize(unsigned idx) {
  return ShapedType::isDynamic(extractFromI64ArrayAttr(static_sizes())[idx]);
}
bool SubViewOp::isDynamicStride(unsigned idx) {
  return ShapedType::isDynamicStrideOrOffset(
      extractFromI64ArrayAttr(static_strides())[idx]);
}

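// The SubViewOp operand list is laid out as: the source memref (operand 0),
// followed by the dynamic offsets, then the dynamic sizes, then the dynamic
// strides. The helpers below translate a logical offset/size/stride index into
// the corresponding position in that flat operand list.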
unsigned SubViewOp::getIndexOfDynamicOffset(unsigned idx) {
  assert(isDynamicOffset(idx) && "expected dynamic offset");
  auto numDynamic =
      getNumDynamicEntriesUpToIdx(static_offsets().cast<ArrayAttr>(),
                                  ShapedType::isDynamicStrideOrOffset, idx);
  return 1 + numDynamic;
}
unsigned SubViewOp::getIndexOfDynamicSize(unsigned idx) {
  assert(isDynamicSize(idx) && "expected dynamic size");
  auto numDynamic = getNumDynamicEntriesUpToIdx(
      static_sizes().cast<ArrayAttr>(), ShapedType::isDynamic, idx);
  return 1 + offsets().size() + numDynamic;
}
unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) {
  assert(isDynamicStride(idx) && "expected dynamic stride");
  auto numDynamic =
      getNumDynamicEntriesUpToIdx(static_strides().cast<ArrayAttr>(),
                                  ShapedType::isDynamicStrideOrOffset, idx);
  return 1 + offsets().size() + sizes().size() + numDynamic;
}

/// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
                                                              Location loc) {
  SmallVector<Range, 8> res;
  unsigned rank = getType().getRank();
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    auto offset = isDynamicOffset(idx)
                      ? getDynamicOffset(idx)
                      : b.create<ConstantIndexOp>(loc, getStaticOffset(idx));
    auto size = isDynamicSize(idx)
                    ? getDynamicSize(idx)
                    : b.create<ConstantIndexOp>(loc, getStaticSize(idx));
    auto stride = isDynamicStride(idx)
                      ? getDynamicStride(idx)
                      : b.create<ConstantIndexOp>(loc, getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

SmallVector<Value, 4> SubViewOp::getOrCreateOffsets(OpBuilder &b,
                                                    Location loc) {
  unsigned dynamicIdx = 1;
  return llvm::to_vector<4>(llvm::map_range(
      static_offsets().cast<ArrayAttr>(), [&](Attribute a) -> Value {
        int64_t staticOffset = a.cast<IntegerAttr>().getInt();
        if (ShapedType::isDynamicStrideOrOffset(staticOffset))
          return getOperand(dynamicIdx++);
        else
          return b.create<ConstantIndexOp>(loc, staticOffset);
      }));
}

SmallVector<Value, 4> SubViewOp::getOrCreateSizes(OpBuilder &b, Location loc) {
  unsigned dynamicIdx = 1 + offsets().size();
  return llvm::to_vector<4>(llvm::map_range(
      static_sizes().cast<ArrayAttr>(), [&](Attribute a) -> Value {
        int64_t staticSize = a.cast<IntegerAttr>().getInt();
        if (ShapedType::isDynamic(staticSize))
          return getOperand(dynamicIdx++);
        else
          return b.create<ConstantIndexOp>(loc, staticSize);
      }));
}

SmallVector<Value, 4> SubViewOp::getOrCreateStrides(OpBuilder &b,
                                                    Location loc) {
  unsigned dynamicIdx = 1 + offsets().size() + sizes().size();
  return llvm::to_vector<4>(llvm::map_range(
      static_strides().cast<ArrayAttr>(), [&](Attribute a) -> Value {
        int64_t staticStride = a.cast<IntegerAttr>().getInt();
        if (ShapedType::isDynamicStrideOrOffset(staticStride))
          return getOperand(dynamicIdx++);
        else
          return b.create<ConstantIndexOp>(loc, staticStride);
      }));
}

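/// If the op has no dynamic stride operands, populate `staticStrides` from the
/// `static_strides` attribute and return success; otherwise return failure.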
LogicalResult
SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
  if (!strides().empty())
    return failure();
  staticStrides = extractFromI64ArrayAttr(static_strides());
  return success();
}

Value SubViewOp::getViewSource() { return source(); }

namespace {

/// Take a list of `values` with potential new constants to extract and a list
/// of `constantValues` containing `values.size()` sentinel entries, i.e.
/// entries for which `isDynamic` evaluates to true.
/// Detect the `values` produced by a ConstantIndexOp, place the new constant
/// in place of the corresponding sentinel value, and erase the value from
/// `values`.
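///
/// For example (illustrative): given `values = [%c2, %v]`, where `%c2` is
/// produced by a ConstantIndexOp, and `constantValues = [dyn, dyn]` (two
/// dynamic sentinels), the lists are updated in place to `values = [%v]` and
/// `constantValues = [2, dyn]`.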
void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
                             SmallVectorImpl<int64_t> &constantValues,
                             llvm::function_ref<bool(int64_t)> isDynamic) {
  bool hasNewStaticValue = llvm::any_of(
      values, [](Value val) { return matchPattern(val, m_ConstantIndex()); });
  if (hasNewStaticValue) {
    for (unsigned cstIdx = 0, valIdx = 0, e = constantValues.size();
         cstIdx != e; ++cstIdx) {
      // Was already static, skip.
      if (!isDynamic(constantValues[cstIdx]))
        continue;
      // Newly static, move from Value to constant.
      if (matchPattern(values[valIdx], m_ConstantIndex())) {
        constantValues[cstIdx] =
            cast<ConstantIndexOp>(values[valIdx].getDefiningOp()).getValue();
        // Erase for implementation simplicity; use a reverse iterator if we
        // really must.
        values.erase(std::next(values.begin(), valIdx));
        continue;
      }
      // Remains dynamic, move to the next value.
      ++valIdx;
    }
  }
}

/// Pattern to rewrite a subview op with constant arguments.
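///
/// For example (an illustrative sketch), with `%c0` and `%c1` produced by
/// ConstantIndexOp:
/// ```
///   %1 = subview %0[%c0, %c0][3, 4][%c1, %c1]
///     : memref<8x16xf32> to memref<3x4xf32, offset:?, strides:[?, ?]>
/// ```
/// may be rewritten into:
/// ```
///   %2 = subview %0[0, 0][3, 4][1, 1]
///     : memref<8x16xf32> to memref<3x4xf32, offset:0, strides:[16, 1]>
///   %1 = memref_cast %2 : memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, ?]>
/// ```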
class SubViewOpConstantArgumentFolder final
    : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // No constant operands, just return.
    if (llvm::none_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, m_ConstantIndex());
        }))
      return failure();

    // At least one of offsets/sizes/strides is a new constant.
    // Form the new list of operands and constant attributes from the existing.
    SmallVector<Value, 8> newOffsets(subViewOp.offsets());
    SmallVector<int64_t, 8> newStaticOffsets =
        extractFromI64ArrayAttr(subViewOp.static_offsets());
    assert(newStaticOffsets.size() == subViewOp.getRank());
    canonicalizeSubViewPart(newOffsets, newStaticOffsets,
                            ShapedType::isDynamicStrideOrOffset);

    SmallVector<Value, 8> newSizes(subViewOp.sizes());
    SmallVector<int64_t, 8> newStaticSizes =
        extractFromI64ArrayAttr(subViewOp.static_sizes());
    assert(newStaticSizes.size() == subViewOp.getRank());
    canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);

    SmallVector<Value, 8> newStrides(subViewOp.strides());
    SmallVector<int64_t, 8> newStaticStrides =
        extractFromI64ArrayAttr(subViewOp.static_strides());
    assert(newStaticStrides.size() == subViewOp.getRank());
    canonicalizeSubViewPart(newStrides, newStaticStrides,
                            ShapedType::isDynamicStrideOrOffset);

    // Create the new op in canonical form.
    auto newSubViewOp = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), subViewOp.source(), newStaticOffsets,
        newStaticSizes, newStaticStrides, newOffsets, newSizes, newStrides);

    // Insert a memref_cast for compatibility of the uses of the op.
    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
                                              subViewOp.getType());

    return success();
  }
};

} // end anonymous namespace

/// Determines whether MemRefCastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref_cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref_cast operations. Such foldable memref_cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
        !MemRefType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamicStrideOrOffset(ss) &&
          !MemRefType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref_cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %2 = subview %V[0, 0][3, 4][1, 1]
///     : memref<16x16xf32> to memref<3x4xf32, offset:0, strides:[16, 1]>
///   %1 = memref_cast %2 : memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantArgumentFolder
    // kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, m_ConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<MemRefCastOp>();
    if (!castOp)
      return failure();

    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
    /// the cast source operand type and the SubViewOp static information. This
    /// is the resulting type if the MemRefCastOp were folded.
    Type resultType = SubViewOp::inferSubViewResultType(
        castOp.source().getType().cast<MemRefType>(),
        extractFromI64ArrayAttr(subViewOp.static_offsets()),
        extractFromI64ArrayAttr(subViewOp.static_sizes()),
        extractFromI64ArrayAttr(subViewOp.static_strides()));
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
                                              newSubView);
    return success();
  }
};
} // namespace

void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<SubViewOpConstantArgumentFolder, SubViewOpMemRefCastFolder>(
      context);
}

//===----------------------------------------------------------------------===//
// TensorCastOp
//===----------------------------------------------------------------------===//

bool TensorCastOp::areCastCompatible(Type a, Type b) {
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
  return impl::foldCastOp(*this);
}

//===----------------------------------------------------------------------===//
// Helpers for Tensor[Load|Store]Op
//===----------------------------------------------------------------------===//

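/// Return a ranked tensor type with the shape and element type of the given
/// memref type, or NoneType if `type` is not a memref.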
static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// TruncateIOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(TruncateIOp op) {
  auto srcType = getElementTypeOrSelf(op.getOperand().getType());
  auto dstType = getElementTypeOrSelf(op.getType());

  if (srcType.isa<IndexType>())
    return op.emitError() << srcType << " is not a valid operand type";
  if (dstType.isa<IndexType>())
    return op.emitError() << dstType << " is not a valid result type";

  if (srcType.cast<IntegerType>().getWidth() <=
      dstType.cast<IntegerType>().getWidth())
    return op.emitError("operand type ")
           << srcType << " must be wider than result type " << dstType;

  return success();
}

//===----------------------------------------------------------------------===//
// UnsignedDivIOp
//===----------------------------------------------------------------------===//

OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
    if (div0 || !b) {
      div0 = true;
      return a;
    }
    return a.udiv(b);
  });

  // Fold out division by one. Assumes all tensors of all ones are splats.
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
    if (rhs.getValue() == 1)
      return lhs();
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
      return lhs();
  }

  return div0 ? Attribute() : result;
}

//===----------------------------------------------------------------------===//
// UnsignedRemIOp
//===----------------------------------------------------------------------===//

OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "remi_unsigned takes two operands");

  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return {};
  auto rhsValue = rhs.getValue();

  // x % 1 = 0
  if (rhsValue.isOneValue())
    return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));

  // Don't fold if it requires division by zero.
  if (rhsValue.isNullValue())
    return {};

  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return {};
  return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

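/// Parse ViewOp of the form:
/// ```
///   view ssa-name `[` byte-shift-ssa-use `]` `[` size-ssa-use-list `]`
///     attribute-dict? `:` memref-type `to` memref-type
/// ```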
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op.getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }

namespace {

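/// Pattern to fold constant size operands of a view op into its result type.
/// Dynamic dimensions whose size operand is defined by a ConstantIndexOp
/// become static in the folded type; a memref_cast back to the original result
/// type is inserted to keep the types of existing uses intact.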
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, m_ConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memrefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
                                              viewOp.getType());
    return success();
  }
};

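/// Pattern to rewrite a view op whose source is produced by a memref_cast of
/// an alloc, so that the view consumes the allocated memref directly; the
/// intervening cast can then be removed if it has no other uses.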
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    MemRefCastOp memrefCastOp = memrefOperand.getDefiningOp<MemRefCastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// XOrOp
//===----------------------------------------------------------------------===//

OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
  /// xor(x, 0) -> x
  if (matchPattern(rhs(), m_Zero()))
    return lhs();
  /// xor(x,x) -> 0
  if (lhs() == rhs())
    return Builder(getContext()).getZeroAttr(getType());

  return constFoldBinaryOp<IntegerAttr>(operands,
                                        [](APInt a, APInt b) { return a ^ b; });
}

//===----------------------------------------------------------------------===//
// ZeroExtendIOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ZeroExtendIOp op) {
  auto srcType = getElementTypeOrSelf(op.getOperand().getType());
  auto dstType = getElementTypeOrSelf(op.getType());

  if (srcType.isa<IndexType>())
    return op.emitError() << srcType << " is not a valid operand type";
  if (dstType.isa<IndexType>())
    return op.emitError() << dstType << " is not a valid result type";

  if (srcType.cast<IntegerType>().getWidth() >=
      dstType.cast<IntegerType>().getWidth())
    return op.emitError("result type ")
           << dstType << " must be wider than operand type " << srcType;

  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"