1 //===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/StandardOps/IR/Ops.h"
10 
11 #include "mlir/Dialect/CommonFolders.h"
12 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
13 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Support/MathExtras.h"
25 #include "mlir/Transforms/InliningUtils.h"
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/StringSwitch.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <numeric>
32 
33 #include "mlir/Dialect/StandardOps/IR/OpsDialect.cpp.inc"
34 
35 // Pull in all enum type definitions and utility function declarations.
36 #include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"
37 
38 using namespace mlir;
39 
40 //===----------------------------------------------------------------------===//
41 // StandardOpsDialect Interfaces
42 //===----------------------------------------------------------------------===//
43 namespace {
44 /// This class defines the interface for handling inlining with standard
45 /// operations.
46 struct StdInlinerInterface : public DialectInlinerInterface {
47   using DialectInlinerInterface::DialectInlinerInterface;
48 
49   //===--------------------------------------------------------------------===//
50   // Analysis Hooks
51   //===--------------------------------------------------------------------===//
52 
53   /// All call operations within standard ops can be inlined.
isLegalToInline__anon1787a6f50111::StdInlinerInterface54   bool isLegalToInline(Operation *call, Operation *callable,
55                        bool wouldBeCloned) const final {
56     return true;
57   }
58 
59   /// All operations within standard ops can be inlined.
isLegalToInline__anon1787a6f50111::StdInlinerInterface60   bool isLegalToInline(Operation *, Region *, bool,
61                        BlockAndValueMapping &) const final {
62     return true;
63   }
64 
65   //===--------------------------------------------------------------------===//
66   // Transformation Hooks
67   //===--------------------------------------------------------------------===//
68 
69   /// Handle the given inlined terminator by replacing it with a new operation
70   /// as necessary.
handleTerminator__anon1787a6f50111::StdInlinerInterface71   void handleTerminator(Operation *op, Block *newDest) const final {
72     // Only "std.return" needs to be handled here.
73     auto returnOp = dyn_cast<ReturnOp>(op);
74     if (!returnOp)
75       return;
76 
77     // Replace the return with a branch to the dest.
78     OpBuilder builder(op);
79     builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
80     op->erase();
81   }
82 
83   /// Handle the given inlined terminator by replacing it with a new operation
84   /// as necessary.
handleTerminator__anon1787a6f50111::StdInlinerInterface85   void handleTerminator(Operation *op,
86                         ArrayRef<Value> valuesToRepl) const final {
87     // Only "std.return" needs to be handled here.
88     auto returnOp = cast<ReturnOp>(op);
89 
90     // Replace the values directly with the return operands.
91     assert(returnOp.getNumOperands() == valuesToRepl.size());
92     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
93       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
94   }
95 };
96 } // end anonymous namespace
97 
98 //===----------------------------------------------------------------------===//
99 // StandardOpsDialect
100 //===----------------------------------------------------------------------===//
101 
102 /// A custom unary operation printer that omits the "std." prefix from the
103 /// operation names.
printStandardUnaryOp(Operation * op,OpAsmPrinter & p)104 static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
105   assert(op->getNumOperands() == 1 && "unary op should have one operand");
106   assert(op->getNumResults() == 1 && "unary op should have one result");
107 
108   p << ' ' << op->getOperand(0);
109   p.printOptionalAttrDict(op->getAttrs());
110   p << " : " << op->getOperand(0).getType();
111 }
112 
113 /// A custom binary operation printer that omits the "std." prefix from the
114 /// operation names.
printStandardBinaryOp(Operation * op,OpAsmPrinter & p)115 static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
116   assert(op->getNumOperands() == 2 && "binary op should have two operands");
117   assert(op->getNumResults() == 1 && "binary op should have one result");
118 
119   // If not all the operand and result types are the same, just use the
120   // generic assembly form to avoid omitting information in printing.
121   auto resultType = op->getResult(0).getType();
122   if (op->getOperand(0).getType() != resultType ||
123       op->getOperand(1).getType() != resultType) {
124     p.printGenericOp(op);
125     return;
126   }
127 
128   p << ' ' << op->getOperand(0) << ", " << op->getOperand(1);
129   p.printOptionalAttrDict(op->getAttrs());
130 
131   // Now we can output only one type for all operands and the result.
132   p << " : " << op->getResult(0).getType();
133 }
134 
135 /// A custom ternary operation printer that omits the "std." prefix from the
136 /// operation names.
printStandardTernaryOp(Operation * op,OpAsmPrinter & p)137 static void printStandardTernaryOp(Operation *op, OpAsmPrinter &p) {
138   assert(op->getNumOperands() == 3 && "ternary op should have three operands");
139   assert(op->getNumResults() == 1 && "ternary op should have one result");
140 
141   // If not all the operand and result types are the same, just use the
142   // generic assembly form to avoid omitting information in printing.
143   auto resultType = op->getResult(0).getType();
144   if (op->getOperand(0).getType() != resultType ||
145       op->getOperand(1).getType() != resultType ||
146       op->getOperand(2).getType() != resultType) {
147     p.printGenericOp(op);
148     return;
149   }
150 
151   p << ' ' << op->getOperand(0) << ", " << op->getOperand(1) << ", "
152     << op->getOperand(2);
153   p.printOptionalAttrDict(op->getAttrs());
154 
155   // Now we can output only one type for all operands and the result.
156   p << " : " << op->getResult(0).getType();
157 }
158 
159 /// A custom cast operation printer that omits the "std." prefix from the
160 /// operation names.
printStandardCastOp(Operation * op,OpAsmPrinter & p)161 static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
162   p << ' ' << op->getOperand(0) << " : " << op->getOperand(0).getType()
163     << " to " << op->getResult(0).getType();
164 }
165 
/// Registers the operations and the inliner interface of the standard
/// dialect.
void StandardOpsDialect::initialize() {
  // The template argument list is the operation list that Ops.cpp.inc expands
  // to when GET_OP_LIST is defined.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
      >();
  // Attach the inlining hooks implemented by StdInlinerInterface.
  addInterfaces<StdInlinerInterface>();
}
173 
174 /// Materialize a single constant operation from a given attribute value with
175 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)176 Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
177                                                    Attribute value, Type type,
178                                                    Location loc) {
179   return builder.create<ConstantOp>(loc, type, value);
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // Common cast compatibility check for vector types.
184 //===----------------------------------------------------------------------===//
185 
186 /// This method checks for cast compatibility of vector types.
187 /// If 'a' and 'b' are vector types, and they are cast compatible,
188 /// it calls the 'areElementsCastCompatible' function to check for
189 /// element cast compatibility.
190 /// Returns 'true' if the vector types are cast compatible,  and 'false'
191 /// otherwise.
areVectorCastSimpleCompatible(Type a,Type b,function_ref<bool (TypeRange,TypeRange)> areElementsCastCompatible)192 static bool areVectorCastSimpleCompatible(
193     Type a, Type b,
194     function_ref<bool(TypeRange, TypeRange)> areElementsCastCompatible) {
195   if (auto va = a.dyn_cast<VectorType>())
196     if (auto vb = b.dyn_cast<VectorType>())
197       return va.getShape().equals(vb.getShape()) &&
198              areElementsCastCompatible(va.getElementType(),
199                                        vb.getElementType());
200   return false;
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // AddFOp
205 //===----------------------------------------------------------------------===//
206 
fold(ArrayRef<Attribute> operands)207 OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
208   return constFoldBinaryOp<FloatAttr>(
209       operands, [](APFloat a, APFloat b) { return a + b; });
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // AddIOp
214 //===----------------------------------------------------------------------===//
215 
fold(ArrayRef<Attribute> operands)216 OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
217   /// addi(x, 0) -> x
218   if (matchPattern(rhs(), m_Zero()))
219     return lhs();
220 
221   return constFoldBinaryOp<IntegerAttr>(operands,
222                                         [](APInt a, APInt b) { return a + b; });
223 }
224 
225 /// Canonicalize a sum of a constant and (constant - something) to simply be
226 /// a sum of constants minus something. This transformation does similar
227 /// transformations for additions of a constant with a subtract/add of
228 /// a constant. This may result in some operations being reordered (but should
229 /// remain equivalent).
230 struct AddConstantReorder : public OpRewritePattern<AddIOp> {
231   using OpRewritePattern<AddIOp>::OpRewritePattern;
232 
matchAndRewriteAddConstantReorder233   LogicalResult matchAndRewrite(AddIOp addop,
234                                 PatternRewriter &rewriter) const override {
235     for (int i = 0; i < 2; i++) {
236       APInt origConst;
237       APInt midConst;
238       if (matchPattern(addop.getOperand(i), m_ConstantInt(&origConst))) {
239         if (auto midAddOp = addop.getOperand(1 - i).getDefiningOp<AddIOp>()) {
240           for (int j = 0; j < 2; j++) {
241             if (matchPattern(midAddOp.getOperand(j),
242                              m_ConstantInt(&midConst))) {
243               auto nextConstant = rewriter.create<ConstantOp>(
244                   addop.getLoc(), rewriter.getIntegerAttr(
245                                       addop.getType(), origConst + midConst));
246               rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
247                                                   midAddOp.getOperand(1 - j));
248               return success();
249             }
250           }
251         }
252         if (auto midSubOp = addop.getOperand(1 - i).getDefiningOp<SubIOp>()) {
253           if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
254             auto nextConstant = rewriter.create<ConstantOp>(
255                 addop.getLoc(),
256                 rewriter.getIntegerAttr(addop.getType(), origConst + midConst));
257             rewriter.replaceOpWithNewOp<SubIOp>(addop, nextConstant,
258                                                 midSubOp.getOperand(1));
259             return success();
260           }
261           if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
262             auto nextConstant = rewriter.create<ConstantOp>(
263                 addop.getLoc(),
264                 rewriter.getIntegerAttr(addop.getType(), origConst - midConst));
265             rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
266                                                 midSubOp.getOperand(0));
267             return success();
268           }
269         }
270       }
271     }
272     return failure();
273   }
274 };
275 
/// Adds the canonicalization patterns of AddIOp to `results`.
void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  // AddConstantReorder combines constants across nested integer add/sub ops.
  results.insert<AddConstantReorder>(context);
}
280 
281 //===----------------------------------------------------------------------===//
282 // AndOp
283 //===----------------------------------------------------------------------===//
284 
fold(ArrayRef<Attribute> operands)285 OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
286   /// and(x, 0) -> 0
287   if (matchPattern(rhs(), m_Zero()))
288     return rhs();
289   /// and(x, allOnes) -> x
290   APInt intValue;
291   if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
292     return lhs();
293   /// and(x,x) -> x
294   if (lhs() == rhs())
295     return rhs();
296 
297   return constFoldBinaryOp<IntegerAttr>(operands,
298                                         [](APInt a, APInt b) { return a & b; });
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // AssertOp
303 //===----------------------------------------------------------------------===//
304 
canonicalize(AssertOp op,PatternRewriter & rewriter)305 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
306   // Erase assertion if argument is constant true.
307   if (matchPattern(op.arg(), m_One())) {
308     rewriter.eraseOp(op);
309     return success();
310   }
311   return failure();
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // AtomicRMWOp
316 //===----------------------------------------------------------------------===//
317 
verify(AtomicRMWOp op)318 static LogicalResult verify(AtomicRMWOp op) {
319   if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
320     return op.emitOpError(
321         "expects the number of subscripts to be equal to memref rank");
322   switch (op.kind()) {
323   case AtomicRMWKind::addf:
324   case AtomicRMWKind::maxf:
325   case AtomicRMWKind::minf:
326   case AtomicRMWKind::mulf:
327     if (!op.value().getType().isa<FloatType>())
328       return op.emitOpError()
329              << "with kind '" << stringifyAtomicRMWKind(op.kind())
330              << "' expects a floating-point type";
331     break;
332   case AtomicRMWKind::addi:
333   case AtomicRMWKind::maxs:
334   case AtomicRMWKind::maxu:
335   case AtomicRMWKind::mins:
336   case AtomicRMWKind::minu:
337   case AtomicRMWKind::muli:
338     if (!op.value().getType().isa<IntegerType>())
339       return op.emitOpError()
340              << "with kind '" << stringifyAtomicRMWKind(op.kind())
341              << "' expects an integer type";
342     break;
343   default:
344     break;
345   }
346   return success();
347 }
348 
349 /// Returns the identity value attribute associated with an AtomicRMWKind op.
getIdentityValueAttr(AtomicRMWKind kind,Type resultType,OpBuilder & builder,Location loc)350 Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
351                                      OpBuilder &builder, Location loc) {
352   switch (kind) {
353   case AtomicRMWKind::maxf:
354     return builder.getFloatAttr(
355         resultType,
356         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
357                         /*Negative=*/true));
358   case AtomicRMWKind::addf:
359   case AtomicRMWKind::addi:
360   case AtomicRMWKind::maxu:
361     return builder.getZeroAttr(resultType);
362   case AtomicRMWKind::maxs:
363     return builder.getIntegerAttr(
364         resultType,
365         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
366   case AtomicRMWKind::minf:
367     return builder.getFloatAttr(
368         resultType,
369         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
370                         /*Negative=*/false));
371   case AtomicRMWKind::mins:
372     return builder.getIntegerAttr(
373         resultType,
374         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
375   case AtomicRMWKind::minu:
376     return builder.getIntegerAttr(
377         resultType,
378         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
379   case AtomicRMWKind::muli:
380     return builder.getIntegerAttr(resultType, 1);
381   case AtomicRMWKind::mulf:
382     return builder.getFloatAttr(resultType, 1);
383   // TODO: Add remaining reduction operations.
384   default:
385     (void)emitOptionalError(loc, "Reduction operation type not supported");
386     break;
387   }
388   return nullptr;
389 }
390 
391 /// Returns the identity value associated with an AtomicRMWKind op.
getIdentityValue(AtomicRMWKind op,Type resultType,OpBuilder & builder,Location loc)392 Value mlir::getIdentityValue(AtomicRMWKind op, Type resultType,
393                              OpBuilder &builder, Location loc) {
394   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
395   return builder.create<ConstantOp>(loc, attr);
396 }
397 
398 /// Return the value obtained by applying the reduction operation kind
399 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
getReductionOp(AtomicRMWKind op,OpBuilder & builder,Location loc,Value lhs,Value rhs)400 Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
401                            Value lhs, Value rhs) {
402   switch (op) {
403   case AtomicRMWKind::addf:
404     return builder.create<AddFOp>(loc, lhs, rhs);
405   case AtomicRMWKind::addi:
406     return builder.create<AddIOp>(loc, lhs, rhs);
407   case AtomicRMWKind::mulf:
408     return builder.create<MulFOp>(loc, lhs, rhs);
409   case AtomicRMWKind::muli:
410     return builder.create<MulIOp>(loc, lhs, rhs);
411   case AtomicRMWKind::maxf:
412     return builder.create<SelectOp>(
413         loc, builder.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs), lhs,
414         rhs);
415   case AtomicRMWKind::minf:
416     return builder.create<SelectOp>(
417         loc, builder.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs), lhs,
418         rhs);
419   case AtomicRMWKind::maxs:
420     return builder.create<SelectOp>(
421         loc, builder.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs), lhs,
422         rhs);
423   case AtomicRMWKind::mins:
424     return builder.create<SelectOp>(
425         loc, builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs), lhs,
426         rhs);
427   case AtomicRMWKind::maxu:
428     return builder.create<SelectOp>(
429         loc, builder.create<CmpIOp>(loc, CmpIPredicate::ugt, lhs, rhs), lhs,
430         rhs);
431   case AtomicRMWKind::minu:
432     return builder.create<SelectOp>(
433         loc, builder.create<CmpIOp>(loc, CmpIPredicate::ult, lhs, rhs), lhs,
434         rhs);
435   // TODO: Add remaining reduction operations.
436   default:
437     (void)emitOptionalError(loc, "Reduction operation type not supported");
438     break;
439   }
440   return nullptr;
441 }
442 
443 //===----------------------------------------------------------------------===//
444 // GenericAtomicRMWOp
445 //===----------------------------------------------------------------------===//
446 
build(OpBuilder & builder,OperationState & result,Value memref,ValueRange ivs)447 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
448                                Value memref, ValueRange ivs) {
449   result.addOperands(memref);
450   result.addOperands(ivs);
451 
452   if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
453     Type elementType = memrefType.getElementType();
454     result.addTypes(elementType);
455 
456     Region *bodyRegion = result.addRegion();
457     bodyRegion->push_back(new Block());
458     bodyRegion->addArgument(elementType);
459   }
460 }
461 
verify(GenericAtomicRMWOp op)462 static LogicalResult verify(GenericAtomicRMWOp op) {
463   auto &body = op.body();
464   if (body.getNumArguments() != 1)
465     return op.emitOpError("expected single number of entry block arguments");
466 
467   if (op.getResult().getType() != body.getArgument(0).getType())
468     return op.emitOpError(
469         "expected block argument of the same type result type");
470 
471   bool hasSideEffects =
472       body.walk([&](Operation *nestedOp) {
473             if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
474               return WalkResult::advance();
475             nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
476                                 "only operations with no side effects");
477             return WalkResult::interrupt();
478           })
479           .wasInterrupted();
480   return hasSideEffects ? failure() : success();
481 }
482 
parseGenericAtomicRMWOp(OpAsmParser & parser,OperationState & result)483 static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
484                                            OperationState &result) {
485   OpAsmParser::OperandType memref;
486   Type memrefType;
487   SmallVector<OpAsmParser::OperandType, 4> ivs;
488 
489   Type indexType = parser.getBuilder().getIndexType();
490   if (parser.parseOperand(memref) ||
491       parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
492       parser.parseColonType(memrefType) ||
493       parser.resolveOperand(memref, memrefType, result.operands) ||
494       parser.resolveOperands(ivs, indexType, result.operands))
495     return failure();
496 
497   Region *body = result.addRegion();
498   if (parser.parseRegion(*body, llvm::None, llvm::None) ||
499       parser.parseOptionalAttrDict(result.attributes))
500     return failure();
501   result.types.push_back(memrefType.cast<MemRefType>().getElementType());
502   return success();
503 }
504 
print(OpAsmPrinter & p,GenericAtomicRMWOp op)505 static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
506   p << ' ' << op.memref() << "[" << op.indices()
507     << "] : " << op.memref().getType();
508   p.printRegion(op.body());
509   p.printOptionalAttrDict(op->getAttrs());
510 }
511 
512 //===----------------------------------------------------------------------===//
513 // AtomicYieldOp
514 //===----------------------------------------------------------------------===//
515 
verify(AtomicYieldOp op)516 static LogicalResult verify(AtomicYieldOp op) {
517   Type parentType = op->getParentOp()->getResultTypes().front();
518   Type resultType = op.result().getType();
519   if (parentType != resultType)
520     return op.emitOpError() << "types mismatch between yield op: " << resultType
521                             << " and its parent: " << parentType;
522   return success();
523 }
524 
525 //===----------------------------------------------------------------------===//
526 // BitcastOp
527 //===----------------------------------------------------------------------===//
528 
areCastCompatible(TypeRange inputs,TypeRange outputs)529 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
530   assert(inputs.size() == 1 && outputs.size() == 1 &&
531          "bitcast op expects one operand and result");
532   Type a = inputs.front(), b = outputs.front();
533   if (a.isSignlessIntOrFloat() && b.isSignlessIntOrFloat())
534     return a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
535   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
536 }
537 
fold(ArrayRef<Attribute> operands)538 OpFoldResult BitcastOp::fold(ArrayRef<Attribute> operands) {
539   assert(operands.size() == 1 && "bitcastop expects 1 operand");
540 
541   // Bitcast of bitcast
542   auto *sourceOp = getOperand().getDefiningOp();
543   if (auto sourceBitcast = dyn_cast_or_null<BitcastOp>(sourceOp)) {
544     setOperand(sourceBitcast.getOperand());
545     return getResult();
546   }
547 
548   auto operand = operands[0];
549   if (!operand)
550     return {};
551 
552   Type resType = getResult().getType();
553 
554   if (auto denseAttr = operand.dyn_cast<DenseElementsAttr>())
555     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
556 
557   APInt bits;
558   if (auto floatAttr = operand.dyn_cast<FloatAttr>())
559     bits = floatAttr.getValue().bitcastToAPInt();
560   else if (auto intAttr = operand.dyn_cast<IntegerAttr>())
561     bits = intAttr.getValue();
562   else
563     return {};
564 
565   if (resType.isa<IntegerType>())
566     return IntegerAttr::get(resType, bits);
567   if (auto resFloatType = resType.dyn_cast<FloatType>())
568     return FloatAttr::get(resType,
569                           APFloat(resFloatType.getFloatSemantics(), bits));
570   return {};
571 }
572 
573 //===----------------------------------------------------------------------===//
574 // BranchOp
575 //===----------------------------------------------------------------------===//
576 
577 /// Given a successor, try to collapse it to a new destination if it only
578 /// contains a passthrough unconditional branch. If the successor is
579 /// collapsable, `successor` and `successorOperands` are updated to reference
580 /// the new destination and values. `argStorage` is used as storage if operands
581 /// to the collapsed successor need to be remapped. It must outlive uses of
582 /// successorOperands.
collapseBranch(Block * & successor,ValueRange & successorOperands,SmallVectorImpl<Value> & argStorage)583 static LogicalResult collapseBranch(Block *&successor,
584                                     ValueRange &successorOperands,
585                                     SmallVectorImpl<Value> &argStorage) {
586   // Check that the successor only contains a unconditional branch.
587   if (std::next(successor->begin()) != successor->end())
588     return failure();
589   // Check that the terminator is an unconditional branch.
590   BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
591   if (!successorBranch)
592     return failure();
593   // Check that the arguments are only used within the terminator.
594   for (BlockArgument arg : successor->getArguments()) {
595     for (Operation *user : arg.getUsers())
596       if (user != successorBranch)
597         return failure();
598   }
599   // Don't try to collapse branches to infinite loops.
600   Block *successorDest = successorBranch.getDest();
601   if (successorDest == successor)
602     return failure();
603 
604   // Update the operands to the successor. If the branch parent has no
605   // arguments, we can use the branch operands directly.
606   OperandRange operands = successorBranch.getOperands();
607   if (successor->args_empty()) {
608     successor = successorDest;
609     successorOperands = operands;
610     return success();
611   }
612 
613   // Otherwise, we need to remap any argument operands.
614   for (Value operand : operands) {
615     BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
616     if (argOperand && argOperand.getOwner() == successor)
617       argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
618     else
619       argStorage.push_back(operand);
620   }
621   successor = successorDest;
622   successorOperands = argStorage;
623   return success();
624 }
625 
626 /// Simplify a branch to a block that has a single predecessor. This effectively
627 /// merges the two blocks.
628 static LogicalResult
simplifyBrToBlockWithSinglePred(BranchOp op,PatternRewriter & rewriter)629 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
630   // Check that the successor block has a single predecessor.
631   Block *succ = op.getDest();
632   Block *opParent = op->getBlock();
633   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
634     return failure();
635 
636   // Merge the successor into the current block and erase the branch.
637   rewriter.mergeBlocks(succ, opParent, op.getOperands());
638   rewriter.eraseOp(op);
639   return success();
640 }
641 
642 ///   br ^bb1
643 /// ^bb1
644 ///   br ^bbN(...)
645 ///
646 ///  -> br ^bbN(...)
647 ///
simplifyPassThroughBr(BranchOp op,PatternRewriter & rewriter)648 static LogicalResult simplifyPassThroughBr(BranchOp op,
649                                            PatternRewriter &rewriter) {
650   Block *dest = op.getDest();
651   ValueRange destOperands = op.getOperands();
652   SmallVector<Value, 4> destOperandStorage;
653 
654   // Try to collapse the successor if it points somewhere other than this
655   // block.
656   if (dest == op->getBlock() ||
657       failed(collapseBranch(dest, destOperands, destOperandStorage)))
658     return failure();
659 
660   // Create a new branch with the collapsed successor.
661   rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
662   return success();
663 }
664 
canonicalize(BranchOp op,PatternRewriter & rewriter)665 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
666   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
667                  succeeded(simplifyPassThroughBr(op, rewriter)));
668 }
669 
getDest()670 Block *BranchOp::getDest() { return getSuccessor(); }
671 
setDest(Block * block)672 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
673 
eraseOperand(unsigned index)674 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
675 
676 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)677 BranchOp::getMutableSuccessorOperands(unsigned index) {
678   assert(index == 0 && "invalid successor index");
679   return destOperandsMutable();
680 }
681 
getSuccessorForOperands(ArrayRef<Attribute>)682 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
683 
684 //===----------------------------------------------------------------------===//
685 // CallOp
686 //===----------------------------------------------------------------------===//
687 
verifySymbolUses(SymbolTableCollection & symbolTable)688 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
689   // Check that the callee attribute was specified.
690   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
691   if (!fnAttr)
692     return emitOpError("requires a 'callee' symbol reference attribute");
693   FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
694   if (!fn)
695     return emitOpError() << "'" << fnAttr.getValue()
696                          << "' does not reference a valid function";
697 
698   // Verify that the operand and result types match the callee.
699   auto fnType = fn.getType();
700   if (fnType.getNumInputs() != getNumOperands())
701     return emitOpError("incorrect number of operands for callee");
702 
703   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
704     if (getOperand(i).getType() != fnType.getInput(i))
705       return emitOpError("operand type mismatch: expected operand type ")
706              << fnType.getInput(i) << ", but provided "
707              << getOperand(i).getType() << " for operand number " << i;
708 
709   if (fnType.getNumResults() != getNumResults())
710     return emitOpError("incorrect number of results for callee");
711 
712   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
713     if (getResult(i).getType() != fnType.getResult(i)) {
714       auto diag = emitOpError("result type mismatch at index ") << i;
715       diag.attachNote() << "      op result types: " << getResultTypes();
716       diag.attachNote() << "function result types: " << fnType.getResults();
717       return diag;
718     }
719 
720   return success();
721 }
722 
getCalleeType()723 FunctionType CallOp::getCalleeType() {
724   return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
725 }
726 
727 //===----------------------------------------------------------------------===//
728 // CallIndirectOp
729 //===----------------------------------------------------------------------===//
730 
731 /// Fold indirect calls that have a constant function as the callee operand.
canonicalize(CallIndirectOp indirectCall,PatternRewriter & rewriter)732 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
733                                            PatternRewriter &rewriter) {
734   // Check that the callee is a constant callee.
735   SymbolRefAttr calledFn;
736   if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
737     return failure();
738 
739   // Replace with a direct call.
740   rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
741                                       indirectCall.getResultTypes(),
742                                       indirectCall.getArgOperands());
743   return success();
744 }
745 
746 //===----------------------------------------------------------------------===//
747 // General helpers for comparison ops
748 //===----------------------------------------------------------------------===//
749 
750 // Return the type of the same shape (scalar, vector or tensor) containing i1.
getI1SameShape(Type type)751 static Type getI1SameShape(Type type) {
752   auto i1Type = IntegerType::get(type.getContext(), 1);
753   if (auto tensorType = type.dyn_cast<RankedTensorType>())
754     return RankedTensorType::get(tensorType.getShape(), i1Type);
755   if (type.isa<UnrankedTensorType>())
756     return UnrankedTensorType::get(i1Type);
757   if (auto vectorType = type.dyn_cast<VectorType>())
758     return VectorType::get(vectorType.getShape(), i1Type);
759   return i1Type;
760 }
761 
762 //===----------------------------------------------------------------------===//
763 // CmpIOp
764 //===----------------------------------------------------------------------===//
765 
buildCmpIOp(OpBuilder & build,OperationState & result,CmpIPredicate predicate,Value lhs,Value rhs)766 static void buildCmpIOp(OpBuilder &build, OperationState &result,
767                         CmpIPredicate predicate, Value lhs, Value rhs) {
768   result.addOperands({lhs, rhs});
769   result.types.push_back(getI1SameShape(lhs.getType()));
770   result.addAttribute(CmpIOp::getPredicateAttrName(),
771                       build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
772 }
773 
774 // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
775 // comparison predicates.
applyCmpPredicate(CmpIPredicate predicate,const APInt & lhs,const APInt & rhs)776 bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
777                              const APInt &rhs) {
778   switch (predicate) {
779   case CmpIPredicate::eq:
780     return lhs.eq(rhs);
781   case CmpIPredicate::ne:
782     return lhs.ne(rhs);
783   case CmpIPredicate::slt:
784     return lhs.slt(rhs);
785   case CmpIPredicate::sle:
786     return lhs.sle(rhs);
787   case CmpIPredicate::sgt:
788     return lhs.sgt(rhs);
789   case CmpIPredicate::sge:
790     return lhs.sge(rhs);
791   case CmpIPredicate::ult:
792     return lhs.ult(rhs);
793   case CmpIPredicate::ule:
794     return lhs.ule(rhs);
795   case CmpIPredicate::ugt:
796     return lhs.ugt(rhs);
797   case CmpIPredicate::uge:
798     return lhs.uge(rhs);
799   }
800   llvm_unreachable("unknown comparison predicate");
801 }
802 
803 // Returns true if the predicate is true for two equal operands.
applyCmpPredicateToEqualOperands(CmpIPredicate predicate)804 static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
805   switch (predicate) {
806   case CmpIPredicate::eq:
807   case CmpIPredicate::sle:
808   case CmpIPredicate::sge:
809   case CmpIPredicate::ule:
810   case CmpIPredicate::uge:
811     return true;
812   case CmpIPredicate::ne:
813   case CmpIPredicate::slt:
814   case CmpIPredicate::sgt:
815   case CmpIPredicate::ult:
816   case CmpIPredicate::ugt:
817     return false;
818   }
819   llvm_unreachable("unknown comparison predicate");
820 }
821 
822 // Constant folding hook for comparisons.
fold(ArrayRef<Attribute> operands)823 OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
824   assert(operands.size() == 2 && "cmpi takes two arguments");
825 
826   if (lhs() == rhs()) {
827     auto val = applyCmpPredicateToEqualOperands(getPredicate());
828     return BoolAttr::get(getContext(), val);
829   }
830 
831   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
832   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
833   if (!lhs || !rhs)
834     return {};
835 
836   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
837   return BoolAttr::get(getContext(), val);
838 }
839 
840 //===----------------------------------------------------------------------===//
841 // CmpFOp
842 //===----------------------------------------------------------------------===//
843 
buildCmpFOp(OpBuilder & build,OperationState & result,CmpFPredicate predicate,Value lhs,Value rhs)844 static void buildCmpFOp(OpBuilder &build, OperationState &result,
845                         CmpFPredicate predicate, Value lhs, Value rhs) {
846   result.addOperands({lhs, rhs});
847   result.types.push_back(getI1SameShape(lhs.getType()));
848   result.addAttribute(CmpFOp::getPredicateAttrName(),
849                       build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
850 }
851 
852 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
853 /// comparison predicates.
applyCmpPredicate(CmpFPredicate predicate,const APFloat & lhs,const APFloat & rhs)854 bool mlir::applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
855                              const APFloat &rhs) {
856   auto cmpResult = lhs.compare(rhs);
857   switch (predicate) {
858   case CmpFPredicate::AlwaysFalse:
859     return false;
860   case CmpFPredicate::OEQ:
861     return cmpResult == APFloat::cmpEqual;
862   case CmpFPredicate::OGT:
863     return cmpResult == APFloat::cmpGreaterThan;
864   case CmpFPredicate::OGE:
865     return cmpResult == APFloat::cmpGreaterThan ||
866            cmpResult == APFloat::cmpEqual;
867   case CmpFPredicate::OLT:
868     return cmpResult == APFloat::cmpLessThan;
869   case CmpFPredicate::OLE:
870     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
871   case CmpFPredicate::ONE:
872     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
873   case CmpFPredicate::ORD:
874     return cmpResult != APFloat::cmpUnordered;
875   case CmpFPredicate::UEQ:
876     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
877   case CmpFPredicate::UGT:
878     return cmpResult == APFloat::cmpUnordered ||
879            cmpResult == APFloat::cmpGreaterThan;
880   case CmpFPredicate::UGE:
881     return cmpResult == APFloat::cmpUnordered ||
882            cmpResult == APFloat::cmpGreaterThan ||
883            cmpResult == APFloat::cmpEqual;
884   case CmpFPredicate::ULT:
885     return cmpResult == APFloat::cmpUnordered ||
886            cmpResult == APFloat::cmpLessThan;
887   case CmpFPredicate::ULE:
888     return cmpResult == APFloat::cmpUnordered ||
889            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
890   case CmpFPredicate::UNE:
891     return cmpResult != APFloat::cmpEqual;
892   case CmpFPredicate::UNO:
893     return cmpResult == APFloat::cmpUnordered;
894   case CmpFPredicate::AlwaysTrue:
895     return true;
896   }
897   llvm_unreachable("unknown comparison predicate");
898 }
899 
900 // Constant folding hook for comparisons.
fold(ArrayRef<Attribute> operands)901 OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
902   assert(operands.size() == 2 && "cmpf takes two arguments");
903 
904   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
905   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
906 
907   // TODO: We could actually do some intelligent things if we know only one
908   // of the operands, but it's inf or nan.
909   if (!lhs || !rhs)
910     return {};
911 
912   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
913   return IntegerAttr::get(IntegerType::get(getContext(), 1), APInt(1, val));
914 }
915 
916 //===----------------------------------------------------------------------===//
917 // CondBranchOp
918 //===----------------------------------------------------------------------===//
919 
920 namespace {
921 /// cond_br true, ^bb1, ^bb2
922 ///  -> br ^bb1
923 /// cond_br false, ^bb1, ^bb2
924 ///  -> br ^bb2
925 ///
926 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
927   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
928 
matchAndRewrite__anon1787a6f50611::SimplifyConstCondBranchPred929   LogicalResult matchAndRewrite(CondBranchOp condbr,
930                                 PatternRewriter &rewriter) const override {
931     if (matchPattern(condbr.getCondition(), m_NonZero())) {
932       // True branch taken.
933       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
934                                             condbr.getTrueOperands());
935       return success();
936     } else if (matchPattern(condbr.getCondition(), m_Zero())) {
937       // False branch taken.
938       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
939                                             condbr.getFalseOperands());
940       return success();
941     }
942     return failure();
943   }
944 };
945 
946 ///   cond_br %cond, ^bb1, ^bb2
947 /// ^bb1
948 ///   br ^bbN(...)
949 /// ^bb2
950 ///   br ^bbK(...)
951 ///
952 ///  -> cond_br %cond, ^bbN(...), ^bbK(...)
953 ///
954 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
955   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
956 
matchAndRewrite__anon1787a6f50611::SimplifyPassThroughCondBranch957   LogicalResult matchAndRewrite(CondBranchOp condbr,
958                                 PatternRewriter &rewriter) const override {
959     Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest();
960     ValueRange trueDestOperands = condbr.getTrueOperands();
961     ValueRange falseDestOperands = condbr.getFalseOperands();
962     SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
963 
964     // Try to collapse one of the current successors.
965     LogicalResult collapsedTrue =
966         collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
967     LogicalResult collapsedFalse =
968         collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
969     if (failed(collapsedTrue) && failed(collapsedFalse))
970       return failure();
971 
972     // Create a new branch with the collapsed successors.
973     rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
974                                               trueDest, trueDestOperands,
975                                               falseDest, falseDestOperands);
976     return success();
977   }
978 };
979 
980 /// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
981 ///  -> br ^bb1(A, ..., N)
982 ///
983 /// cond_br %cond, ^bb1(A), ^bb1(B)
984 ///  -> %select = select %cond, A, B
985 ///     br ^bb1(%select)
986 ///
987 struct SimplifyCondBranchIdenticalSuccessors
988     : public OpRewritePattern<CondBranchOp> {
989   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
990 
matchAndRewrite__anon1787a6f50611::SimplifyCondBranchIdenticalSuccessors991   LogicalResult matchAndRewrite(CondBranchOp condbr,
992                                 PatternRewriter &rewriter) const override {
993     // Check that the true and false destinations are the same and have the same
994     // operands.
995     Block *trueDest = condbr.trueDest();
996     if (trueDest != condbr.falseDest())
997       return failure();
998 
999     // If all of the operands match, no selects need to be generated.
1000     OperandRange trueOperands = condbr.getTrueOperands();
1001     OperandRange falseOperands = condbr.getFalseOperands();
1002     if (trueOperands == falseOperands) {
1003       rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
1004       return success();
1005     }
1006 
1007     // Otherwise, if the current block is the only predecessor insert selects
1008     // for any mismatched branch operands.
1009     if (trueDest->getUniquePredecessor() != condbr->getBlock())
1010       return failure();
1011 
1012     // Generate a select for any operands that differ between the two.
1013     SmallVector<Value, 8> mergedOperands;
1014     mergedOperands.reserve(trueOperands.size());
1015     Value condition = condbr.getCondition();
1016     for (auto it : llvm::zip(trueOperands, falseOperands)) {
1017       if (std::get<0>(it) == std::get<1>(it))
1018         mergedOperands.push_back(std::get<0>(it));
1019       else
1020         mergedOperands.push_back(rewriter.create<SelectOp>(
1021             condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
1022     }
1023 
1024     rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
1025     return success();
1026   }
1027 };
1028 
1029 ///   ...
1030 ///   cond_br %cond, ^bb1(...), ^bb2(...)
1031 /// ...
1032 /// ^bb1: // has single predecessor
1033 ///   ...
1034 ///   cond_br %cond, ^bb3(...), ^bb4(...)
1035 ///
1036 /// ->
1037 ///
1038 ///   ...
1039 ///   cond_br %cond, ^bb1(...), ^bb2(...)
1040 /// ...
1041 /// ^bb1: // has single predecessor
1042 ///   ...
1043 ///   br ^bb3(...)
1044 ///
1045 struct SimplifyCondBranchFromCondBranchOnSameCondition
1046     : public OpRewritePattern<CondBranchOp> {
1047   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
1048 
matchAndRewrite__anon1787a6f50611::SimplifyCondBranchFromCondBranchOnSameCondition1049   LogicalResult matchAndRewrite(CondBranchOp condbr,
1050                                 PatternRewriter &rewriter) const override {
1051     // Check that we have a single distinct predecessor.
1052     Block *currentBlock = condbr->getBlock();
1053     Block *predecessor = currentBlock->getSinglePredecessor();
1054     if (!predecessor)
1055       return failure();
1056 
1057     // Check that the predecessor terminates with a conditional branch to this
1058     // block and that it branches on the same condition.
1059     auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
1060     if (!predBranch || condbr.getCondition() != predBranch.getCondition())
1061       return failure();
1062 
1063     // Fold this branch to an unconditional branch.
1064     if (currentBlock == predBranch.trueDest())
1065       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.trueDest(),
1066                                             condbr.trueDestOperands());
1067     else
1068       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.falseDest(),
1069                                             condbr.falseDestOperands());
1070     return success();
1071   }
1072 };
1073 
1074 ///   cond_br %arg0, ^trueB, ^falseB
1075 ///
1076 /// ^trueB:
1077 ///   "test.consumer1"(%arg0) : (i1) -> ()
1078 ///    ...
1079 ///
1080 /// ^falseB:
1081 ///   "test.consumer2"(%arg0) : (i1) -> ()
1082 ///   ...
1083 ///
1084 /// ->
1085 ///
1086 ///   cond_br %arg0, ^trueB, ^falseB
1087 /// ^trueB:
1088 ///   "test.consumer1"(%true) : (i1) -> ()
1089 ///   ...
1090 ///
1091 /// ^falseB:
1092 ///   "test.consumer2"(%false) : (i1) -> ()
1093 ///   ...
1094 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
1095   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
1096 
matchAndRewrite__anon1787a6f50611::CondBranchTruthPropagation1097   LogicalResult matchAndRewrite(CondBranchOp condbr,
1098                                 PatternRewriter &rewriter) const override {
1099     // Check that we have a single distinct predecessor.
1100     bool replaced = false;
1101     Type ty = rewriter.getI1Type();
1102 
1103     // These variables serve to prevent creating duplicate constants
1104     // and hold constant true or false values.
1105     Value constantTrue = nullptr;
1106     Value constantFalse = nullptr;
1107 
1108     // TODO These checks can be expanded to encompas any use with only
1109     // either the true of false edge as a predecessor. For now, we fall
1110     // back to checking the single predecessor is given by the true/fasle
1111     // destination, thereby ensuring that only that edge can reach the
1112     // op.
1113     if (condbr.getTrueDest()->getSinglePredecessor()) {
1114       for (OpOperand &use :
1115            llvm::make_early_inc_range(condbr.condition().getUses())) {
1116         if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
1117           replaced = true;
1118 
1119           if (!constantTrue)
1120             constantTrue = rewriter.create<mlir::ConstantOp>(
1121                 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
1122 
1123           rewriter.updateRootInPlace(use.getOwner(),
1124                                      [&] { use.set(constantTrue); });
1125         }
1126       }
1127     }
1128     if (condbr.getFalseDest()->getSinglePredecessor()) {
1129       for (OpOperand &use :
1130            llvm::make_early_inc_range(condbr.condition().getUses())) {
1131         if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
1132           replaced = true;
1133 
1134           if (!constantFalse)
1135             constantFalse = rewriter.create<mlir::ConstantOp>(
1136                 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
1137 
1138           rewriter.updateRootInPlace(use.getOwner(),
1139                                      [&] { use.set(constantFalse); });
1140         }
1141       }
1142     }
1143     return success(replaced);
1144   }
1145 };
1146 } // end anonymous namespace
1147 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1148 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
1149                                                MLIRContext *context) {
1150   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
1151               SimplifyCondBranchIdenticalSuccessors,
1152               SimplifyCondBranchFromCondBranchOnSameCondition,
1153               CondBranchTruthPropagation>(context);
1154 }
1155 
1156 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1157 CondBranchOp::getMutableSuccessorOperands(unsigned index) {
1158   assert(index < getNumSuccessors() && "invalid successor index");
1159   return index == trueIndex ? trueDestOperandsMutable()
1160                             : falseDestOperandsMutable();
1161 }
1162 
getSuccessorForOperands(ArrayRef<Attribute> operands)1163 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1164   if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
1165     return condAttr.getValue().isOneValue() ? trueDest() : falseDest();
1166   return nullptr;
1167 }
1168 
1169 //===----------------------------------------------------------------------===//
1170 // Constant*Op
1171 //===----------------------------------------------------------------------===//
1172 
print(OpAsmPrinter & p,ConstantOp & op)1173 static void print(OpAsmPrinter &p, ConstantOp &op) {
1174   p << " ";
1175   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
1176 
1177   if (op->getAttrs().size() > 1)
1178     p << ' ';
1179   p << op.getValue();
1180 
1181   // If the value is a symbol reference or Array, print a trailing type.
1182   if (op.getValue().isa<SymbolRefAttr, ArrayAttr>())
1183     p << " : " << op.getType();
1184 }
1185 
parseConstantOp(OpAsmParser & parser,OperationState & result)1186 static ParseResult parseConstantOp(OpAsmParser &parser,
1187                                    OperationState &result) {
1188   Attribute valueAttr;
1189   if (parser.parseOptionalAttrDict(result.attributes) ||
1190       parser.parseAttribute(valueAttr, "value", result.attributes))
1191     return failure();
1192 
1193   // If the attribute is a symbol reference or array, then we expect a trailing
1194   // type.
1195   Type type;
1196   if (!valueAttr.isa<SymbolRefAttr, ArrayAttr>())
1197     type = valueAttr.getType();
1198   else if (parser.parseColonType(type))
1199     return failure();
1200 
1201   // Add the attribute type to the list.
1202   return parser.addTypeToList(type, result.types);
1203 }
1204 
1205 /// The constant op requires an attribute, and furthermore requires that it
1206 /// matches the return type.
verify(ConstantOp & op)1207 static LogicalResult verify(ConstantOp &op) {
1208   auto value = op.getValue();
1209   if (!value)
1210     return op.emitOpError("requires a 'value' attribute");
1211 
1212   Type type = op.getType();
1213   if (!value.getType().isa<NoneType>() && type != value.getType())
1214     return op.emitOpError() << "requires attribute's type (" << value.getType()
1215                             << ") to match op's return type (" << type << ")";
1216 
1217   if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
1218     if (type.isa<IndexType>() || value.isa<BoolAttr>())
1219       return success();
1220     IntegerType intType = type.cast<IntegerType>();
1221     if (!intType.isSignless())
1222       return op.emitOpError("requires integer result types to be signless");
1223 
1224     // If the type has a known bitwidth we verify that the value can be
1225     // represented with the given bitwidth.
1226     unsigned bitwidth = intType.getWidth();
1227     APInt intVal = intAttr.getValue();
1228     if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
1229       return op.emitOpError("requires 'value' to be an integer within the "
1230                             "range of the integer result type");
1231     return success();
1232   }
1233 
1234   if (auto complexTy = type.dyn_cast<ComplexType>()) {
1235     auto arrayAttr = value.dyn_cast<ArrayAttr>();
1236     if (!complexTy || arrayAttr.size() != 2)
1237       return op.emitOpError(
1238           "requires 'value' to be a complex constant, represented as array of "
1239           "two values");
1240     auto complexEltTy = complexTy.getElementType();
1241     if (complexEltTy != arrayAttr[0].getType() ||
1242         complexEltTy != arrayAttr[1].getType()) {
1243       return op.emitOpError()
1244              << "requires attribute's element types (" << arrayAttr[0].getType()
1245              << ", " << arrayAttr[1].getType()
1246              << ") to match the element type of the op's return type ("
1247              << complexEltTy << ")";
1248     }
1249     return success();
1250   }
1251 
1252   if (type.isa<FloatType>()) {
1253     if (!value.isa<FloatAttr>())
1254       return op.emitOpError("requires 'value' to be a floating point constant");
1255     return success();
1256   }
1257 
1258   if (type.isa<ShapedType>()) {
1259     if (!value.isa<ElementsAttr>())
1260       return op.emitOpError("requires 'value' to be a shaped constant");
1261     return success();
1262   }
1263 
1264   if (type.isa<FunctionType>()) {
1265     auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
1266     if (!fnAttr)
1267       return op.emitOpError("requires 'value' to be a function reference");
1268 
1269     // Try to find the referenced function.
1270     auto fn =
1271         op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
1272     if (!fn)
1273       return op.emitOpError()
1274              << "reference to undefined function '" << fnAttr.getValue() << "'";
1275 
1276     // Check that the referenced function has the correct type.
1277     if (fn.getType() != type)
1278       return op.emitOpError("reference to function with mismatched type");
1279 
1280     return success();
1281   }
1282 
1283   if (type.isa<NoneType>() && value.isa<UnitAttr>())
1284     return success();
1285 
1286   return op.emitOpError("unsupported 'value' attribute: ") << value;
1287 }
1288 
fold(ArrayRef<Attribute> operands)1289 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
1290   assert(operands.empty() && "constant has no operands");
1291   return getValue();
1292 }
1293 
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)1294 void ConstantOp::getAsmResultNames(
1295     function_ref<void(Value, StringRef)> setNameFn) {
1296   Type type = getType();
1297   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
1298     IntegerType intTy = type.dyn_cast<IntegerType>();
1299 
1300     // Sugar i1 constants with 'true' and 'false'.
1301     if (intTy && intTy.getWidth() == 1)
1302       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1303 
1304     // Otherwise, build a complex name with the value and type.
1305     SmallString<32> specialNameBuffer;
1306     llvm::raw_svector_ostream specialName(specialNameBuffer);
1307     specialName << 'c' << intCst.getInt();
1308     if (intTy)
1309       specialName << '_' << type;
1310     setNameFn(getResult(), specialName.str());
1311 
1312   } else if (type.isa<FunctionType>()) {
1313     setNameFn(getResult(), "f");
1314   } else {
1315     setNameFn(getResult(), "cst");
1316   }
1317 }
1318 
1319 /// Returns true if a constant operation can be built with the given value and
1320 /// result type.
isBuildableWith(Attribute value,Type type)1321 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
1322   // SymbolRefAttr can only be used with a function type.
1323   if (value.isa<SymbolRefAttr>())
1324     return type.isa<FunctionType>();
1325   // The attribute must have the same type as 'type'.
1326   if (!value.getType().isa<NoneType>() && value.getType() != type)
1327     return false;
1328   // If the type is an integer type, it must be signless.
1329   if (IntegerType integerTy = type.dyn_cast<IntegerType>())
1330     if (!integerTy.isSignless())
1331       return false;
1332   // Finally, check that the attribute kind is handled.
1333   if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
1334     auto complexTy = type.dyn_cast<ComplexType>();
1335     if (!complexTy)
1336       return false;
1337     auto complexEltTy = complexTy.getElementType();
1338     return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
1339            arrAttr[1].getType() == complexEltTy;
1340   }
1341   return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
1342 }
1343 
build(OpBuilder & builder,OperationState & result,const APFloat & value,FloatType type)1344 void ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
1345                             const APFloat &value, FloatType type) {
1346   ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value));
1347 }
1348 
classof(Operation * op)1349 bool ConstantFloatOp::classof(Operation *op) {
1350   return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>();
1351 }
1352 
1353 /// ConstantIntOp only matches values whose result type is an IntegerType.
classof(Operation * op)1354 bool ConstantIntOp::classof(Operation *op) {
1355   return ConstantOp::classof(op) &&
1356          op->getResult(0).getType().isSignlessInteger();
1357 }
1358 
build(OpBuilder & builder,OperationState & result,int64_t value,unsigned width)1359 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1360                           int64_t value, unsigned width) {
1361   Type type = builder.getIntegerType(width);
1362   ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1363 }
1364 
1365 /// Build a constant int op producing an integer with the specified type,
1366 /// which must be an integer type.
build(OpBuilder & builder,OperationState & result,int64_t value,Type type)1367 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1368                           int64_t value, Type type) {
1369   assert(type.isSignlessInteger() &&
1370          "ConstantIntOp can only have signless integer type");
1371   ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1372 }
1373 
1374 /// ConstantIndexOp only matches values whose result type is Index.
classof(Operation * op)1375 bool ConstantIndexOp::classof(Operation *op) {
1376   return ConstantOp::classof(op) && op->getResult(0).getType().isIndex();
1377 }
1378 
build(OpBuilder & builder,OperationState & result,int64_t value)1379 void ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
1380                             int64_t value) {
1381   Type type = builder.getIndexType();
1382   ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1383 }
1384 
1385 // ---------------------------------------------------------------------------
1386 // DivFOp
1387 // ---------------------------------------------------------------------------
1388 
fold(ArrayRef<Attribute> operands)1389 OpFoldResult DivFOp::fold(ArrayRef<Attribute> operands) {
1390   return constFoldBinaryOp<FloatAttr>(
1391       operands, [](APFloat a, APFloat b) { return a / b; });
1392 }
1393 
1394 //===----------------------------------------------------------------------===//
1395 // FPExtOp
1396 //===----------------------------------------------------------------------===//
1397 
areCastCompatible(TypeRange inputs,TypeRange outputs)1398 bool FPExtOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1399   if (inputs.size() != 1 || outputs.size() != 1)
1400     return false;
1401   Type a = inputs.front(), b = outputs.front();
1402   if (auto fa = a.dyn_cast<FloatType>())
1403     if (auto fb = b.dyn_cast<FloatType>())
1404       return fa.getWidth() < fb.getWidth();
1405   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1406 }
1407 
1408 //===----------------------------------------------------------------------===//
1409 // FPToSIOp
1410 //===----------------------------------------------------------------------===//
1411 
areCastCompatible(TypeRange inputs,TypeRange outputs)1412 bool FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1413   if (inputs.size() != 1 || outputs.size() != 1)
1414     return false;
1415   Type a = inputs.front(), b = outputs.front();
1416   if (a.isa<FloatType>() && b.isSignlessInteger())
1417     return true;
1418   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1419 }
1420 
1421 //===----------------------------------------------------------------------===//
1422 // FPToUIOp
1423 //===----------------------------------------------------------------------===//
1424 
areCastCompatible(TypeRange inputs,TypeRange outputs)1425 bool FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1426   if (inputs.size() != 1 || outputs.size() != 1)
1427     return false;
1428   Type a = inputs.front(), b = outputs.front();
1429   if (a.isa<FloatType>() && b.isSignlessInteger())
1430     return true;
1431   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1432 }
1433 
1434 //===----------------------------------------------------------------------===//
1435 // FPTruncOp
1436 //===----------------------------------------------------------------------===//
1437 
areCastCompatible(TypeRange inputs,TypeRange outputs)1438 bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1439   if (inputs.size() != 1 || outputs.size() != 1)
1440     return false;
1441   Type a = inputs.front(), b = outputs.front();
1442   if (auto fa = a.dyn_cast<FloatType>())
1443     if (auto fb = b.dyn_cast<FloatType>())
1444       return fa.getWidth() > fb.getWidth();
1445   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1446 }
1447 
1448 /// Perform safe const propagation for fptrunc, i.e. only propagate
1449 /// if FP value can be represented without precision loss or rounding.
fold(ArrayRef<Attribute> operands)1450 OpFoldResult FPTruncOp::fold(ArrayRef<Attribute> operands) {
1451   assert(operands.size() == 1 && "unary operation takes one operand");
1452 
1453   auto constOperand = operands.front();
1454   if (!constOperand || !constOperand.isa<FloatAttr>())
1455     return {};
1456 
1457   // Convert to target type via 'double'.
1458   double sourceValue =
1459       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
1460   auto targetAttr = FloatAttr::get(getType(), sourceValue);
1461 
1462   // Propagate if constant's value does not change after truncation.
1463   if (sourceValue == targetAttr.getValue().convertToDouble())
1464     return targetAttr;
1465 
1466   return {};
1467 }
1468 
1469 //===----------------------------------------------------------------------===//
1470 // IndexCastOp
1471 //===----------------------------------------------------------------------===//
1472 
1473 // Index cast is applicable from index to integer and backwards.
areCastCompatible(TypeRange inputs,TypeRange outputs)1474 bool IndexCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1475   if (inputs.size() != 1 || outputs.size() != 1)
1476     return false;
1477   Type a = inputs.front(), b = outputs.front();
1478   if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
1479     auto aShaped = a.cast<ShapedType>();
1480     auto bShaped = b.cast<ShapedType>();
1481 
1482     return (aShaped.getShape() == bShaped.getShape()) &&
1483            areCastCompatible(aShaped.getElementType(),
1484                              bShaped.getElementType());
1485   }
1486 
1487   return (a.isIndex() && b.isSignlessInteger()) ||
1488          (a.isSignlessInteger() && b.isIndex());
1489 }
1490 
fold(ArrayRef<Attribute> cstOperands)1491 OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
1492   // Fold IndexCast(IndexCast(x)) -> x
1493   auto cast = getOperand().getDefiningOp<IndexCastOp>();
1494   if (cast && cast.getOperand().getType() == getType())
1495     return cast.getOperand();
1496 
1497   // Fold IndexCast(constant) -> constant
1498   // A little hack because we go through int.  Otherwise, the size
1499   // of the constant might need to change.
1500   if (auto value = cstOperands[0].dyn_cast_or_null<IntegerAttr>())
1501     return IntegerAttr::get(getType(), value.getInt());
1502 
1503   return {};
1504 }
1505 
1506 namespace {
1507 ///  index_cast(sign_extend x) => index_cast(x)
1508 struct IndexCastOfSExt : public OpRewritePattern<IndexCastOp> {
1509   using OpRewritePattern<IndexCastOp>::OpRewritePattern;
1510 
matchAndRewrite__anon1787a6f50a11::IndexCastOfSExt1511   LogicalResult matchAndRewrite(IndexCastOp op,
1512                                 PatternRewriter &rewriter) const override {
1513 
1514     if (auto extop = op.getOperand().getDefiningOp<SignExtendIOp>()) {
1515       op.setOperand(extop.getOperand());
1516       return success();
1517     }
1518     return failure();
1519   }
1520 };
1521 
1522 } // namespace
1523 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1524 void IndexCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1525                                               MLIRContext *context) {
1526   results.insert<IndexCastOfSExt>(context);
1527 }
1528 
1529 //===----------------------------------------------------------------------===//
1530 // MulFOp
1531 //===----------------------------------------------------------------------===//
1532 
fold(ArrayRef<Attribute> operands)1533 OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
1534   return constFoldBinaryOp<FloatAttr>(
1535       operands, [](APFloat a, APFloat b) { return a * b; });
1536 }
1537 
1538 //===----------------------------------------------------------------------===//
1539 // MulIOp
1540 //===----------------------------------------------------------------------===//
1541 
fold(ArrayRef<Attribute> operands)1542 OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
1543   /// muli(x, 0) -> 0
1544   if (matchPattern(rhs(), m_Zero()))
1545     return rhs();
1546   /// muli(x, 1) -> x
1547   if (matchPattern(rhs(), m_One()))
1548     return getOperand(0);
1549 
1550   // TODO: Handle the overflow case.
1551   return constFoldBinaryOp<IntegerAttr>(operands,
1552                                         [](APInt a, APInt b) { return a * b; });
1553 }
1554 
1555 //===----------------------------------------------------------------------===//
1556 // OrOp
1557 //===----------------------------------------------------------------------===//
1558 
fold(ArrayRef<Attribute> operands)1559 OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
1560   /// or(x, 0) -> x
1561   if (matchPattern(rhs(), m_Zero()))
1562     return lhs();
1563   /// or(x, x) -> x
1564   if (lhs() == rhs())
1565     return rhs();
1566   /// or(x, <all ones>) -> <all ones>
1567   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
1568     if (rhsAttr.getValue().isAllOnes())
1569       return rhsAttr;
1570 
1571   return constFoldBinaryOp<IntegerAttr>(operands,
1572                                         [](APInt a, APInt b) { return a | b; });
1573 }
1574 
1575 //===----------------------------------------------------------------------===//
1576 // RankOp
1577 //===----------------------------------------------------------------------===//
1578 
fold(ArrayRef<Attribute> operands)1579 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1580   // Constant fold rank when the rank of the operand is known.
1581   auto type = getOperand().getType();
1582   if (auto shapedType = type.dyn_cast<ShapedType>())
1583     if (shapedType.hasRank())
1584       return IntegerAttr::get(IndexType::get(getContext()),
1585                               shapedType.getRank());
1586   return IntegerAttr();
1587 }
1588 
1589 //===----------------------------------------------------------------------===//
1590 // ReturnOp
1591 //===----------------------------------------------------------------------===//
1592 
verify(ReturnOp op)1593 static LogicalResult verify(ReturnOp op) {
1594   auto function = cast<FuncOp>(op->getParentOp());
1595 
1596   // The operand number and types must match the function signature.
1597   const auto &results = function.getType().getResults();
1598   if (op.getNumOperands() != results.size())
1599     return op.emitOpError("has ")
1600            << op.getNumOperands() << " operands, but enclosing function (@"
1601            << function.getName() << ") returns " << results.size();
1602 
1603   for (unsigned i = 0, e = results.size(); i != e; ++i)
1604     if (op.getOperand(i).getType() != results[i])
1605       return op.emitError()
1606              << "type of return operand " << i << " ("
1607              << op.getOperand(i).getType()
1608              << ") doesn't match function result type (" << results[i] << ")"
1609              << " in function @" << function.getName();
1610 
1611   return success();
1612 }
1613 
1614 //===----------------------------------------------------------------------===//
1615 // SelectOp
1616 //===----------------------------------------------------------------------===//
1617 
1618 // Transforms a select to a not, where relevant.
1619 //
1620 //  select %arg, %false, %true
1621 //
1622 //  becomes
1623 //
1624 //  xor %arg, %true
1625 struct SelectToNot : public OpRewritePattern<SelectOp> {
1626   using OpRewritePattern<SelectOp>::OpRewritePattern;
1627 
matchAndRewriteSelectToNot1628   LogicalResult matchAndRewrite(SelectOp op,
1629                                 PatternRewriter &rewriter) const override {
1630     if (!matchPattern(op.getTrueValue(), m_Zero()))
1631       return failure();
1632 
1633     if (!matchPattern(op.getFalseValue(), m_One()))
1634       return failure();
1635 
1636     if (!op.getType().isInteger(1))
1637       return failure();
1638 
1639     rewriter.replaceOpWithNewOp<XOrOp>(op, op.condition(), op.getFalseValue());
1640     return success();
1641   }
1642 };
1643 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1644 void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1645                                            MLIRContext *context) {
1646   results.insert<SelectToNot>(context);
1647 }
1648 
fold(ArrayRef<Attribute> operands)1649 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
1650   auto trueVal = getTrueValue();
1651   auto falseVal = getFalseValue();
1652   if (trueVal == falseVal)
1653     return trueVal;
1654 
1655   auto condition = getCondition();
1656 
1657   // select true, %0, %1 => %0
1658   if (matchPattern(condition, m_One()))
1659     return trueVal;
1660 
1661   // select false, %0, %1 => %1
1662   if (matchPattern(condition, m_Zero()))
1663     return falseVal;
1664 
1665   if (auto cmp = dyn_cast_or_null<CmpIOp>(condition.getDefiningOp())) {
1666     auto pred = cmp.predicate();
1667     if (pred == mlir::CmpIPredicate::eq || pred == mlir::CmpIPredicate::ne) {
1668       auto cmpLhs = cmp.lhs();
1669       auto cmpRhs = cmp.rhs();
1670 
1671       // %0 = cmpi eq, %arg0, %arg1
1672       // %1 = select %0, %arg0, %arg1 => %arg1
1673 
1674       // %0 = cmpi ne, %arg0, %arg1
1675       // %1 = select %0, %arg0, %arg1 => %arg0
1676 
1677       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1678           (cmpRhs == trueVal && cmpLhs == falseVal))
1679         return pred == mlir::CmpIPredicate::ne ? trueVal : falseVal;
1680     }
1681   }
1682   return nullptr;
1683 }
1684 
print(OpAsmPrinter & p,SelectOp op)1685 static void print(OpAsmPrinter &p, SelectOp op) {
1686   p << " " << op.getOperands();
1687   p.printOptionalAttrDict(op->getAttrs());
1688   p << " : ";
1689   if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
1690     p << condType << ", ";
1691   p << op.getType();
1692 }
1693 
parseSelectOp(OpAsmParser & parser,OperationState & result)1694 static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
1695   Type conditionType, resultType;
1696   SmallVector<OpAsmParser::OperandType, 3> operands;
1697   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1698       parser.parseOptionalAttrDict(result.attributes) ||
1699       parser.parseColonType(resultType))
1700     return failure();
1701 
1702   // Check for the explicit condition type if this is a masked tensor or vector.
1703   if (succeeded(parser.parseOptionalComma())) {
1704     conditionType = resultType;
1705     if (parser.parseType(resultType))
1706       return failure();
1707   } else {
1708     conditionType = parser.getBuilder().getI1Type();
1709   }
1710 
1711   result.addTypes(resultType);
1712   return parser.resolveOperands(operands,
1713                                 {conditionType, resultType, resultType},
1714                                 parser.getNameLoc(), result.operands);
1715 }
1716 
verify(SelectOp op)1717 static LogicalResult verify(SelectOp op) {
1718   Type conditionType = op.getCondition().getType();
1719   if (conditionType.isSignlessInteger(1))
1720     return success();
1721 
1722   // If the result type is a vector or tensor, the type can be a mask with the
1723   // same elements.
1724   Type resultType = op.getType();
1725   if (!resultType.isa<TensorType, VectorType>())
1726     return op.emitOpError()
1727            << "expected condition to be a signless i1, but got "
1728            << conditionType;
1729   Type shapedConditionType = getI1SameShape(resultType);
1730   if (conditionType != shapedConditionType)
1731     return op.emitOpError()
1732            << "expected condition type to have the same shape "
1733               "as the result type, expected "
1734            << shapedConditionType << ", but got " << conditionType;
1735   return success();
1736 }
1737 
1738 //===----------------------------------------------------------------------===//
1739 // SignExtendIOp
1740 //===----------------------------------------------------------------------===//
1741 
verify(SignExtendIOp op)1742 static LogicalResult verify(SignExtendIOp op) {
1743   // Get the scalar type (which is either directly the type of the operand
1744   // or the vector's/tensor's element type.
1745   auto srcType = getElementTypeOrSelf(op.getOperand().getType());
1746   auto dstType = getElementTypeOrSelf(op.getType());
1747 
1748   // For now, index is forbidden for the source and the destination type.
1749   if (srcType.isa<IndexType>())
1750     return op.emitError() << srcType << " is not a valid operand type";
1751   if (dstType.isa<IndexType>())
1752     return op.emitError() << dstType << " is not a valid result type";
1753 
1754   if (srcType.cast<IntegerType>().getWidth() >=
1755       dstType.cast<IntegerType>().getWidth())
1756     return op.emitError("result type ")
1757            << dstType << " must be wider than operand type " << srcType;
1758 
1759   return success();
1760 }
1761 
fold(ArrayRef<Attribute> operands)1762 OpFoldResult SignExtendIOp::fold(ArrayRef<Attribute> operands) {
1763   assert(operands.size() == 1 && "unary operation takes one operand");
1764 
1765   if (!operands[0])
1766     return {};
1767 
1768   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
1769     return IntegerAttr::get(
1770         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
1771   }
1772 
1773   return {};
1774 }
1775 
1776 //===----------------------------------------------------------------------===//
1777 // SignedDivIOp
1778 //===----------------------------------------------------------------------===//
1779 
fold(ArrayRef<Attribute> operands)1780 OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
1781   assert(operands.size() == 2 && "binary operation takes two operands");
1782 
1783   // Don't fold if it would overflow or if it requires a division by zero.
1784   bool overflowOrDiv0 = false;
1785   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
1786     if (overflowOrDiv0 || !b) {
1787       overflowOrDiv0 = true;
1788       return a;
1789     }
1790     return a.sdiv_ov(b, overflowOrDiv0);
1791   });
1792 
1793   // Fold out division by one. Assumes all tensors of all ones are splats.
1794   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
1795     if (rhs.getValue() == 1)
1796       return lhs();
1797   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
1798     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
1799       return lhs();
1800   }
1801 
1802   return overflowOrDiv0 ? Attribute() : result;
1803 }
1804 
1805 //===----------------------------------------------------------------------===//
1806 // SignedFloorDivIOp
1807 //===----------------------------------------------------------------------===//
1808 
signedCeilNonnegInputs(APInt a,APInt b,bool & overflow)1809 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
1810   // Returns (a-1)/b + 1
1811   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
1812   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
1813   return val.sadd_ov(one, overflow);
1814 }
1815 
fold(ArrayRef<Attribute> operands)1816 OpFoldResult SignedFloorDivIOp::fold(ArrayRef<Attribute> operands) {
1817   assert(operands.size() == 2 && "binary operation takes two operands");
1818 
1819   // Don't fold if it would overflow or if it requires a division by zero.
1820   bool overflowOrDiv0 = false;
1821   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
1822     if (overflowOrDiv0 || !b) {
1823       overflowOrDiv0 = true;
1824       return a;
1825     }
1826     unsigned bits = a.getBitWidth();
1827     APInt zero = APInt::getZero(bits);
1828     if (a.sge(zero) && b.sgt(zero)) {
1829       // Both positive (or a is zero), return a / b.
1830       return a.sdiv_ov(b, overflowOrDiv0);
1831     } else if (a.sle(zero) && b.slt(zero)) {
1832       // Both negative (or a is zero), return -a / -b.
1833       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
1834       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
1835       return posA.sdiv_ov(posB, overflowOrDiv0);
1836     } else if (a.slt(zero) && b.sgt(zero)) {
1837       // A is negative, b is positive, return - ceil(-a, b).
1838       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
1839       APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
1840       return zero.ssub_ov(ceil, overflowOrDiv0);
1841     } else {
1842       // A is positive, b is negative, return - ceil(a, -b).
1843       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
1844       APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
1845       return zero.ssub_ov(ceil, overflowOrDiv0);
1846     }
1847   });
1848 
1849   // Fold out floor division by one. Assumes all tensors of all ones are
1850   // splats.
1851   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
1852     if (rhs.getValue() == 1)
1853       return lhs();
1854   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
1855     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
1856       return lhs();
1857   }
1858 
1859   return overflowOrDiv0 ? Attribute() : result;
1860 }
1861 
1862 //===----------------------------------------------------------------------===//
1863 // SignedCeilDivIOp
1864 //===----------------------------------------------------------------------===//
1865 
fold(ArrayRef<Attribute> operands)1866 OpFoldResult SignedCeilDivIOp::fold(ArrayRef<Attribute> operands) {
1867   assert(operands.size() == 2 && "binary operation takes two operands");
1868 
1869   // Don't fold if it would overflow or if it requires a division by zero.
1870   bool overflowOrDiv0 = false;
1871   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
1872     if (overflowOrDiv0 || !b) {
1873       overflowOrDiv0 = true;
1874       return a;
1875     }
1876     unsigned bits = a.getBitWidth();
1877     APInt zero = APInt::getZero(bits);
1878     if (a.sgt(zero) && b.sgt(zero)) {
1879       // Both positive, return ceil(a, b).
1880       return signedCeilNonnegInputs(a, b, overflowOrDiv0);
1881     } else if (a.slt(zero) && b.slt(zero)) {
1882       // Both negative, return ceil(-a, -b).
1883       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
1884       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
1885       return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
1886     } else if (a.slt(zero) && b.sgt(zero)) {
1887       // A is negative, b is positive, return - ( -a / b).
1888       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
1889       APInt div = posA.sdiv_ov(b, overflowOrDiv0);
1890       return zero.ssub_ov(div, overflowOrDiv0);
1891     } else {
1892       // A is positive (or zero), b is negative, return - (a / -b).
1893       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
1894       APInt div = a.sdiv_ov(posB, overflowOrDiv0);
1895       return zero.ssub_ov(div, overflowOrDiv0);
1896     }
1897   });
1898 
1899   // Fold out floor division by one. Assumes all tensors of all ones are
1900   // splats.
1901   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
1902     if (rhs.getValue() == 1)
1903       return lhs();
1904   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
1905     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
1906       return lhs();
1907   }
1908 
1909   return overflowOrDiv0 ? Attribute() : result;
1910 }
1911 
1912 //===----------------------------------------------------------------------===//
1913 // SignedRemIOp
1914 //===----------------------------------------------------------------------===//
1915 
fold(ArrayRef<Attribute> operands)1916 OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
1917   assert(operands.size() == 2 && "remi_signed takes two operands");
1918 
1919   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1920   if (!rhs)
1921     return {};
1922   auto rhsValue = rhs.getValue();
1923 
1924   // x % 1 = 0
1925   if (rhsValue.isOneValue())
1926     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
1927 
1928   // Don't fold if it requires division by zero.
1929   if (rhsValue.isNullValue())
1930     return {};
1931 
1932   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1933   if (!lhs)
1934     return {};
1935   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
1936 }
1937 
1938 //===----------------------------------------------------------------------===//
1939 // SIToFPOp
1940 //===----------------------------------------------------------------------===//
1941 
1942 // sitofp is applicable from integer types to float types.
areCastCompatible(TypeRange inputs,TypeRange outputs)1943 bool SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1944   if (inputs.size() != 1 || outputs.size() != 1)
1945     return false;
1946   Type a = inputs.front(), b = outputs.front();
1947   if (a.isSignlessInteger() && b.isa<FloatType>())
1948     return true;
1949   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1950 }
1951 
1952 //===----------------------------------------------------------------------===//
1953 // SplatOp
1954 //===----------------------------------------------------------------------===//
1955 
verify(SplatOp op)1956 static LogicalResult verify(SplatOp op) {
1957   // TODO: we could replace this by a trait.
1958   if (op.getOperand().getType() !=
1959       op.getType().cast<ShapedType>().getElementType())
1960     return op.emitError("operand should be of elemental type of result type");
1961 
1962   return success();
1963 }
1964 
1965 // Constant folding hook for SplatOp.
fold(ArrayRef<Attribute> operands)1966 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
1967   assert(operands.size() == 1 && "splat takes one operand");
1968 
1969   auto constOperand = operands.front();
1970   if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
1971     return {};
1972 
1973   auto shapedType = getType().cast<ShapedType>();
1974   assert(shapedType.getElementType() == constOperand.getType() &&
1975          "incorrect input attribute type for folding");
1976 
1977   // SplatElementsAttr::get treats single value for second arg as being a splat.
1978   return SplatElementsAttr::get(shapedType, {constOperand});
1979 }
1980 
1981 //===----------------------------------------------------------------------===//
1982 // SubFOp
1983 //===----------------------------------------------------------------------===//
1984 
fold(ArrayRef<Attribute> operands)1985 OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
1986   return constFoldBinaryOp<FloatAttr>(
1987       operands, [](APFloat a, APFloat b) { return a - b; });
1988 }
1989 
1990 //===----------------------------------------------------------------------===//
1991 // SubIOp
1992 //===----------------------------------------------------------------------===//
1993 
fold(ArrayRef<Attribute> operands)1994 OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
1995   // subi(x,x) -> 0
1996   if (getOperand(0) == getOperand(1))
1997     return Builder(getContext()).getZeroAttr(getType());
1998   // subi(x,0) -> x
1999   if (matchPattern(rhs(), m_Zero()))
2000     return lhs();
2001 
2002   return constFoldBinaryOp<IntegerAttr>(operands,
2003                                         [](APInt a, APInt b) { return a - b; });
2004 }
2005 
2006 /// Canonicalize a sub of a constant and (constant +/- something) to simply be
2007 /// a single operation that merges the two constants.
2008 struct SubConstantReorder : public OpRewritePattern<SubIOp> {
2009   using OpRewritePattern<SubIOp>::OpRewritePattern;
2010 
matchAndRewriteSubConstantReorder2011   LogicalResult matchAndRewrite(SubIOp subOp,
2012                                 PatternRewriter &rewriter) const override {
2013     APInt origConst;
2014     APInt midConst;
2015 
2016     if (matchPattern(subOp.getOperand(0), m_ConstantInt(&origConst))) {
2017       if (auto midAddOp = subOp.getOperand(1).getDefiningOp<AddIOp>()) {
2018         // origConst - (midConst + something) == (origConst - midConst) -
2019         // something
2020         for (int j = 0; j < 2; j++) {
2021           if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
2022             auto nextConstant = rewriter.create<ConstantOp>(
2023                 subOp.getLoc(),
2024                 rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
2025             rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
2026                                                 midAddOp.getOperand(1 - j));
2027             return success();
2028           }
2029         }
2030       }
2031 
2032       if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
2033         if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
2034           // (midConst - something) - origConst == (midConst - origConst) -
2035           // something
2036           auto nextConstant = rewriter.create<ConstantOp>(
2037               subOp.getLoc(),
2038               rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
2039           rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
2040                                               midSubOp.getOperand(1));
2041           return success();
2042         }
2043 
2044         if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
2045           // (something - midConst) - origConst == something - (origConst +
2046           // midConst)
2047           auto nextConstant = rewriter.create<ConstantOp>(
2048               subOp.getLoc(),
2049               rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
2050           rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
2051                                               nextConstant);
2052           return success();
2053         }
2054       }
2055 
2056       if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
2057         if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
2058           // origConst - (midConst - something) == (origConst - midConst) +
2059           // something
2060           auto nextConstant = rewriter.create<ConstantOp>(
2061               subOp.getLoc(),
2062               rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
2063           rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
2064                                               midSubOp.getOperand(1));
2065           return success();
2066         }
2067 
2068         if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
2069           // origConst - (something - midConst) == (origConst + midConst) -
2070           // something
2071           auto nextConstant = rewriter.create<ConstantOp>(
2072               subOp.getLoc(),
2073               rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
2074           rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
2075                                               midSubOp.getOperand(0));
2076           return success();
2077         }
2078       }
2079     }
2080 
2081     if (matchPattern(subOp.getOperand(1), m_ConstantInt(&origConst))) {
2082       if (auto midAddOp = subOp.getOperand(0).getDefiningOp<AddIOp>()) {
2083         // (midConst + something) - origConst == (midConst - origConst) +
2084         // something
2085         for (int j = 0; j < 2; j++) {
2086           if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
2087             auto nextConstant = rewriter.create<ConstantOp>(
2088                 subOp.getLoc(),
2089                 rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
2090             rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
2091                                                 midAddOp.getOperand(1 - j));
2092             return success();
2093           }
2094         }
2095       }
2096 
2097       if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
2098         if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
2099           // (midConst - something) - origConst == (midConst - origConst) -
2100           // something
2101           auto nextConstant = rewriter.create<ConstantOp>(
2102               subOp.getLoc(),
2103               rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
2104           rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
2105                                               midSubOp.getOperand(1));
2106           return success();
2107         }
2108 
2109         if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
2110           // (something - midConst) - origConst == something - (midConst +
2111           // origConst)
2112           auto nextConstant = rewriter.create<ConstantOp>(
2113               subOp.getLoc(),
2114               rewriter.getIntegerAttr(subOp.getType(), midConst + origConst));
2115           rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
2116                                               nextConstant);
2117           return success();
2118         }
2119       }
2120 
2121       if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
2122         if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
2123           // origConst - (midConst - something) == (origConst - midConst) +
2124           // something
2125           auto nextConstant = rewriter.create<ConstantOp>(
2126               subOp.getLoc(),
2127               rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
2128           rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
2129                                               midSubOp.getOperand(1));
2130           return success();
2131         }
2132         if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
2133           // origConst - (something - midConst) == (origConst - midConst) -
2134           // something
2135           auto nextConstant = rewriter.create<ConstantOp>(
2136               subOp.getLoc(),
2137               rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
2138           rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
2139                                               midSubOp.getOperand(0));
2140           return success();
2141         }
2142       }
2143     }
2144     return failure();
2145   }
2146 };
2147 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2148 void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2149                                          MLIRContext *context) {
2150   results.insert<SubConstantReorder>(context);
2151 }
2152 
2153 //===----------------------------------------------------------------------===//
2154 // UIToFPOp
2155 //===----------------------------------------------------------------------===//
2156 
2157 // uitofp is applicable from integer types to float types.
areCastCompatible(TypeRange inputs,TypeRange outputs)2158 bool UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
2159   if (inputs.size() != 1 || outputs.size() != 1)
2160     return false;
2161   Type a = inputs.front(), b = outputs.front();
2162   if (a.isSignlessInteger() && b.isa<FloatType>())
2163     return true;
2164   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2165 }
2166 
2167 //===----------------------------------------------------------------------===//
2168 // SwitchOp
2169 //===----------------------------------------------------------------------===//
2170 
/// Builds a SwitchOp whose case values are already packed into a
/// DenseIntElementsAttr. This overload only reorders its arguments and forwards
/// them to another `build` overload: operand groups first, then the case-value
/// attribute, then the successor blocks.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     DenseIntElementsAttr caseValues,
                     BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  build(builder, result, value, defaultOperands, caseOperands, caseValues,
        defaultDestination, caseDestinations);
}
2179 
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<APInt> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)2180 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
2181                      Block *defaultDestination, ValueRange defaultOperands,
2182                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
2183                      ArrayRef<ValueRange> caseOperands) {
2184   DenseIntElementsAttr caseValuesAttr;
2185   if (!caseValues.empty()) {
2186     ShapedType caseValueType = VectorType::get(
2187         static_cast<int64_t>(caseValues.size()), value.getType());
2188     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
2189   }
2190   build(builder, result, value, defaultDestination, defaultOperands,
2191         caseValuesAttr, caseDestinations, caseOperands);
2192 }
2193 
2194 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
2195 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
parseSwitchOpCases(OpAsmParser & parser,Type & flagType,Block * & defaultDestination,SmallVectorImpl<OpAsmParser::OperandType> & defaultOperands,SmallVectorImpl<Type> & defaultOperandTypes,DenseIntElementsAttr & caseValues,SmallVectorImpl<Block * > & caseDestinations,SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> & caseOperands,SmallVectorImpl<SmallVector<Type>> & caseOperandTypes)2196 static ParseResult parseSwitchOpCases(
2197     OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
2198     SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
2199     SmallVectorImpl<Type> &defaultOperandTypes,
2200     DenseIntElementsAttr &caseValues,
2201     SmallVectorImpl<Block *> &caseDestinations,
2202     SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
2203     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
2204   if (parser.parseKeyword("default") || parser.parseColon() ||
2205       parser.parseSuccessor(defaultDestination))
2206     return failure();
2207   if (succeeded(parser.parseOptionalLParen())) {
2208     if (parser.parseRegionArgumentList(defaultOperands) ||
2209         parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
2210       return failure();
2211   }
2212 
2213   SmallVector<APInt> values;
2214   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
2215   while (succeeded(parser.parseOptionalComma())) {
2216     int64_t value = 0;
2217     if (failed(parser.parseInteger(value)))
2218       return failure();
2219     values.push_back(APInt(bitWidth, value));
2220 
2221     Block *destination;
2222     SmallVector<OpAsmParser::OperandType> operands;
2223     SmallVector<Type> operandTypes;
2224     if (failed(parser.parseColon()) ||
2225         failed(parser.parseSuccessor(destination)))
2226       return failure();
2227     if (succeeded(parser.parseOptionalLParen())) {
2228       if (failed(parser.parseRegionArgumentList(operands)) ||
2229           failed(parser.parseColonTypeList(operandTypes)) ||
2230           failed(parser.parseRParen()))
2231         return failure();
2232     }
2233     caseDestinations.push_back(destination);
2234     caseOperands.emplace_back(operands);
2235     caseOperandTypes.emplace_back(operandTypes);
2236   }
2237 
2238   if (!values.empty()) {
2239     ShapedType caseValueType =
2240         VectorType::get(static_cast<int64_t>(values.size()), flagType);
2241     caseValues = DenseIntElementsAttr::get(caseValueType, values);
2242   }
2243   return success();
2244 }
2245 
printSwitchOpCases(OpAsmPrinter & p,SwitchOp op,Type flagType,Block * defaultDestination,OperandRange defaultOperands,TypeRange defaultOperandTypes,DenseIntElementsAttr caseValues,SuccessorRange caseDestinations,OperandRangeRange caseOperands,TypeRangeRange caseOperandTypes)2246 static void printSwitchOpCases(
2247     OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
2248     OperandRange defaultOperands, TypeRange defaultOperandTypes,
2249     DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
2250     OperandRangeRange caseOperands, TypeRangeRange caseOperandTypes) {
2251   p << "  default: ";
2252   p.printSuccessorAndUseList(defaultDestination, defaultOperands);
2253 
2254   if (!caseValues)
2255     return;
2256 
2257   for (int64_t i = 0, size = caseValues.size(); i < size; ++i) {
2258     p << ',';
2259     p.printNewline();
2260     p << "  ";
2261     p << caseValues.getValue<APInt>(i).getLimitedValue();
2262     p << ": ";
2263     p.printSuccessorAndUseList(caseDestinations[i], caseOperands[i]);
2264   }
2265   p.printNewline();
2266 }
2267 
verify(SwitchOp op)2268 static LogicalResult verify(SwitchOp op) {
2269   auto caseValues = op.case_values();
2270   auto caseDestinations = op.caseDestinations();
2271 
2272   if (!caseValues && caseDestinations.empty())
2273     return success();
2274 
2275   Type flagType = op.flag().getType();
2276   Type caseValueType = caseValues->getType().getElementType();
2277   if (caseValueType != flagType)
2278     return op.emitOpError()
2279            << "'flag' type (" << flagType << ") should match case value type ("
2280            << caseValueType << ")";
2281 
2282   if (caseValues &&
2283       caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
2284     return op.emitOpError() << "number of case values (" << caseValues->size()
2285                             << ") should match number of "
2286                                "case destinations ("
2287                             << caseDestinations.size() << ")";
2288   return success();
2289 }
2290 
2291 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)2292 SwitchOp::getMutableSuccessorOperands(unsigned index) {
2293   assert(index < getNumSuccessors() && "invalid successor index");
2294   return index == 0 ? defaultOperandsMutable()
2295                     : getCaseOperandsMutable(index - 1);
2296 }
2297 
getSuccessorForOperands(ArrayRef<Attribute> operands)2298 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
2299   Optional<DenseIntElementsAttr> caseValues = case_values();
2300 
2301   if (!caseValues)
2302     return defaultDestination();
2303 
2304   SuccessorRange caseDests = caseDestinations();
2305   if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
2306     for (int64_t i = 0, size = case_values()->size(); i < size; ++i)
2307       if (value == caseValues->getValue<IntegerAttr>(i))
2308         return caseDests[i];
2309     return defaultDestination();
2310   }
2311   return nullptr;
2312 }
2313 
2314 /// switch %flag : i32, [
2315 ///   default:  ^bb1
2316 /// ]
2317 ///  -> br ^bb1
simplifySwitchWithOnlyDefault(SwitchOp op,PatternRewriter & rewriter)2318 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
2319                                                    PatternRewriter &rewriter) {
2320   if (!op.caseDestinations().empty())
2321     return failure();
2322 
2323   rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
2324                                         op.defaultOperands());
2325   return success();
2326 }
2327 
2328 /// switch %flag : i32, [
2329 ///   default: ^bb1,
2330 ///   42: ^bb1,
2331 ///   43: ^bb2
2332 /// ]
2333 /// ->
2334 /// switch %flag : i32, [
2335 ///   default: ^bb1,
2336 ///   43: ^bb2
2337 /// ]
2338 static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op,PatternRewriter & rewriter)2339 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
2340   SmallVector<Block *> newCaseDestinations;
2341   SmallVector<ValueRange> newCaseOperands;
2342   SmallVector<APInt> newCaseValues;
2343   bool requiresChange = false;
2344   auto caseValues = op.case_values();
2345   auto caseDests = op.caseDestinations();
2346 
2347   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2348     if (caseDests[i] == op.defaultDestination() &&
2349         op.getCaseOperands(i) == op.defaultOperands()) {
2350       requiresChange = true;
2351       continue;
2352     }
2353     newCaseDestinations.push_back(caseDests[i]);
2354     newCaseOperands.push_back(op.getCaseOperands(i));
2355     newCaseValues.push_back(caseValues->getValue<APInt>(i));
2356   }
2357 
2358   if (!requiresChange)
2359     return failure();
2360 
2361   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
2362                                         op.defaultOperands(), newCaseValues,
2363                                         newCaseDestinations, newCaseOperands);
2364   return success();
2365 }
2366 
2367 /// Helper for folding a switch with a constant value.
2368 /// switch %c_42 : i32, [
2369 ///   default: ^bb1 ,
2370 ///   42: ^bb2,
2371 ///   43: ^bb3
2372 /// ]
2373 /// -> br ^bb2
foldSwitch(SwitchOp op,PatternRewriter & rewriter,APInt caseValue)2374 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
2375                        APInt caseValue) {
2376   auto caseValues = op.case_values();
2377   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2378     if (caseValues->getValue<APInt>(i) == caseValue) {
2379       rewriter.replaceOpWithNewOp<BranchOp>(op, op.caseDestinations()[i],
2380                                             op.getCaseOperands(i));
2381       return;
2382     }
2383   }
2384   rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
2385                                         op.defaultOperands());
2386 }
2387 
2388 /// switch %c_42 : i32, [
2389 ///   default: ^bb1,
2390 ///   42: ^bb2,
2391 ///   43: ^bb3
2392 /// ]
2393 /// -> br ^bb2
simplifyConstSwitchValue(SwitchOp op,PatternRewriter & rewriter)2394 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
2395                                               PatternRewriter &rewriter) {
2396   APInt caseValue;
2397   if (!matchPattern(op.flag(), m_ConstantInt(&caseValue)))
2398     return failure();
2399 
2400   foldSwitch(op, rewriter, caseValue);
2401   return success();
2402 }
2403 
2404 /// switch %c_42 : i32, [
2405 ///   default: ^bb1,
2406 ///   42: ^bb2,
2407 /// ]
2408 /// ^bb2:
2409 ///   br ^bb3
2410 /// ->
2411 /// switch %c_42 : i32, [
2412 ///   default: ^bb1,
2413 ///   42: ^bb3,
2414 /// ]
simplifyPassThroughSwitch(SwitchOp op,PatternRewriter & rewriter)2415 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
2416                                                PatternRewriter &rewriter) {
2417   SmallVector<Block *> newCaseDests;
2418   SmallVector<ValueRange> newCaseOperands;
2419   SmallVector<SmallVector<Value>> argStorage;
2420   auto caseValues = op.case_values();
2421   auto caseDests = op.caseDestinations();
2422   bool requiresChange = false;
2423   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2424     Block *caseDest = caseDests[i];
2425     ValueRange caseOperands = op.getCaseOperands(i);
2426     argStorage.emplace_back();
2427     if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
2428       requiresChange = true;
2429 
2430     newCaseDests.push_back(caseDest);
2431     newCaseOperands.push_back(caseOperands);
2432   }
2433 
2434   Block *defaultDest = op.defaultDestination();
2435   ValueRange defaultOperands = op.defaultOperands();
2436   argStorage.emplace_back();
2437 
2438   if (succeeded(
2439           collapseBranch(defaultDest, defaultOperands, argStorage.back())))
2440     requiresChange = true;
2441 
2442   if (!requiresChange)
2443     return failure();
2444 
2445   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), defaultDest,
2446                                         defaultOperands, caseValues.getValue(),
2447                                         newCaseDests, newCaseOperands);
2448   return success();
2449 }
2450 
2451 /// switch %flag : i32, [
2452 ///   default: ^bb1,
2453 ///   42: ^bb2,
2454 /// ]
2455 /// ^bb2:
2456 ///   switch %flag : i32, [
2457 ///     default: ^bb3,
2458 ///     42: ^bb4
2459 ///   ]
2460 /// ->
2461 /// switch %flag : i32, [
2462 ///   default: ^bb1,
2463 ///   42: ^bb2,
2464 /// ]
2465 /// ^bb2:
2466 ///   br ^bb4
2467 ///
2468 ///  and
2469 ///
2470 /// switch %flag : i32, [
2471 ///   default: ^bb1,
2472 ///   42: ^bb2,
2473 /// ]
2474 /// ^bb2:
2475 ///   switch %flag : i32, [
2476 ///     default: ^bb3,
2477 ///     43: ^bb4
2478 ///   ]
2479 /// ->
2480 /// switch %flag : i32, [
2481 ///   default: ^bb1,
2482 ///   42: ^bb2,
2483 /// ]
2484 /// ^bb2:
2485 ///   br ^bb3
2486 static LogicalResult
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,PatternRewriter & rewriter)2487 simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
2488                                         PatternRewriter &rewriter) {
2489   // Check that we have a single distinct predecessor.
2490   Block *currentBlock = op->getBlock();
2491   Block *predecessor = currentBlock->getSinglePredecessor();
2492   if (!predecessor)
2493     return failure();
2494 
2495   // Check that the predecessor terminates with a switch branch to this block
2496   // and that it branches on the same condition and that this branch isn't the
2497   // default destination.
2498   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
2499   if (!predSwitch || op.flag() != predSwitch.flag() ||
2500       predSwitch.defaultDestination() == currentBlock)
2501     return failure();
2502 
2503   // Fold this switch to an unconditional branch.
2504   APInt caseValue;
2505   bool isDefault = true;
2506   SuccessorRange predDests = predSwitch.caseDestinations();
2507   Optional<DenseIntElementsAttr> predCaseValues = predSwitch.case_values();
2508   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) {
2509     if (currentBlock == predDests[i]) {
2510       caseValue = predCaseValues->getValue<APInt>(i);
2511       isDefault = false;
2512       break;
2513     }
2514   }
2515   if (isDefault)
2516     rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
2517                                           op.defaultOperands());
2518   else
2519     foldSwitch(op, rewriter, caseValue);
2520   return success();
2521 }
2522 
2523 /// switch %flag : i32, [
2524 ///   default: ^bb1,
2525 ///   42: ^bb2
2526 /// ]
2527 /// ^bb1:
2528 ///   switch %flag : i32, [
2529 ///     default: ^bb3,
2530 ///     42: ^bb4,
2531 ///     43: ^bb5
2532 ///   ]
2533 /// ->
2534 /// switch %flag : i32, [
2535 ///   default: ^bb1,
2536 ///   42: ^bb2,
2537 /// ]
2538 /// ^bb1:
2539 ///   switch %flag : i32, [
2540 ///     default: ^bb3,
2541 ///     43: ^bb5
2542 ///   ]
2543 static LogicalResult
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,PatternRewriter & rewriter)2544 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
2545                                                PatternRewriter &rewriter) {
2546   // Check that we have a single distinct predecessor.
2547   Block *currentBlock = op->getBlock();
2548   Block *predecessor = currentBlock->getSinglePredecessor();
2549   if (!predecessor)
2550     return failure();
2551 
2552   // Check that the predecessor terminates with a switch branch to this block
2553   // and that it branches on the same condition and that this branch is the
2554   // default destination.
2555   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
2556   if (!predSwitch || op.flag() != predSwitch.flag() ||
2557       predSwitch.defaultDestination() != currentBlock)
2558     return failure();
2559 
2560   // Delete case values that are not possible here.
2561   DenseSet<APInt> caseValuesToRemove;
2562   auto predDests = predSwitch.caseDestinations();
2563   auto predCaseValues = predSwitch.case_values();
2564   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
2565     if (currentBlock != predDests[i])
2566       caseValuesToRemove.insert(predCaseValues->getValue<APInt>(i));
2567 
2568   SmallVector<Block *> newCaseDestinations;
2569   SmallVector<ValueRange> newCaseOperands;
2570   SmallVector<APInt> newCaseValues;
2571   bool requiresChange = false;
2572 
2573   auto caseValues = op.case_values();
2574   auto caseDests = op.caseDestinations();
2575   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2576     if (caseValuesToRemove.contains(caseValues->getValue<APInt>(i))) {
2577       requiresChange = true;
2578       continue;
2579     }
2580     newCaseDestinations.push_back(caseDests[i]);
2581     newCaseOperands.push_back(op.getCaseOperands(i));
2582     newCaseValues.push_back(caseValues->getValue<APInt>(i));
2583   }
2584 
2585   if (!requiresChange)
2586     return failure();
2587 
2588   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
2589                                         op.defaultOperands(), newCaseValues,
2590                                         newCaseDestinations, newCaseOperands);
2591   return success();
2592 }
2593 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2594 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
2595                                            MLIRContext *context) {
2596   results.add(&simplifySwitchWithOnlyDefault)
2597       .add(&dropSwitchCasesThatMatchDefault)
2598       .add(&simplifyConstSwitchValue)
2599       .add(&simplifyPassThroughSwitch)
2600       .add(&simplifySwitchFromSwitchOnSameCondition)
2601       .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
2602 }
2603 
2604 //===----------------------------------------------------------------------===//
2605 // TruncateIOp
2606 //===----------------------------------------------------------------------===//
2607 
verify(TruncateIOp op)2608 static LogicalResult verify(TruncateIOp op) {
2609   auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2610   auto dstType = getElementTypeOrSelf(op.getType());
2611 
2612   if (srcType.isa<IndexType>())
2613     return op.emitError() << srcType << " is not a valid operand type";
2614   if (dstType.isa<IndexType>())
2615     return op.emitError() << dstType << " is not a valid result type";
2616 
2617   if (srcType.cast<IntegerType>().getWidth() <=
2618       dstType.cast<IntegerType>().getWidth())
2619     return op.emitError("operand type ")
2620            << srcType << " must be wider than result type " << dstType;
2621 
2622   return success();
2623 }
2624 
fold(ArrayRef<Attribute> operands)2625 OpFoldResult TruncateIOp::fold(ArrayRef<Attribute> operands) {
2626   // trunci(zexti(a)) -> a
2627   // trunci(sexti(a)) -> a
2628   if (matchPattern(getOperand(), m_Op<ZeroExtendIOp>()) ||
2629       matchPattern(getOperand(), m_Op<SignExtendIOp>()))
2630     return getOperand().getDefiningOp()->getOperand(0);
2631 
2632   assert(operands.size() == 1 && "unary operation takes one operand");
2633 
2634   if (!operands[0])
2635     return {};
2636 
2637   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
2638 
2639     return IntegerAttr::get(
2640         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
2641   }
2642 
2643   return {};
2644 }
2645 
2646 //===----------------------------------------------------------------------===//
2647 // UnsignedDivIOp
2648 //===----------------------------------------------------------------------===//
2649 
fold(ArrayRef<Attribute> operands)2650 OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
2651   assert(operands.size() == 2 && "binary operation takes two operands");
2652 
2653   // Don't fold if it would require a division by zero.
2654   bool div0 = false;
2655   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2656     if (div0 || !b) {
2657       div0 = true;
2658       return a;
2659     }
2660     return a.udiv(b);
2661   });
2662 
2663   // Fold out division by one. Assumes all tensors of all ones are splats.
2664   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2665     if (rhs.getValue() == 1)
2666       return lhs();
2667   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2668     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2669       return lhs();
2670   }
2671 
2672   return div0 ? Attribute() : result;
2673 }
2674 
2675 //===----------------------------------------------------------------------===//
2676 // UnsignedRemIOp
2677 //===----------------------------------------------------------------------===//
2678 
fold(ArrayRef<Attribute> operands)2679 OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
2680   assert(operands.size() == 2 && "remi_unsigned takes two operands");
2681 
2682   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
2683   if (!rhs)
2684     return {};
2685   auto rhsValue = rhs.getValue();
2686 
2687   // x % 1 = 0
2688   if (rhsValue.isOneValue())
2689     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
2690 
2691   // Don't fold if it requires division by zero.
2692   if (rhsValue.isNullValue())
2693     return {};
2694 
2695   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
2696   if (!lhs)
2697     return {};
2698   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
2699 }
2700 
2701 //===----------------------------------------------------------------------===//
2702 // XOrOp
2703 //===----------------------------------------------------------------------===//
2704 
fold(ArrayRef<Attribute> operands)2705 OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
2706   /// xor(x, 0) -> x
2707   if (matchPattern(rhs(), m_Zero()))
2708     return lhs();
2709   /// xor(x,x) -> 0
2710   if (lhs() == rhs())
2711     return Builder(getContext()).getZeroAttr(getType());
2712 
2713   return constFoldBinaryOp<IntegerAttr>(operands,
2714                                         [](APInt a, APInt b) { return a ^ b; });
2715 }
2716 
2717 namespace {
2718 /// Replace a not of a comparison operation, for example: not(cmp eq A, B) =>
2719 /// cmp ne A, B. Note that a logical not is implemented as xor 1, val.
2720 struct NotICmp : public OpRewritePattern<XOrOp> {
2721   using OpRewritePattern<XOrOp>::OpRewritePattern;
2722 
matchAndRewrite__anon1787a6f51511::NotICmp2723   LogicalResult matchAndRewrite(XOrOp op,
2724                                 PatternRewriter &rewriter) const override {
2725     // Commutative ops (such as xor) have the constant appear second, which
2726     // we assume here.
2727 
2728     APInt constValue;
2729     if (!matchPattern(op.getOperand(1), m_ConstantInt(&constValue)))
2730       return failure();
2731 
2732     if (constValue != 1)
2733       return failure();
2734 
2735     auto prev = op.getOperand(0).getDefiningOp<CmpIOp>();
2736     if (!prev)
2737       return failure();
2738 
2739     switch (prev.predicate()) {
2740     case CmpIPredicate::eq:
2741       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ne, prev.lhs(),
2742                                           prev.rhs());
2743       return success();
2744     case CmpIPredicate::ne:
2745       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::eq, prev.lhs(),
2746                                           prev.rhs());
2747       return success();
2748 
2749     case CmpIPredicate::slt:
2750       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sge, prev.lhs(),
2751                                           prev.rhs());
2752       return success();
2753     case CmpIPredicate::sle:
2754       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sgt, prev.lhs(),
2755                                           prev.rhs());
2756       return success();
2757     case CmpIPredicate::sgt:
2758       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sle, prev.lhs(),
2759                                           prev.rhs());
2760       return success();
2761     case CmpIPredicate::sge:
2762       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, prev.lhs(),
2763                                           prev.rhs());
2764       return success();
2765 
2766     case CmpIPredicate::ult:
2767       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::uge, prev.lhs(),
2768                                           prev.rhs());
2769       return success();
2770     case CmpIPredicate::ule:
2771       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ugt, prev.lhs(),
2772                                           prev.rhs());
2773       return success();
2774     case CmpIPredicate::ugt:
2775       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ule, prev.lhs(),
2776                                           prev.rhs());
2777       return success();
2778     case CmpIPredicate::uge:
2779       rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ult, prev.lhs(),
2780                                           prev.rhs());
2781       return success();
2782     }
2783     return failure();
2784   }
2785 };
2786 } // namespace
2787 
/// Registers the canonicalization patterns for XOrOp with `results`. The only
/// pattern added here is NotICmp.
void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  results.insert<NotICmp>(context);
}
2792 
2793 //===----------------------------------------------------------------------===//
2794 // ZeroExtendIOp
2795 //===----------------------------------------------------------------------===//
2796 
verify(ZeroExtendIOp op)2797 static LogicalResult verify(ZeroExtendIOp op) {
2798   auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2799   auto dstType = getElementTypeOrSelf(op.getType());
2800 
2801   if (srcType.isa<IndexType>())
2802     return op.emitError() << srcType << " is not a valid operand type";
2803   if (dstType.isa<IndexType>())
2804     return op.emitError() << dstType << " is not a valid result type";
2805 
2806   if (srcType.cast<IntegerType>().getWidth() >=
2807       dstType.cast<IntegerType>().getWidth())
2808     return op.emitError("result type ")
2809            << dstType << " must be wider than operand type " << srcType;
2810 
2811   return success();
2812 }
2813 
2814 //===----------------------------------------------------------------------===//
2815 // TableGen'd op method definitions
2816 //===----------------------------------------------------------------------===//
2817 
2818 #define GET_OP_CLASSES
2819 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
2820