1 //===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/StandardOps/IR/Ops.h"
10
11 #include "mlir/Dialect/CommonFolders.h"
12 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
13 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Support/MathExtras.h"
25 #include "mlir/Transforms/InliningUtils.h"
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/StringSwitch.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <numeric>
32
33 #include "mlir/Dialect/StandardOps/IR/OpsDialect.cpp.inc"
34
35 // Pull in all enum type definitions and utility function declarations.
36 #include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"
37
38 using namespace mlir;
39
40 //===----------------------------------------------------------------------===//
41 // StandardOpsDialect Interfaces
42 //===----------------------------------------------------------------------===//
namespace {
/// This class defines the interface for handling inlining with standard
/// operations.
struct StdInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// All call operations within standard ops can be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// All operations within standard ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  /// Handle the given inlined terminator by replacing it with a new operation
  /// as necessary.
  void handleTerminator(Operation *op, Block *newDest) const final {
    // Only "std.return" needs to be handled here; any other terminator is
    // left untouched.
    auto returnOp = dyn_cast<ReturnOp>(op);
    if (!returnOp)
      return;

    // Replace the return with a branch to the dest, forwarding the return
    // operands as branch operands.
    OpBuilder builder(op);
    builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
    op->erase();
  }

  /// Handle the given inlined terminator by replacing it with a new operation
  /// as necessary.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    // Only "std.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands: the i-th value
    // being replaced is substituted by the i-th return operand.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }
};
} // end anonymous namespace
97
98 //===----------------------------------------------------------------------===//
99 // StandardOpsDialect
100 //===----------------------------------------------------------------------===//
101
102 /// A custom unary operation printer that omits the "std." prefix from the
103 /// operation names.
printStandardUnaryOp(Operation * op,OpAsmPrinter & p)104 static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
105 assert(op->getNumOperands() == 1 && "unary op should have one operand");
106 assert(op->getNumResults() == 1 && "unary op should have one result");
107
108 p << ' ' << op->getOperand(0);
109 p.printOptionalAttrDict(op->getAttrs());
110 p << " : " << op->getOperand(0).getType();
111 }
112
113 /// A custom binary operation printer that omits the "std." prefix from the
114 /// operation names.
printStandardBinaryOp(Operation * op,OpAsmPrinter & p)115 static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
116 assert(op->getNumOperands() == 2 && "binary op should have two operands");
117 assert(op->getNumResults() == 1 && "binary op should have one result");
118
119 // If not all the operand and result types are the same, just use the
120 // generic assembly form to avoid omitting information in printing.
121 auto resultType = op->getResult(0).getType();
122 if (op->getOperand(0).getType() != resultType ||
123 op->getOperand(1).getType() != resultType) {
124 p.printGenericOp(op);
125 return;
126 }
127
128 p << ' ' << op->getOperand(0) << ", " << op->getOperand(1);
129 p.printOptionalAttrDict(op->getAttrs());
130
131 // Now we can output only one type for all operands and the result.
132 p << " : " << op->getResult(0).getType();
133 }
134
135 /// A custom ternary operation printer that omits the "std." prefix from the
136 /// operation names.
printStandardTernaryOp(Operation * op,OpAsmPrinter & p)137 static void printStandardTernaryOp(Operation *op, OpAsmPrinter &p) {
138 assert(op->getNumOperands() == 3 && "ternary op should have three operands");
139 assert(op->getNumResults() == 1 && "ternary op should have one result");
140
141 // If not all the operand and result types are the same, just use the
142 // generic assembly form to avoid omitting information in printing.
143 auto resultType = op->getResult(0).getType();
144 if (op->getOperand(0).getType() != resultType ||
145 op->getOperand(1).getType() != resultType ||
146 op->getOperand(2).getType() != resultType) {
147 p.printGenericOp(op);
148 return;
149 }
150
151 p << ' ' << op->getOperand(0) << ", " << op->getOperand(1) << ", "
152 << op->getOperand(2);
153 p.printOptionalAttrDict(op->getAttrs());
154
155 // Now we can output only one type for all operands and the result.
156 p << " : " << op->getResult(0).getType();
157 }
158
159 /// A custom cast operation printer that omits the "std." prefix from the
160 /// operation names.
printStandardCastOp(Operation * op,OpAsmPrinter & p)161 static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
162 p << ' ' << op->getOperand(0) << " : " << op->getOperand(0).getType()
163 << " to " << op->getResult(0).getType();
164 }
165
initialize()166 void StandardOpsDialect::initialize() {
167 addOperations<
168 #define GET_OP_LIST
169 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
170 >();
171 addInterfaces<StdInlinerInterface>();
172 }
173
174 /// Materialize a single constant operation from a given attribute value with
175 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)176 Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
177 Attribute value, Type type,
178 Location loc) {
179 return builder.create<ConstantOp>(loc, type, value);
180 }
181
182 //===----------------------------------------------------------------------===//
183 // Common cast compatibility check for vector types.
184 //===----------------------------------------------------------------------===//
185
186 /// This method checks for cast compatibility of vector types.
187 /// If 'a' and 'b' are vector types, and they are cast compatible,
188 /// it calls the 'areElementsCastCompatible' function to check for
189 /// element cast compatibility.
190 /// Returns 'true' if the vector types are cast compatible, and 'false'
191 /// otherwise.
areVectorCastSimpleCompatible(Type a,Type b,function_ref<bool (TypeRange,TypeRange)> areElementsCastCompatible)192 static bool areVectorCastSimpleCompatible(
193 Type a, Type b,
194 function_ref<bool(TypeRange, TypeRange)> areElementsCastCompatible) {
195 if (auto va = a.dyn_cast<VectorType>())
196 if (auto vb = b.dyn_cast<VectorType>())
197 return va.getShape().equals(vb.getShape()) &&
198 areElementsCastCompatible(va.getElementType(),
199 vb.getElementType());
200 return false;
201 }
202
203 //===----------------------------------------------------------------------===//
204 // AddFOp
205 //===----------------------------------------------------------------------===//
206
fold(ArrayRef<Attribute> operands)207 OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
208 return constFoldBinaryOp<FloatAttr>(
209 operands, [](APFloat a, APFloat b) { return a + b; });
210 }
211
212 //===----------------------------------------------------------------------===//
213 // AddIOp
214 //===----------------------------------------------------------------------===//
215
fold(ArrayRef<Attribute> operands)216 OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
217 /// addi(x, 0) -> x
218 if (matchPattern(rhs(), m_Zero()))
219 return lhs();
220
221 return constFoldBinaryOp<IntegerAttr>(operands,
222 [](APInt a, APInt b) { return a + b; });
223 }
224
/// Canonicalize a sum of a constant and (constant - something) to simply be
/// a sum of constants minus something. This transformation does similar
/// transformations for additions of a constant with a subtract/add of
/// a constant. This may result in some operations being reordered (but should
/// remain equivalent).
struct AddConstantReorder : public OpRewritePattern<AddIOp> {
  using OpRewritePattern<AddIOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddIOp addop,
                                PatternRewriter &rewriter) const override {
    // Try both operand orders: `i` indexes the constant operand, `1 - i` the
    // nested addi/subi it is combined with.
    for (int i = 0; i < 2; i++) {
      APInt origConst;
      APInt midConst;
      if (matchPattern(addop.getOperand(i), m_ConstantInt(&origConst))) {
        // addi(c0, addi(c1, x)) -> addi(c0 + c1, x)
        // (again trying both operand orders of the inner addi via `j`).
        if (auto midAddOp = addop.getOperand(1 - i).getDefiningOp<AddIOp>()) {
          for (int j = 0; j < 2; j++) {
            if (matchPattern(midAddOp.getOperand(j),
                             m_ConstantInt(&midConst))) {
              auto nextConstant = rewriter.create<ConstantOp>(
                  addop.getLoc(), rewriter.getIntegerAttr(
                                      addop.getType(), origConst + midConst));
              rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
                                                  midAddOp.getOperand(1 - j));
              return success();
            }
          }
        }
        if (auto midSubOp = addop.getOperand(1 - i).getDefiningOp<SubIOp>()) {
          // addi(c0, subi(c1, x)) -> subi(c0 + c1, x)
          if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
            auto nextConstant = rewriter.create<ConstantOp>(
                addop.getLoc(),
                rewriter.getIntegerAttr(addop.getType(), origConst + midConst));
            rewriter.replaceOpWithNewOp<SubIOp>(addop, nextConstant,
                                                midSubOp.getOperand(1));
            return success();
          }
          // addi(c0, subi(x, c1)) -> addi(c0 - c1, x)
          if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
            auto nextConstant = rewriter.create<ConstantOp>(
                addop.getLoc(),
                rewriter.getIntegerAttr(addop.getType(), origConst - midConst));
            rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
                                                midSubOp.getOperand(0));
            return success();
          }
        }
      }
    }
    return failure();
  }
};
275
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)276 void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
277 MLIRContext *context) {
278 results.insert<AddConstantReorder>(context);
279 }
280
281 //===----------------------------------------------------------------------===//
282 // AndOp
283 //===----------------------------------------------------------------------===//
284
fold(ArrayRef<Attribute> operands)285 OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
286 /// and(x, 0) -> 0
287 if (matchPattern(rhs(), m_Zero()))
288 return rhs();
289 /// and(x, allOnes) -> x
290 APInt intValue;
291 if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
292 return lhs();
293 /// and(x,x) -> x
294 if (lhs() == rhs())
295 return rhs();
296
297 return constFoldBinaryOp<IntegerAttr>(operands,
298 [](APInt a, APInt b) { return a & b; });
299 }
300
301 //===----------------------------------------------------------------------===//
302 // AssertOp
303 //===----------------------------------------------------------------------===//
304
canonicalize(AssertOp op,PatternRewriter & rewriter)305 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
306 // Erase assertion if argument is constant true.
307 if (matchPattern(op.arg(), m_One())) {
308 rewriter.eraseOp(op);
309 return success();
310 }
311 return failure();
312 }
313
314 //===----------------------------------------------------------------------===//
315 // AtomicRMWOp
316 //===----------------------------------------------------------------------===//
317
verify(AtomicRMWOp op)318 static LogicalResult verify(AtomicRMWOp op) {
319 if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
320 return op.emitOpError(
321 "expects the number of subscripts to be equal to memref rank");
322 switch (op.kind()) {
323 case AtomicRMWKind::addf:
324 case AtomicRMWKind::maxf:
325 case AtomicRMWKind::minf:
326 case AtomicRMWKind::mulf:
327 if (!op.value().getType().isa<FloatType>())
328 return op.emitOpError()
329 << "with kind '" << stringifyAtomicRMWKind(op.kind())
330 << "' expects a floating-point type";
331 break;
332 case AtomicRMWKind::addi:
333 case AtomicRMWKind::maxs:
334 case AtomicRMWKind::maxu:
335 case AtomicRMWKind::mins:
336 case AtomicRMWKind::minu:
337 case AtomicRMWKind::muli:
338 if (!op.value().getType().isa<IntegerType>())
339 return op.emitOpError()
340 << "with kind '" << stringifyAtomicRMWKind(op.kind())
341 << "' expects an integer type";
342 break;
343 default:
344 break;
345 }
346 return success();
347 }
348
/// Returns the identity value attribute associated with an AtomicRMWKind op.
/// The identity is the value `e` such that `op(e, x) == x` for all `x`; it is
/// used to seed reductions. Returns null for unsupported kinds (after
/// reporting via emitOptionalError).
Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                     OpBuilder &builder, Location loc) {
  switch (kind) {
  // Float max: identity is -infinity.
  case AtomicRMWKind::maxf:
    return builder.getFloatAttr(
        resultType,
        APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
                        /*Negative=*/true));
  // Addition: identity is 0. Unsigned max also uses 0, the unsigned minimum.
  case AtomicRMWKind::addf:
  case AtomicRMWKind::addi:
  case AtomicRMWKind::maxu:
    return builder.getZeroAttr(resultType);
  // Signed max: identity is the smallest signed value.
  case AtomicRMWKind::maxs:
    return builder.getIntegerAttr(
        resultType,
        APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
  // Float min: identity is +infinity.
  case AtomicRMWKind::minf:
    return builder.getFloatAttr(
        resultType,
        APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
                        /*Negative=*/false));
  // Signed min: identity is the largest signed value.
  case AtomicRMWKind::mins:
    return builder.getIntegerAttr(
        resultType,
        APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
  // Unsigned min: identity is the all-ones (maximum unsigned) value.
  case AtomicRMWKind::minu:
    return builder.getIntegerAttr(
        resultType,
        APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
  // Multiplication: identity is 1.
  case AtomicRMWKind::muli:
    return builder.getIntegerAttr(resultType, 1);
  case AtomicRMWKind::mulf:
    return builder.getFloatAttr(resultType, 1);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
390
391 /// Returns the identity value associated with an AtomicRMWKind op.
getIdentityValue(AtomicRMWKind op,Type resultType,OpBuilder & builder,Location loc)392 Value mlir::getIdentityValue(AtomicRMWKind op, Type resultType,
393 OpBuilder &builder, Location loc) {
394 Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
395 return builder.create<ConstantOp>(loc, attr);
396 }
397
398 /// Return the value obtained by applying the reduction operation kind
399 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
getReductionOp(AtomicRMWKind op,OpBuilder & builder,Location loc,Value lhs,Value rhs)400 Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
401 Value lhs, Value rhs) {
402 switch (op) {
403 case AtomicRMWKind::addf:
404 return builder.create<AddFOp>(loc, lhs, rhs);
405 case AtomicRMWKind::addi:
406 return builder.create<AddIOp>(loc, lhs, rhs);
407 case AtomicRMWKind::mulf:
408 return builder.create<MulFOp>(loc, lhs, rhs);
409 case AtomicRMWKind::muli:
410 return builder.create<MulIOp>(loc, lhs, rhs);
411 case AtomicRMWKind::maxf:
412 return builder.create<SelectOp>(
413 loc, builder.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs), lhs,
414 rhs);
415 case AtomicRMWKind::minf:
416 return builder.create<SelectOp>(
417 loc, builder.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs), lhs,
418 rhs);
419 case AtomicRMWKind::maxs:
420 return builder.create<SelectOp>(
421 loc, builder.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs), lhs,
422 rhs);
423 case AtomicRMWKind::mins:
424 return builder.create<SelectOp>(
425 loc, builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs), lhs,
426 rhs);
427 case AtomicRMWKind::maxu:
428 return builder.create<SelectOp>(
429 loc, builder.create<CmpIOp>(loc, CmpIPredicate::ugt, lhs, rhs), lhs,
430 rhs);
431 case AtomicRMWKind::minu:
432 return builder.create<SelectOp>(
433 loc, builder.create<CmpIOp>(loc, CmpIPredicate::ult, lhs, rhs), lhs,
434 rhs);
435 // TODO: Add remaining reduction operations.
436 default:
437 (void)emitOptionalError(loc, "Reduction operation type not supported");
438 break;
439 }
440 return nullptr;
441 }
442
443 //===----------------------------------------------------------------------===//
444 // GenericAtomicRMWOp
445 //===----------------------------------------------------------------------===//
446
build(OpBuilder & builder,OperationState & result,Value memref,ValueRange ivs)447 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
448 Value memref, ValueRange ivs) {
449 result.addOperands(memref);
450 result.addOperands(ivs);
451
452 if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
453 Type elementType = memrefType.getElementType();
454 result.addTypes(elementType);
455
456 Region *bodyRegion = result.addRegion();
457 bodyRegion->push_back(new Block());
458 bodyRegion->addArgument(elementType);
459 }
460 }
461
verify(GenericAtomicRMWOp op)462 static LogicalResult verify(GenericAtomicRMWOp op) {
463 auto &body = op.body();
464 if (body.getNumArguments() != 1)
465 return op.emitOpError("expected single number of entry block arguments");
466
467 if (op.getResult().getType() != body.getArgument(0).getType())
468 return op.emitOpError(
469 "expected block argument of the same type result type");
470
471 bool hasSideEffects =
472 body.walk([&](Operation *nestedOp) {
473 if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
474 return WalkResult::advance();
475 nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
476 "only operations with no side effects");
477 return WalkResult::interrupt();
478 })
479 .wasInterrupted();
480 return hasSideEffects ? failure() : success();
481 }
482
parseGenericAtomicRMWOp(OpAsmParser & parser,OperationState & result)483 static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
484 OperationState &result) {
485 OpAsmParser::OperandType memref;
486 Type memrefType;
487 SmallVector<OpAsmParser::OperandType, 4> ivs;
488
489 Type indexType = parser.getBuilder().getIndexType();
490 if (parser.parseOperand(memref) ||
491 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
492 parser.parseColonType(memrefType) ||
493 parser.resolveOperand(memref, memrefType, result.operands) ||
494 parser.resolveOperands(ivs, indexType, result.operands))
495 return failure();
496
497 Region *body = result.addRegion();
498 if (parser.parseRegion(*body, llvm::None, llvm::None) ||
499 parser.parseOptionalAttrDict(result.attributes))
500 return failure();
501 result.types.push_back(memrefType.cast<MemRefType>().getElementType());
502 return success();
503 }
504
print(OpAsmPrinter & p,GenericAtomicRMWOp op)505 static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
506 p << ' ' << op.memref() << "[" << op.indices()
507 << "] : " << op.memref().getType();
508 p.printRegion(op.body());
509 p.printOptionalAttrDict(op->getAttrs());
510 }
511
512 //===----------------------------------------------------------------------===//
513 // AtomicYieldOp
514 //===----------------------------------------------------------------------===//
515
verify(AtomicYieldOp op)516 static LogicalResult verify(AtomicYieldOp op) {
517 Type parentType = op->getParentOp()->getResultTypes().front();
518 Type resultType = op.result().getType();
519 if (parentType != resultType)
520 return op.emitOpError() << "types mismatch between yield op: " << resultType
521 << " and its parent: " << parentType;
522 return success();
523 }
524
525 //===----------------------------------------------------------------------===//
526 // BitcastOp
527 //===----------------------------------------------------------------------===//
528
areCastCompatible(TypeRange inputs,TypeRange outputs)529 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
530 assert(inputs.size() == 1 && outputs.size() == 1 &&
531 "bitcast op expects one operand and result");
532 Type a = inputs.front(), b = outputs.front();
533 if (a.isSignlessIntOrFloat() && b.isSignlessIntOrFloat())
534 return a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
535 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
536 }
537
fold(ArrayRef<Attribute> operands)538 OpFoldResult BitcastOp::fold(ArrayRef<Attribute> operands) {
539 assert(operands.size() == 1 && "bitcastop expects 1 operand");
540
541 // Bitcast of bitcast
542 auto *sourceOp = getOperand().getDefiningOp();
543 if (auto sourceBitcast = dyn_cast_or_null<BitcastOp>(sourceOp)) {
544 setOperand(sourceBitcast.getOperand());
545 return getResult();
546 }
547
548 auto operand = operands[0];
549 if (!operand)
550 return {};
551
552 Type resType = getResult().getType();
553
554 if (auto denseAttr = operand.dyn_cast<DenseElementsAttr>())
555 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
556
557 APInt bits;
558 if (auto floatAttr = operand.dyn_cast<FloatAttr>())
559 bits = floatAttr.getValue().bitcastToAPInt();
560 else if (auto intAttr = operand.dyn_cast<IntegerAttr>())
561 bits = intAttr.getValue();
562 else
563 return {};
564
565 if (resType.isa<IntegerType>())
566 return IntegerAttr::get(resType, bits);
567 if (auto resFloatType = resType.dyn_cast<FloatType>())
568 return FloatAttr::get(resType,
569 APFloat(resFloatType.getFloatSemantics(), bits));
570 return {};
571 }
572
573 //===----------------------------------------------------------------------===//
574 // BranchOp
575 //===----------------------------------------------------------------------===//
576
/// Given a successor, try to collapse it to a new destination if it only
/// contains a passthrough unconditional branch. If the successor is
/// collapsable, `successor` and `successorOperands` are updated to reference
/// the new destination and values. `argStorage` is used as storage if operands
/// to the collapsed successor need to be remapped. It must outlive uses of
/// successorOperands.
static LogicalResult collapseBranch(Block *&successor,
                                    ValueRange &successorOperands,
                                    SmallVectorImpl<Value> &argStorage) {
  // Check that the successor only contains a unconditional branch.
  if (std::next(successor->begin()) != successor->end())
    return failure();
  // Check that the terminator is an unconditional branch.
  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
  if (!successorBranch)
    return failure();
  // Check that the arguments are only used within the terminator. Any other
  // use would become invalid once the block is bypassed.
  for (BlockArgument arg : successor->getArguments()) {
    for (Operation *user : arg.getUsers())
      if (user != successorBranch)
        return failure();
  }
  // Don't try to collapse branches to infinite loops.
  Block *successorDest = successorBranch.getDest();
  if (successorDest == successor)
    return failure();

  // Update the operands to the successor. If the branch parent has no
  // arguments, we can use the branch operands directly.
  OperandRange operands = successorBranch.getOperands();
  if (successor->args_empty()) {
    successor = successorDest;
    successorOperands = operands;
    return success();
  }

  // Otherwise, we need to remap any argument operands: each operand of the
  // inner branch that is a block argument of `successor` is replaced by the
  // value the original branch supplied for that argument (read from the
  // incoming `successorOperands`).
  for (Value operand : operands) {
    BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
    if (argOperand && argOperand.getOwner() == successor)
      argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
    else
      argStorage.push_back(operand);
  }
  // Re-point the out-parameters at the collapsed destination; note that
  // `successorOperands` now aliases caller-owned `argStorage`.
  successor = successorDest;
  successorOperands = argStorage;
  return success();
}
625
626 /// Simplify a branch to a block that has a single predecessor. This effectively
627 /// merges the two blocks.
628 static LogicalResult
simplifyBrToBlockWithSinglePred(BranchOp op,PatternRewriter & rewriter)629 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
630 // Check that the successor block has a single predecessor.
631 Block *succ = op.getDest();
632 Block *opParent = op->getBlock();
633 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
634 return failure();
635
636 // Merge the successor into the current block and erase the branch.
637 rewriter.mergeBlocks(succ, opParent, op.getOperands());
638 rewriter.eraseOp(op);
639 return success();
640 }
641
/// br ^bb1
/// ^bb1
///   br ^bbN(...)
///
///  -> br ^bbN(...)
///
static LogicalResult simplifyPassThroughBr(BranchOp op,
                                           PatternRewriter &rewriter) {
  Block *dest = op.getDest();
  ValueRange destOperands = op.getOperands();
  // Must outlive `destOperands`, which collapseBranch may re-point at it.
  SmallVector<Value, 4> destOperandStorage;

  // Try to collapse the successor if it points somewhere other than this
  // block.
  if (dest == op->getBlock() ||
      failed(collapseBranch(dest, destOperands, destOperandStorage)))
    return failure();

  // Create a new branch with the collapsed successor.
  rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
  return success();
}
664
canonicalize(BranchOp op,PatternRewriter & rewriter)665 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
666 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
667 succeeded(simplifyPassThroughBr(op, rewriter)));
668 }
669
getDest()670 Block *BranchOp::getDest() { return getSuccessor(); }
671
setDest(Block * block)672 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
673
eraseOperand(unsigned index)674 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
675
/// Returns the mutable range of operands forwarded to the branch's single
/// successor.
Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  return destOperandsMutable();
}
681
getSuccessorForOperands(ArrayRef<Attribute>)682 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
683
684 //===----------------------------------------------------------------------===//
685 // CallOp
686 //===----------------------------------------------------------------------===//
687
verifySymbolUses(SymbolTableCollection & symbolTable)688 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
689 // Check that the callee attribute was specified.
690 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
691 if (!fnAttr)
692 return emitOpError("requires a 'callee' symbol reference attribute");
693 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
694 if (!fn)
695 return emitOpError() << "'" << fnAttr.getValue()
696 << "' does not reference a valid function";
697
698 // Verify that the operand and result types match the callee.
699 auto fnType = fn.getType();
700 if (fnType.getNumInputs() != getNumOperands())
701 return emitOpError("incorrect number of operands for callee");
702
703 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
704 if (getOperand(i).getType() != fnType.getInput(i))
705 return emitOpError("operand type mismatch: expected operand type ")
706 << fnType.getInput(i) << ", but provided "
707 << getOperand(i).getType() << " for operand number " << i;
708
709 if (fnType.getNumResults() != getNumResults())
710 return emitOpError("incorrect number of results for callee");
711
712 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
713 if (getResult(i).getType() != fnType.getResult(i)) {
714 auto diag = emitOpError("result type mismatch at index ") << i;
715 diag.attachNote() << " op result types: " << getResultTypes();
716 diag.attachNote() << "function result types: " << fnType.getResults();
717 return diag;
718 }
719
720 return success();
721 }
722
getCalleeType()723 FunctionType CallOp::getCalleeType() {
724 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
725 }
726
727 //===----------------------------------------------------------------------===//
728 // CallIndirectOp
729 //===----------------------------------------------------------------------===//
730
/// Fold indirect calls that have a constant function as the callee operand.
LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
                                           PatternRewriter &rewriter) {
  // Check that the callee is a constant callee (i.e. produced by a constant
  // op holding a symbol reference).
  SymbolRefAttr calledFn;
  if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
    return failure();

  // Replace with a direct call to the resolved symbol, preserving the result
  // types and argument operands.
  rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
                                      indirectCall.getResultTypes(),
                                      indirectCall.getArgOperands());
  return success();
}
745
746 //===----------------------------------------------------------------------===//
747 // General helpers for comparison ops
748 //===----------------------------------------------------------------------===//
749
750 // Return the type of the same shape (scalar, vector or tensor) containing i1.
getI1SameShape(Type type)751 static Type getI1SameShape(Type type) {
752 auto i1Type = IntegerType::get(type.getContext(), 1);
753 if (auto tensorType = type.dyn_cast<RankedTensorType>())
754 return RankedTensorType::get(tensorType.getShape(), i1Type);
755 if (type.isa<UnrankedTensorType>())
756 return UnrankedTensorType::get(i1Type);
757 if (auto vectorType = type.dyn_cast<VectorType>())
758 return VectorType::get(vectorType.getShape(), i1Type);
759 return i1Type;
760 }
761
762 //===----------------------------------------------------------------------===//
763 // CmpIOp
764 //===----------------------------------------------------------------------===//
765
buildCmpIOp(OpBuilder & build,OperationState & result,CmpIPredicate predicate,Value lhs,Value rhs)766 static void buildCmpIOp(OpBuilder &build, OperationState &result,
767 CmpIPredicate predicate, Value lhs, Value rhs) {
768 result.addOperands({lhs, rhs});
769 result.types.push_back(getI1SameShape(lhs.getType()));
770 result.addAttribute(CmpIOp::getPredicateAttrName(),
771 build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
772 }
773
774 // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
775 // comparison predicates.
applyCmpPredicate(CmpIPredicate predicate,const APInt & lhs,const APInt & rhs)776 bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
777 const APInt &rhs) {
778 switch (predicate) {
779 case CmpIPredicate::eq:
780 return lhs.eq(rhs);
781 case CmpIPredicate::ne:
782 return lhs.ne(rhs);
783 case CmpIPredicate::slt:
784 return lhs.slt(rhs);
785 case CmpIPredicate::sle:
786 return lhs.sle(rhs);
787 case CmpIPredicate::sgt:
788 return lhs.sgt(rhs);
789 case CmpIPredicate::sge:
790 return lhs.sge(rhs);
791 case CmpIPredicate::ult:
792 return lhs.ult(rhs);
793 case CmpIPredicate::ule:
794 return lhs.ule(rhs);
795 case CmpIPredicate::ugt:
796 return lhs.ugt(rhs);
797 case CmpIPredicate::uge:
798 return lhs.uge(rhs);
799 }
800 llvm_unreachable("unknown comparison predicate");
801 }
802
803 // Returns true if the predicate is true for two equal operands.
applyCmpPredicateToEqualOperands(CmpIPredicate predicate)804 static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
805 switch (predicate) {
806 case CmpIPredicate::eq:
807 case CmpIPredicate::sle:
808 case CmpIPredicate::sge:
809 case CmpIPredicate::ule:
810 case CmpIPredicate::uge:
811 return true;
812 case CmpIPredicate::ne:
813 case CmpIPredicate::slt:
814 case CmpIPredicate::sgt:
815 case CmpIPredicate::ult:
816 case CmpIPredicate::ugt:
817 return false;
818 }
819 llvm_unreachable("unknown comparison predicate");
820 }
821
822 // Constant folding hook for comparisons.
fold(ArrayRef<Attribute> operands)823 OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
824 assert(operands.size() == 2 && "cmpi takes two arguments");
825
826 if (lhs() == rhs()) {
827 auto val = applyCmpPredicateToEqualOperands(getPredicate());
828 return BoolAttr::get(getContext(), val);
829 }
830
831 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
832 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
833 if (!lhs || !rhs)
834 return {};
835
836 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
837 return BoolAttr::get(getContext(), val);
838 }
839
840 //===----------------------------------------------------------------------===//
841 // CmpFOp
842 //===----------------------------------------------------------------------===//
843
buildCmpFOp(OpBuilder & build,OperationState & result,CmpFPredicate predicate,Value lhs,Value rhs)844 static void buildCmpFOp(OpBuilder &build, OperationState &result,
845 CmpFPredicate predicate, Value lhs, Value rhs) {
846 result.addOperands({lhs, rhs});
847 result.types.push_back(getI1SameShape(lhs.getType()));
848 result.addAttribute(CmpFOp::getPredicateAttrName(),
849 build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
850 }
851
852 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
853 /// comparison predicates.
applyCmpPredicate(CmpFPredicate predicate,const APFloat & lhs,const APFloat & rhs)854 bool mlir::applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
855 const APFloat &rhs) {
856 auto cmpResult = lhs.compare(rhs);
857 switch (predicate) {
858 case CmpFPredicate::AlwaysFalse:
859 return false;
860 case CmpFPredicate::OEQ:
861 return cmpResult == APFloat::cmpEqual;
862 case CmpFPredicate::OGT:
863 return cmpResult == APFloat::cmpGreaterThan;
864 case CmpFPredicate::OGE:
865 return cmpResult == APFloat::cmpGreaterThan ||
866 cmpResult == APFloat::cmpEqual;
867 case CmpFPredicate::OLT:
868 return cmpResult == APFloat::cmpLessThan;
869 case CmpFPredicate::OLE:
870 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
871 case CmpFPredicate::ONE:
872 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
873 case CmpFPredicate::ORD:
874 return cmpResult != APFloat::cmpUnordered;
875 case CmpFPredicate::UEQ:
876 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
877 case CmpFPredicate::UGT:
878 return cmpResult == APFloat::cmpUnordered ||
879 cmpResult == APFloat::cmpGreaterThan;
880 case CmpFPredicate::UGE:
881 return cmpResult == APFloat::cmpUnordered ||
882 cmpResult == APFloat::cmpGreaterThan ||
883 cmpResult == APFloat::cmpEqual;
884 case CmpFPredicate::ULT:
885 return cmpResult == APFloat::cmpUnordered ||
886 cmpResult == APFloat::cmpLessThan;
887 case CmpFPredicate::ULE:
888 return cmpResult == APFloat::cmpUnordered ||
889 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
890 case CmpFPredicate::UNE:
891 return cmpResult != APFloat::cmpEqual;
892 case CmpFPredicate::UNO:
893 return cmpResult == APFloat::cmpUnordered;
894 case CmpFPredicate::AlwaysTrue:
895 return true;
896 }
897 llvm_unreachable("unknown comparison predicate");
898 }
899
900 // Constant folding hook for comparisons.
fold(ArrayRef<Attribute> operands)901 OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
902 assert(operands.size() == 2 && "cmpf takes two arguments");
903
904 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
905 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
906
907 // TODO: We could actually do some intelligent things if we know only one
908 // of the operands, but it's inf or nan.
909 if (!lhs || !rhs)
910 return {};
911
912 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
913 return IntegerAttr::get(IntegerType::get(getContext(), 1), APInt(1, val));
914 }
915
916 //===----------------------------------------------------------------------===//
917 // CondBranchOp
918 //===----------------------------------------------------------------------===//
919
920 namespace {
921 /// cond_br true, ^bb1, ^bb2
922 /// -> br ^bb1
923 /// cond_br false, ^bb1, ^bb2
924 /// -> br ^bb2
925 ///
926 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
927 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
928
matchAndRewrite__anon1787a6f50611::SimplifyConstCondBranchPred929 LogicalResult matchAndRewrite(CondBranchOp condbr,
930 PatternRewriter &rewriter) const override {
931 if (matchPattern(condbr.getCondition(), m_NonZero())) {
932 // True branch taken.
933 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
934 condbr.getTrueOperands());
935 return success();
936 } else if (matchPattern(condbr.getCondition(), m_Zero())) {
937 // False branch taken.
938 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
939 condbr.getFalseOperands());
940 return success();
941 }
942 return failure();
943 }
944 };
945
946 /// cond_br %cond, ^bb1, ^bb2
947 /// ^bb1
948 /// br ^bbN(...)
949 /// ^bb2
950 /// br ^bbK(...)
951 ///
952 /// -> cond_br %cond, ^bbN(...), ^bbK(...)
953 ///
954 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
955 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
956
matchAndRewrite__anon1787a6f50611::SimplifyPassThroughCondBranch957 LogicalResult matchAndRewrite(CondBranchOp condbr,
958 PatternRewriter &rewriter) const override {
959 Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest();
960 ValueRange trueDestOperands = condbr.getTrueOperands();
961 ValueRange falseDestOperands = condbr.getFalseOperands();
962 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
963
964 // Try to collapse one of the current successors.
965 LogicalResult collapsedTrue =
966 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
967 LogicalResult collapsedFalse =
968 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
969 if (failed(collapsedTrue) && failed(collapsedFalse))
970 return failure();
971
972 // Create a new branch with the collapsed successors.
973 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
974 trueDest, trueDestOperands,
975 falseDest, falseDestOperands);
976 return success();
977 }
978 };
979
980 /// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
981 /// -> br ^bb1(A, ..., N)
982 ///
983 /// cond_br %cond, ^bb1(A), ^bb1(B)
984 /// -> %select = select %cond, A, B
985 /// br ^bb1(%select)
986 ///
987 struct SimplifyCondBranchIdenticalSuccessors
988 : public OpRewritePattern<CondBranchOp> {
989 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
990
matchAndRewrite__anon1787a6f50611::SimplifyCondBranchIdenticalSuccessors991 LogicalResult matchAndRewrite(CondBranchOp condbr,
992 PatternRewriter &rewriter) const override {
993 // Check that the true and false destinations are the same and have the same
994 // operands.
995 Block *trueDest = condbr.trueDest();
996 if (trueDest != condbr.falseDest())
997 return failure();
998
999 // If all of the operands match, no selects need to be generated.
1000 OperandRange trueOperands = condbr.getTrueOperands();
1001 OperandRange falseOperands = condbr.getFalseOperands();
1002 if (trueOperands == falseOperands) {
1003 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
1004 return success();
1005 }
1006
1007 // Otherwise, if the current block is the only predecessor insert selects
1008 // for any mismatched branch operands.
1009 if (trueDest->getUniquePredecessor() != condbr->getBlock())
1010 return failure();
1011
1012 // Generate a select for any operands that differ between the two.
1013 SmallVector<Value, 8> mergedOperands;
1014 mergedOperands.reserve(trueOperands.size());
1015 Value condition = condbr.getCondition();
1016 for (auto it : llvm::zip(trueOperands, falseOperands)) {
1017 if (std::get<0>(it) == std::get<1>(it))
1018 mergedOperands.push_back(std::get<0>(it));
1019 else
1020 mergedOperands.push_back(rewriter.create<SelectOp>(
1021 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
1022 }
1023
1024 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
1025 return success();
1026 }
1027 };
1028
1029 /// ...
1030 /// cond_br %cond, ^bb1(...), ^bb2(...)
1031 /// ...
1032 /// ^bb1: // has single predecessor
1033 /// ...
1034 /// cond_br %cond, ^bb3(...), ^bb4(...)
1035 ///
1036 /// ->
1037 ///
1038 /// ...
1039 /// cond_br %cond, ^bb1(...), ^bb2(...)
1040 /// ...
1041 /// ^bb1: // has single predecessor
1042 /// ...
1043 /// br ^bb3(...)
1044 ///
1045 struct SimplifyCondBranchFromCondBranchOnSameCondition
1046 : public OpRewritePattern<CondBranchOp> {
1047 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
1048
matchAndRewrite__anon1787a6f50611::SimplifyCondBranchFromCondBranchOnSameCondition1049 LogicalResult matchAndRewrite(CondBranchOp condbr,
1050 PatternRewriter &rewriter) const override {
1051 // Check that we have a single distinct predecessor.
1052 Block *currentBlock = condbr->getBlock();
1053 Block *predecessor = currentBlock->getSinglePredecessor();
1054 if (!predecessor)
1055 return failure();
1056
1057 // Check that the predecessor terminates with a conditional branch to this
1058 // block and that it branches on the same condition.
1059 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
1060 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
1061 return failure();
1062
1063 // Fold this branch to an unconditional branch.
1064 if (currentBlock == predBranch.trueDest())
1065 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.trueDest(),
1066 condbr.trueDestOperands());
1067 else
1068 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.falseDest(),
1069 condbr.falseDestOperands());
1070 return success();
1071 }
1072 };
1073
1074 /// cond_br %arg0, ^trueB, ^falseB
1075 ///
1076 /// ^trueB:
1077 /// "test.consumer1"(%arg0) : (i1) -> ()
1078 /// ...
1079 ///
1080 /// ^falseB:
1081 /// "test.consumer2"(%arg0) : (i1) -> ()
1082 /// ...
1083 ///
1084 /// ->
1085 ///
1086 /// cond_br %arg0, ^trueB, ^falseB
1087 /// ^trueB:
1088 /// "test.consumer1"(%true) : (i1) -> ()
1089 /// ...
1090 ///
1091 /// ^falseB:
1092 /// "test.consumer2"(%false) : (i1) -> ()
1093 /// ...
1094 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
1095 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
1096
matchAndRewrite__anon1787a6f50611::CondBranchTruthPropagation1097 LogicalResult matchAndRewrite(CondBranchOp condbr,
1098 PatternRewriter &rewriter) const override {
1099 // Check that we have a single distinct predecessor.
1100 bool replaced = false;
1101 Type ty = rewriter.getI1Type();
1102
1103 // These variables serve to prevent creating duplicate constants
1104 // and hold constant true or false values.
1105 Value constantTrue = nullptr;
1106 Value constantFalse = nullptr;
1107
1108 // TODO These checks can be expanded to encompas any use with only
1109 // either the true of false edge as a predecessor. For now, we fall
1110 // back to checking the single predecessor is given by the true/fasle
1111 // destination, thereby ensuring that only that edge can reach the
1112 // op.
1113 if (condbr.getTrueDest()->getSinglePredecessor()) {
1114 for (OpOperand &use :
1115 llvm::make_early_inc_range(condbr.condition().getUses())) {
1116 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
1117 replaced = true;
1118
1119 if (!constantTrue)
1120 constantTrue = rewriter.create<mlir::ConstantOp>(
1121 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
1122
1123 rewriter.updateRootInPlace(use.getOwner(),
1124 [&] { use.set(constantTrue); });
1125 }
1126 }
1127 }
1128 if (condbr.getFalseDest()->getSinglePredecessor()) {
1129 for (OpOperand &use :
1130 llvm::make_early_inc_range(condbr.condition().getUses())) {
1131 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
1132 replaced = true;
1133
1134 if (!constantFalse)
1135 constantFalse = rewriter.create<mlir::ConstantOp>(
1136 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
1137
1138 rewriter.updateRootInPlace(use.getOwner(),
1139 [&] { use.set(constantFalse); });
1140 }
1141 }
1142 }
1143 return success(replaced);
1144 }
1145 };
1146 } // end anonymous namespace
1147
/// Register all cond_br canonicalization patterns defined above.
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
              SimplifyCondBranchIdenticalSuccessors,
              SimplifyCondBranchFromCondBranchOnSameCondition,
              CondBranchTruthPropagation>(context);
}
1155
1156 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1157 CondBranchOp::getMutableSuccessorOperands(unsigned index) {
1158 assert(index < getNumSuccessors() && "invalid successor index");
1159 return index == trueIndex ? trueDestOperandsMutable()
1160 : falseDestOperandsMutable();
1161 }
1162
getSuccessorForOperands(ArrayRef<Attribute> operands)1163 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1164 if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
1165 return condAttr.getValue().isOneValue() ? trueDest() : falseDest();
1166 return nullptr;
1167 }
1168
1169 //===----------------------------------------------------------------------===//
1170 // Constant*Op
1171 //===----------------------------------------------------------------------===//
1172
/// Print a constant op: attribute dict (eliding "value"), the value
/// attribute, and a trailing type for attribute kinds whose textual form
/// does not encode a type.
static void print(OpAsmPrinter &p, ConstantOp &op) {
  p << " ";
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});

  // Separate the printed dict from the value when extra attributes exist.
  if (op->getAttrs().size() > 1)
    p << ' ';
  p << op.getValue();

  // If the value is a symbol reference or Array, print a trailing type.
  if (op.getValue().isa<SymbolRefAttr, ArrayAttr>())
    p << " : " << op.getType();
}
1185
/// Parse a constant op: optional attribute dict, the mandatory "value"
/// attribute, and a trailing type only for symbol-ref/array values (mirrors
/// the printer above).
static ParseResult parseConstantOp(OpAsmParser &parser,
                                   OperationState &result) {
  Attribute valueAttr;
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseAttribute(valueAttr, "value", result.attributes))
    return failure();

  // If the attribute is a symbol reference or array, then we expect a trailing
  // type; otherwise the attribute carries its own type.
  Type type;
  if (!valueAttr.isa<SymbolRefAttr, ArrayAttr>())
    type = valueAttr.getType();
  else if (parser.parseColonType(type))
    return failure();

  // Add the attribute type to the list.
  return parser.addTypeToList(type, result.types);
}
1204
1205 /// The constant op requires an attribute, and furthermore requires that it
1206 /// matches the return type.
verify(ConstantOp & op)1207 static LogicalResult verify(ConstantOp &op) {
1208 auto value = op.getValue();
1209 if (!value)
1210 return op.emitOpError("requires a 'value' attribute");
1211
1212 Type type = op.getType();
1213 if (!value.getType().isa<NoneType>() && type != value.getType())
1214 return op.emitOpError() << "requires attribute's type (" << value.getType()
1215 << ") to match op's return type (" << type << ")";
1216
1217 if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
1218 if (type.isa<IndexType>() || value.isa<BoolAttr>())
1219 return success();
1220 IntegerType intType = type.cast<IntegerType>();
1221 if (!intType.isSignless())
1222 return op.emitOpError("requires integer result types to be signless");
1223
1224 // If the type has a known bitwidth we verify that the value can be
1225 // represented with the given bitwidth.
1226 unsigned bitwidth = intType.getWidth();
1227 APInt intVal = intAttr.getValue();
1228 if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
1229 return op.emitOpError("requires 'value' to be an integer within the "
1230 "range of the integer result type");
1231 return success();
1232 }
1233
1234 if (auto complexTy = type.dyn_cast<ComplexType>()) {
1235 auto arrayAttr = value.dyn_cast<ArrayAttr>();
1236 if (!complexTy || arrayAttr.size() != 2)
1237 return op.emitOpError(
1238 "requires 'value' to be a complex constant, represented as array of "
1239 "two values");
1240 auto complexEltTy = complexTy.getElementType();
1241 if (complexEltTy != arrayAttr[0].getType() ||
1242 complexEltTy != arrayAttr[1].getType()) {
1243 return op.emitOpError()
1244 << "requires attribute's element types (" << arrayAttr[0].getType()
1245 << ", " << arrayAttr[1].getType()
1246 << ") to match the element type of the op's return type ("
1247 << complexEltTy << ")";
1248 }
1249 return success();
1250 }
1251
1252 if (type.isa<FloatType>()) {
1253 if (!value.isa<FloatAttr>())
1254 return op.emitOpError("requires 'value' to be a floating point constant");
1255 return success();
1256 }
1257
1258 if (type.isa<ShapedType>()) {
1259 if (!value.isa<ElementsAttr>())
1260 return op.emitOpError("requires 'value' to be a shaped constant");
1261 return success();
1262 }
1263
1264 if (type.isa<FunctionType>()) {
1265 auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
1266 if (!fnAttr)
1267 return op.emitOpError("requires 'value' to be a function reference");
1268
1269 // Try to find the referenced function.
1270 auto fn =
1271 op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
1272 if (!fn)
1273 return op.emitOpError()
1274 << "reference to undefined function '" << fnAttr.getValue() << "'";
1275
1276 // Check that the referenced function has the correct type.
1277 if (fn.getType() != type)
1278 return op.emitOpError("reference to function with mismatched type");
1279
1280 return success();
1281 }
1282
1283 if (type.isa<NoneType>() && value.isa<UnitAttr>())
1284 return success();
1285
1286 return op.emitOpError("unsupported 'value' attribute: ") << value;
1287 }
1288
/// A constant op always folds to its 'value' attribute; it takes no operands.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");
  return getValue();
}
1293
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)1294 void ConstantOp::getAsmResultNames(
1295 function_ref<void(Value, StringRef)> setNameFn) {
1296 Type type = getType();
1297 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
1298 IntegerType intTy = type.dyn_cast<IntegerType>();
1299
1300 // Sugar i1 constants with 'true' and 'false'.
1301 if (intTy && intTy.getWidth() == 1)
1302 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1303
1304 // Otherwise, build a complex name with the value and type.
1305 SmallString<32> specialNameBuffer;
1306 llvm::raw_svector_ostream specialName(specialNameBuffer);
1307 specialName << 'c' << intCst.getInt();
1308 if (intTy)
1309 specialName << '_' << type;
1310 setNameFn(getResult(), specialName.str());
1311
1312 } else if (type.isa<FunctionType>()) {
1313 setNameFn(getResult(), "f");
1314 } else {
1315 setNameFn(getResult(), "cst");
1316 }
1317 }
1318
1319 /// Returns true if a constant operation can be built with the given value and
1320 /// result type.
isBuildableWith(Attribute value,Type type)1321 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
1322 // SymbolRefAttr can only be used with a function type.
1323 if (value.isa<SymbolRefAttr>())
1324 return type.isa<FunctionType>();
1325 // The attribute must have the same type as 'type'.
1326 if (!value.getType().isa<NoneType>() && value.getType() != type)
1327 return false;
1328 // If the type is an integer type, it must be signless.
1329 if (IntegerType integerTy = type.dyn_cast<IntegerType>())
1330 if (!integerTy.isSignless())
1331 return false;
1332 // Finally, check that the attribute kind is handled.
1333 if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
1334 auto complexTy = type.dyn_cast<ComplexType>();
1335 if (!complexTy)
1336 return false;
1337 auto complexEltTy = complexTy.getElementType();
1338 return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
1339 arrAttr[1].getType() == complexEltTy;
1340 }
1341 return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
1342 }
1343
build(OpBuilder & builder,OperationState & result,const APFloat & value,FloatType type)1344 void ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
1345 const APFloat &value, FloatType type) {
1346 ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value));
1347 }
1348
classof(Operation * op)1349 bool ConstantFloatOp::classof(Operation *op) {
1350 return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>();
1351 }
1352
1353 /// ConstantIntOp only matches values whose result type is an IntegerType.
classof(Operation * op)1354 bool ConstantIntOp::classof(Operation *op) {
1355 return ConstantOp::classof(op) &&
1356 op->getResult(0).getType().isSignlessInteger();
1357 }
1358
build(OpBuilder & builder,OperationState & result,int64_t value,unsigned width)1359 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1360 int64_t value, unsigned width) {
1361 Type type = builder.getIntegerType(width);
1362 ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1363 }
1364
1365 /// Build a constant int op producing an integer with the specified type,
1366 /// which must be an integer type.
build(OpBuilder & builder,OperationState & result,int64_t value,Type type)1367 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1368 int64_t value, Type type) {
1369 assert(type.isSignlessInteger() &&
1370 "ConstantIntOp can only have signless integer type");
1371 ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1372 }
1373
1374 /// ConstantIndexOp only matches values whose result type is Index.
classof(Operation * op)1375 bool ConstantIndexOp::classof(Operation *op) {
1376 return ConstantOp::classof(op) && op->getResult(0).getType().isIndex();
1377 }
1378
build(OpBuilder & builder,OperationState & result,int64_t value)1379 void ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
1380 int64_t value) {
1381 Type type = builder.getIndexType();
1382 ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1383 }
1384
1385 // ---------------------------------------------------------------------------
1386 // DivFOp
1387 // ---------------------------------------------------------------------------
1388
fold(ArrayRef<Attribute> operands)1389 OpFoldResult DivFOp::fold(ArrayRef<Attribute> operands) {
1390 return constFoldBinaryOp<FloatAttr>(
1391 operands, [](APFloat a, APFloat b) { return a / b; });
1392 }
1393
1394 //===----------------------------------------------------------------------===//
1395 // FPExtOp
1396 //===----------------------------------------------------------------------===//
1397
areCastCompatible(TypeRange inputs,TypeRange outputs)1398 bool FPExtOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1399 if (inputs.size() != 1 || outputs.size() != 1)
1400 return false;
1401 Type a = inputs.front(), b = outputs.front();
1402 if (auto fa = a.dyn_cast<FloatType>())
1403 if (auto fb = b.dyn_cast<FloatType>())
1404 return fa.getWidth() < fb.getWidth();
1405 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1406 }
1407
1408 //===----------------------------------------------------------------------===//
1409 // FPToSIOp
1410 //===----------------------------------------------------------------------===//
1411
areCastCompatible(TypeRange inputs,TypeRange outputs)1412 bool FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1413 if (inputs.size() != 1 || outputs.size() != 1)
1414 return false;
1415 Type a = inputs.front(), b = outputs.front();
1416 if (a.isa<FloatType>() && b.isSignlessInteger())
1417 return true;
1418 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1419 }
1420
1421 //===----------------------------------------------------------------------===//
1422 // FPToUIOp
1423 //===----------------------------------------------------------------------===//
1424
areCastCompatible(TypeRange inputs,TypeRange outputs)1425 bool FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1426 if (inputs.size() != 1 || outputs.size() != 1)
1427 return false;
1428 Type a = inputs.front(), b = outputs.front();
1429 if (a.isa<FloatType>() && b.isSignlessInteger())
1430 return true;
1431 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1432 }
1433
1434 //===----------------------------------------------------------------------===//
1435 // FPTruncOp
1436 //===----------------------------------------------------------------------===//
1437
areCastCompatible(TypeRange inputs,TypeRange outputs)1438 bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1439 if (inputs.size() != 1 || outputs.size() != 1)
1440 return false;
1441 Type a = inputs.front(), b = outputs.front();
1442 if (auto fa = a.dyn_cast<FloatType>())
1443 if (auto fb = b.dyn_cast<FloatType>())
1444 return fa.getWidth() > fb.getWidth();
1445 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1446 }
1447
1448 /// Perform safe const propagation for fptrunc, i.e. only propagate
1449 /// if FP value can be represented without precision loss or rounding.
fold(ArrayRef<Attribute> operands)1450 OpFoldResult FPTruncOp::fold(ArrayRef<Attribute> operands) {
1451 assert(operands.size() == 1 && "unary operation takes one operand");
1452
1453 auto constOperand = operands.front();
1454 if (!constOperand || !constOperand.isa<FloatAttr>())
1455 return {};
1456
1457 // Convert to target type via 'double'.
1458 double sourceValue =
1459 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
1460 auto targetAttr = FloatAttr::get(getType(), sourceValue);
1461
1462 // Propagate if constant's value does not change after truncation.
1463 if (sourceValue == targetAttr.getValue().convertToDouble())
1464 return targetAttr;
1465
1466 return {};
1467 }
1468
1469 //===----------------------------------------------------------------------===//
1470 // IndexCastOp
1471 //===----------------------------------------------------------------------===//
1472
1473 // Index cast is applicable from index to integer and backwards.
areCastCompatible(TypeRange inputs,TypeRange outputs)1474 bool IndexCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1475 if (inputs.size() != 1 || outputs.size() != 1)
1476 return false;
1477 Type a = inputs.front(), b = outputs.front();
1478 if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
1479 auto aShaped = a.cast<ShapedType>();
1480 auto bShaped = b.cast<ShapedType>();
1481
1482 return (aShaped.getShape() == bShaped.getShape()) &&
1483 areCastCompatible(aShaped.getElementType(),
1484 bShaped.getElementType());
1485 }
1486
1487 return (a.isIndex() && b.isSignlessInteger()) ||
1488 (a.isSignlessInteger() && b.isIndex());
1489 }
1490
fold(ArrayRef<Attribute> cstOperands)1491 OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
1492 // Fold IndexCast(IndexCast(x)) -> x
1493 auto cast = getOperand().getDefiningOp<IndexCastOp>();
1494 if (cast && cast.getOperand().getType() == getType())
1495 return cast.getOperand();
1496
1497 // Fold IndexCast(constant) -> constant
1498 // A little hack because we go through int. Otherwise, the size
1499 // of the constant might need to change.
1500 if (auto value = cstOperands[0].dyn_cast_or_null<IntegerAttr>())
1501 return IntegerAttr::get(getType(), value.getInt());
1502
1503 return {};
1504 }
1505
1506 namespace {
1507 /// index_cast(sign_extend x) => index_cast(x)
1508 struct IndexCastOfSExt : public OpRewritePattern<IndexCastOp> {
1509 using OpRewritePattern<IndexCastOp>::OpRewritePattern;
1510
matchAndRewrite__anon1787a6f50a11::IndexCastOfSExt1511 LogicalResult matchAndRewrite(IndexCastOp op,
1512 PatternRewriter &rewriter) const override {
1513
1514 if (auto extop = op.getOperand().getDefiningOp<SignExtendIOp>()) {
1515 op.setOperand(extop.getOperand());
1516 return success();
1517 }
1518 return failure();
1519 }
1520 };
1521
1522 } // namespace
1523
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1524 void IndexCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1525 MLIRContext *context) {
1526 results.insert<IndexCastOfSExt>(context);
1527 }
1528
1529 //===----------------------------------------------------------------------===//
1530 // MulFOp
1531 //===----------------------------------------------------------------------===//
1532
fold(ArrayRef<Attribute> operands)1533 OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
1534 return constFoldBinaryOp<FloatAttr>(
1535 operands, [](APFloat a, APFloat b) { return a * b; });
1536 }
1537
1538 //===----------------------------------------------------------------------===//
1539 // MulIOp
1540 //===----------------------------------------------------------------------===//
1541
fold(ArrayRef<Attribute> operands)1542 OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
1543 /// muli(x, 0) -> 0
1544 if (matchPattern(rhs(), m_Zero()))
1545 return rhs();
1546 /// muli(x, 1) -> x
1547 if (matchPattern(rhs(), m_One()))
1548 return getOperand(0);
1549
1550 // TODO: Handle the overflow case.
1551 return constFoldBinaryOp<IntegerAttr>(operands,
1552 [](APInt a, APInt b) { return a * b; });
1553 }
1554
1555 //===----------------------------------------------------------------------===//
1556 // OrOp
1557 //===----------------------------------------------------------------------===//
1558
fold(ArrayRef<Attribute> operands)1559 OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
1560 /// or(x, 0) -> x
1561 if (matchPattern(rhs(), m_Zero()))
1562 return lhs();
1563 /// or(x, x) -> x
1564 if (lhs() == rhs())
1565 return rhs();
1566 /// or(x, <all ones>) -> <all ones>
1567 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
1568 if (rhsAttr.getValue().isAllOnes())
1569 return rhsAttr;
1570
1571 return constFoldBinaryOp<IntegerAttr>(operands,
1572 [](APInt a, APInt b) { return a | b; });
1573 }
1574
1575 //===----------------------------------------------------------------------===//
1576 // RankOp
1577 //===----------------------------------------------------------------------===//
1578
fold(ArrayRef<Attribute> operands)1579 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1580 // Constant fold rank when the rank of the operand is known.
1581 auto type = getOperand().getType();
1582 if (auto shapedType = type.dyn_cast<ShapedType>())
1583 if (shapedType.hasRank())
1584 return IntegerAttr::get(IndexType::get(getContext()),
1585 shapedType.getRank());
1586 return IntegerAttr();
1587 }
1588
1589 //===----------------------------------------------------------------------===//
1590 // ReturnOp
1591 //===----------------------------------------------------------------------===//
1592
verify(ReturnOp op)1593 static LogicalResult verify(ReturnOp op) {
1594 auto function = cast<FuncOp>(op->getParentOp());
1595
1596 // The operand number and types must match the function signature.
1597 const auto &results = function.getType().getResults();
1598 if (op.getNumOperands() != results.size())
1599 return op.emitOpError("has ")
1600 << op.getNumOperands() << " operands, but enclosing function (@"
1601 << function.getName() << ") returns " << results.size();
1602
1603 for (unsigned i = 0, e = results.size(); i != e; ++i)
1604 if (op.getOperand(i).getType() != results[i])
1605 return op.emitError()
1606 << "type of return operand " << i << " ("
1607 << op.getOperand(i).getType()
1608 << ") doesn't match function result type (" << results[i] << ")"
1609 << " in function @" << function.getName();
1610
1611 return success();
1612 }
1613
1614 //===----------------------------------------------------------------------===//
1615 // SelectOp
1616 //===----------------------------------------------------------------------===//
1617
1618 // Transforms a select to a not, where relevant.
1619 //
1620 // select %arg, %false, %true
1621 //
1622 // becomes
1623 //
1624 // xor %arg, %true
1625 struct SelectToNot : public OpRewritePattern<SelectOp> {
1626 using OpRewritePattern<SelectOp>::OpRewritePattern;
1627
matchAndRewriteSelectToNot1628 LogicalResult matchAndRewrite(SelectOp op,
1629 PatternRewriter &rewriter) const override {
1630 if (!matchPattern(op.getTrueValue(), m_Zero()))
1631 return failure();
1632
1633 if (!matchPattern(op.getFalseValue(), m_One()))
1634 return failure();
1635
1636 if (!op.getType().isInteger(1))
1637 return failure();
1638
1639 rewriter.replaceOpWithNewOp<XOrOp>(op, op.condition(), op.getFalseValue());
1640 return success();
1641 }
1642 };
1643
/// Register SelectOp canonicalizations: currently just the select-to-xor
/// rewrite implemented by SelectToNot.
void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<SelectToNot>(context);
}
1648
fold(ArrayRef<Attribute> operands)1649 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
1650 auto trueVal = getTrueValue();
1651 auto falseVal = getFalseValue();
1652 if (trueVal == falseVal)
1653 return trueVal;
1654
1655 auto condition = getCondition();
1656
1657 // select true, %0, %1 => %0
1658 if (matchPattern(condition, m_One()))
1659 return trueVal;
1660
1661 // select false, %0, %1 => %1
1662 if (matchPattern(condition, m_Zero()))
1663 return falseVal;
1664
1665 if (auto cmp = dyn_cast_or_null<CmpIOp>(condition.getDefiningOp())) {
1666 auto pred = cmp.predicate();
1667 if (pred == mlir::CmpIPredicate::eq || pred == mlir::CmpIPredicate::ne) {
1668 auto cmpLhs = cmp.lhs();
1669 auto cmpRhs = cmp.rhs();
1670
1671 // %0 = cmpi eq, %arg0, %arg1
1672 // %1 = select %0, %arg0, %arg1 => %arg1
1673
1674 // %0 = cmpi ne, %arg0, %arg1
1675 // %1 = select %0, %arg0, %arg1 => %arg0
1676
1677 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1678 (cmpRhs == trueVal && cmpLhs == falseVal))
1679 return pred == mlir::CmpIPredicate::ne ? trueVal : falseVal;
1680 }
1681 }
1682 return nullptr;
1683 }
1684
print(OpAsmPrinter & p,SelectOp op)1685 static void print(OpAsmPrinter &p, SelectOp op) {
1686 p << " " << op.getOperands();
1687 p.printOptionalAttrDict(op->getAttrs());
1688 p << " : ";
1689 if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
1690 p << condType << ", ";
1691 p << op.getType();
1692 }
1693
/// Parse select in the form:
///   operands attr-dict `:` [condition-type `,`] result-type
/// The condition type is optional: it is only present when the condition is a
/// shaped (tensor/vector) mask; otherwise it defaults to i1.
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  SmallVector<OpAsmParser::OperandType, 3> operands;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or vector.
  if (succeeded(parser.parseOptionalComma())) {
    // The type parsed above was actually the condition's mask type; the real
    // result type follows the comma.
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  } else {
    // Scalar select: the condition is a plain i1.
    conditionType = parser.getBuilder().getI1Type();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}
1716
verify(SelectOp op)1717 static LogicalResult verify(SelectOp op) {
1718 Type conditionType = op.getCondition().getType();
1719 if (conditionType.isSignlessInteger(1))
1720 return success();
1721
1722 // If the result type is a vector or tensor, the type can be a mask with the
1723 // same elements.
1724 Type resultType = op.getType();
1725 if (!resultType.isa<TensorType, VectorType>())
1726 return op.emitOpError()
1727 << "expected condition to be a signless i1, but got "
1728 << conditionType;
1729 Type shapedConditionType = getI1SameShape(resultType);
1730 if (conditionType != shapedConditionType)
1731 return op.emitOpError()
1732 << "expected condition type to have the same shape "
1733 "as the result type, expected "
1734 << shapedConditionType << ", but got " << conditionType;
1735 return success();
1736 }
1737
1738 //===----------------------------------------------------------------------===//
1739 // SignExtendIOp
1740 //===----------------------------------------------------------------------===//
1741
verify(SignExtendIOp op)1742 static LogicalResult verify(SignExtendIOp op) {
1743 // Get the scalar type (which is either directly the type of the operand
1744 // or the vector's/tensor's element type.
1745 auto srcType = getElementTypeOrSelf(op.getOperand().getType());
1746 auto dstType = getElementTypeOrSelf(op.getType());
1747
1748 // For now, index is forbidden for the source and the destination type.
1749 if (srcType.isa<IndexType>())
1750 return op.emitError() << srcType << " is not a valid operand type";
1751 if (dstType.isa<IndexType>())
1752 return op.emitError() << dstType << " is not a valid result type";
1753
1754 if (srcType.cast<IntegerType>().getWidth() >=
1755 dstType.cast<IntegerType>().getWidth())
1756 return op.emitError("result type ")
1757 << dstType << " must be wider than operand type " << srcType;
1758
1759 return success();
1760 }
1761
fold(ArrayRef<Attribute> operands)1762 OpFoldResult SignExtendIOp::fold(ArrayRef<Attribute> operands) {
1763 assert(operands.size() == 1 && "unary operation takes one operand");
1764
1765 if (!operands[0])
1766 return {};
1767
1768 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
1769 return IntegerAttr::get(
1770 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
1771 }
1772
1773 return {};
1774 }
1775
1776 //===----------------------------------------------------------------------===//
1777 // SignedDivIOp
1778 //===----------------------------------------------------------------------===//
1779
/// Constant folding hook for signed integer division.
///
///   divi_signed(x, 1)         -> x
///   divi_signed(cst_a, cst_b) -> cst_a / cst_b  (when no element overflows
///                                                or divides by zero)
OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
    // Once a problem element has been seen, short-circuit the rest; the
    // computed result is discarded below anyway.
    if (overflowOrDiv0 || !b) {
      overflowOrDiv0 = true;
      return a;
    }
    // sdiv_ov sets the flag on signed overflow (INT_MIN / -1).
    return a.sdiv_ov(b, overflowOrDiv0);
  });

  // Fold out division by one. Assumes all tensors of all ones are splats.
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
    if (rhs.getValue() == 1)
      return lhs();
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
      return lhs();
  }

  // Discard the folded result if any element overflowed or divided by zero.
  return overflowOrDiv0 ? Attribute() : result;
}
1804
1805 //===----------------------------------------------------------------------===//
1806 // SignedFloorDivIOp
1807 //===----------------------------------------------------------------------===//
1808
signedCeilNonnegInputs(APInt a,APInt b,bool & overflow)1809 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
1810 // Returns (a-1)/b + 1
1811 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
1812 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
1813 return val.sadd_ov(one, overflow);
1814 }
1815
/// Constant folding hook for signed floor division (round toward negative
/// infinity).
///
///   floordivi_signed(x, 1)         -> x
///   floordivi_signed(cst_a, cst_b) -> floor(cst_a / cst_b)  (when no element
///                                     overflows or divides by zero)
OpFoldResult SignedFloorDivIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
    // Once a problem element has been seen, short-circuit the rest; the
    // computed result is discarded below anyway.
    if (overflowOrDiv0 || !b) {
      overflowOrDiv0 = true;
      return a;
    }
    // Floor division is reduced by sign-case analysis to round-toward-zero
    // sdiv or to a ceil division of nonnegative inputs.
    unsigned bits = a.getBitWidth();
    APInt zero = APInt::getZero(bits);
    if (a.sge(zero) && b.sgt(zero)) {
      // Both positive (or a is zero), return a / b.
      return a.sdiv_ov(b, overflowOrDiv0);
    } else if (a.sle(zero) && b.slt(zero)) {
      // Both negative (or a is zero), return -a / -b.
      APInt posA = zero.ssub_ov(a, overflowOrDiv0);
      APInt posB = zero.ssub_ov(b, overflowOrDiv0);
      return posA.sdiv_ov(posB, overflowOrDiv0);
    } else if (a.slt(zero) && b.sgt(zero)) {
      // A is negative, b is positive, return - ceil(-a, b).
      APInt posA = zero.ssub_ov(a, overflowOrDiv0);
      APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
      return zero.ssub_ov(ceil, overflowOrDiv0);
    } else {
      // A is positive, b is negative, return - ceil(a, -b).
      APInt posB = zero.ssub_ov(b, overflowOrDiv0);
      APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
      return zero.ssub_ov(ceil, overflowOrDiv0);
    }
  });

  // Fold out floor division by one. Assumes all tensors of all ones are
  // splats.
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
    if (rhs.getValue() == 1)
      return lhs();
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
      return lhs();
  }

  // Discard the folded result if any element overflowed or divided by zero.
  return overflowOrDiv0 ? Attribute() : result;
}
1861
1862 //===----------------------------------------------------------------------===//
1863 // SignedCeilDivIOp
1864 //===----------------------------------------------------------------------===//
1865
/// Constant folding hook for signed ceil division (round toward positive
/// infinity).
///
///   ceildivi_signed(x, 1)         -> x
///   ceildivi_signed(cst_a, cst_b) -> ceil(cst_a / cst_b)  (when no element
///                                    overflows or divides by zero)
OpFoldResult SignedCeilDivIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
    // Once a problem element has been seen, short-circuit the rest; the
    // computed result is discarded below anyway.
    if (overflowOrDiv0 || !b) {
      overflowOrDiv0 = true;
      return a;
    }
    // Ceil division is reduced by sign-case analysis to a ceil division of
    // nonnegative inputs or to a negated round-toward-zero sdiv.
    unsigned bits = a.getBitWidth();
    APInt zero = APInt::getZero(bits);
    if (a.sgt(zero) && b.sgt(zero)) {
      // Both positive, return ceil(a, b).
      return signedCeilNonnegInputs(a, b, overflowOrDiv0);
    } else if (a.slt(zero) && b.slt(zero)) {
      // Both negative, return ceil(-a, -b).
      APInt posA = zero.ssub_ov(a, overflowOrDiv0);
      APInt posB = zero.ssub_ov(b, overflowOrDiv0);
      return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
    } else if (a.slt(zero) && b.sgt(zero)) {
      // A is negative, b is positive, return - ( -a / b).
      APInt posA = zero.ssub_ov(a, overflowOrDiv0);
      APInt div = posA.sdiv_ov(b, overflowOrDiv0);
      return zero.ssub_ov(div, overflowOrDiv0);
    } else {
      // A is positive (or zero), b is negative, return - (a / -b).
      APInt posB = zero.ssub_ov(b, overflowOrDiv0);
      APInt div = a.sdiv_ov(posB, overflowOrDiv0);
      return zero.ssub_ov(div, overflowOrDiv0);
    }
  });

  // Fold out ceil division by one. Assumes all tensors of all ones are
  // splats.
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
    if (rhs.getValue() == 1)
      return lhs();
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
      return lhs();
  }

  // Discard the folded result if any element overflowed or divided by zero.
  return overflowOrDiv0 ? Attribute() : result;
}
1911
1912 //===----------------------------------------------------------------------===//
1913 // SignedRemIOp
1914 //===----------------------------------------------------------------------===//
1915
fold(ArrayRef<Attribute> operands)1916 OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
1917 assert(operands.size() == 2 && "remi_signed takes two operands");
1918
1919 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1920 if (!rhs)
1921 return {};
1922 auto rhsValue = rhs.getValue();
1923
1924 // x % 1 = 0
1925 if (rhsValue.isOneValue())
1926 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
1927
1928 // Don't fold if it requires division by zero.
1929 if (rhsValue.isNullValue())
1930 return {};
1931
1932 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1933 if (!lhs)
1934 return {};
1935 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
1936 }
1937
1938 //===----------------------------------------------------------------------===//
1939 // SIToFPOp
1940 //===----------------------------------------------------------------------===//
1941
1942 // sitofp is applicable from integer types to float types.
areCastCompatible(TypeRange inputs,TypeRange outputs)1943 bool SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1944 if (inputs.size() != 1 || outputs.size() != 1)
1945 return false;
1946 Type a = inputs.front(), b = outputs.front();
1947 if (a.isSignlessInteger() && b.isa<FloatType>())
1948 return true;
1949 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1950 }
1951
1952 //===----------------------------------------------------------------------===//
1953 // SplatOp
1954 //===----------------------------------------------------------------------===//
1955
verify(SplatOp op)1956 static LogicalResult verify(SplatOp op) {
1957 // TODO: we could replace this by a trait.
1958 if (op.getOperand().getType() !=
1959 op.getType().cast<ShapedType>().getElementType())
1960 return op.emitError("operand should be of elemental type of result type");
1961
1962 return success();
1963 }
1964
1965 // Constant folding hook for SplatOp.
fold(ArrayRef<Attribute> operands)1966 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
1967 assert(operands.size() == 1 && "splat takes one operand");
1968
1969 auto constOperand = operands.front();
1970 if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
1971 return {};
1972
1973 auto shapedType = getType().cast<ShapedType>();
1974 assert(shapedType.getElementType() == constOperand.getType() &&
1975 "incorrect input attribute type for folding");
1976
1977 // SplatElementsAttr::get treats single value for second arg as being a splat.
1978 return SplatElementsAttr::get(shapedType, {constOperand});
1979 }
1980
1981 //===----------------------------------------------------------------------===//
1982 // SubFOp
1983 //===----------------------------------------------------------------------===//
1984
fold(ArrayRef<Attribute> operands)1985 OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
1986 return constFoldBinaryOp<FloatAttr>(
1987 operands, [](APFloat a, APFloat b) { return a - b; });
1988 }
1989
1990 //===----------------------------------------------------------------------===//
1991 // SubIOp
1992 //===----------------------------------------------------------------------===//
1993
fold(ArrayRef<Attribute> operands)1994 OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
1995 // subi(x,x) -> 0
1996 if (getOperand(0) == getOperand(1))
1997 return Builder(getContext()).getZeroAttr(getType());
1998 // subi(x,0) -> x
1999 if (matchPattern(rhs(), m_Zero()))
2000 return lhs();
2001
2002 return constFoldBinaryOp<IntegerAttr>(operands,
2003 [](APInt a, APInt b) { return a - b; });
2004 }
2005
/// Canonicalize a sub of a constant and (constant +/- something) to simply be
/// a single operation that merges the two constants.
struct SubConstantReorder : public OpRewritePattern<SubIOp> {
  using OpRewritePattern<SubIOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubIOp subOp,
                                PatternRewriter &rewriter) const override {
    APInt origConst;
    APInt midConst;

    // Case 1: the left operand of the sub is the outer constant.
    if (matchPattern(subOp.getOperand(0), m_ConstantInt(&origConst))) {
      if (auto midAddOp = subOp.getOperand(1).getDefiningOp<AddIOp>()) {
        // origConst - (midConst + something) == (origConst - midConst) -
        // something
        // Addition is commutative: look for the constant on either side.
        for (int j = 0; j < 2; j++) {
          if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
            auto nextConstant = rewriter.create<ConstantOp>(
                subOp.getLoc(),
                rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
            rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
                                                midAddOp.getOperand(1 - j));
            return success();
          }
        }
      }

      // NOTE(review): operand 0 matched m_ConstantInt above, so its defining
      // op is constant-like and cannot be a SubIOp; this branch appears
      // unreachable — confirm and consider removing.
      if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
          // (midConst - something) - origConst == (midConst - origConst) -
          // something
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
          rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
                                              midSubOp.getOperand(1));
          return success();
        }

        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
          // (something - midConst) - origConst == something - (origConst +
          // midConst)
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
          rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
                                              nextConstant);
          return success();
        }
      }

      if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
          // origConst - (midConst - something) == (origConst - midConst) +
          // something
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
          rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
                                              midSubOp.getOperand(1));
          return success();
        }

        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
          // origConst - (something - midConst) == (origConst + midConst) -
          // something
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
          rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
                                              midSubOp.getOperand(0));
          return success();
        }
      }
    }

    // Case 2: the right operand of the sub is the outer constant.
    if (matchPattern(subOp.getOperand(1), m_ConstantInt(&origConst))) {
      if (auto midAddOp = subOp.getOperand(0).getDefiningOp<AddIOp>()) {
        // (midConst + something) - origConst == (midConst - origConst) +
        // something
        // Addition is commutative: look for the constant on either side.
        for (int j = 0; j < 2; j++) {
          if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
            auto nextConstant = rewriter.create<ConstantOp>(
                subOp.getLoc(),
                rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
            rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
                                                midAddOp.getOperand(1 - j));
            return success();
          }
        }
      }

      if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
          // (midConst - something) - origConst == (midConst - origConst) -
          // something
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
          rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
                                              midSubOp.getOperand(1));
          return success();
        }

        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
          // (something - midConst) - origConst == something - (midConst +
          // origConst)
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), midConst + origConst));
          rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
                                              nextConstant);
          return success();
        }
      }

      // NOTE(review): operand 1 matched m_ConstantInt above, so its defining
      // op cannot be a SubIOp; this branch also appears unreachable — confirm.
      if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
          // origConst - (midConst - something) == (origConst - midConst) +
          // something
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
          rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
                                              midSubOp.getOperand(1));
          return success();
        }
        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
          // origConst - (something - midConst) == (origConst - midConst) -
          // something
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
          rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
                                              midSubOp.getOperand(0));
          return success();
        }
      }
    }
    return failure();
  }
};
2147
/// Register SubIOp canonicalizations: currently just the constant-reordering
/// rewrite implemented by SubConstantReorder.
void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<SubConstantReorder>(context);
}
2152
2153 //===----------------------------------------------------------------------===//
2154 // UIToFPOp
2155 //===----------------------------------------------------------------------===//
2156
2157 // uitofp is applicable from integer types to float types.
areCastCompatible(TypeRange inputs,TypeRange outputs)2158 bool UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
2159 if (inputs.size() != 1 || outputs.size() != 1)
2160 return false;
2161 Type a = inputs.front(), b = outputs.front();
2162 if (a.isSignlessInteger() && b.isa<FloatType>())
2163 return true;
2164 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2165 }
2166
2167 //===----------------------------------------------------------------------===//
2168 // SwitchOp
2169 //===----------------------------------------------------------------------===//
2170
/// Builds a switch whose case values are given as a DenseIntElementsAttr.
/// Simply forwards to the generated builder, which takes the same arguments
/// in a different order.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     DenseIntElementsAttr caseValues,
                     BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  build(builder, result, value, defaultOperands, caseOperands, caseValues,
        defaultDestination, caseDestinations);
}
2179
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<APInt> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)2180 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
2181 Block *defaultDestination, ValueRange defaultOperands,
2182 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
2183 ArrayRef<ValueRange> caseOperands) {
2184 DenseIntElementsAttr caseValuesAttr;
2185 if (!caseValues.empty()) {
2186 ShapedType caseValueType = VectorType::get(
2187 static_cast<int64_t>(caseValues.size()), value.getType());
2188 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
2189 }
2190 build(builder, result, value, defaultDestination, defaultOperands,
2191 caseValuesAttr, caseDestinations, caseOperands);
2192 }
2193
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
static ParseResult parseSwitchOpCases(
    OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
    SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
    SmallVectorImpl<Type> &defaultOperandTypes,
    DenseIntElementsAttr &caseValues,
    SmallVectorImpl<Block *> &caseDestinations,
    SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
    SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
  // The default case is mandatory and always comes first.
  if (parser.parseKeyword("default") || parser.parseColon() ||
      parser.parseSuccessor(defaultDestination))
    return failure();
  // Optional parenthesized operand/type list forwarded to the default block.
  if (succeeded(parser.parseOptionalLParen())) {
    if (parser.parseRegionArgumentList(defaultOperands) ||
        parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
      return failure();
  }

  // Each remaining case is `, value : successor (operands : types)?`.
  SmallVector<APInt> values;
  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
  while (succeeded(parser.parseOptionalComma())) {
    int64_t value = 0;
    if (failed(parser.parseInteger(value)))
      return failure();
    // Case values are stored at the bit width of the flag type.
    values.push_back(APInt(bitWidth, value));

    Block *destination;
    SmallVector<OpAsmParser::OperandType> operands;
    SmallVector<Type> operandTypes;
    if (failed(parser.parseColon()) ||
        failed(parser.parseSuccessor(destination)))
      return failure();
    if (succeeded(parser.parseOptionalLParen())) {
      if (failed(parser.parseRegionArgumentList(operands)) ||
          failed(parser.parseColonTypeList(operandTypes)) ||
          failed(parser.parseRParen()))
        return failure();
    }
    caseDestinations.push_back(destination);
    caseOperands.emplace_back(operands);
    caseOperandTypes.emplace_back(operandTypes);
  }

  // Leave `caseValues` null when there are no non-default cases.
  if (!values.empty()) {
    ShapedType caseValueType =
        VectorType::get(static_cast<int64_t>(values.size()), flagType);
    caseValues = DenseIntElementsAttr::get(caseValueType, values);
  }
  return success();
}
2245
printSwitchOpCases(OpAsmPrinter & p,SwitchOp op,Type flagType,Block * defaultDestination,OperandRange defaultOperands,TypeRange defaultOperandTypes,DenseIntElementsAttr caseValues,SuccessorRange caseDestinations,OperandRangeRange caseOperands,TypeRangeRange caseOperandTypes)2246 static void printSwitchOpCases(
2247 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
2248 OperandRange defaultOperands, TypeRange defaultOperandTypes,
2249 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
2250 OperandRangeRange caseOperands, TypeRangeRange caseOperandTypes) {
2251 p << " default: ";
2252 p.printSuccessorAndUseList(defaultDestination, defaultOperands);
2253
2254 if (!caseValues)
2255 return;
2256
2257 for (int64_t i = 0, size = caseValues.size(); i < size; ++i) {
2258 p << ',';
2259 p.printNewline();
2260 p << " ";
2261 p << caseValues.getValue<APInt>(i).getLimitedValue();
2262 p << ": ";
2263 p.printSuccessorAndUseList(caseDestinations[i], caseOperands[i]);
2264 }
2265 p.printNewline();
2266 }
2267
verify(SwitchOp op)2268 static LogicalResult verify(SwitchOp op) {
2269 auto caseValues = op.case_values();
2270 auto caseDestinations = op.caseDestinations();
2271
2272 if (!caseValues && caseDestinations.empty())
2273 return success();
2274
2275 Type flagType = op.flag().getType();
2276 Type caseValueType = caseValues->getType().getElementType();
2277 if (caseValueType != flagType)
2278 return op.emitOpError()
2279 << "'flag' type (" << flagType << ") should match case value type ("
2280 << caseValueType << ")";
2281
2282 if (caseValues &&
2283 caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
2284 return op.emitOpError() << "number of case values (" << caseValues->size()
2285 << ") should match number of "
2286 "case destinations ("
2287 << caseDestinations.size() << ")";
2288 return success();
2289 }
2290
2291 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)2292 SwitchOp::getMutableSuccessorOperands(unsigned index) {
2293 assert(index < getNumSuccessors() && "invalid successor index");
2294 return index == 0 ? defaultOperandsMutable()
2295 : getCaseOperandsMutable(index - 1);
2296 }
2297
getSuccessorForOperands(ArrayRef<Attribute> operands)2298 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
2299 Optional<DenseIntElementsAttr> caseValues = case_values();
2300
2301 if (!caseValues)
2302 return defaultDestination();
2303
2304 SuccessorRange caseDests = caseDestinations();
2305 if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
2306 for (int64_t i = 0, size = case_values()->size(); i < size; ++i)
2307 if (value == caseValues->getValue<IntegerAttr>(i))
2308 return caseDests[i];
2309 return defaultDestination();
2310 }
2311 return nullptr;
2312 }
2313
2314 /// switch %flag : i32, [
2315 /// default: ^bb1
2316 /// ]
2317 /// -> br ^bb1
simplifySwitchWithOnlyDefault(SwitchOp op,PatternRewriter & rewriter)2318 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
2319 PatternRewriter &rewriter) {
2320 if (!op.caseDestinations().empty())
2321 return failure();
2322
2323 rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
2324 op.defaultOperands());
2325 return success();
2326 }
2327
2328 /// switch %flag : i32, [
2329 /// default: ^bb1,
2330 /// 42: ^bb1,
2331 /// 43: ^bb2
2332 /// ]
2333 /// ->
2334 /// switch %flag : i32, [
2335 /// default: ^bb1,
2336 /// 43: ^bb2
2337 /// ]
2338 static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op,PatternRewriter & rewriter)2339 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
2340 SmallVector<Block *> newCaseDestinations;
2341 SmallVector<ValueRange> newCaseOperands;
2342 SmallVector<APInt> newCaseValues;
2343 bool requiresChange = false;
2344 auto caseValues = op.case_values();
2345 auto caseDests = op.caseDestinations();
2346
2347 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2348 if (caseDests[i] == op.defaultDestination() &&
2349 op.getCaseOperands(i) == op.defaultOperands()) {
2350 requiresChange = true;
2351 continue;
2352 }
2353 newCaseDestinations.push_back(caseDests[i]);
2354 newCaseOperands.push_back(op.getCaseOperands(i));
2355 newCaseValues.push_back(caseValues->getValue<APInt>(i));
2356 }
2357
2358 if (!requiresChange)
2359 return failure();
2360
2361 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
2362 op.defaultOperands(), newCaseValues,
2363 newCaseDestinations, newCaseOperands);
2364 return success();
2365 }
2366
2367 /// Helper for folding a switch with a constant value.
2368 /// switch %c_42 : i32, [
2369 /// default: ^bb1 ,
2370 /// 42: ^bb2,
2371 /// 43: ^bb3
2372 /// ]
2373 /// -> br ^bb2
foldSwitch(SwitchOp op,PatternRewriter & rewriter,APInt caseValue)2374 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
2375 APInt caseValue) {
2376 auto caseValues = op.case_values();
2377 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2378 if (caseValues->getValue<APInt>(i) == caseValue) {
2379 rewriter.replaceOpWithNewOp<BranchOp>(op, op.caseDestinations()[i],
2380 op.getCaseOperands(i));
2381 return;
2382 }
2383 }
2384 rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
2385 op.defaultOperands());
2386 }
2387
2388 /// switch %c_42 : i32, [
2389 /// default: ^bb1,
2390 /// 42: ^bb2,
2391 /// 43: ^bb3
2392 /// ]
2393 /// -> br ^bb2
simplifyConstSwitchValue(SwitchOp op,PatternRewriter & rewriter)2394 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
2395 PatternRewriter &rewriter) {
2396 APInt caseValue;
2397 if (!matchPattern(op.flag(), m_ConstantInt(&caseValue)))
2398 return failure();
2399
2400 foldSwitch(op, rewriter, caseValue);
2401 return success();
2402 }
2403
2404 /// switch %c_42 : i32, [
2405 /// default: ^bb1,
2406 /// 42: ^bb2,
2407 /// ]
2408 /// ^bb2:
2409 /// br ^bb3
2410 /// ->
2411 /// switch %c_42 : i32, [
2412 /// default: ^bb1,
2413 /// 42: ^bb3,
2414 /// ]
simplifyPassThroughSwitch(SwitchOp op,PatternRewriter & rewriter)2415 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
2416 PatternRewriter &rewriter) {
2417 SmallVector<Block *> newCaseDests;
2418 SmallVector<ValueRange> newCaseOperands;
2419 SmallVector<SmallVector<Value>> argStorage;
2420 auto caseValues = op.case_values();
2421 auto caseDests = op.caseDestinations();
2422 bool requiresChange = false;
2423 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2424 Block *caseDest = caseDests[i];
2425 ValueRange caseOperands = op.getCaseOperands(i);
2426 argStorage.emplace_back();
2427 if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
2428 requiresChange = true;
2429
2430 newCaseDests.push_back(caseDest);
2431 newCaseOperands.push_back(caseOperands);
2432 }
2433
2434 Block *defaultDest = op.defaultDestination();
2435 ValueRange defaultOperands = op.defaultOperands();
2436 argStorage.emplace_back();
2437
2438 if (succeeded(
2439 collapseBranch(defaultDest, defaultOperands, argStorage.back())))
2440 requiresChange = true;
2441
2442 if (!requiresChange)
2443 return failure();
2444
2445 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), defaultDest,
2446 defaultOperands, caseValues.getValue(),
2447 newCaseDests, newCaseOperands);
2448 return success();
2449 }
2450
2451 /// switch %flag : i32, [
2452 /// default: ^bb1,
2453 /// 42: ^bb2,
2454 /// ]
2455 /// ^bb2:
2456 /// switch %flag : i32, [
2457 /// default: ^bb3,
2458 /// 42: ^bb4
2459 /// ]
2460 /// ->
2461 /// switch %flag : i32, [
2462 /// default: ^bb1,
2463 /// 42: ^bb2,
2464 /// ]
2465 /// ^bb2:
2466 /// br ^bb4
2467 ///
2468 /// and
2469 ///
2470 /// switch %flag : i32, [
2471 /// default: ^bb1,
2472 /// 42: ^bb2,
2473 /// ]
2474 /// ^bb2:
2475 /// switch %flag : i32, [
2476 /// default: ^bb3,
2477 /// 43: ^bb4
2478 /// ]
2479 /// ->
2480 /// switch %flag : i32, [
2481 /// default: ^bb1,
2482 /// 42: ^bb2,
2483 /// ]
2484 /// ^bb2:
2485 /// br ^bb3
2486 static LogicalResult
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,PatternRewriter & rewriter)2487 simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
2488 PatternRewriter &rewriter) {
2489 // Check that we have a single distinct predecessor.
2490 Block *currentBlock = op->getBlock();
2491 Block *predecessor = currentBlock->getSinglePredecessor();
2492 if (!predecessor)
2493 return failure();
2494
2495 // Check that the predecessor terminates with a switch branch to this block
2496 // and that it branches on the same condition and that this branch isn't the
2497 // default destination.
2498 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
2499 if (!predSwitch || op.flag() != predSwitch.flag() ||
2500 predSwitch.defaultDestination() == currentBlock)
2501 return failure();
2502
2503 // Fold this switch to an unconditional branch.
2504 APInt caseValue;
2505 bool isDefault = true;
2506 SuccessorRange predDests = predSwitch.caseDestinations();
2507 Optional<DenseIntElementsAttr> predCaseValues = predSwitch.case_values();
2508 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) {
2509 if (currentBlock == predDests[i]) {
2510 caseValue = predCaseValues->getValue<APInt>(i);
2511 isDefault = false;
2512 break;
2513 }
2514 }
2515 if (isDefault)
2516 rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
2517 op.defaultOperands());
2518 else
2519 foldSwitch(op, rewriter, caseValue);
2520 return success();
2521 }
2522
2523 /// switch %flag : i32, [
2524 /// default: ^bb1,
2525 /// 42: ^bb2
2526 /// ]
2527 /// ^bb1:
2528 /// switch %flag : i32, [
2529 /// default: ^bb3,
2530 /// 42: ^bb4,
2531 /// 43: ^bb5
2532 /// ]
2533 /// ->
2534 /// switch %flag : i32, [
2535 /// default: ^bb1,
2536 /// 42: ^bb2,
2537 /// ]
2538 /// ^bb1:
2539 /// switch %flag : i32, [
2540 /// default: ^bb3,
2541 /// 43: ^bb5
2542 /// ]
2543 static LogicalResult
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,PatternRewriter & rewriter)2544 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
2545 PatternRewriter &rewriter) {
2546 // Check that we have a single distinct predecessor.
2547 Block *currentBlock = op->getBlock();
2548 Block *predecessor = currentBlock->getSinglePredecessor();
2549 if (!predecessor)
2550 return failure();
2551
2552 // Check that the predecessor terminates with a switch branch to this block
2553 // and that it branches on the same condition and that this branch is the
2554 // default destination.
2555 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
2556 if (!predSwitch || op.flag() != predSwitch.flag() ||
2557 predSwitch.defaultDestination() != currentBlock)
2558 return failure();
2559
2560 // Delete case values that are not possible here.
2561 DenseSet<APInt> caseValuesToRemove;
2562 auto predDests = predSwitch.caseDestinations();
2563 auto predCaseValues = predSwitch.case_values();
2564 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
2565 if (currentBlock != predDests[i])
2566 caseValuesToRemove.insert(predCaseValues->getValue<APInt>(i));
2567
2568 SmallVector<Block *> newCaseDestinations;
2569 SmallVector<ValueRange> newCaseOperands;
2570 SmallVector<APInt> newCaseValues;
2571 bool requiresChange = false;
2572
2573 auto caseValues = op.case_values();
2574 auto caseDests = op.caseDestinations();
2575 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2576 if (caseValuesToRemove.contains(caseValues->getValue<APInt>(i))) {
2577 requiresChange = true;
2578 continue;
2579 }
2580 newCaseDestinations.push_back(caseDests[i]);
2581 newCaseOperands.push_back(op.getCaseOperands(i));
2582 newCaseValues.push_back(caseValues->getValue<APInt>(i));
2583 }
2584
2585 if (!requiresChange)
2586 return failure();
2587
2588 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
2589 op.defaultOperands(), newCaseValues,
2590 newCaseDestinations, newCaseOperands);
2591 return success();
2592 }
2593
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2594 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
2595 MLIRContext *context) {
2596 results.add(&simplifySwitchWithOnlyDefault)
2597 .add(&dropSwitchCasesThatMatchDefault)
2598 .add(&simplifyConstSwitchValue)
2599 .add(&simplifyPassThroughSwitch)
2600 .add(&simplifySwitchFromSwitchOnSameCondition)
2601 .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
2602 }
2603
2604 //===----------------------------------------------------------------------===//
2605 // TruncateIOp
2606 //===----------------------------------------------------------------------===//
2607
verify(TruncateIOp op)2608 static LogicalResult verify(TruncateIOp op) {
2609 auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2610 auto dstType = getElementTypeOrSelf(op.getType());
2611
2612 if (srcType.isa<IndexType>())
2613 return op.emitError() << srcType << " is not a valid operand type";
2614 if (dstType.isa<IndexType>())
2615 return op.emitError() << dstType << " is not a valid result type";
2616
2617 if (srcType.cast<IntegerType>().getWidth() <=
2618 dstType.cast<IntegerType>().getWidth())
2619 return op.emitError("operand type ")
2620 << srcType << " must be wider than result type " << dstType;
2621
2622 return success();
2623 }
2624
fold(ArrayRef<Attribute> operands)2625 OpFoldResult TruncateIOp::fold(ArrayRef<Attribute> operands) {
2626 // trunci(zexti(a)) -> a
2627 // trunci(sexti(a)) -> a
2628 if (matchPattern(getOperand(), m_Op<ZeroExtendIOp>()) ||
2629 matchPattern(getOperand(), m_Op<SignExtendIOp>()))
2630 return getOperand().getDefiningOp()->getOperand(0);
2631
2632 assert(operands.size() == 1 && "unary operation takes one operand");
2633
2634 if (!operands[0])
2635 return {};
2636
2637 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
2638
2639 return IntegerAttr::get(
2640 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
2641 }
2642
2643 return {};
2644 }
2645
2646 //===----------------------------------------------------------------------===//
2647 // UnsignedDivIOp
2648 //===----------------------------------------------------------------------===//
2649
fold(ArrayRef<Attribute> operands)2650 OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
2651 assert(operands.size() == 2 && "binary operation takes two operands");
2652
2653 // Don't fold if it would require a division by zero.
2654 bool div0 = false;
2655 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2656 if (div0 || !b) {
2657 div0 = true;
2658 return a;
2659 }
2660 return a.udiv(b);
2661 });
2662
2663 // Fold out division by one. Assumes all tensors of all ones are splats.
2664 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2665 if (rhs.getValue() == 1)
2666 return lhs();
2667 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2668 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2669 return lhs();
2670 }
2671
2672 return div0 ? Attribute() : result;
2673 }
2674
2675 //===----------------------------------------------------------------------===//
2676 // UnsignedRemIOp
2677 //===----------------------------------------------------------------------===//
2678
fold(ArrayRef<Attribute> operands)2679 OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
2680 assert(operands.size() == 2 && "remi_unsigned takes two operands");
2681
2682 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
2683 if (!rhs)
2684 return {};
2685 auto rhsValue = rhs.getValue();
2686
2687 // x % 1 = 0
2688 if (rhsValue.isOneValue())
2689 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
2690
2691 // Don't fold if it requires division by zero.
2692 if (rhsValue.isNullValue())
2693 return {};
2694
2695 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
2696 if (!lhs)
2697 return {};
2698 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
2699 }
2700
2701 //===----------------------------------------------------------------------===//
2702 // XOrOp
2703 //===----------------------------------------------------------------------===//
2704
fold(ArrayRef<Attribute> operands)2705 OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
2706 /// xor(x, 0) -> x
2707 if (matchPattern(rhs(), m_Zero()))
2708 return lhs();
2709 /// xor(x,x) -> 0
2710 if (lhs() == rhs())
2711 return Builder(getContext()).getZeroAttr(getType());
2712
2713 return constFoldBinaryOp<IntegerAttr>(operands,
2714 [](APInt a, APInt b) { return a ^ b; });
2715 }
2716
2717 namespace {
2718 /// Replace a not of a comparison operation, for example: not(cmp eq A, B) =>
2719 /// cmp ne A, B. Note that a logical not is implemented as xor 1, val.
2720 struct NotICmp : public OpRewritePattern<XOrOp> {
2721 using OpRewritePattern<XOrOp>::OpRewritePattern;
2722
matchAndRewrite__anon1787a6f51511::NotICmp2723 LogicalResult matchAndRewrite(XOrOp op,
2724 PatternRewriter &rewriter) const override {
2725 // Commutative ops (such as xor) have the constant appear second, which
2726 // we assume here.
2727
2728 APInt constValue;
2729 if (!matchPattern(op.getOperand(1), m_ConstantInt(&constValue)))
2730 return failure();
2731
2732 if (constValue != 1)
2733 return failure();
2734
2735 auto prev = op.getOperand(0).getDefiningOp<CmpIOp>();
2736 if (!prev)
2737 return failure();
2738
2739 switch (prev.predicate()) {
2740 case CmpIPredicate::eq:
2741 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ne, prev.lhs(),
2742 prev.rhs());
2743 return success();
2744 case CmpIPredicate::ne:
2745 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::eq, prev.lhs(),
2746 prev.rhs());
2747 return success();
2748
2749 case CmpIPredicate::slt:
2750 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sge, prev.lhs(),
2751 prev.rhs());
2752 return success();
2753 case CmpIPredicate::sle:
2754 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sgt, prev.lhs(),
2755 prev.rhs());
2756 return success();
2757 case CmpIPredicate::sgt:
2758 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sle, prev.lhs(),
2759 prev.rhs());
2760 return success();
2761 case CmpIPredicate::sge:
2762 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, prev.lhs(),
2763 prev.rhs());
2764 return success();
2765
2766 case CmpIPredicate::ult:
2767 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::uge, prev.lhs(),
2768 prev.rhs());
2769 return success();
2770 case CmpIPredicate::ule:
2771 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ugt, prev.lhs(),
2772 prev.rhs());
2773 return success();
2774 case CmpIPredicate::ugt:
2775 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ule, prev.lhs(),
2776 prev.rhs());
2777 return success();
2778 case CmpIPredicate::uge:
2779 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ult, prev.lhs(),
2780 prev.rhs());
2781 return success();
2782 }
2783 return failure();
2784 }
2785 };
2786 } // namespace
2787
void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  // Fold xor-with-1 of an integer comparison into the negated comparison.
  results.insert<NotICmp>(context);
}
2792
2793 //===----------------------------------------------------------------------===//
2794 // ZeroExtendIOp
2795 //===----------------------------------------------------------------------===//
2796
verify(ZeroExtendIOp op)2797 static LogicalResult verify(ZeroExtendIOp op) {
2798 auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2799 auto dstType = getElementTypeOrSelf(op.getType());
2800
2801 if (srcType.isa<IndexType>())
2802 return op.emitError() << srcType << " is not a valid operand type";
2803 if (dstType.isa<IndexType>())
2804 return op.emitError() << dstType << " is not a valid result type";
2805
2806 if (srcType.cast<IntegerType>().getWidth() >=
2807 dstType.cast<IntegerType>().getWidth())
2808 return op.emitError("result type ")
2809 << dstType << " must be wider than operand type " << srcType;
2810
2811 return success();
2812 }
2813
2814 //===----------------------------------------------------------------------===//
2815 // TableGen'd op method definitions
2816 //===----------------------------------------------------------------------===//
2817
2818 #define GET_OP_CLASSES
2819 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
2820