1 //===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/StandardOps/IR/Ops.h"
10 
11 #include "mlir/Dialect/CommonFolders.h"
12 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
13 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Support/MathExtras.h"
25 #include "mlir/Transforms/InliningUtils.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/StringSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/raw_ostream.h"
30 
31 #include "mlir/Dialect/StandardOps/IR/OpsDialect.cpp.inc"
32 
33 // Pull in all enum type definitions and utility function declarations.
34 #include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"
35 
36 using namespace mlir;
37 
38 //===----------------------------------------------------------------------===//
39 // StandardOpsDialect Interfaces
40 //===----------------------------------------------------------------------===//
41 namespace {
42 /// This class defines the interface for handling inlining with standard
43 /// operations.
44 struct StdInlinerInterface : public DialectInlinerInterface {
45   using DialectInlinerInterface::DialectInlinerInterface;
46 
47   //===--------------------------------------------------------------------===//
48   // Analysis Hooks
49   //===--------------------------------------------------------------------===//
50 
51   /// All call operations within standard ops can be inlined.
isLegalToInline__anon476530360111::StdInlinerInterface52   bool isLegalToInline(Operation *call, Operation *callable,
53                        bool wouldBeCloned) const final {
54     return true;
55   }
56 
57   /// All operations within standard ops can be inlined.
isLegalToInline__anon476530360111::StdInlinerInterface58   bool isLegalToInline(Operation *, Region *, bool,
59                        BlockAndValueMapping &) const final {
60     return true;
61   }
62 
63   //===--------------------------------------------------------------------===//
64   // Transformation Hooks
65   //===--------------------------------------------------------------------===//
66 
67   /// Handle the given inlined terminator by replacing it with a new operation
68   /// as necessary.
handleTerminator__anon476530360111::StdInlinerInterface69   void handleTerminator(Operation *op, Block *newDest) const final {
70     // Only "std.return" needs to be handled here.
71     auto returnOp = dyn_cast<ReturnOp>(op);
72     if (!returnOp)
73       return;
74 
75     // Replace the return with a branch to the dest.
76     OpBuilder builder(op);
77     builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
78     op->erase();
79   }
80 
81   /// Handle the given inlined terminator by replacing it with a new operation
82   /// as necessary.
handleTerminator__anon476530360111::StdInlinerInterface83   void handleTerminator(Operation *op,
84                         ArrayRef<Value> valuesToRepl) const final {
85     // Only "std.return" needs to be handled here.
86     auto returnOp = cast<ReturnOp>(op);
87 
88     // Replace the values directly with the return operands.
89     assert(returnOp.getNumOperands() == valuesToRepl.size());
90     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
91       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
92   }
93 };
94 } // end anonymous namespace
95 
96 //===----------------------------------------------------------------------===//
97 // StandardOpsDialect
98 //===----------------------------------------------------------------------===//
99 
100 /// A custom unary operation printer that omits the "std." prefix from the
101 /// operation names.
printStandardUnaryOp(Operation * op,OpAsmPrinter & p)102 static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
103   assert(op->getNumOperands() == 1 && "unary op should have one operand");
104   assert(op->getNumResults() == 1 && "unary op should have one result");
105 
106   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
107   p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
108     << op->getOperand(0);
109   p.printOptionalAttrDict(op->getAttrs());
110   p << " : " << op->getOperand(0).getType();
111 }
112 
113 /// A custom binary operation printer that omits the "std." prefix from the
114 /// operation names.
printStandardBinaryOp(Operation * op,OpAsmPrinter & p)115 static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
116   assert(op->getNumOperands() == 2 && "binary op should have two operands");
117   assert(op->getNumResults() == 1 && "binary op should have one result");
118 
119   // If not all the operand and result types are the same, just use the
120   // generic assembly form to avoid omitting information in printing.
121   auto resultType = op->getResult(0).getType();
122   if (op->getOperand(0).getType() != resultType ||
123       op->getOperand(1).getType() != resultType) {
124     p.printGenericOp(op);
125     return;
126   }
127 
128   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
129   p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
130     << op->getOperand(0) << ", " << op->getOperand(1);
131   p.printOptionalAttrDict(op->getAttrs());
132 
133   // Now we can output only one type for all operands and the result.
134   p << " : " << op->getResult(0).getType();
135 }
136 
137 /// A custom ternary operation printer that omits the "std." prefix from the
138 /// operation names.
printStandardTernaryOp(Operation * op,OpAsmPrinter & p)139 static void printStandardTernaryOp(Operation *op, OpAsmPrinter &p) {
140   assert(op->getNumOperands() == 3 && "ternary op should have three operands");
141   assert(op->getNumResults() == 1 && "ternary op should have one result");
142 
143   // If not all the operand and result types are the same, just use the
144   // generic assembly form to avoid omitting information in printing.
145   auto resultType = op->getResult(0).getType();
146   if (op->getOperand(0).getType() != resultType ||
147       op->getOperand(1).getType() != resultType ||
148       op->getOperand(2).getType() != resultType) {
149     p.printGenericOp(op);
150     return;
151   }
152 
153   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
154   p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
155     << op->getOperand(0) << ", " << op->getOperand(1) << ", "
156     << op->getOperand(2);
157   p.printOptionalAttrDict(op->getAttrs());
158 
159   // Now we can output only one type for all operands and the result.
160   p << " : " << op->getResult(0).getType();
161 }
162 
163 /// A custom cast operation printer that omits the "std." prefix from the
164 /// operation names.
printStandardCastOp(Operation * op,OpAsmPrinter & p)165 static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
166   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
167   p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
168     << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to "
169     << op->getResult(0).getType();
170 }
171 
/// Dialect initialization: registers every op listed in the GET_OP_LIST
/// section of the generated Ops.cpp.inc and attaches the inliner interface
/// (StdInlinerInterface) to the dialect.
void StandardOpsDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
      >();
  addInterfaces<StdInlinerInterface>();
}
179 
180 /// Materialize a single constant operation from a given attribute value with
181 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)182 Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
183                                                    Attribute value, Type type,
184                                                    Location loc) {
185   return builder.create<ConstantOp>(loc, type, value);
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // Common cast compatibility check for vector types.
190 //===----------------------------------------------------------------------===//
191 
192 /// This method checks for cast compatibility of vector types.
193 /// If 'a' and 'b' are vector types, and they are cast compatible,
194 /// it calls the 'areElementsCastCompatible' function to check for
195 /// element cast compatibility.
196 /// Returns 'true' if the vector types are cast compatible,  and 'false'
197 /// otherwise.
areVectorCastSimpleCompatible(Type a,Type b,function_ref<bool (TypeRange,TypeRange)> areElementsCastCompatible)198 static bool areVectorCastSimpleCompatible(
199     Type a, Type b,
200     function_ref<bool(TypeRange, TypeRange)> areElementsCastCompatible) {
201   if (auto va = a.dyn_cast<VectorType>())
202     if (auto vb = b.dyn_cast<VectorType>())
203       return va.getShape().equals(vb.getShape()) &&
204              areElementsCastCompatible(va.getElementType(),
205                                        vb.getElementType());
206   return false;
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // AddFOp
211 //===----------------------------------------------------------------------===//
212 
fold(ArrayRef<Attribute> operands)213 OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
214   return constFoldBinaryOp<FloatAttr>(
215       operands, [](APFloat a, APFloat b) { return a + b; });
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // AddIOp
220 //===----------------------------------------------------------------------===//
221 
fold(ArrayRef<Attribute> operands)222 OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
223   /// addi(x, 0) -> x
224   if (matchPattern(rhs(), m_Zero()))
225     return lhs();
226 
227   return constFoldBinaryOp<IntegerAttr>(operands,
228                                         [](APInt a, APInt b) { return a + b; });
229 }
230 
231 /// Canonicalize a sum of a constant and (constant - something) to simply be
232 /// a sum of constants minus something. This transformation does similar
233 /// transformations for additions of a constant with a subtract/add of
234 /// a constant. This may result in some operations being reordered (but should
235 /// remain equivalent).
236 struct AddConstantReorder : public OpRewritePattern<AddIOp> {
237   using OpRewritePattern<AddIOp>::OpRewritePattern;
238 
matchAndRewriteAddConstantReorder239   LogicalResult matchAndRewrite(AddIOp addop,
240                                 PatternRewriter &rewriter) const override {
241     for (int i = 0; i < 2; i++) {
242       APInt origConst;
243       APInt midConst;
244       if (matchPattern(addop.getOperand(i), m_ConstantInt(&origConst))) {
245         if (auto midAddOp = addop.getOperand(1 - i).getDefiningOp<AddIOp>()) {
246           for (int j = 0; j < 2; j++) {
247             if (matchPattern(midAddOp.getOperand(j),
248                              m_ConstantInt(&midConst))) {
249               auto nextConstant = rewriter.create<ConstantOp>(
250                   addop.getLoc(), rewriter.getIntegerAttr(
251                                       addop.getType(), origConst + midConst));
252               rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
253                                                   midAddOp.getOperand(1 - j));
254               return success();
255             }
256           }
257         }
258         if (auto midSubOp = addop.getOperand(1 - i).getDefiningOp<SubIOp>()) {
259           if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
260             auto nextConstant = rewriter.create<ConstantOp>(
261                 addop.getLoc(),
262                 rewriter.getIntegerAttr(addop.getType(), origConst + midConst));
263             rewriter.replaceOpWithNewOp<SubIOp>(addop, nextConstant,
264                                                 midSubOp.getOperand(1));
265             return success();
266           }
267           if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
268             auto nextConstant = rewriter.create<ConstantOp>(
269                 addop.getLoc(),
270                 rewriter.getIntegerAttr(addop.getType(), origConst - midConst));
271             rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
272                                                 midSubOp.getOperand(0));
273             return success();
274           }
275         }
276       }
277     }
278     return failure();
279   }
280 };
281 
/// Registers the canonicalization patterns of `addi`; currently only
/// AddConstantReorder, which combines nested constant add/sub operands.
void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<AddConstantReorder>(context);
}
286 
287 //===----------------------------------------------------------------------===//
288 // AndOp
289 //===----------------------------------------------------------------------===//
290 
fold(ArrayRef<Attribute> operands)291 OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
292   /// and(x, 0) -> 0
293   if (matchPattern(rhs(), m_Zero()))
294     return rhs();
295   /// and(x, allOnes) -> x
296   APInt intValue;
297   if (matchPattern(rhs(), m_ConstantInt(&intValue)) &&
298       intValue.isAllOnesValue())
299     return lhs();
300   /// and(x,x) -> x
301   if (lhs() == rhs())
302     return rhs();
303 
304   return constFoldBinaryOp<IntegerAttr>(operands,
305                                         [](APInt a, APInt b) { return a & b; });
306 }
307 
308 //===----------------------------------------------------------------------===//
309 // AssertOp
310 //===----------------------------------------------------------------------===//
311 
canonicalize(AssertOp op,PatternRewriter & rewriter)312 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
313   // Erase assertion if argument is constant true.
314   if (matchPattern(op.arg(), m_One())) {
315     rewriter.eraseOp(op);
316     return success();
317   }
318   return failure();
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // AtomicRMWOp
323 //===----------------------------------------------------------------------===//
324 
verify(AtomicRMWOp op)325 static LogicalResult verify(AtomicRMWOp op) {
326   if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
327     return op.emitOpError(
328         "expects the number of subscripts to be equal to memref rank");
329   switch (op.kind()) {
330   case AtomicRMWKind::addf:
331   case AtomicRMWKind::maxf:
332   case AtomicRMWKind::minf:
333   case AtomicRMWKind::mulf:
334     if (!op.value().getType().isa<FloatType>())
335       return op.emitOpError()
336              << "with kind '" << stringifyAtomicRMWKind(op.kind())
337              << "' expects a floating-point type";
338     break;
339   case AtomicRMWKind::addi:
340   case AtomicRMWKind::maxs:
341   case AtomicRMWKind::maxu:
342   case AtomicRMWKind::mins:
343   case AtomicRMWKind::minu:
344   case AtomicRMWKind::muli:
345     if (!op.value().getType().isa<IntegerType>())
346       return op.emitOpError()
347              << "with kind '" << stringifyAtomicRMWKind(op.kind())
348              << "' expects an integer type";
349     break;
350   default:
351     break;
352   }
353   return success();
354 }
355 
356 /// Returns the identity value attribute associated with an AtomicRMWKind op.
getIdentityValueAttr(AtomicRMWKind kind,Type resultType,OpBuilder & builder,Location loc)357 Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
358                                      OpBuilder &builder, Location loc) {
359   switch (kind) {
360   case AtomicRMWKind::addf:
361   case AtomicRMWKind::addi:
362     return builder.getZeroAttr(resultType);
363   case AtomicRMWKind::muli:
364     return builder.getIntegerAttr(resultType, 1);
365   case AtomicRMWKind::mulf:
366     return builder.getFloatAttr(resultType, 1);
367   // TODO: Add remaining reduction operations.
368   default:
369     (void)emitOptionalError(loc, "Reduction operation type not supported");
370     break;
371   }
372   return nullptr;
373 }
374 
375 /// Returns the identity value associated with an AtomicRMWKind op.
getIdentityValue(AtomicRMWKind op,Type resultType,OpBuilder & builder,Location loc)376 Value mlir::getIdentityValue(AtomicRMWKind op, Type resultType,
377                              OpBuilder &builder, Location loc) {
378   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
379   return builder.create<ConstantOp>(loc, attr);
380 }
381 
382 /// Return the value obtained by applying the reduction operation kind
383 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
getReductionOp(AtomicRMWKind op,OpBuilder & builder,Location loc,Value lhs,Value rhs)384 Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
385                            Value lhs, Value rhs) {
386   switch (op) {
387   case AtomicRMWKind::addf:
388     return builder.create<AddFOp>(loc, lhs, rhs);
389   case AtomicRMWKind::addi:
390     return builder.create<AddIOp>(loc, lhs, rhs);
391   case AtomicRMWKind::mulf:
392     return builder.create<MulFOp>(loc, lhs, rhs);
393   case AtomicRMWKind::muli:
394     return builder.create<MulIOp>(loc, lhs, rhs);
395   // TODO: Add remaining reduction operations.
396   default:
397     (void)emitOptionalError(loc, "Reduction operation type not supported");
398     break;
399   }
400   return nullptr;
401 }
402 
403 //===----------------------------------------------------------------------===//
404 // GenericAtomicRMWOp
405 //===----------------------------------------------------------------------===//
406 
build(OpBuilder & builder,OperationState & result,Value memref,ValueRange ivs)407 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
408                                Value memref, ValueRange ivs) {
409   result.addOperands(memref);
410   result.addOperands(ivs);
411 
412   if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
413     Type elementType = memrefType.getElementType();
414     result.addTypes(elementType);
415 
416     Region *bodyRegion = result.addRegion();
417     bodyRegion->push_back(new Block());
418     bodyRegion->addArgument(elementType);
419   }
420 }
421 
verify(GenericAtomicRMWOp op)422 static LogicalResult verify(GenericAtomicRMWOp op) {
423   auto &body = op.body();
424   if (body.getNumArguments() != 1)
425     return op.emitOpError("expected single number of entry block arguments");
426 
427   if (op.getResult().getType() != body.getArgument(0).getType())
428     return op.emitOpError(
429         "expected block argument of the same type result type");
430 
431   bool hasSideEffects =
432       body.walk([&](Operation *nestedOp) {
433             if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
434               return WalkResult::advance();
435             nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
436                                 "only operations with no side effects");
437             return WalkResult::interrupt();
438           })
439           .wasInterrupted();
440   return hasSideEffects ? failure() : success();
441 }
442 
parseGenericAtomicRMWOp(OpAsmParser & parser,OperationState & result)443 static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
444                                            OperationState &result) {
445   OpAsmParser::OperandType memref;
446   Type memrefType;
447   SmallVector<OpAsmParser::OperandType, 4> ivs;
448 
449   Type indexType = parser.getBuilder().getIndexType();
450   if (parser.parseOperand(memref) ||
451       parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
452       parser.parseColonType(memrefType) ||
453       parser.resolveOperand(memref, memrefType, result.operands) ||
454       parser.resolveOperands(ivs, indexType, result.operands))
455     return failure();
456 
457   Region *body = result.addRegion();
458   if (parser.parseRegion(*body, llvm::None, llvm::None) ||
459       parser.parseOptionalAttrDict(result.attributes))
460     return failure();
461   result.types.push_back(memrefType.cast<MemRefType>().getElementType());
462   return success();
463 }
464 
print(OpAsmPrinter & p,GenericAtomicRMWOp op)465 static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
466   p << op.getOperationName() << ' ' << op.memref() << "[" << op.indices()
467     << "] : " << op.memref().getType();
468   p.printRegion(op.body());
469   p.printOptionalAttrDict(op->getAttrs());
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // AtomicYieldOp
474 //===----------------------------------------------------------------------===//
475 
verify(AtomicYieldOp op)476 static LogicalResult verify(AtomicYieldOp op) {
477   Type parentType = op->getParentOp()->getResultTypes().front();
478   Type resultType = op.result().getType();
479   if (parentType != resultType)
480     return op.emitOpError() << "types mismatch between yield op: " << resultType
481                             << " and its parent: " << parentType;
482   return success();
483 }
484 
485 //===----------------------------------------------------------------------===//
486 // BranchOp
487 //===----------------------------------------------------------------------===//
488 
/// Given a successor, try to collapse it to a new destination if it only
/// contains a passthrough unconditional branch. If the successor is
/// collapsible, `successor` and `successorOperands` are updated to reference
/// the new destination and values. `argStorage` is used as storage if operands
/// to the collapsed successor need to be remapped. It must outlive uses of
/// successorOperands.
///
/// Returns failure, without modifying any argument, when the successor holds
/// more than its terminator, the terminator is not a `BranchOp`, one of the
/// successor's arguments is used outside that branch, or the branch targets
/// the successor itself (an infinite loop).
static LogicalResult collapseBranch(Block *&successor,
                                    ValueRange &successorOperands,
                                    SmallVectorImpl<Value> &argStorage) {
  // Check that the successor only contains an unconditional branch.
  if (std::next(successor->begin()) != successor->end())
    return failure();
  // Check that the terminator is an unconditional branch.
  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
  if (!successorBranch)
    return failure();
  // Check that the arguments are only used within the terminator.
  for (BlockArgument arg : successor->getArguments()) {
    for (Operation *user : arg.getUsers())
      if (user != successorBranch)
        return failure();
  }
  // Don't try to collapse branches to infinite loops.
  Block *successorDest = successorBranch.getDest();
  if (successorDest == successor)
    return failure();

  // Update the operands to the successor. If the branch parent has no
  // arguments, we can use the branch operands directly.
  OperandRange operands = successorBranch.getOperands();
  if (successor->args_empty()) {
    successor = successorDest;
    successorOperands = operands;
    return success();
  }

  // Otherwise, we need to remap any argument operands: an operand that is an
  // argument of `successor` is replaced by the value the incoming edge passed
  // for that argument; any other operand is forwarded unchanged.
  // `successorOperands` is still the caller's original range here; it is only
  // overwritten after the loop.
  for (Value operand : operands) {
    BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
    if (argOperand && argOperand.getOwner() == successor)
      argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
    else
      argStorage.push_back(operand);
  }
  successor = successorDest;
  // The returned range views `argStorage`, hence its lifetime requirement.
  successorOperands = argStorage;
  return success();
}
537 
538 /// Simplify a branch to a block that has a single predecessor. This effectively
539 /// merges the two blocks.
540 static LogicalResult
simplifyBrToBlockWithSinglePred(BranchOp op,PatternRewriter & rewriter)541 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
542   // Check that the successor block has a single predecessor.
543   Block *succ = op.getDest();
544   Block *opParent = op->getBlock();
545   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
546     return failure();
547 
548   // Merge the successor into the current block and erase the branch.
549   rewriter.mergeBlocks(succ, opParent, op.getOperands());
550   rewriter.eraseOp(op);
551   return success();
552 }
553 
554 ///   br ^bb1
555 /// ^bb1
556 ///   br ^bbN(...)
557 ///
558 ///  -> br ^bbN(...)
559 ///
simplifyPassThroughBr(BranchOp op,PatternRewriter & rewriter)560 static LogicalResult simplifyPassThroughBr(BranchOp op,
561                                            PatternRewriter &rewriter) {
562   Block *dest = op.getDest();
563   ValueRange destOperands = op.getOperands();
564   SmallVector<Value, 4> destOperandStorage;
565 
566   // Try to collapse the successor if it points somewhere other than this
567   // block.
568   if (dest == op->getBlock() ||
569       failed(collapseBranch(dest, destOperands, destOperandStorage)))
570     return failure();
571 
572   // Create a new branch with the collapsed successor.
573   rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
574   return success();
575 }
576 
canonicalize(BranchOp op,PatternRewriter & rewriter)577 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
578   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
579                  succeeded(simplifyPassThroughBr(op, rewriter)));
580 }
581 
getDest()582 Block *BranchOp::getDest() { return getSuccessor(); }
583 
setDest(Block * block)584 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
585 
eraseOperand(unsigned index)586 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
587 
588 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)589 BranchOp::getMutableSuccessorOperands(unsigned index) {
590   assert(index == 0 && "invalid successor index");
591   return destOperandsMutable();
592 }
593 
getSuccessorForOperands(ArrayRef<Attribute>)594 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
595 
596 //===----------------------------------------------------------------------===//
597 // CallOp
598 //===----------------------------------------------------------------------===//
599 
verifySymbolUses(SymbolTableCollection & symbolTable)600 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
601   // Check that the callee attribute was specified.
602   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
603   if (!fnAttr)
604     return emitOpError("requires a 'callee' symbol reference attribute");
605   FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
606   if (!fn)
607     return emitOpError() << "'" << fnAttr.getValue()
608                          << "' does not reference a valid function";
609 
610   // Verify that the operand and result types match the callee.
611   auto fnType = fn.getType();
612   if (fnType.getNumInputs() != getNumOperands())
613     return emitOpError("incorrect number of operands for callee");
614 
615   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
616     if (getOperand(i).getType() != fnType.getInput(i))
617       return emitOpError("operand type mismatch: expected operand type ")
618              << fnType.getInput(i) << ", but provided "
619              << getOperand(i).getType() << " for operand number " << i;
620 
621   if (fnType.getNumResults() != getNumResults())
622     return emitOpError("incorrect number of results for callee");
623 
624   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
625     if (getResult(i).getType() != fnType.getResult(i))
626       return emitOpError("result type mismatch");
627 
628   return success();
629 }
630 
getCalleeType()631 FunctionType CallOp::getCalleeType() {
632   return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
633 }
634 
635 //===----------------------------------------------------------------------===//
636 // CallIndirectOp
637 //===----------------------------------------------------------------------===//
638 
639 /// Fold indirect calls that have a constant function as the callee operand.
canonicalize(CallIndirectOp indirectCall,PatternRewriter & rewriter)640 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
641                                            PatternRewriter &rewriter) {
642   // Check that the callee is a constant callee.
643   SymbolRefAttr calledFn;
644   if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
645     return failure();
646 
647   // Replace with a direct call.
648   rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
649                                       indirectCall.getResultTypes(),
650                                       indirectCall.getArgOperands());
651   return success();
652 }
653 
654 //===----------------------------------------------------------------------===//
655 // General helpers for comparison ops
656 //===----------------------------------------------------------------------===//
657 
658 // Return the type of the same shape (scalar, vector or tensor) containing i1.
getI1SameShape(Type type)659 static Type getI1SameShape(Type type) {
660   auto i1Type = IntegerType::get(type.getContext(), 1);
661   if (auto tensorType = type.dyn_cast<RankedTensorType>())
662     return RankedTensorType::get(tensorType.getShape(), i1Type);
663   if (type.isa<UnrankedTensorType>())
664     return UnrankedTensorType::get(i1Type);
665   if (auto vectorType = type.dyn_cast<VectorType>())
666     return VectorType::get(vectorType.getShape(), i1Type);
667   return i1Type;
668 }
669 
670 //===----------------------------------------------------------------------===//
671 // CmpIOp
672 //===----------------------------------------------------------------------===//
673 
buildCmpIOp(OpBuilder & build,OperationState & result,CmpIPredicate predicate,Value lhs,Value rhs)674 static void buildCmpIOp(OpBuilder &build, OperationState &result,
675                         CmpIPredicate predicate, Value lhs, Value rhs) {
676   result.addOperands({lhs, rhs});
677   result.types.push_back(getI1SameShape(lhs.getType()));
678   result.addAttribute(CmpIOp::getPredicateAttrName(),
679                       build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
680 }
681 
682 // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
683 // comparison predicates.
applyCmpPredicate(CmpIPredicate predicate,const APInt & lhs,const APInt & rhs)684 bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
685                              const APInt &rhs) {
686   switch (predicate) {
687   case CmpIPredicate::eq:
688     return lhs.eq(rhs);
689   case CmpIPredicate::ne:
690     return lhs.ne(rhs);
691   case CmpIPredicate::slt:
692     return lhs.slt(rhs);
693   case CmpIPredicate::sle:
694     return lhs.sle(rhs);
695   case CmpIPredicate::sgt:
696     return lhs.sgt(rhs);
697   case CmpIPredicate::sge:
698     return lhs.sge(rhs);
699   case CmpIPredicate::ult:
700     return lhs.ult(rhs);
701   case CmpIPredicate::ule:
702     return lhs.ule(rhs);
703   case CmpIPredicate::ugt:
704     return lhs.ugt(rhs);
705   case CmpIPredicate::uge:
706     return lhs.uge(rhs);
707   }
708   llvm_unreachable("unknown comparison predicate");
709 }
710 
711 // Returns true if the predicate is true for two equal operands.
applyCmpPredicateToEqualOperands(CmpIPredicate predicate)712 static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
713   switch (predicate) {
714   case CmpIPredicate::eq:
715   case CmpIPredicate::sle:
716   case CmpIPredicate::sge:
717   case CmpIPredicate::ule:
718   case CmpIPredicate::uge:
719     return true;
720   case CmpIPredicate::ne:
721   case CmpIPredicate::slt:
722   case CmpIPredicate::sgt:
723   case CmpIPredicate::ult:
724   case CmpIPredicate::ugt:
725     return false;
726   }
727   llvm_unreachable("unknown comparison predicate");
728 }
729 
730 // Constant folding hook for comparisons.
fold(ArrayRef<Attribute> operands)731 OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
732   assert(operands.size() == 2 && "cmpi takes two arguments");
733 
734   if (lhs() == rhs()) {
735     auto val = applyCmpPredicateToEqualOperands(getPredicate());
736     return BoolAttr::get(getContext(), val);
737   }
738 
739   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
740   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
741   if (!lhs || !rhs)
742     return {};
743 
744   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
745   return BoolAttr::get(getContext(), val);
746 }
747 
748 //===----------------------------------------------------------------------===//
749 // CmpFOp
750 //===----------------------------------------------------------------------===//
751 
buildCmpFOp(OpBuilder & build,OperationState & result,CmpFPredicate predicate,Value lhs,Value rhs)752 static void buildCmpFOp(OpBuilder &build, OperationState &result,
753                         CmpFPredicate predicate, Value lhs, Value rhs) {
754   result.addOperands({lhs, rhs});
755   result.types.push_back(getI1SameShape(lhs.getType()));
756   result.addAttribute(CmpFOp::getPredicateAttrName(),
757                       build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
758 }
759 
760 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
761 /// comparison predicates.
applyCmpPredicate(CmpFPredicate predicate,const APFloat & lhs,const APFloat & rhs)762 bool mlir::applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
763                              const APFloat &rhs) {
764   auto cmpResult = lhs.compare(rhs);
765   switch (predicate) {
766   case CmpFPredicate::AlwaysFalse:
767     return false;
768   case CmpFPredicate::OEQ:
769     return cmpResult == APFloat::cmpEqual;
770   case CmpFPredicate::OGT:
771     return cmpResult == APFloat::cmpGreaterThan;
772   case CmpFPredicate::OGE:
773     return cmpResult == APFloat::cmpGreaterThan ||
774            cmpResult == APFloat::cmpEqual;
775   case CmpFPredicate::OLT:
776     return cmpResult == APFloat::cmpLessThan;
777   case CmpFPredicate::OLE:
778     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
779   case CmpFPredicate::ONE:
780     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
781   case CmpFPredicate::ORD:
782     return cmpResult != APFloat::cmpUnordered;
783   case CmpFPredicate::UEQ:
784     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
785   case CmpFPredicate::UGT:
786     return cmpResult == APFloat::cmpUnordered ||
787            cmpResult == APFloat::cmpGreaterThan;
788   case CmpFPredicate::UGE:
789     return cmpResult == APFloat::cmpUnordered ||
790            cmpResult == APFloat::cmpGreaterThan ||
791            cmpResult == APFloat::cmpEqual;
792   case CmpFPredicate::ULT:
793     return cmpResult == APFloat::cmpUnordered ||
794            cmpResult == APFloat::cmpLessThan;
795   case CmpFPredicate::ULE:
796     return cmpResult == APFloat::cmpUnordered ||
797            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
798   case CmpFPredicate::UNE:
799     return cmpResult != APFloat::cmpEqual;
800   case CmpFPredicate::UNO:
801     return cmpResult == APFloat::cmpUnordered;
802   case CmpFPredicate::AlwaysTrue:
803     return true;
804   }
805   llvm_unreachable("unknown comparison predicate");
806 }
807 
808 // Constant folding hook for comparisons.
fold(ArrayRef<Attribute> operands)809 OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
810   assert(operands.size() == 2 && "cmpf takes two arguments");
811 
812   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
813   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
814 
815   // TODO: We could actually do some intelligent things if we know only one
816   // of the operands, but it's inf or nan.
817   if (!lhs || !rhs)
818     return {};
819 
820   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
821   return IntegerAttr::get(IntegerType::get(getContext(), 1), APInt(1, val));
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // CondBranchOp
826 //===----------------------------------------------------------------------===//
827 
828 namespace {
829 /// cond_br true, ^bb1, ^bb2
830 ///  -> br ^bb1
831 /// cond_br false, ^bb1, ^bb2
832 ///  -> br ^bb2
833 ///
834 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
835   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
836 
matchAndRewrite__anon476530360611::SimplifyConstCondBranchPred837   LogicalResult matchAndRewrite(CondBranchOp condbr,
838                                 PatternRewriter &rewriter) const override {
839     if (matchPattern(condbr.getCondition(), m_NonZero())) {
840       // True branch taken.
841       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
842                                             condbr.getTrueOperands());
843       return success();
844     } else if (matchPattern(condbr.getCondition(), m_Zero())) {
845       // False branch taken.
846       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
847                                             condbr.getFalseOperands());
848       return success();
849     }
850     return failure();
851   }
852 };
853 
854 ///   cond_br %cond, ^bb1, ^bb2
855 /// ^bb1
856 ///   br ^bbN(...)
857 /// ^bb2
858 ///   br ^bbK(...)
859 ///
860 ///  -> cond_br %cond, ^bbN(...), ^bbK(...)
861 ///
862 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
863   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
864 
matchAndRewrite__anon476530360611::SimplifyPassThroughCondBranch865   LogicalResult matchAndRewrite(CondBranchOp condbr,
866                                 PatternRewriter &rewriter) const override {
867     Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest();
868     ValueRange trueDestOperands = condbr.getTrueOperands();
869     ValueRange falseDestOperands = condbr.getFalseOperands();
870     SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
871 
872     // Try to collapse one of the current successors.
873     LogicalResult collapsedTrue =
874         collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
875     LogicalResult collapsedFalse =
876         collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
877     if (failed(collapsedTrue) && failed(collapsedFalse))
878       return failure();
879 
880     // Create a new branch with the collapsed successors.
881     rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
882                                               trueDest, trueDestOperands,
883                                               falseDest, falseDestOperands);
884     return success();
885   }
886 };
887 
888 /// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
889 ///  -> br ^bb1(A, ..., N)
890 ///
891 /// cond_br %cond, ^bb1(A), ^bb1(B)
892 ///  -> %select = select %cond, A, B
893 ///     br ^bb1(%select)
894 ///
895 struct SimplifyCondBranchIdenticalSuccessors
896     : public OpRewritePattern<CondBranchOp> {
897   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
898 
matchAndRewrite__anon476530360611::SimplifyCondBranchIdenticalSuccessors899   LogicalResult matchAndRewrite(CondBranchOp condbr,
900                                 PatternRewriter &rewriter) const override {
901     // Check that the true and false destinations are the same and have the same
902     // operands.
903     Block *trueDest = condbr.trueDest();
904     if (trueDest != condbr.falseDest())
905       return failure();
906 
907     // If all of the operands match, no selects need to be generated.
908     OperandRange trueOperands = condbr.getTrueOperands();
909     OperandRange falseOperands = condbr.getFalseOperands();
910     if (trueOperands == falseOperands) {
911       rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
912       return success();
913     }
914 
915     // Otherwise, if the current block is the only predecessor insert selects
916     // for any mismatched branch operands.
917     if (trueDest->getUniquePredecessor() != condbr->getBlock())
918       return failure();
919 
920     // Generate a select for any operands that differ between the two.
921     SmallVector<Value, 8> mergedOperands;
922     mergedOperands.reserve(trueOperands.size());
923     Value condition = condbr.getCondition();
924     for (auto it : llvm::zip(trueOperands, falseOperands)) {
925       if (std::get<0>(it) == std::get<1>(it))
926         mergedOperands.push_back(std::get<0>(it));
927       else
928         mergedOperands.push_back(rewriter.create<SelectOp>(
929             condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
930     }
931 
932     rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
933     return success();
934   }
935 };
936 
937 ///   ...
938 ///   cond_br %cond, ^bb1(...), ^bb2(...)
939 /// ...
940 /// ^bb1: // has single predecessor
941 ///   ...
942 ///   cond_br %cond, ^bb3(...), ^bb4(...)
943 ///
944 /// ->
945 ///
946 ///   ...
947 ///   cond_br %cond, ^bb1(...), ^bb2(...)
948 /// ...
949 /// ^bb1: // has single predecessor
950 ///   ...
951 ///   br ^bb3(...)
952 ///
953 struct SimplifyCondBranchFromCondBranchOnSameCondition
954     : public OpRewritePattern<CondBranchOp> {
955   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
956 
matchAndRewrite__anon476530360611::SimplifyCondBranchFromCondBranchOnSameCondition957   LogicalResult matchAndRewrite(CondBranchOp condbr,
958                                 PatternRewriter &rewriter) const override {
959     // Check that we have a single distinct predecessor.
960     Block *currentBlock = condbr->getBlock();
961     Block *predecessor = currentBlock->getSinglePredecessor();
962     if (!predecessor)
963       return failure();
964 
965     // Check that the predecessor terminates with a conditional branch to this
966     // block and that it branches on the same condition.
967     auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
968     if (!predBranch || condbr.getCondition() != predBranch.getCondition())
969       return failure();
970 
971     // Fold this branch to an unconditional branch.
972     if (currentBlock == predBranch.trueDest())
973       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.trueDest(),
974                                             condbr.trueDestOperands());
975     else
976       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.falseDest(),
977                                             condbr.falseDestOperands());
978     return success();
979   }
980 };
981 
982 ///   cond_br %arg0, ^trueB, ^falseB
983 ///
984 /// ^trueB:
985 ///   "test.consumer1"(%arg0) : (i1) -> ()
986 ///    ...
987 ///
988 /// ^falseB:
989 ///   "test.consumer2"(%arg0) : (i1) -> ()
990 ///   ...
991 ///
992 /// ->
993 ///
994 ///   cond_br %arg0, ^trueB, ^falseB
995 /// ^trueB:
996 ///   "test.consumer1"(%true) : (i1) -> ()
997 ///   ...
998 ///
999 /// ^falseB:
1000 ///   "test.consumer2"(%false) : (i1) -> ()
1001 ///   ...
1002 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
1003   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
1004 
matchAndRewrite__anon476530360611::CondBranchTruthPropagation1005   LogicalResult matchAndRewrite(CondBranchOp condbr,
1006                                 PatternRewriter &rewriter) const override {
1007     // Check that we have a single distinct predecessor.
1008     bool replaced = false;
1009     Type ty = rewriter.getI1Type();
1010 
1011     // These variables serve to prevent creating duplicate constants
1012     // and hold constant true or false values.
1013     Value constantTrue = nullptr;
1014     Value constantFalse = nullptr;
1015 
1016     // TODO These checks can be expanded to encompas any use with only
1017     // either the true of false edge as a predecessor. For now, we fall
1018     // back to checking the single predecessor is given by the true/fasle
1019     // destination, thereby ensuring that only that edge can reach the
1020     // op.
1021     if (condbr.getTrueDest()->getSinglePredecessor()) {
1022       for (OpOperand &use :
1023            llvm::make_early_inc_range(condbr.condition().getUses())) {
1024         if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
1025           replaced = true;
1026 
1027           if (!constantTrue)
1028             constantTrue = rewriter.create<mlir::ConstantOp>(
1029                 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
1030 
1031           rewriter.updateRootInPlace(use.getOwner(),
1032                                      [&] { use.set(constantTrue); });
1033         }
1034       }
1035     }
1036     if (condbr.getFalseDest()->getSinglePredecessor()) {
1037       for (OpOperand &use :
1038            llvm::make_early_inc_range(condbr.condition().getUses())) {
1039         if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
1040           replaced = true;
1041 
1042           if (!constantFalse)
1043             constantFalse = rewriter.create<mlir::ConstantOp>(
1044                 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
1045 
1046           rewriter.updateRootInPlace(use.getOwner(),
1047                                      [&] { use.set(constantFalse); });
1048         }
1049       }
1050     }
1051     return success(replaced);
1052   }
1053 };
1054 } // end anonymous namespace
1055 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1056 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
1057                                                MLIRContext *context) {
1058   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
1059               SimplifyCondBranchIdenticalSuccessors,
1060               SimplifyCondBranchFromCondBranchOnSameCondition,
1061               CondBranchTruthPropagation>(context);
1062 }
1063 
1064 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1065 CondBranchOp::getMutableSuccessorOperands(unsigned index) {
1066   assert(index < getNumSuccessors() && "invalid successor index");
1067   return index == trueIndex ? trueDestOperandsMutable()
1068                             : falseDestOperandsMutable();
1069 }
1070 
getSuccessorForOperands(ArrayRef<Attribute> operands)1071 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1072   if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
1073     return condAttr.getValue().isOneValue() ? trueDest() : falseDest();
1074   return nullptr;
1075 }
1076 
1077 //===----------------------------------------------------------------------===//
1078 // Constant*Op
1079 //===----------------------------------------------------------------------===//
1080 
print(OpAsmPrinter & p,ConstantOp & op)1081 static void print(OpAsmPrinter &p, ConstantOp &op) {
1082   p << "constant ";
1083   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
1084 
1085   if (op->getAttrs().size() > 1)
1086     p << ' ';
1087   p << op.getValue();
1088 
1089   // If the value is a symbol reference or Array, print a trailing type.
1090   if (op.getValue().isa<SymbolRefAttr, ArrayAttr>())
1091     p << " : " << op.getType();
1092 }
1093 
parseConstantOp(OpAsmParser & parser,OperationState & result)1094 static ParseResult parseConstantOp(OpAsmParser &parser,
1095                                    OperationState &result) {
1096   Attribute valueAttr;
1097   if (parser.parseOptionalAttrDict(result.attributes) ||
1098       parser.parseAttribute(valueAttr, "value", result.attributes))
1099     return failure();
1100 
1101   // If the attribute is a symbol reference or array, then we expect a trailing
1102   // type.
1103   Type type;
1104   if (!valueAttr.isa<SymbolRefAttr, ArrayAttr>())
1105     type = valueAttr.getType();
1106   else if (parser.parseColonType(type))
1107     return failure();
1108 
1109   // Add the attribute type to the list.
1110   return parser.addTypeToList(type, result.types);
1111 }
1112 
1113 /// The constant op requires an attribute, and furthermore requires that it
1114 /// matches the return type.
verify(ConstantOp & op)1115 static LogicalResult verify(ConstantOp &op) {
1116   auto value = op.getValue();
1117   if (!value)
1118     return op.emitOpError("requires a 'value' attribute");
1119 
1120   Type type = op.getType();
1121   if (!value.getType().isa<NoneType>() && type != value.getType())
1122     return op.emitOpError() << "requires attribute's type (" << value.getType()
1123                             << ") to match op's return type (" << type << ")";
1124 
1125   if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
1126     if (type.isa<IndexType>() || value.isa<BoolAttr>())
1127       return success();
1128     IntegerType intType = type.cast<IntegerType>();
1129     if (!intType.isSignless())
1130       return op.emitOpError("requires integer result types to be signless");
1131 
1132     // If the type has a known bitwidth we verify that the value can be
1133     // represented with the given bitwidth.
1134     unsigned bitwidth = intType.getWidth();
1135     APInt intVal = intAttr.getValue();
1136     if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
1137       return op.emitOpError("requires 'value' to be an integer within the "
1138                             "range of the integer result type");
1139     return success();
1140   }
1141 
1142   if (auto complexTy = type.dyn_cast<ComplexType>()) {
1143     auto arrayAttr = value.dyn_cast<ArrayAttr>();
1144     if (!complexTy || arrayAttr.size() != 2)
1145       return op.emitOpError(
1146           "requires 'value' to be a complex constant, represented as array of "
1147           "two values");
1148     auto complexEltTy = complexTy.getElementType();
1149     if (complexEltTy != arrayAttr[0].getType() ||
1150         complexEltTy != arrayAttr[1].getType()) {
1151       return op.emitOpError()
1152              << "requires attribute's element types (" << arrayAttr[0].getType()
1153              << ", " << arrayAttr[1].getType()
1154              << ") to match the element type of the op's return type ("
1155              << complexEltTy << ")";
1156     }
1157     return success();
1158   }
1159 
1160   if (type.isa<FloatType>()) {
1161     if (!value.isa<FloatAttr>())
1162       return op.emitOpError("requires 'value' to be a floating point constant");
1163     return success();
1164   }
1165 
1166   if (type.isa<ShapedType>()) {
1167     if (!value.isa<ElementsAttr>())
1168       return op.emitOpError("requires 'value' to be a shaped constant");
1169     return success();
1170   }
1171 
1172   if (type.isa<FunctionType>()) {
1173     auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
1174     if (!fnAttr)
1175       return op.emitOpError("requires 'value' to be a function reference");
1176 
1177     // Try to find the referenced function.
1178     auto fn =
1179         op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
1180     if (!fn)
1181       return op.emitOpError()
1182              << "reference to undefined function '" << fnAttr.getValue() << "'";
1183 
1184     // Check that the referenced function has the correct type.
1185     if (fn.getType() != type)
1186       return op.emitOpError("reference to function with mismatched type");
1187 
1188     return success();
1189   }
1190 
1191   if (type.isa<NoneType>() && value.isa<UnitAttr>())
1192     return success();
1193 
1194   return op.emitOpError("unsupported 'value' attribute: ") << value;
1195 }
1196 
fold(ArrayRef<Attribute> operands)1197 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
1198   assert(operands.empty() && "constant has no operands");
1199   return getValue();
1200 }
1201 
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)1202 void ConstantOp::getAsmResultNames(
1203     function_ref<void(Value, StringRef)> setNameFn) {
1204   Type type = getType();
1205   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
1206     IntegerType intTy = type.dyn_cast<IntegerType>();
1207 
1208     // Sugar i1 constants with 'true' and 'false'.
1209     if (intTy && intTy.getWidth() == 1)
1210       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1211 
1212     // Otherwise, build a complex name with the value and type.
1213     SmallString<32> specialNameBuffer;
1214     llvm::raw_svector_ostream specialName(specialNameBuffer);
1215     specialName << 'c' << intCst.getInt();
1216     if (intTy)
1217       specialName << '_' << type;
1218     setNameFn(getResult(), specialName.str());
1219 
1220   } else if (type.isa<FunctionType>()) {
1221     setNameFn(getResult(), "f");
1222   } else {
1223     setNameFn(getResult(), "cst");
1224   }
1225 }
1226 
1227 /// Returns true if a constant operation can be built with the given value and
1228 /// result type.
isBuildableWith(Attribute value,Type type)1229 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
1230   // SymbolRefAttr can only be used with a function type.
1231   if (value.isa<SymbolRefAttr>())
1232     return type.isa<FunctionType>();
1233   // The attribute must have the same type as 'type'.
1234   if (!value.getType().isa<NoneType>() && value.getType() != type)
1235     return false;
1236   // If the type is an integer type, it must be signless.
1237   if (IntegerType integerTy = type.dyn_cast<IntegerType>())
1238     if (!integerTy.isSignless())
1239       return false;
1240   // Finally, check that the attribute kind is handled.
1241   if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
1242     auto complexTy = type.dyn_cast<ComplexType>();
1243     if (!complexTy)
1244       return false;
1245     auto complexEltTy = complexTy.getElementType();
1246     return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
1247            arrAttr[1].getType() == complexEltTy;
1248   }
1249   return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
1250 }
1251 
build(OpBuilder & builder,OperationState & result,const APFloat & value,FloatType type)1252 void ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
1253                             const APFloat &value, FloatType type) {
1254   ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value));
1255 }
1256 
classof(Operation * op)1257 bool ConstantFloatOp::classof(Operation *op) {
1258   return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>();
1259 }
1260 
1261 /// ConstantIntOp only matches values whose result type is an IntegerType.
classof(Operation * op)1262 bool ConstantIntOp::classof(Operation *op) {
1263   return ConstantOp::classof(op) &&
1264          op->getResult(0).getType().isSignlessInteger();
1265 }
1266 
build(OpBuilder & builder,OperationState & result,int64_t value,unsigned width)1267 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1268                           int64_t value, unsigned width) {
1269   Type type = builder.getIntegerType(width);
1270   ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1271 }
1272 
1273 /// Build a constant int op producing an integer with the specified type,
1274 /// which must be an integer type.
build(OpBuilder & builder,OperationState & result,int64_t value,Type type)1275 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1276                           int64_t value, Type type) {
1277   assert(type.isSignlessInteger() &&
1278          "ConstantIntOp can only have signless integer type");
1279   ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1280 }
1281 
1282 /// ConstantIndexOp only matches values whose result type is Index.
classof(Operation * op)1283 bool ConstantIndexOp::classof(Operation *op) {
1284   return ConstantOp::classof(op) && op->getResult(0).getType().isIndex();
1285 }
1286 
build(OpBuilder & builder,OperationState & result,int64_t value)1287 void ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
1288                             int64_t value) {
1289   Type type = builder.getIndexType();
1290   ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1291 }
1292 
1293 // ---------------------------------------------------------------------------
1294 // DivFOp
1295 // ---------------------------------------------------------------------------
1296 
fold(ArrayRef<Attribute> operands)1297 OpFoldResult DivFOp::fold(ArrayRef<Attribute> operands) {
1298   return constFoldBinaryOp<FloatAttr>(
1299       operands, [](APFloat a, APFloat b) { return a / b; });
1300 }
1301 
1302 //===----------------------------------------------------------------------===//
1303 // FPExtOp
1304 //===----------------------------------------------------------------------===//
1305 
areCastCompatible(TypeRange inputs,TypeRange outputs)1306 bool FPExtOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1307   if (inputs.size() != 1 || outputs.size() != 1)
1308     return false;
1309   Type a = inputs.front(), b = outputs.front();
1310   if (auto fa = a.dyn_cast<FloatType>())
1311     if (auto fb = b.dyn_cast<FloatType>())
1312       return fa.getWidth() < fb.getWidth();
1313   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1314 }
1315 
1316 //===----------------------------------------------------------------------===//
1317 // FPToSIOp
1318 //===----------------------------------------------------------------------===//
1319 
areCastCompatible(TypeRange inputs,TypeRange outputs)1320 bool FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1321   if (inputs.size() != 1 || outputs.size() != 1)
1322     return false;
1323   Type a = inputs.front(), b = outputs.front();
1324   if (a.isa<FloatType>() && b.isSignlessInteger())
1325     return true;
1326   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1327 }
1328 
1329 //===----------------------------------------------------------------------===//
1330 // FPToUIOp
1331 //===----------------------------------------------------------------------===//
1332 
areCastCompatible(TypeRange inputs,TypeRange outputs)1333 bool FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1334   if (inputs.size() != 1 || outputs.size() != 1)
1335     return false;
1336   Type a = inputs.front(), b = outputs.front();
1337   if (a.isa<FloatType>() && b.isSignlessInteger())
1338     return true;
1339   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1340 }
1341 
1342 //===----------------------------------------------------------------------===//
1343 // FPTruncOp
1344 //===----------------------------------------------------------------------===//
1345 
areCastCompatible(TypeRange inputs,TypeRange outputs)1346 bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1347   if (inputs.size() != 1 || outputs.size() != 1)
1348     return false;
1349   Type a = inputs.front(), b = outputs.front();
1350   if (auto fa = a.dyn_cast<FloatType>())
1351     if (auto fb = b.dyn_cast<FloatType>())
1352       return fa.getWidth() > fb.getWidth();
1353   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1354 }
1355 
1356 //===----------------------------------------------------------------------===//
1357 // IndexCastOp
1358 //===----------------------------------------------------------------------===//
1359 
1360 // Index cast is applicable from index to integer and backwards.
areCastCompatible(TypeRange inputs,TypeRange outputs)1361 bool IndexCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1362   if (inputs.size() != 1 || outputs.size() != 1)
1363     return false;
1364   Type a = inputs.front(), b = outputs.front();
1365   if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
1366     auto aShaped = a.cast<ShapedType>();
1367     auto bShaped = b.cast<ShapedType>();
1368 
1369     return (aShaped.getShape() == bShaped.getShape()) &&
1370            areCastCompatible(aShaped.getElementType(),
1371                              bShaped.getElementType());
1372   }
1373 
1374   return (a.isIndex() && b.isSignlessInteger()) ||
1375          (a.isSignlessInteger() && b.isIndex());
1376 }
1377 
fold(ArrayRef<Attribute> cstOperands)1378 OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
1379   // Fold IndexCast(IndexCast(x)) -> x
1380   auto cast = getOperand().getDefiningOp<IndexCastOp>();
1381   if (cast && cast.getOperand().getType() == getType())
1382     return cast.getOperand();
1383 
1384   // Fold IndexCast(constant) -> constant
1385   // A little hack because we go through int.  Otherwise, the size
1386   // of the constant might need to change.
1387   if (auto value = cstOperands[0].dyn_cast_or_null<IntegerAttr>())
1388     return IntegerAttr::get(getType(), value.getInt());
1389 
1390   return {};
1391 }
1392 
1393 namespace {
1394 ///  index_cast(sign_extend x) => index_cast(x)
1395 struct IndexCastOfSExt : public OpRewritePattern<IndexCastOp> {
1396   using OpRewritePattern<IndexCastOp>::OpRewritePattern;
1397 
matchAndRewrite__anon476530360a11::IndexCastOfSExt1398   LogicalResult matchAndRewrite(IndexCastOp op,
1399                                 PatternRewriter &rewriter) const override {
1400 
1401     if (auto extop = op.getOperand().getDefiningOp<SignExtendIOp>()) {
1402       op.setOperand(extop.getOperand());
1403       return success();
1404     }
1405     return failure();
1406   }
1407 };
1408 
1409 } // namespace
1410 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1411 void IndexCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1412                                               MLIRContext *context) {
1413   results.insert<IndexCastOfSExt>(context);
1414 }
1415 
1416 //===----------------------------------------------------------------------===//
1417 // MulFOp
1418 //===----------------------------------------------------------------------===//
1419 
fold(ArrayRef<Attribute> operands)1420 OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
1421   return constFoldBinaryOp<FloatAttr>(
1422       operands, [](APFloat a, APFloat b) { return a * b; });
1423 }
1424 
1425 //===----------------------------------------------------------------------===//
1426 // MulIOp
1427 //===----------------------------------------------------------------------===//
1428 
fold(ArrayRef<Attribute> operands)1429 OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
1430   /// muli(x, 0) -> 0
1431   if (matchPattern(rhs(), m_Zero()))
1432     return rhs();
1433   /// muli(x, 1) -> x
1434   if (matchPattern(rhs(), m_One()))
1435     return getOperand(0);
1436 
1437   // TODO: Handle the overflow case.
1438   return constFoldBinaryOp<IntegerAttr>(operands,
1439                                         [](APInt a, APInt b) { return a * b; });
1440 }
1441 
1442 //===----------------------------------------------------------------------===//
1443 // OrOp
1444 //===----------------------------------------------------------------------===//
1445 
fold(ArrayRef<Attribute> operands)1446 OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
1447   /// or(x, 0) -> x
1448   if (matchPattern(rhs(), m_Zero()))
1449     return lhs();
1450   /// or(x,x) -> x
1451   if (lhs() == rhs())
1452     return rhs();
1453 
1454   return constFoldBinaryOp<IntegerAttr>(operands,
1455                                         [](APInt a, APInt b) { return a | b; });
1456 }
1457 
1458 //===----------------------------------------------------------------------===//
1459 // RankOp
1460 //===----------------------------------------------------------------------===//
1461 
fold(ArrayRef<Attribute> operands)1462 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1463   // Constant fold rank when the rank of the operand is known.
1464   auto type = getOperand().getType();
1465   if (auto shapedType = type.dyn_cast<ShapedType>())
1466     if (shapedType.hasRank())
1467       return IntegerAttr::get(IndexType::get(getContext()),
1468                               shapedType.getRank());
1469   return IntegerAttr();
1470 }
1471 
1472 //===----------------------------------------------------------------------===//
1473 // ReturnOp
1474 //===----------------------------------------------------------------------===//
1475 
verify(ReturnOp op)1476 static LogicalResult verify(ReturnOp op) {
1477   auto function = cast<FuncOp>(op->getParentOp());
1478 
1479   // The operand number and types must match the function signature.
1480   const auto &results = function.getType().getResults();
1481   if (op.getNumOperands() != results.size())
1482     return op.emitOpError("has ")
1483            << op.getNumOperands() << " operands, but enclosing function (@"
1484            << function.getName() << ") returns " << results.size();
1485 
1486   for (unsigned i = 0, e = results.size(); i != e; ++i)
1487     if (op.getOperand(i).getType() != results[i])
1488       return op.emitError()
1489              << "type of return operand " << i << " ("
1490              << op.getOperand(i).getType()
1491              << ") doesn't match function result type (" << results[i] << ")"
1492              << " in function @" << function.getName();
1493 
1494   return success();
1495 }
1496 
1497 //===----------------------------------------------------------------------===//
1498 // SelectOp
1499 //===----------------------------------------------------------------------===//
1500 
1501 // Transforms a select to a not, where relevant.
1502 //
1503 //  select %arg, %false, %true
1504 //
1505 //  becomes
1506 //
1507 //  xor %arg, %true
1508 struct SelectToNot : public OpRewritePattern<SelectOp> {
1509   using OpRewritePattern<SelectOp>::OpRewritePattern;
1510 
matchAndRewriteSelectToNot1511   LogicalResult matchAndRewrite(SelectOp op,
1512                                 PatternRewriter &rewriter) const override {
1513     if (!matchPattern(op.getTrueValue(), m_Zero()))
1514       return failure();
1515 
1516     if (!matchPattern(op.getFalseValue(), m_One()))
1517       return failure();
1518 
1519     if (!op.getType().isInteger(1))
1520       return failure();
1521 
1522     rewriter.replaceOpWithNewOp<XOrOp>(op, op.condition(), op.getFalseValue());
1523     return success();
1524   }
1525 };
1526 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1527 void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1528                                            MLIRContext *context) {
1529   results.insert<SelectToNot>(context);
1530 }
1531 
fold(ArrayRef<Attribute> operands)1532 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
1533   auto trueVal = getTrueValue();
1534   auto falseVal = getFalseValue();
1535   if (trueVal == falseVal)
1536     return trueVal;
1537 
1538   auto condition = getCondition();
1539 
1540   // select true, %0, %1 => %0
1541   if (matchPattern(condition, m_One()))
1542     return trueVal;
1543 
1544   // select false, %0, %1 => %1
1545   if (matchPattern(condition, m_Zero()))
1546     return falseVal;
1547 
1548   if (auto cmp = dyn_cast_or_null<CmpIOp>(condition.getDefiningOp())) {
1549     auto pred = cmp.predicate();
1550     if (pred == mlir::CmpIPredicate::eq || pred == mlir::CmpIPredicate::ne) {
1551       auto cmpLhs = cmp.lhs();
1552       auto cmpRhs = cmp.rhs();
1553 
1554       // %0 = cmpi eq, %arg0, %arg1
1555       // %1 = select %0, %arg0, %arg1 => %arg1
1556 
1557       // %0 = cmpi ne, %arg0, %arg1
1558       // %1 = select %0, %arg0, %arg1 => %arg0
1559 
1560       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1561           (cmpRhs == trueVal && cmpLhs == falseVal))
1562         return pred == mlir::CmpIPredicate::ne ? trueVal : falseVal;
1563     }
1564   }
1565   return nullptr;
1566 }
1567 
print(OpAsmPrinter & p,SelectOp op)1568 static void print(OpAsmPrinter &p, SelectOp op) {
1569   p << "select " << op.getOperands();
1570   p.printOptionalAttrDict(op->getAttrs());
1571   p << " : ";
1572   if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
1573     p << condType << ", ";
1574   p << op.getType();
1575 }
1576 
parseSelectOp(OpAsmParser & parser,OperationState & result)1577 static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
1578   Type conditionType, resultType;
1579   SmallVector<OpAsmParser::OperandType, 3> operands;
1580   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1581       parser.parseOptionalAttrDict(result.attributes) ||
1582       parser.parseColonType(resultType))
1583     return failure();
1584 
1585   // Check for the explicit condition type if this is a masked tensor or vector.
1586   if (succeeded(parser.parseOptionalComma())) {
1587     conditionType = resultType;
1588     if (parser.parseType(resultType))
1589       return failure();
1590   } else {
1591     conditionType = parser.getBuilder().getI1Type();
1592   }
1593 
1594   result.addTypes(resultType);
1595   return parser.resolveOperands(operands,
1596                                 {conditionType, resultType, resultType},
1597                                 parser.getNameLoc(), result.operands);
1598 }
1599 
verify(SelectOp op)1600 static LogicalResult verify(SelectOp op) {
1601   Type conditionType = op.getCondition().getType();
1602   if (conditionType.isSignlessInteger(1))
1603     return success();
1604 
1605   // If the result type is a vector or tensor, the type can be a mask with the
1606   // same elements.
1607   Type resultType = op.getType();
1608   if (!resultType.isa<TensorType, VectorType>())
1609     return op.emitOpError()
1610            << "expected condition to be a signless i1, but got "
1611            << conditionType;
1612   Type shapedConditionType = getI1SameShape(resultType);
1613   if (conditionType != shapedConditionType)
1614     return op.emitOpError()
1615            << "expected condition type to have the same shape "
1616               "as the result type, expected "
1617            << shapedConditionType << ", but got " << conditionType;
1618   return success();
1619 }
1620 
1621 //===----------------------------------------------------------------------===//
1622 // SignExtendIOp
1623 //===----------------------------------------------------------------------===//
1624 
verify(SignExtendIOp op)1625 static LogicalResult verify(SignExtendIOp op) {
1626   // Get the scalar type (which is either directly the type of the operand
1627   // or the vector's/tensor's element type.
1628   auto srcType = getElementTypeOrSelf(op.getOperand().getType());
1629   auto dstType = getElementTypeOrSelf(op.getType());
1630 
1631   // For now, index is forbidden for the source and the destination type.
1632   if (srcType.isa<IndexType>())
1633     return op.emitError() << srcType << " is not a valid operand type";
1634   if (dstType.isa<IndexType>())
1635     return op.emitError() << dstType << " is not a valid result type";
1636 
1637   if (srcType.cast<IntegerType>().getWidth() >=
1638       dstType.cast<IntegerType>().getWidth())
1639     return op.emitError("result type ")
1640            << dstType << " must be wider than operand type " << srcType;
1641 
1642   return success();
1643 }
1644 
fold(ArrayRef<Attribute> operands)1645 OpFoldResult SignExtendIOp::fold(ArrayRef<Attribute> operands) {
1646   assert(operands.size() == 1 && "unary operation takes one operand");
1647 
1648   if (!operands[0])
1649     return {};
1650 
1651   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
1652     return IntegerAttr::get(
1653         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
1654   }
1655 
1656   return {};
1657 }
1658 
1659 //===----------------------------------------------------------------------===//
1660 // SignedDivIOp
1661 //===----------------------------------------------------------------------===//
1662 
fold(ArrayRef<Attribute> operands)1663 OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
1664   assert(operands.size() == 2 && "binary operation takes two operands");
1665 
1666   // Don't fold if it would overflow or if it requires a division by zero.
1667   bool overflowOrDiv0 = false;
1668   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
1669     if (overflowOrDiv0 || !b) {
1670       overflowOrDiv0 = true;
1671       return a;
1672     }
1673     return a.sdiv_ov(b, overflowOrDiv0);
1674   });
1675 
1676   // Fold out division by one. Assumes all tensors of all ones are splats.
1677   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
1678     if (rhs.getValue() == 1)
1679       return lhs();
1680   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
1681     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
1682       return lhs();
1683   }
1684 
1685   return overflowOrDiv0 ? Attribute() : result;
1686 }
1687 
1688 //===----------------------------------------------------------------------===//
1689 // SignedFloorDivIOp
1690 //===----------------------------------------------------------------------===//
1691 
signedCeilNonnegInputs(APInt a,APInt b,bool & overflow)1692 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
1693   // Returns (a-1)/b + 1
1694   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
1695   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
1696   return val.sadd_ov(one, overflow);
1697 }
1698 
fold(ArrayRef<Attribute> operands)1699 OpFoldResult SignedFloorDivIOp::fold(ArrayRef<Attribute> operands) {
1700   assert(operands.size() == 2 && "binary operation takes two operands");
1701 
1702   // Don't fold if it would overflow or if it requires a division by zero.
1703   bool overflowOrDiv0 = false;
1704   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
1705     if (overflowOrDiv0 || !b) {
1706       overflowOrDiv0 = true;
1707       return a;
1708     }
1709     unsigned bits = a.getBitWidth();
1710     APInt zero = APInt::getNullValue(bits);
1711     if (a.sge(zero) && b.sgt(zero)) {
1712       // Both positive (or a is zero), return a / b.
1713       return a.sdiv_ov(b, overflowOrDiv0);
1714     } else if (a.sle(zero) && b.slt(zero)) {
1715       // Both negative (or a is zero), return -a / -b.
1716       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
1717       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
1718       return posA.sdiv_ov(posB, overflowOrDiv0);
1719     } else if (a.slt(zero) && b.sgt(zero)) {
1720       // A is negative, b is positive, return - ceil(-a, b).
1721       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
1722       APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
1723       return zero.ssub_ov(ceil, overflowOrDiv0);
1724     } else {
1725       // A is positive, b is negative, return - ceil(a, -b).
1726       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
1727       APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
1728       return zero.ssub_ov(ceil, overflowOrDiv0);
1729     }
1730   });
1731 
1732   // Fold out floor division by one. Assumes all tensors of all ones are
1733   // splats.
1734   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
1735     if (rhs.getValue() == 1)
1736       return lhs();
1737   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
1738     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
1739       return lhs();
1740   }
1741 
1742   return overflowOrDiv0 ? Attribute() : result;
1743 }
1744 
1745 //===----------------------------------------------------------------------===//
1746 // SignedCeilDivIOp
1747 //===----------------------------------------------------------------------===//
1748 
fold(ArrayRef<Attribute> operands)1749 OpFoldResult SignedCeilDivIOp::fold(ArrayRef<Attribute> operands) {
1750   assert(operands.size() == 2 && "binary operation takes two operands");
1751 
1752   // Don't fold if it would overflow or if it requires a division by zero.
1753   bool overflowOrDiv0 = false;
1754   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
1755     if (overflowOrDiv0 || !b) {
1756       overflowOrDiv0 = true;
1757       return a;
1758     }
1759     unsigned bits = a.getBitWidth();
1760     APInt zero = APInt::getNullValue(bits);
1761     if (a.sgt(zero) && b.sgt(zero)) {
1762       // Both positive, return ceil(a, b).
1763       return signedCeilNonnegInputs(a, b, overflowOrDiv0);
1764     } else if (a.slt(zero) && b.slt(zero)) {
1765       // Both negative, return ceil(-a, -b).
1766       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
1767       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
1768       return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
1769     } else if (a.slt(zero) && b.sgt(zero)) {
1770       // A is negative, b is positive, return - ( -a / b).
1771       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
1772       APInt div = posA.sdiv_ov(b, overflowOrDiv0);
1773       return zero.ssub_ov(div, overflowOrDiv0);
1774     } else {
1775       // A is positive (or zero), b is negative, return - (a / -b).
1776       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
1777       APInt div = a.sdiv_ov(posB, overflowOrDiv0);
1778       return zero.ssub_ov(div, overflowOrDiv0);
1779     }
1780   });
1781 
1782   // Fold out floor division by one. Assumes all tensors of all ones are
1783   // splats.
1784   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
1785     if (rhs.getValue() == 1)
1786       return lhs();
1787   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
1788     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
1789       return lhs();
1790   }
1791 
1792   return overflowOrDiv0 ? Attribute() : result;
1793 }
1794 
1795 //===----------------------------------------------------------------------===//
1796 // SignedRemIOp
1797 //===----------------------------------------------------------------------===//
1798 
fold(ArrayRef<Attribute> operands)1799 OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
1800   assert(operands.size() == 2 && "remi_signed takes two operands");
1801 
1802   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1803   if (!rhs)
1804     return {};
1805   auto rhsValue = rhs.getValue();
1806 
1807   // x % 1 = 0
1808   if (rhsValue.isOneValue())
1809     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
1810 
1811   // Don't fold if it requires division by zero.
1812   if (rhsValue.isNullValue())
1813     return {};
1814 
1815   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1816   if (!lhs)
1817     return {};
1818   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
1819 }
1820 
1821 //===----------------------------------------------------------------------===//
1822 // SIToFPOp
1823 //===----------------------------------------------------------------------===//
1824 
1825 // sitofp is applicable from integer types to float types.
areCastCompatible(TypeRange inputs,TypeRange outputs)1826 bool SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1827   if (inputs.size() != 1 || outputs.size() != 1)
1828     return false;
1829   Type a = inputs.front(), b = outputs.front();
1830   if (a.isSignlessInteger() && b.isa<FloatType>())
1831     return true;
1832   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1833 }
1834 
1835 //===----------------------------------------------------------------------===//
1836 // SplatOp
1837 //===----------------------------------------------------------------------===//
1838 
verify(SplatOp op)1839 static LogicalResult verify(SplatOp op) {
1840   // TODO: we could replace this by a trait.
1841   if (op.getOperand().getType() !=
1842       op.getType().cast<ShapedType>().getElementType())
1843     return op.emitError("operand should be of elemental type of result type");
1844 
1845   return success();
1846 }
1847 
1848 // Constant folding hook for SplatOp.
fold(ArrayRef<Attribute> operands)1849 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
1850   assert(operands.size() == 1 && "splat takes one operand");
1851 
1852   auto constOperand = operands.front();
1853   if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
1854     return {};
1855 
1856   auto shapedType = getType().cast<ShapedType>();
1857   assert(shapedType.getElementType() == constOperand.getType() &&
1858          "incorrect input attribute type for folding");
1859 
1860   // SplatElementsAttr::get treats single value for second arg as being a splat.
1861   return SplatElementsAttr::get(shapedType, {constOperand});
1862 }
1863 
1864 //===----------------------------------------------------------------------===//
1865 // SubFOp
1866 //===----------------------------------------------------------------------===//
1867 
fold(ArrayRef<Attribute> operands)1868 OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
1869   return constFoldBinaryOp<FloatAttr>(
1870       operands, [](APFloat a, APFloat b) { return a - b; });
1871 }
1872 
1873 //===----------------------------------------------------------------------===//
1874 // SubIOp
1875 //===----------------------------------------------------------------------===//
1876 
fold(ArrayRef<Attribute> operands)1877 OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
1878   // subi(x,x) -> 0
1879   if (getOperand(0) == getOperand(1))
1880     return Builder(getContext()).getZeroAttr(getType());
1881   // subi(x,0) -> x
1882   if (matchPattern(rhs(), m_Zero()))
1883     return lhs();
1884 
1885   return constFoldBinaryOp<IntegerAttr>(operands,
1886                                         [](APInt a, APInt b) { return a - b; });
1887 }
1888 
1889 /// Canonicalize a sub of a constant and (constant +/- something) to simply be
1890 /// a single operation that merges the two constants.
1891 struct SubConstantReorder : public OpRewritePattern<SubIOp> {
1892   using OpRewritePattern<SubIOp>::OpRewritePattern;
1893 
matchAndRewriteSubConstantReorder1894   LogicalResult matchAndRewrite(SubIOp subOp,
1895                                 PatternRewriter &rewriter) const override {
1896     APInt origConst;
1897     APInt midConst;
1898 
1899     if (matchPattern(subOp.getOperand(0), m_ConstantInt(&origConst))) {
1900       if (auto midAddOp = subOp.getOperand(1).getDefiningOp<AddIOp>()) {
1901         // origConst - (midConst + something) == (origConst - midConst) -
1902         // something
1903         for (int j = 0; j < 2; j++) {
1904           if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
1905             auto nextConstant = rewriter.create<ConstantOp>(
1906                 subOp.getLoc(),
1907                 rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
1908             rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
1909                                                 midAddOp.getOperand(1 - j));
1910             return success();
1911           }
1912         }
1913       }
1914 
1915       if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
1916         if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
1917           // (midConst - something) - origConst == (midConst - origConst) -
1918           // something
1919           auto nextConstant = rewriter.create<ConstantOp>(
1920               subOp.getLoc(),
1921               rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
1922           rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
1923                                               midSubOp.getOperand(1));
1924           return success();
1925         }
1926 
1927         if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
1928           // (something - midConst) - origConst == something - (origConst +
1929           // midConst)
1930           auto nextConstant = rewriter.create<ConstantOp>(
1931               subOp.getLoc(),
1932               rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
1933           rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
1934                                               nextConstant);
1935           return success();
1936         }
1937       }
1938 
1939       if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
1940         if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
1941           // origConst - (midConst - something) == (origConst - midConst) +
1942           // something
1943           auto nextConstant = rewriter.create<ConstantOp>(
1944               subOp.getLoc(),
1945               rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
1946           rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
1947                                               midSubOp.getOperand(1));
1948           return success();
1949         }
1950 
1951         if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
1952           // origConst - (something - midConst) == (origConst + midConst) -
1953           // something
1954           auto nextConstant = rewriter.create<ConstantOp>(
1955               subOp.getLoc(),
1956               rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
1957           rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
1958                                               midSubOp.getOperand(0));
1959           return success();
1960         }
1961       }
1962     }
1963 
1964     if (matchPattern(subOp.getOperand(1), m_ConstantInt(&origConst))) {
1965       if (auto midAddOp = subOp.getOperand(0).getDefiningOp<AddIOp>()) {
1966         // (midConst + something) - origConst == (midConst - origConst) +
1967         // something
1968         for (int j = 0; j < 2; j++) {
1969           if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
1970             auto nextConstant = rewriter.create<ConstantOp>(
1971                 subOp.getLoc(),
1972                 rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
1973             rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
1974                                                 midAddOp.getOperand(1 - j));
1975             return success();
1976           }
1977         }
1978       }
1979 
1980       if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
1981         if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
1982           // (midConst - something) - origConst == (midConst - origConst) -
1983           // something
1984           auto nextConstant = rewriter.create<ConstantOp>(
1985               subOp.getLoc(),
1986               rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
1987           rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
1988                                               midSubOp.getOperand(1));
1989           return success();
1990         }
1991 
1992         if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
1993           // (something - midConst) - origConst == something - (midConst +
1994           // origConst)
1995           auto nextConstant = rewriter.create<ConstantOp>(
1996               subOp.getLoc(),
1997               rewriter.getIntegerAttr(subOp.getType(), midConst + origConst));
1998           rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
1999                                               nextConstant);
2000           return success();
2001         }
2002       }
2003 
2004       if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
2005         if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
2006           // origConst - (midConst - something) == (origConst - midConst) +
2007           // something
2008           auto nextConstant = rewriter.create<ConstantOp>(
2009               subOp.getLoc(),
2010               rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
2011           rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
2012                                               midSubOp.getOperand(1));
2013           return success();
2014         }
2015         if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
2016           // origConst - (something - midConst) == (origConst - midConst) -
2017           // something
2018           auto nextConstant = rewriter.create<ConstantOp>(
2019               subOp.getLoc(),
2020               rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
2021           rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
2022                                               midSubOp.getOperand(0));
2023           return success();
2024         }
2025       }
2026     }
2027     return failure();
2028   }
2029 };
2030 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2031 void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2032                                          MLIRContext *context) {
2033   results.insert<SubConstantReorder>(context);
2034 }
2035 
2036 //===----------------------------------------------------------------------===//
2037 // UIToFPOp
2038 //===----------------------------------------------------------------------===//
2039 
2040 // uitofp is applicable from integer types to float types.
areCastCompatible(TypeRange inputs,TypeRange outputs)2041 bool UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
2042   if (inputs.size() != 1 || outputs.size() != 1)
2043     return false;
2044   Type a = inputs.front(), b = outputs.front();
2045   if (a.isSignlessInteger() && b.isa<FloatType>())
2046     return true;
2047   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2048 }
2049 
2050 //===----------------------------------------------------------------------===//
2051 // SwitchOp
2052 //===----------------------------------------------------------------------===//
2053 
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,DenseIntElementsAttr caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)2054 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
2055                      Block *defaultDestination, ValueRange defaultOperands,
2056                      DenseIntElementsAttr caseValues,
2057                      BlockRange caseDestinations,
2058                      ArrayRef<ValueRange> caseOperands) {
2059   SmallVector<Value> flattenedCaseOperands;
2060   SmallVector<int32_t> caseOperandOffsets;
2061   int32_t offset = 0;
2062   for (ValueRange operands : caseOperands) {
2063     flattenedCaseOperands.append(operands.begin(), operands.end());
2064     caseOperandOffsets.push_back(offset);
2065     offset += operands.size();
2066   }
2067   DenseIntElementsAttr caseOperandOffsetsAttr;
2068   if (!caseOperandOffsets.empty())
2069     caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
2070 
2071   build(builder, result, value, defaultOperands, flattenedCaseOperands,
2072         caseValues, caseOperandOffsetsAttr, defaultDestination,
2073         caseDestinations);
2074 }
2075 
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<APInt> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)2076 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
2077                      Block *defaultDestination, ValueRange defaultOperands,
2078                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
2079                      ArrayRef<ValueRange> caseOperands) {
2080   DenseIntElementsAttr caseValuesAttr;
2081   if (!caseValues.empty()) {
2082     ShapedType caseValueType = VectorType::get(
2083         static_cast<int64_t>(caseValues.size()), value.getType());
2084     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
2085   }
2086   build(builder, result, value, defaultDestination, defaultOperands,
2087         caseValuesAttr, caseDestinations, caseOperands);
2088 }
2089 
2090 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
2091 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
2092 static ParseResult
parseSwitchOpCases(OpAsmParser & parser,Type & flagType,Block * & defaultDestination,SmallVectorImpl<OpAsmParser::OperandType> & defaultOperands,SmallVectorImpl<Type> & defaultOperandTypes,DenseIntElementsAttr & caseValues,SmallVectorImpl<Block * > & caseDestinations,SmallVectorImpl<OpAsmParser::OperandType> & caseOperands,SmallVectorImpl<Type> & caseOperandTypes,DenseIntElementsAttr & caseOperandOffsets)2093 parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
2094                    Block *&defaultDestination,
2095                    SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
2096                    SmallVectorImpl<Type> &defaultOperandTypes,
2097                    DenseIntElementsAttr &caseValues,
2098                    SmallVectorImpl<Block *> &caseDestinations,
2099                    SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
2100                    SmallVectorImpl<Type> &caseOperandTypes,
2101                    DenseIntElementsAttr &caseOperandOffsets) {
2102   if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
2103       failed(parser.parseSuccessor(defaultDestination)))
2104     return failure();
2105   if (succeeded(parser.parseOptionalLParen())) {
2106     if (failed(parser.parseRegionArgumentList(defaultOperands)) ||
2107         failed(parser.parseColonTypeList(defaultOperandTypes)) ||
2108         failed(parser.parseRParen()))
2109       return failure();
2110   }
2111 
2112   SmallVector<APInt> values;
2113   SmallVector<int32_t> offsets;
2114   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
2115   int64_t offset = 0;
2116   while (succeeded(parser.parseOptionalComma())) {
2117     int64_t value = 0;
2118     if (failed(parser.parseInteger(value)))
2119       return failure();
2120     values.push_back(APInt(bitWidth, value));
2121 
2122     Block *destination;
2123     SmallVector<OpAsmParser::OperandType> operands;
2124     if (failed(parser.parseColon()) ||
2125         failed(parser.parseSuccessor(destination)))
2126       return failure();
2127     if (succeeded(parser.parseOptionalLParen())) {
2128       if (failed(parser.parseRegionArgumentList(operands)) ||
2129           failed(parser.parseColonTypeList(caseOperandTypes)) ||
2130           failed(parser.parseRParen()))
2131         return failure();
2132     }
2133     caseDestinations.push_back(destination);
2134     caseOperands.append(operands.begin(), operands.end());
2135     offsets.push_back(offset);
2136     offset += operands.size();
2137   }
2138 
2139   if (values.empty())
2140     return success();
2141 
2142   Builder &builder = parser.getBuilder();
2143   ShapedType caseValueType =
2144       VectorType::get(static_cast<int64_t>(values.size()), flagType);
2145   caseValues = DenseIntElementsAttr::get(caseValueType, values);
2146   caseOperandOffsets = builder.getI32VectorAttr(offsets);
2147 
2148   return success();
2149 }
2150 
printSwitchOpCases(OpAsmPrinter & p,SwitchOp op,Type flagType,Block * defaultDestination,OperandRange defaultOperands,TypeRange defaultOperandTypes,DenseIntElementsAttr caseValues,SuccessorRange caseDestinations,OperandRange caseOperands,TypeRange caseOperandTypes,ElementsAttr caseOperandOffsets)2151 static void printSwitchOpCases(
2152     OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
2153     OperandRange defaultOperands, TypeRange defaultOperandTypes,
2154     DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
2155     OperandRange caseOperands, TypeRange caseOperandTypes,
2156     ElementsAttr caseOperandOffsets) {
2157   p << "  default: ";
2158   p.printSuccessorAndUseList(defaultDestination, defaultOperands);
2159 
2160   if (!caseValues)
2161     return;
2162 
2163   for (int64_t i = 0, size = caseValues.size(); i < size; ++i) {
2164     p << ',';
2165     p.printNewline();
2166     p << "  ";
2167     p << caseValues.getValue<APInt>(i).getLimitedValue();
2168     p << ": ";
2169     p.printSuccessorAndUseList(caseDestinations[i], op.getCaseOperands(i));
2170   }
2171   p.printNewline();
2172 }
2173 
verify(SwitchOp op)2174 static LogicalResult verify(SwitchOp op) {
2175   auto caseValues = op.case_values();
2176   auto caseDestinations = op.caseDestinations();
2177 
2178   if (!caseValues && caseDestinations.empty())
2179     return success();
2180 
2181   Type flagType = op.flag().getType();
2182   Type caseValueType = caseValues->getType().getElementType();
2183   if (caseValueType != flagType)
2184     return op.emitOpError()
2185            << "'flag' type (" << flagType << ") should match case value type ("
2186            << caseValueType << ")";
2187 
2188   if (caseValues &&
2189       caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
2190     return op.emitOpError() << "number of case values (" << caseValues->size()
2191                             << ") should match number of "
2192                                "case destinations ("
2193                             << caseDestinations.size() << ")";
2194   return success();
2195 }
2196 
getCaseOperands(unsigned index)2197 OperandRange SwitchOp::getCaseOperands(unsigned index) {
2198   return getCaseOperandsMutable(index);
2199 }
2200 
getCaseOperandsMutable(unsigned index)2201 MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
2202   MutableOperandRange caseOperands = caseOperandsMutable();
2203   if (!case_operand_offsets()) {
2204     assert(caseOperands.size() == 0 &&
2205            "non-empty case operands must have offsets");
2206     return caseOperands;
2207   }
2208 
2209   ElementsAttr offsets = case_operand_offsets().getValue();
2210   assert(index < offsets.size() && "invalid case operand offset index");
2211 
2212   int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
2213   int64_t end = index + 1 == offsets.size()
2214                     ? caseOperands.size()
2215                     : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
2216   return caseOperandsMutable().slice(begin, end - begin);
2217 }
2218 
2219 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)2220 SwitchOp::getMutableSuccessorOperands(unsigned index) {
2221   assert(index < getNumSuccessors() && "invalid successor index");
2222   return index == 0 ? defaultOperandsMutable()
2223                     : getCaseOperandsMutable(index - 1);
2224 }
2225 
getSuccessorForOperands(ArrayRef<Attribute> operands)2226 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
2227   Optional<DenseIntElementsAttr> caseValues = case_values();
2228 
2229   if (!caseValues)
2230     return defaultDestination();
2231 
2232   SuccessorRange caseDests = caseDestinations();
2233   if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
2234     for (int64_t i = 0, size = case_values()->size(); i < size; ++i)
2235       if (value == caseValues->getValue<IntegerAttr>(i))
2236         return caseDests[i];
2237     return defaultDestination();
2238   }
2239   return nullptr;
2240 }
2241 
2242 /// switch %flag : i32, [
2243 ///   default:  ^bb1
2244 /// ]
2245 ///  -> br ^bb1
simplifySwitchWithOnlyDefault(SwitchOp op,PatternRewriter & rewriter)2246 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
2247                                                    PatternRewriter &rewriter) {
2248   if (!op.caseDestinations().empty())
2249     return failure();
2250 
2251   rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
2252                                         op.defaultOperands());
2253   return success();
2254 }
2255 
2256 /// switch %flag : i32, [
2257 ///   default: ^bb1,
2258 ///   42: ^bb1,
2259 ///   43: ^bb2
2260 /// ]
2261 /// ->
2262 /// switch %flag : i32, [
2263 ///   default: ^bb1,
2264 ///   43: ^bb2
2265 /// ]
2266 static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op,PatternRewriter & rewriter)2267 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
2268   SmallVector<Block *> newCaseDestinations;
2269   SmallVector<ValueRange> newCaseOperands;
2270   SmallVector<APInt> newCaseValues;
2271   bool requiresChange = false;
2272   auto caseValues = op.case_values();
2273   auto caseDests = op.caseDestinations();
2274 
2275   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2276     if (caseDests[i] == op.defaultDestination() &&
2277         op.getCaseOperands(i) == op.defaultOperands()) {
2278       requiresChange = true;
2279       continue;
2280     }
2281     newCaseDestinations.push_back(caseDests[i]);
2282     newCaseOperands.push_back(op.getCaseOperands(i));
2283     newCaseValues.push_back(caseValues->getValue<APInt>(i));
2284   }
2285 
2286   if (!requiresChange)
2287     return failure();
2288 
2289   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
2290                                         op.defaultOperands(), newCaseValues,
2291                                         newCaseDestinations, newCaseOperands);
2292   return success();
2293 }
2294 
2295 /// Helper for folding a switch with a constant value.
2296 /// switch %c_42 : i32, [
2297 ///   default: ^bb1 ,
2298 ///   42: ^bb2,
2299 ///   43: ^bb3
2300 /// ]
2301 /// -> br ^bb2
foldSwitch(SwitchOp op,PatternRewriter & rewriter,APInt caseValue)2302 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
2303                        APInt caseValue) {
2304   auto caseValues = op.case_values();
2305   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2306     if (caseValues->getValue<APInt>(i) == caseValue) {
2307       rewriter.replaceOpWithNewOp<BranchOp>(op, op.caseDestinations()[i],
2308                                             op.getCaseOperands(i));
2309       return;
2310     }
2311   }
2312   rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
2313                                         op.defaultOperands());
2314 }
2315 
2316 /// switch %c_42 : i32, [
2317 ///   default: ^bb1,
2318 ///   42: ^bb2,
2319 ///   43: ^bb3
2320 /// ]
2321 /// -> br ^bb2
simplifyConstSwitchValue(SwitchOp op,PatternRewriter & rewriter)2322 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
2323                                               PatternRewriter &rewriter) {
2324   APInt caseValue;
2325   if (!matchPattern(op.flag(), m_ConstantInt(&caseValue)))
2326     return failure();
2327 
2328   foldSwitch(op, rewriter, caseValue);
2329   return success();
2330 }
2331 
2332 /// switch %c_42 : i32, [
2333 ///   default: ^bb1,
2334 ///   42: ^bb2,
2335 /// ]
2336 /// ^bb2:
2337 ///   br ^bb3
2338 /// ->
2339 /// switch %c_42 : i32, [
2340 ///   default: ^bb1,
2341 ///   42: ^bb3,
2342 /// ]
simplifyPassThroughSwitch(SwitchOp op,PatternRewriter & rewriter)2343 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
2344                                                PatternRewriter &rewriter) {
2345   SmallVector<Block *> newCaseDests;
2346   SmallVector<ValueRange> newCaseOperands;
2347   SmallVector<SmallVector<Value>> argStorage;
2348   auto caseValues = op.case_values();
2349   auto caseDests = op.caseDestinations();
2350   bool requiresChange = false;
2351   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2352     Block *caseDest = caseDests[i];
2353     ValueRange caseOperands = op.getCaseOperands(i);
2354     argStorage.emplace_back();
2355     if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
2356       requiresChange = true;
2357 
2358     newCaseDests.push_back(caseDest);
2359     newCaseOperands.push_back(caseOperands);
2360   }
2361 
2362   Block *defaultDest = op.defaultDestination();
2363   ValueRange defaultOperands = op.defaultOperands();
2364   argStorage.emplace_back();
2365 
2366   if (succeeded(
2367           collapseBranch(defaultDest, defaultOperands, argStorage.back())))
2368     requiresChange = true;
2369 
2370   if (!requiresChange)
2371     return failure();
2372 
2373   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), defaultDest,
2374                                         defaultOperands, caseValues.getValue(),
2375                                         newCaseDests, newCaseOperands);
2376   return success();
2377 }
2378 
2379 /// switch %flag : i32, [
2380 ///   default: ^bb1,
2381 ///   42: ^bb2,
2382 /// ]
2383 /// ^bb2:
2384 ///   switch %flag : i32, [
2385 ///     default: ^bb3,
2386 ///     42: ^bb4
2387 ///   ]
2388 /// ->
2389 /// switch %flag : i32, [
2390 ///   default: ^bb1,
2391 ///   42: ^bb2,
2392 /// ]
2393 /// ^bb2:
2394 ///   br ^bb4
2395 ///
2396 ///  and
2397 ///
2398 /// switch %flag : i32, [
2399 ///   default: ^bb1,
2400 ///   42: ^bb2,
2401 /// ]
2402 /// ^bb2:
2403 ///   switch %flag : i32, [
2404 ///     default: ^bb3,
2405 ///     43: ^bb4
2406 ///   ]
2407 /// ->
2408 /// switch %flag : i32, [
2409 ///   default: ^bb1,
2410 ///   42: ^bb2,
2411 /// ]
2412 /// ^bb2:
2413 ///   br ^bb3
2414 static LogicalResult
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,PatternRewriter & rewriter)2415 simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
2416                                         PatternRewriter &rewriter) {
2417   // Check that we have a single distinct predecessor.
2418   Block *currentBlock = op->getBlock();
2419   Block *predecessor = currentBlock->getSinglePredecessor();
2420   if (!predecessor)
2421     return failure();
2422 
2423   // Check that the predecessor terminates with a switch branch to this block
2424   // and that it branches on the same condition and that this branch isn't the
2425   // default destination.
2426   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
2427   if (!predSwitch || op.flag() != predSwitch.flag() ||
2428       predSwitch.defaultDestination() == currentBlock)
2429     return failure();
2430 
2431   // Fold this switch to an unconditional branch.
2432   APInt caseValue;
2433   bool isDefault = true;
2434   SuccessorRange predDests = predSwitch.caseDestinations();
2435   Optional<DenseIntElementsAttr> predCaseValues = predSwitch.case_values();
2436   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) {
2437     if (currentBlock == predDests[i]) {
2438       caseValue = predCaseValues->getValue<APInt>(i);
2439       isDefault = false;
2440       break;
2441     }
2442   }
2443   if (isDefault)
2444     rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
2445                                           op.defaultOperands());
2446   else
2447     foldSwitch(op, rewriter, caseValue);
2448   return success();
2449 }
2450 
2451 /// switch %flag : i32, [
2452 ///   default: ^bb1,
2453 ///   42: ^bb2
2454 /// ]
2455 /// ^bb1:
2456 ///   switch %flag : i32, [
2457 ///     default: ^bb3,
2458 ///     42: ^bb4,
2459 ///     43: ^bb5
2460 ///   ]
2461 /// ->
2462 /// switch %flag : i32, [
2463 ///   default: ^bb1,
2464 ///   42: ^bb2,
2465 /// ]
2466 /// ^bb1:
2467 ///   switch %flag : i32, [
2468 ///     default: ^bb3,
2469 ///     43: ^bb5
2470 ///   ]
2471 static LogicalResult
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,PatternRewriter & rewriter)2472 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
2473                                                PatternRewriter &rewriter) {
2474   // Check that we have a single distinct predecessor.
2475   Block *currentBlock = op->getBlock();
2476   Block *predecessor = currentBlock->getSinglePredecessor();
2477   if (!predecessor)
2478     return failure();
2479 
2480   // Check that the predecessor terminates with a switch branch to this block
2481   // and that it branches on the same condition and that this branch is the
2482   // default destination.
2483   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
2484   if (!predSwitch || op.flag() != predSwitch.flag() ||
2485       predSwitch.defaultDestination() != currentBlock)
2486     return failure();
2487 
2488   // Delete case values that are not possible here.
2489   DenseSet<APInt> caseValuesToRemove;
2490   auto predDests = predSwitch.caseDestinations();
2491   auto predCaseValues = predSwitch.case_values();
2492   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
2493     if (currentBlock != predDests[i])
2494       caseValuesToRemove.insert(predCaseValues->getValue<APInt>(i));
2495 
2496   SmallVector<Block *> newCaseDestinations;
2497   SmallVector<ValueRange> newCaseOperands;
2498   SmallVector<APInt> newCaseValues;
2499   bool requiresChange = false;
2500 
2501   auto caseValues = op.case_values();
2502   auto caseDests = op.caseDestinations();
2503   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2504     if (caseValuesToRemove.contains(caseValues->getValue<APInt>(i))) {
2505       requiresChange = true;
2506       continue;
2507     }
2508     newCaseDestinations.push_back(caseDests[i]);
2509     newCaseOperands.push_back(op.getCaseOperands(i));
2510     newCaseValues.push_back(caseValues->getValue<APInt>(i));
2511   }
2512 
2513   if (!requiresChange)
2514     return failure();
2515 
2516   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
2517                                         op.defaultOperands(), newCaseValues,
2518                                         newCaseDestinations, newCaseOperands);
2519   return success();
2520 }
2521 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2522 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
2523                                            MLIRContext *context) {
2524   results.add(&simplifySwitchWithOnlyDefault)
2525       .add(&dropSwitchCasesThatMatchDefault)
2526       .add(&simplifyConstSwitchValue)
2527       .add(&simplifyPassThroughSwitch)
2528       .add(&simplifySwitchFromSwitchOnSameCondition)
2529       .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
2530 }
2531 
2532 //===----------------------------------------------------------------------===//
2533 // TruncateIOp
2534 //===----------------------------------------------------------------------===//
2535 
verify(TruncateIOp op)2536 static LogicalResult verify(TruncateIOp op) {
2537   auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2538   auto dstType = getElementTypeOrSelf(op.getType());
2539 
2540   if (srcType.isa<IndexType>())
2541     return op.emitError() << srcType << " is not a valid operand type";
2542   if (dstType.isa<IndexType>())
2543     return op.emitError() << dstType << " is not a valid result type";
2544 
2545   if (srcType.cast<IntegerType>().getWidth() <=
2546       dstType.cast<IntegerType>().getWidth())
2547     return op.emitError("operand type ")
2548            << srcType << " must be wider than result type " << dstType;
2549 
2550   return success();
2551 }
2552 
fold(ArrayRef<Attribute> operands)2553 OpFoldResult TruncateIOp::fold(ArrayRef<Attribute> operands) {
2554   // trunci(zexti(a)) -> a
2555   // trunci(sexti(a)) -> a
2556   if (matchPattern(getOperand(), m_Op<ZeroExtendIOp>()) ||
2557       matchPattern(getOperand(), m_Op<SignExtendIOp>()))
2558     return getOperand().getDefiningOp()->getOperand(0);
2559 
2560   assert(operands.size() == 1 && "unary operation takes one operand");
2561 
2562   if (!operands[0])
2563     return {};
2564 
2565   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
2566 
2567     return IntegerAttr::get(
2568         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
2569   }
2570 
2571   return {};
2572 }
2573 
2574 //===----------------------------------------------------------------------===//
2575 // UnsignedDivIOp
2576 //===----------------------------------------------------------------------===//
2577 
fold(ArrayRef<Attribute> operands)2578 OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
2579   assert(operands.size() == 2 && "binary operation takes two operands");
2580 
2581   // Don't fold if it would require a division by zero.
2582   bool div0 = false;
2583   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2584     if (div0 || !b) {
2585       div0 = true;
2586       return a;
2587     }
2588     return a.udiv(b);
2589   });
2590 
2591   // Fold out division by one. Assumes all tensors of all ones are splats.
2592   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2593     if (rhs.getValue() == 1)
2594       return lhs();
2595   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2596     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2597       return lhs();
2598   }
2599 
2600   return div0 ? Attribute() : result;
2601 }
2602 
2603 //===----------------------------------------------------------------------===//
2604 // UnsignedRemIOp
2605 //===----------------------------------------------------------------------===//
2606 
fold(ArrayRef<Attribute> operands)2607 OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
2608   assert(operands.size() == 2 && "remi_unsigned takes two operands");
2609 
2610   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
2611   if (!rhs)
2612     return {};
2613   auto rhsValue = rhs.getValue();
2614 
2615   // x % 1 = 0
2616   if (rhsValue.isOneValue())
2617     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
2618 
2619   // Don't fold if it requires division by zero.
2620   if (rhsValue.isNullValue())
2621     return {};
2622 
2623   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
2624   if (!lhs)
2625     return {};
2626   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
2627 }
2628 
2629 //===----------------------------------------------------------------------===//
2630 // XOrOp
2631 //===----------------------------------------------------------------------===//
2632 
fold(ArrayRef<Attribute> operands)2633 OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
2634   /// xor(x, 0) -> x
2635   if (matchPattern(rhs(), m_Zero()))
2636     return lhs();
2637   /// xor(x,x) -> 0
2638   if (lhs() == rhs())
2639     return Builder(getContext()).getZeroAttr(getType());
2640 
2641   return constFoldBinaryOp<IntegerAttr>(operands,
2642                                         [](APInt a, APInt b) { return a ^ b; });
2643 }
2644 
2645 namespace {
2646 /// Replace a not of a comparison operation, for example: not(cmp eq A, B) =>
2647 /// cmp ne A, B. Note that a logical not is implemented as xor 1, val.
2648 struct NotICmp : public OpRewritePattern<XOrOp> {
2649   using OpRewritePattern<XOrOp>::OpRewritePattern;
2650 
matchAndRewrite__anon476530361511::NotICmp2651   LogicalResult matchAndRewrite(XOrOp op,
2652                                 PatternRewriter &rewriter) const override {
2653     // Commutative ops (such as xor) have the constant appear second, which
2654     // we assume here.
2655 
2656     APInt constValue;
2657     if (!matchPattern(op.getOperand(1), m_ConstantInt(&constValue)))
2658       return failure();
2659 
2660     if (constValue != 1)
2661       return failure();
2662 
2663     auto prev = op.getOperand(0).getDefiningOp<CmpIOp>();
2664     if (!prev)
2665       return failure();
2666 
2667     switch (prev.predicate()) {
2668     case CmpIPredicate::eq:
2669       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ne, prev.lhs(),
2670                                           prev.rhs());
2671       return success();
2672     case CmpIPredicate::ne:
2673       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::eq, prev.lhs(),
2674                                           prev.rhs());
2675       return success();
2676 
2677     case CmpIPredicate::slt:
2678       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sge, prev.lhs(),
2679                                           prev.rhs());
2680       return success();
2681     case CmpIPredicate::sle:
2682       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sgt, prev.lhs(),
2683                                           prev.rhs());
2684       return success();
2685     case CmpIPredicate::sgt:
2686       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sle, prev.lhs(),
2687                                           prev.rhs());
2688       return success();
2689     case CmpIPredicate::sge:
2690       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, prev.lhs(),
2691                                           prev.rhs());
2692       return success();
2693 
2694     case CmpIPredicate::ult:
2695       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::uge, prev.lhs(),
2696                                           prev.rhs());
2697       return success();
2698     case CmpIPredicate::ule:
2699       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ugt, prev.lhs(),
2700                                           prev.rhs());
2701       return success();
2702     case CmpIPredicate::ugt:
2703       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ule, prev.lhs(),
2704                                           prev.rhs());
2705       return success();
2706     case CmpIPredicate::uge:
2707       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ult, prev.lhs(),
2708                                           prev.rhs());
2709       return success();
2710     }
2711     return failure();
2712   }
2713 };
2714 } // namespace
2715 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2716 void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2717                                         MLIRContext *context) {
2718   results.insert<NotICmp>(context);
2719 }
2720 
2721 //===----------------------------------------------------------------------===//
2722 // ZeroExtendIOp
2723 //===----------------------------------------------------------------------===//
2724 
verify(ZeroExtendIOp op)2725 static LogicalResult verify(ZeroExtendIOp op) {
2726   auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2727   auto dstType = getElementTypeOrSelf(op.getType());
2728 
2729   if (srcType.isa<IndexType>())
2730     return op.emitError() << srcType << " is not a valid operand type";
2731   if (dstType.isa<IndexType>())
2732     return op.emitError() << dstType << " is not a valid result type";
2733 
2734   if (srcType.cast<IntegerType>().getWidth() >=
2735       dstType.cast<IntegerType>().getWidth())
2736     return op.emitError("result type ")
2737            << dstType << " must be wider than operand type " << srcType;
2738 
2739   return success();
2740 }
2741 
2742 //===----------------------------------------------------------------------===//
2743 // TableGen'd op method definitions
2744 //===----------------------------------------------------------------------===//
2745 
2746 #define GET_OP_CLASSES
2747 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
2748