1 //===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/StandardOps/IR/Ops.h"
10
11 #include "mlir/Dialect/CommonFolders.h"
12 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
13 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Support/MathExtras.h"
25 #include "mlir/Transforms/InliningUtils.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/StringSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/raw_ostream.h"
30
31 #include "mlir/Dialect/StandardOps/IR/OpsDialect.cpp.inc"
32
33 // Pull in all enum type definitions and utility function declarations.
34 #include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"
35
36 using namespace mlir;
37
38 //===----------------------------------------------------------------------===//
39 // StandardOpsDialect Interfaces
40 //===----------------------------------------------------------------------===//
41 namespace {
42 /// This class defines the interface for handling inlining with standard
43 /// operations.
44 struct StdInlinerInterface : public DialectInlinerInterface {
45 using DialectInlinerInterface::DialectInlinerInterface;
46
47 //===--------------------------------------------------------------------===//
48 // Analysis Hooks
49 //===--------------------------------------------------------------------===//
50
51 /// All call operations within standard ops can be inlined.
isLegalToInline__anon476530360111::StdInlinerInterface52 bool isLegalToInline(Operation *call, Operation *callable,
53 bool wouldBeCloned) const final {
54 return true;
55 }
56
57 /// All operations within standard ops can be inlined.
isLegalToInline__anon476530360111::StdInlinerInterface58 bool isLegalToInline(Operation *, Region *, bool,
59 BlockAndValueMapping &) const final {
60 return true;
61 }
62
63 //===--------------------------------------------------------------------===//
64 // Transformation Hooks
65 //===--------------------------------------------------------------------===//
66
67 /// Handle the given inlined terminator by replacing it with a new operation
68 /// as necessary.
handleTerminator__anon476530360111::StdInlinerInterface69 void handleTerminator(Operation *op, Block *newDest) const final {
70 // Only "std.return" needs to be handled here.
71 auto returnOp = dyn_cast<ReturnOp>(op);
72 if (!returnOp)
73 return;
74
75 // Replace the return with a branch to the dest.
76 OpBuilder builder(op);
77 builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
78 op->erase();
79 }
80
81 /// Handle the given inlined terminator by replacing it with a new operation
82 /// as necessary.
handleTerminator__anon476530360111::StdInlinerInterface83 void handleTerminator(Operation *op,
84 ArrayRef<Value> valuesToRepl) const final {
85 // Only "std.return" needs to be handled here.
86 auto returnOp = cast<ReturnOp>(op);
87
88 // Replace the values directly with the return operands.
89 assert(returnOp.getNumOperands() == valuesToRepl.size());
90 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
91 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
92 }
93 };
94 } // end anonymous namespace
95
96 //===----------------------------------------------------------------------===//
97 // StandardOpsDialect
98 //===----------------------------------------------------------------------===//
99
100 /// A custom unary operation printer that omits the "std." prefix from the
101 /// operation names.
printStandardUnaryOp(Operation * op,OpAsmPrinter & p)102 static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
103 assert(op->getNumOperands() == 1 && "unary op should have one operand");
104 assert(op->getNumResults() == 1 && "unary op should have one result");
105
106 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
107 p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
108 << op->getOperand(0);
109 p.printOptionalAttrDict(op->getAttrs());
110 p << " : " << op->getOperand(0).getType();
111 }
112
113 /// A custom binary operation printer that omits the "std." prefix from the
114 /// operation names.
printStandardBinaryOp(Operation * op,OpAsmPrinter & p)115 static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
116 assert(op->getNumOperands() == 2 && "binary op should have two operands");
117 assert(op->getNumResults() == 1 && "binary op should have one result");
118
119 // If not all the operand and result types are the same, just use the
120 // generic assembly form to avoid omitting information in printing.
121 auto resultType = op->getResult(0).getType();
122 if (op->getOperand(0).getType() != resultType ||
123 op->getOperand(1).getType() != resultType) {
124 p.printGenericOp(op);
125 return;
126 }
127
128 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
129 p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
130 << op->getOperand(0) << ", " << op->getOperand(1);
131 p.printOptionalAttrDict(op->getAttrs());
132
133 // Now we can output only one type for all operands and the result.
134 p << " : " << op->getResult(0).getType();
135 }
136
137 /// A custom ternary operation printer that omits the "std." prefix from the
138 /// operation names.
printStandardTernaryOp(Operation * op,OpAsmPrinter & p)139 static void printStandardTernaryOp(Operation *op, OpAsmPrinter &p) {
140 assert(op->getNumOperands() == 3 && "ternary op should have three operands");
141 assert(op->getNumResults() == 1 && "ternary op should have one result");
142
143 // If not all the operand and result types are the same, just use the
144 // generic assembly form to avoid omitting information in printing.
145 auto resultType = op->getResult(0).getType();
146 if (op->getOperand(0).getType() != resultType ||
147 op->getOperand(1).getType() != resultType ||
148 op->getOperand(2).getType() != resultType) {
149 p.printGenericOp(op);
150 return;
151 }
152
153 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
154 p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
155 << op->getOperand(0) << ", " << op->getOperand(1) << ", "
156 << op->getOperand(2);
157 p.printOptionalAttrDict(op->getAttrs());
158
159 // Now we can output only one type for all operands and the result.
160 p << " : " << op->getResult(0).getType();
161 }
162
163 /// A custom cast operation printer that omits the "std." prefix from the
164 /// operation names.
printStandardCastOp(Operation * op,OpAsmPrinter & p)165 static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
166 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
167 p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
168 << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to "
169 << op->getResult(0).getType();
170 }
171
initialize()172 void StandardOpsDialect::initialize() {
173 addOperations<
174 #define GET_OP_LIST
175 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
176 >();
177 addInterfaces<StdInlinerInterface>();
178 }
179
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
                                                   Location loc) {
  // std.constant can carry any attribute/type pair the folder produces.
  return builder.create<ConstantOp>(loc, type, value);
}
187
188 //===----------------------------------------------------------------------===//
189 // Common cast compatibility check for vector types.
190 //===----------------------------------------------------------------------===//
191
192 /// This method checks for cast compatibility of vector types.
193 /// If 'a' and 'b' are vector types, and they are cast compatible,
194 /// it calls the 'areElementsCastCompatible' function to check for
195 /// element cast compatibility.
196 /// Returns 'true' if the vector types are cast compatible, and 'false'
197 /// otherwise.
areVectorCastSimpleCompatible(Type a,Type b,function_ref<bool (TypeRange,TypeRange)> areElementsCastCompatible)198 static bool areVectorCastSimpleCompatible(
199 Type a, Type b,
200 function_ref<bool(TypeRange, TypeRange)> areElementsCastCompatible) {
201 if (auto va = a.dyn_cast<VectorType>())
202 if (auto vb = b.dyn_cast<VectorType>())
203 return va.getShape().equals(vb.getShape()) &&
204 areElementsCastCompatible(va.getElementType(),
205 vb.getElementType());
206 return false;
207 }
208
209 //===----------------------------------------------------------------------===//
210 // AddFOp
211 //===----------------------------------------------------------------------===//
212
fold(ArrayRef<Attribute> operands)213 OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
214 return constFoldBinaryOp<FloatAttr>(
215 operands, [](APFloat a, APFloat b) { return a + b; });
216 }
217
218 //===----------------------------------------------------------------------===//
219 // AddIOp
220 //===----------------------------------------------------------------------===//
221
fold(ArrayRef<Attribute> operands)222 OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
223 /// addi(x, 0) -> x
224 if (matchPattern(rhs(), m_Zero()))
225 return lhs();
226
227 return constFoldBinaryOp<IntegerAttr>(operands,
228 [](APInt a, APInt b) { return a + b; });
229 }
230
/// Canonicalize a sum of a constant and (constant - something) to simply be
/// a sum of constants minus something. This transformation does similar
/// transformations for additions of a constant with a subtract/add of
/// a constant. This may result in some operations being reordered (but should
/// remain equivalent).
struct AddConstantReorder : public OpRewritePattern<AddIOp> {
  using OpRewritePattern<AddIOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddIOp addop,
                                PatternRewriter &rewriter) const override {
    // The constant may be either operand of the outer add; try both sides.
    for (int i = 0; i < 2; i++) {
      APInt origConst;
      APInt midConst;
      if (matchPattern(addop.getOperand(i), m_ConstantInt(&origConst))) {
        // c1 + (c2 + x) -> (c1 + c2) + x, checking both operands of the
        // inner addition for the constant.
        if (auto midAddOp = addop.getOperand(1 - i).getDefiningOp<AddIOp>()) {
          for (int j = 0; j < 2; j++) {
            if (matchPattern(midAddOp.getOperand(j),
                             m_ConstantInt(&midConst))) {
              auto nextConstant = rewriter.create<ConstantOp>(
                  addop.getLoc(), rewriter.getIntegerAttr(
                                      addop.getType(), origConst + midConst));
              rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
                                                  midAddOp.getOperand(1 - j));
              return success();
            }
          }
        }
        if (auto midSubOp = addop.getOperand(1 - i).getDefiningOp<SubIOp>()) {
          // c1 + (c2 - x) -> (c1 + c2) - x
          if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
            auto nextConstant = rewriter.create<ConstantOp>(
                addop.getLoc(),
                rewriter.getIntegerAttr(addop.getType(), origConst + midConst));
            rewriter.replaceOpWithNewOp<SubIOp>(addop, nextConstant,
                                                midSubOp.getOperand(1));
            return success();
          }
          // c1 + (x - c2) -> (c1 - c2) + x
          if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
            auto nextConstant = rewriter.create<ConstantOp>(
                addop.getLoc(),
                rewriter.getIntegerAttr(addop.getType(), origConst - midConst));
            rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
                                                midSubOp.getOperand(0));
            return success();
          }
        }
      }
    }
    return failure();
  }
};
281
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)282 void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
283 MLIRContext *context) {
284 results.insert<AddConstantReorder>(context);
285 }
286
287 //===----------------------------------------------------------------------===//
288 // AndOp
289 //===----------------------------------------------------------------------===//
290
fold(ArrayRef<Attribute> operands)291 OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
292 /// and(x, 0) -> 0
293 if (matchPattern(rhs(), m_Zero()))
294 return rhs();
295 /// and(x, allOnes) -> x
296 APInt intValue;
297 if (matchPattern(rhs(), m_ConstantInt(&intValue)) &&
298 intValue.isAllOnesValue())
299 return lhs();
300 /// and(x,x) -> x
301 if (lhs() == rhs())
302 return rhs();
303
304 return constFoldBinaryOp<IntegerAttr>(operands,
305 [](APInt a, APInt b) { return a & b; });
306 }
307
308 //===----------------------------------------------------------------------===//
309 // AssertOp
310 //===----------------------------------------------------------------------===//
311
canonicalize(AssertOp op,PatternRewriter & rewriter)312 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
313 // Erase assertion if argument is constant true.
314 if (matchPattern(op.arg(), m_One())) {
315 rewriter.eraseOp(op);
316 return success();
317 }
318 return failure();
319 }
320
321 //===----------------------------------------------------------------------===//
322 // AtomicRMWOp
323 //===----------------------------------------------------------------------===//
324
verify(AtomicRMWOp op)325 static LogicalResult verify(AtomicRMWOp op) {
326 if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
327 return op.emitOpError(
328 "expects the number of subscripts to be equal to memref rank");
329 switch (op.kind()) {
330 case AtomicRMWKind::addf:
331 case AtomicRMWKind::maxf:
332 case AtomicRMWKind::minf:
333 case AtomicRMWKind::mulf:
334 if (!op.value().getType().isa<FloatType>())
335 return op.emitOpError()
336 << "with kind '" << stringifyAtomicRMWKind(op.kind())
337 << "' expects a floating-point type";
338 break;
339 case AtomicRMWKind::addi:
340 case AtomicRMWKind::maxs:
341 case AtomicRMWKind::maxu:
342 case AtomicRMWKind::mins:
343 case AtomicRMWKind::minu:
344 case AtomicRMWKind::muli:
345 if (!op.value().getType().isa<IntegerType>())
346 return op.emitOpError()
347 << "with kind '" << stringifyAtomicRMWKind(op.kind())
348 << "' expects an integer type";
349 break;
350 default:
351 break;
352 }
353 return success();
354 }
355
/// Returns the identity value attribute associated with an AtomicRMWKind op.
/// Returns null (after emitting an error at `loc`) for unsupported kinds.
Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                     OpBuilder &builder, Location loc) {
  switch (kind) {
  // Additive identity is 0 for both integer and float additions.
  case AtomicRMWKind::addf:
  case AtomicRMWKind::addi:
    return builder.getZeroAttr(resultType);
  // Multiplicative identity is 1.
  case AtomicRMWKind::muli:
    return builder.getIntegerAttr(resultType, 1);
  case AtomicRMWKind::mulf:
    return builder.getFloatAttr(resultType, 1);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

/// Returns the identity value associated with an AtomicRMWKind op.
Value mlir::getIdentityValue(AtomicRMWKind op, Type resultType,
                             OpBuilder &builder, Location loc) {
  // Materialize the identity attribute as a std.constant at `loc`.
  Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
  return builder.create<ConstantOp>(loc, attr);
}

/// Return the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
/// Returns null (after emitting an error at `loc`) for unsupported kinds.
Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
                           Value lhs, Value rhs) {
  switch (op) {
  case AtomicRMWKind::addf:
    return builder.create<AddFOp>(loc, lhs, rhs);
  case AtomicRMWKind::addi:
    return builder.create<AddIOp>(loc, lhs, rhs);
  case AtomicRMWKind::mulf:
    return builder.create<MulFOp>(loc, lhs, rhs);
  case AtomicRMWKind::muli:
    return builder.create<MulIOp>(loc, lhs, rhs);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
402
403 //===----------------------------------------------------------------------===//
404 // GenericAtomicRMWOp
405 //===----------------------------------------------------------------------===//
406
build(OpBuilder & builder,OperationState & result,Value memref,ValueRange ivs)407 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
408 Value memref, ValueRange ivs) {
409 result.addOperands(memref);
410 result.addOperands(ivs);
411
412 if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
413 Type elementType = memrefType.getElementType();
414 result.addTypes(elementType);
415
416 Region *bodyRegion = result.addRegion();
417 bodyRegion->push_back(new Block());
418 bodyRegion->addArgument(elementType);
419 }
420 }
421
verify(GenericAtomicRMWOp op)422 static LogicalResult verify(GenericAtomicRMWOp op) {
423 auto &body = op.body();
424 if (body.getNumArguments() != 1)
425 return op.emitOpError("expected single number of entry block arguments");
426
427 if (op.getResult().getType() != body.getArgument(0).getType())
428 return op.emitOpError(
429 "expected block argument of the same type result type");
430
431 bool hasSideEffects =
432 body.walk([&](Operation *nestedOp) {
433 if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
434 return WalkResult::advance();
435 nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
436 "only operations with no side effects");
437 return WalkResult::interrupt();
438 })
439 .wasInterrupted();
440 return hasSideEffects ? failure() : success();
441 }
442
parseGenericAtomicRMWOp(OpAsmParser & parser,OperationState & result)443 static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
444 OperationState &result) {
445 OpAsmParser::OperandType memref;
446 Type memrefType;
447 SmallVector<OpAsmParser::OperandType, 4> ivs;
448
449 Type indexType = parser.getBuilder().getIndexType();
450 if (parser.parseOperand(memref) ||
451 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
452 parser.parseColonType(memrefType) ||
453 parser.resolveOperand(memref, memrefType, result.operands) ||
454 parser.resolveOperands(ivs, indexType, result.operands))
455 return failure();
456
457 Region *body = result.addRegion();
458 if (parser.parseRegion(*body, llvm::None, llvm::None) ||
459 parser.parseOptionalAttrDict(result.attributes))
460 return failure();
461 result.types.push_back(memrefType.cast<MemRefType>().getElementType());
462 return success();
463 }
464
print(OpAsmPrinter & p,GenericAtomicRMWOp op)465 static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
466 p << op.getOperationName() << ' ' << op.memref() << "[" << op.indices()
467 << "] : " << op.memref().getType();
468 p.printRegion(op.body());
469 p.printOptionalAttrDict(op->getAttrs());
470 }
471
472 //===----------------------------------------------------------------------===//
473 // AtomicYieldOp
474 //===----------------------------------------------------------------------===//
475
verify(AtomicYieldOp op)476 static LogicalResult verify(AtomicYieldOp op) {
477 Type parentType = op->getParentOp()->getResultTypes().front();
478 Type resultType = op.result().getType();
479 if (parentType != resultType)
480 return op.emitOpError() << "types mismatch between yield op: " << resultType
481 << " and its parent: " << parentType;
482 return success();
483 }
484
485 //===----------------------------------------------------------------------===//
486 // BranchOp
487 //===----------------------------------------------------------------------===//
488
/// Given a successor, try to collapse it to a new destination if it only
/// contains a passthrough unconditional branch. If the successor is
/// collapsable, `successor` and `successorOperands` are updated to reference
/// the new destination and values. `argStorage` is used as storage if operands
/// to the collapsed successor need to be remapped. It must outlive uses of
/// successorOperands.
static LogicalResult collapseBranch(Block *&successor,
                                    ValueRange &successorOperands,
                                    SmallVectorImpl<Value> &argStorage) {
  // Check that the successor only contains a unconditional branch.
  if (std::next(successor->begin()) != successor->end())
    return failure();
  // Check that the terminator is an unconditional branch.
  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
  if (!successorBranch)
    return failure();
  // Check that the arguments are only used within the terminator.
  for (BlockArgument arg : successor->getArguments()) {
    for (Operation *user : arg.getUsers())
      if (user != successorBranch)
        return failure();
  }
  // Don't try to collapse branches to infinite loops.
  Block *successorDest = successorBranch.getDest();
  if (successorDest == successor)
    return failure();

  // Update the operands to the successor. If the branch parent has no
  // arguments, we can use the branch operands directly.
  OperandRange operands = successorBranch.getOperands();
  if (successor->args_empty()) {
    successor = successorDest;
    successorOperands = operands;
    return success();
  }

  // Otherwise, we need to remap any argument operands.
  for (Value operand : operands) {
    BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
    if (argOperand && argOperand.getOwner() == successor)
      // The inner branch forwards one of the successor's block arguments;
      // substitute the value the original branch passes for that argument.
      argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
    else
      argStorage.push_back(operand);
  }
  // Point the caller's range at the remapped storage (hence the lifetime
  // requirement on `argStorage`).
  successor = successorDest;
  successorOperands = argStorage;
  return success();
}
537
538 /// Simplify a branch to a block that has a single predecessor. This effectively
539 /// merges the two blocks.
540 static LogicalResult
simplifyBrToBlockWithSinglePred(BranchOp op,PatternRewriter & rewriter)541 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
542 // Check that the successor block has a single predecessor.
543 Block *succ = op.getDest();
544 Block *opParent = op->getBlock();
545 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
546 return failure();
547
548 // Merge the successor into the current block and erase the branch.
549 rewriter.mergeBlocks(succ, opParent, op.getOperands());
550 rewriter.eraseOp(op);
551 return success();
552 }
553
/// br ^bb1
/// ^bb1
/// br ^bbN(...)
///
/// -> br ^bbN(...)
///
static LogicalResult simplifyPassThroughBr(BranchOp op,
                                           PatternRewriter &rewriter) {
  Block *dest = op.getDest();
  ValueRange destOperands = op.getOperands();
  // Backing storage for destOperands if collapseBranch needs to remap them;
  // must stay alive until the new branch is created below.
  SmallVector<Value, 4> destOperandStorage;

  // Try to collapse the successor if it points somewhere other than this
  // block.
  if (dest == op->getBlock() ||
      failed(collapseBranch(dest, destOperands, destOperandStorage)))
    return failure();

  // Create a new branch with the collapsed successor.
  rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
  return success();
}
576
canonicalize(BranchOp op,PatternRewriter & rewriter)577 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
578 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
579 succeeded(simplifyPassThroughBr(op, rewriter)));
580 }
581
getDest()582 Block *BranchOp::getDest() { return getSuccessor(); }
583
setDest(Block * block)584 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
585
eraseOperand(unsigned index)586 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
587
588 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)589 BranchOp::getMutableSuccessorOperands(unsigned index) {
590 assert(index == 0 && "invalid successor index");
591 return destOperandsMutable();
592 }
593
getSuccessorForOperands(ArrayRef<Attribute>)594 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
595
596 //===----------------------------------------------------------------------===//
597 // CallOp
598 //===----------------------------------------------------------------------===//
599
verifySymbolUses(SymbolTableCollection & symbolTable)600 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
601 // Check that the callee attribute was specified.
602 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
603 if (!fnAttr)
604 return emitOpError("requires a 'callee' symbol reference attribute");
605 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
606 if (!fn)
607 return emitOpError() << "'" << fnAttr.getValue()
608 << "' does not reference a valid function";
609
610 // Verify that the operand and result types match the callee.
611 auto fnType = fn.getType();
612 if (fnType.getNumInputs() != getNumOperands())
613 return emitOpError("incorrect number of operands for callee");
614
615 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
616 if (getOperand(i).getType() != fnType.getInput(i))
617 return emitOpError("operand type mismatch: expected operand type ")
618 << fnType.getInput(i) << ", but provided "
619 << getOperand(i).getType() << " for operand number " << i;
620
621 if (fnType.getNumResults() != getNumResults())
622 return emitOpError("incorrect number of results for callee");
623
624 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
625 if (getResult(i).getType() != fnType.getResult(i))
626 return emitOpError("result type mismatch");
627
628 return success();
629 }
630
getCalleeType()631 FunctionType CallOp::getCalleeType() {
632 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
633 }
634
635 //===----------------------------------------------------------------------===//
636 // CallIndirectOp
637 //===----------------------------------------------------------------------===//
638
639 /// Fold indirect calls that have a constant function as the callee operand.
canonicalize(CallIndirectOp indirectCall,PatternRewriter & rewriter)640 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
641 PatternRewriter &rewriter) {
642 // Check that the callee is a constant callee.
643 SymbolRefAttr calledFn;
644 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
645 return failure();
646
647 // Replace with a direct call.
648 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
649 indirectCall.getResultTypes(),
650 indirectCall.getArgOperands());
651 return success();
652 }
653
654 //===----------------------------------------------------------------------===//
655 // General helpers for comparison ops
656 //===----------------------------------------------------------------------===//
657
658 // Return the type of the same shape (scalar, vector or tensor) containing i1.
getI1SameShape(Type type)659 static Type getI1SameShape(Type type) {
660 auto i1Type = IntegerType::get(type.getContext(), 1);
661 if (auto tensorType = type.dyn_cast<RankedTensorType>())
662 return RankedTensorType::get(tensorType.getShape(), i1Type);
663 if (type.isa<UnrankedTensorType>())
664 return UnrankedTensorType::get(i1Type);
665 if (auto vectorType = type.dyn_cast<VectorType>())
666 return VectorType::get(vectorType.getShape(), i1Type);
667 return i1Type;
668 }
669
670 //===----------------------------------------------------------------------===//
671 // CmpIOp
672 //===----------------------------------------------------------------------===//
673
buildCmpIOp(OpBuilder & build,OperationState & result,CmpIPredicate predicate,Value lhs,Value rhs)674 static void buildCmpIOp(OpBuilder &build, OperationState &result,
675 CmpIPredicate predicate, Value lhs, Value rhs) {
676 result.addOperands({lhs, rhs});
677 result.types.push_back(getI1SameShape(lhs.getType()));
678 result.addAttribute(CmpIOp::getPredicateAttrName(),
679 build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
680 }
681
682 // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
683 // comparison predicates.
applyCmpPredicate(CmpIPredicate predicate,const APInt & lhs,const APInt & rhs)684 bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
685 const APInt &rhs) {
686 switch (predicate) {
687 case CmpIPredicate::eq:
688 return lhs.eq(rhs);
689 case CmpIPredicate::ne:
690 return lhs.ne(rhs);
691 case CmpIPredicate::slt:
692 return lhs.slt(rhs);
693 case CmpIPredicate::sle:
694 return lhs.sle(rhs);
695 case CmpIPredicate::sgt:
696 return lhs.sgt(rhs);
697 case CmpIPredicate::sge:
698 return lhs.sge(rhs);
699 case CmpIPredicate::ult:
700 return lhs.ult(rhs);
701 case CmpIPredicate::ule:
702 return lhs.ule(rhs);
703 case CmpIPredicate::ugt:
704 return lhs.ugt(rhs);
705 case CmpIPredicate::uge:
706 return lhs.uge(rhs);
707 }
708 llvm_unreachable("unknown comparison predicate");
709 }
710
711 // Returns true if the predicate is true for two equal operands.
applyCmpPredicateToEqualOperands(CmpIPredicate predicate)712 static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
713 switch (predicate) {
714 case CmpIPredicate::eq:
715 case CmpIPredicate::sle:
716 case CmpIPredicate::sge:
717 case CmpIPredicate::ule:
718 case CmpIPredicate::uge:
719 return true;
720 case CmpIPredicate::ne:
721 case CmpIPredicate::slt:
722 case CmpIPredicate::sgt:
723 case CmpIPredicate::ult:
724 case CmpIPredicate::ugt:
725 return false;
726 }
727 llvm_unreachable("unknown comparison predicate");
728 }
729
// Constant folding hook for comparisons.
OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "cmpi takes two arguments");

  // Comparing a value against itself folds without needing constants.
  if (lhs() == rhs()) {
    auto val = applyCmpPredicateToEqualOperands(getPredicate());
    return BoolAttr::get(getContext(), val);
  }

  // Otherwise both operands must have folded to integer constants.
  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!lhs || !rhs)
    return {};

  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
  return BoolAttr::get(getContext(), val);
}
747
748 //===----------------------------------------------------------------------===//
749 // CmpFOp
750 //===----------------------------------------------------------------------===//
751
buildCmpFOp(OpBuilder & build,OperationState & result,CmpFPredicate predicate,Value lhs,Value rhs)752 static void buildCmpFOp(OpBuilder &build, OperationState &result,
753 CmpFPredicate predicate, Value lhs, Value rhs) {
754 result.addOperands({lhs, rhs});
755 result.types.push_back(getI1SameShape(lhs.getType()));
756 result.addAttribute(CmpFOp::getPredicateAttrName(),
757 build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
758 }
759
760 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
761 /// comparison predicates.
applyCmpPredicate(CmpFPredicate predicate,const APFloat & lhs,const APFloat & rhs)762 bool mlir::applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
763 const APFloat &rhs) {
764 auto cmpResult = lhs.compare(rhs);
765 switch (predicate) {
766 case CmpFPredicate::AlwaysFalse:
767 return false;
768 case CmpFPredicate::OEQ:
769 return cmpResult == APFloat::cmpEqual;
770 case CmpFPredicate::OGT:
771 return cmpResult == APFloat::cmpGreaterThan;
772 case CmpFPredicate::OGE:
773 return cmpResult == APFloat::cmpGreaterThan ||
774 cmpResult == APFloat::cmpEqual;
775 case CmpFPredicate::OLT:
776 return cmpResult == APFloat::cmpLessThan;
777 case CmpFPredicate::OLE:
778 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
779 case CmpFPredicate::ONE:
780 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
781 case CmpFPredicate::ORD:
782 return cmpResult != APFloat::cmpUnordered;
783 case CmpFPredicate::UEQ:
784 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
785 case CmpFPredicate::UGT:
786 return cmpResult == APFloat::cmpUnordered ||
787 cmpResult == APFloat::cmpGreaterThan;
788 case CmpFPredicate::UGE:
789 return cmpResult == APFloat::cmpUnordered ||
790 cmpResult == APFloat::cmpGreaterThan ||
791 cmpResult == APFloat::cmpEqual;
792 case CmpFPredicate::ULT:
793 return cmpResult == APFloat::cmpUnordered ||
794 cmpResult == APFloat::cmpLessThan;
795 case CmpFPredicate::ULE:
796 return cmpResult == APFloat::cmpUnordered ||
797 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
798 case CmpFPredicate::UNE:
799 return cmpResult != APFloat::cmpEqual;
800 case CmpFPredicate::UNO:
801 return cmpResult == APFloat::cmpUnordered;
802 case CmpFPredicate::AlwaysTrue:
803 return true;
804 }
805 llvm_unreachable("unknown comparison predicate");
806 }
807
808 // Constant folding hook for comparisons.
fold(ArrayRef<Attribute> operands)809 OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
810 assert(operands.size() == 2 && "cmpf takes two arguments");
811
812 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
813 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
814
815 // TODO: We could actually do some intelligent things if we know only one
816 // of the operands, but it's inf or nan.
817 if (!lhs || !rhs)
818 return {};
819
820 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
821 return IntegerAttr::get(IntegerType::get(getContext(), 1), APInt(1, val));
822 }
823
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
827
828 namespace {
829 /// cond_br true, ^bb1, ^bb2
830 /// -> br ^bb1
831 /// cond_br false, ^bb1, ^bb2
832 /// -> br ^bb2
833 ///
834 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
835 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
836
matchAndRewrite__anon476530360611::SimplifyConstCondBranchPred837 LogicalResult matchAndRewrite(CondBranchOp condbr,
838 PatternRewriter &rewriter) const override {
839 if (matchPattern(condbr.getCondition(), m_NonZero())) {
840 // True branch taken.
841 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
842 condbr.getTrueOperands());
843 return success();
844 } else if (matchPattern(condbr.getCondition(), m_Zero())) {
845 // False branch taken.
846 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
847 condbr.getFalseOperands());
848 return success();
849 }
850 return failure();
851 }
852 };
853
854 /// cond_br %cond, ^bb1, ^bb2
855 /// ^bb1
856 /// br ^bbN(...)
857 /// ^bb2
858 /// br ^bbK(...)
859 ///
860 /// -> cond_br %cond, ^bbN(...), ^bbK(...)
861 ///
862 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
863 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
864
matchAndRewrite__anon476530360611::SimplifyPassThroughCondBranch865 LogicalResult matchAndRewrite(CondBranchOp condbr,
866 PatternRewriter &rewriter) const override {
867 Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest();
868 ValueRange trueDestOperands = condbr.getTrueOperands();
869 ValueRange falseDestOperands = condbr.getFalseOperands();
870 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
871
872 // Try to collapse one of the current successors.
873 LogicalResult collapsedTrue =
874 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
875 LogicalResult collapsedFalse =
876 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
877 if (failed(collapsedTrue) && failed(collapsedFalse))
878 return failure();
879
880 // Create a new branch with the collapsed successors.
881 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
882 trueDest, trueDestOperands,
883 falseDest, falseDestOperands);
884 return success();
885 }
886 };
887
888 /// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
889 /// -> br ^bb1(A, ..., N)
890 ///
891 /// cond_br %cond, ^bb1(A), ^bb1(B)
892 /// -> %select = select %cond, A, B
893 /// br ^bb1(%select)
894 ///
895 struct SimplifyCondBranchIdenticalSuccessors
896 : public OpRewritePattern<CondBranchOp> {
897 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
898
matchAndRewrite__anon476530360611::SimplifyCondBranchIdenticalSuccessors899 LogicalResult matchAndRewrite(CondBranchOp condbr,
900 PatternRewriter &rewriter) const override {
901 // Check that the true and false destinations are the same and have the same
902 // operands.
903 Block *trueDest = condbr.trueDest();
904 if (trueDest != condbr.falseDest())
905 return failure();
906
907 // If all of the operands match, no selects need to be generated.
908 OperandRange trueOperands = condbr.getTrueOperands();
909 OperandRange falseOperands = condbr.getFalseOperands();
910 if (trueOperands == falseOperands) {
911 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
912 return success();
913 }
914
915 // Otherwise, if the current block is the only predecessor insert selects
916 // for any mismatched branch operands.
917 if (trueDest->getUniquePredecessor() != condbr->getBlock())
918 return failure();
919
920 // Generate a select for any operands that differ between the two.
921 SmallVector<Value, 8> mergedOperands;
922 mergedOperands.reserve(trueOperands.size());
923 Value condition = condbr.getCondition();
924 for (auto it : llvm::zip(trueOperands, falseOperands)) {
925 if (std::get<0>(it) == std::get<1>(it))
926 mergedOperands.push_back(std::get<0>(it));
927 else
928 mergedOperands.push_back(rewriter.create<SelectOp>(
929 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
930 }
931
932 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
933 return success();
934 }
935 };
936
937 /// ...
938 /// cond_br %cond, ^bb1(...), ^bb2(...)
939 /// ...
940 /// ^bb1: // has single predecessor
941 /// ...
942 /// cond_br %cond, ^bb3(...), ^bb4(...)
943 ///
944 /// ->
945 ///
946 /// ...
947 /// cond_br %cond, ^bb1(...), ^bb2(...)
948 /// ...
949 /// ^bb1: // has single predecessor
950 /// ...
951 /// br ^bb3(...)
952 ///
953 struct SimplifyCondBranchFromCondBranchOnSameCondition
954 : public OpRewritePattern<CondBranchOp> {
955 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
956
matchAndRewrite__anon476530360611::SimplifyCondBranchFromCondBranchOnSameCondition957 LogicalResult matchAndRewrite(CondBranchOp condbr,
958 PatternRewriter &rewriter) const override {
959 // Check that we have a single distinct predecessor.
960 Block *currentBlock = condbr->getBlock();
961 Block *predecessor = currentBlock->getSinglePredecessor();
962 if (!predecessor)
963 return failure();
964
965 // Check that the predecessor terminates with a conditional branch to this
966 // block and that it branches on the same condition.
967 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
968 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
969 return failure();
970
971 // Fold this branch to an unconditional branch.
972 if (currentBlock == predBranch.trueDest())
973 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.trueDest(),
974 condbr.trueDestOperands());
975 else
976 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.falseDest(),
977 condbr.falseDestOperands());
978 return success();
979 }
980 };
981
982 /// cond_br %arg0, ^trueB, ^falseB
983 ///
984 /// ^trueB:
985 /// "test.consumer1"(%arg0) : (i1) -> ()
986 /// ...
987 ///
988 /// ^falseB:
989 /// "test.consumer2"(%arg0) : (i1) -> ()
990 /// ...
991 ///
992 /// ->
993 ///
994 /// cond_br %arg0, ^trueB, ^falseB
995 /// ^trueB:
996 /// "test.consumer1"(%true) : (i1) -> ()
997 /// ...
998 ///
999 /// ^falseB:
1000 /// "test.consumer2"(%false) : (i1) -> ()
1001 /// ...
1002 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
1003 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
1004
matchAndRewrite__anon476530360611::CondBranchTruthPropagation1005 LogicalResult matchAndRewrite(CondBranchOp condbr,
1006 PatternRewriter &rewriter) const override {
1007 // Check that we have a single distinct predecessor.
1008 bool replaced = false;
1009 Type ty = rewriter.getI1Type();
1010
1011 // These variables serve to prevent creating duplicate constants
1012 // and hold constant true or false values.
1013 Value constantTrue = nullptr;
1014 Value constantFalse = nullptr;
1015
1016 // TODO These checks can be expanded to encompas any use with only
1017 // either the true of false edge as a predecessor. For now, we fall
1018 // back to checking the single predecessor is given by the true/fasle
1019 // destination, thereby ensuring that only that edge can reach the
1020 // op.
1021 if (condbr.getTrueDest()->getSinglePredecessor()) {
1022 for (OpOperand &use :
1023 llvm::make_early_inc_range(condbr.condition().getUses())) {
1024 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
1025 replaced = true;
1026
1027 if (!constantTrue)
1028 constantTrue = rewriter.create<mlir::ConstantOp>(
1029 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
1030
1031 rewriter.updateRootInPlace(use.getOwner(),
1032 [&] { use.set(constantTrue); });
1033 }
1034 }
1035 }
1036 if (condbr.getFalseDest()->getSinglePredecessor()) {
1037 for (OpOperand &use :
1038 llvm::make_early_inc_range(condbr.condition().getUses())) {
1039 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
1040 replaced = true;
1041
1042 if (!constantFalse)
1043 constantFalse = rewriter.create<mlir::ConstantOp>(
1044 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
1045
1046 rewriter.updateRootInPlace(use.getOwner(),
1047 [&] { use.set(constantFalse); });
1048 }
1049 }
1050 }
1051 return success(replaced);
1052 }
1053 };
1054 } // end anonymous namespace
1055
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1056 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
1057 MLIRContext *context) {
1058 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
1059 SimplifyCondBranchIdenticalSuccessors,
1060 SimplifyCondBranchFromCondBranchOnSameCondition,
1061 CondBranchTruthPropagation>(context);
1062 }
1063
1064 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1065 CondBranchOp::getMutableSuccessorOperands(unsigned index) {
1066 assert(index < getNumSuccessors() && "invalid successor index");
1067 return index == trueIndex ? trueDestOperandsMutable()
1068 : falseDestOperandsMutable();
1069 }
1070
getSuccessorForOperands(ArrayRef<Attribute> operands)1071 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1072 if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
1073 return condAttr.getValue().isOneValue() ? trueDest() : falseDest();
1074 return nullptr;
1075 }
1076
//===----------------------------------------------------------------------===//
// Constant*Op
//===----------------------------------------------------------------------===//
1080
print(OpAsmPrinter & p,ConstantOp & op)1081 static void print(OpAsmPrinter &p, ConstantOp &op) {
1082 p << "constant ";
1083 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
1084
1085 if (op->getAttrs().size() > 1)
1086 p << ' ';
1087 p << op.getValue();
1088
1089 // If the value is a symbol reference or Array, print a trailing type.
1090 if (op.getValue().isa<SymbolRefAttr, ArrayAttr>())
1091 p << " : " << op.getType();
1092 }
1093
parseConstantOp(OpAsmParser & parser,OperationState & result)1094 static ParseResult parseConstantOp(OpAsmParser &parser,
1095 OperationState &result) {
1096 Attribute valueAttr;
1097 if (parser.parseOptionalAttrDict(result.attributes) ||
1098 parser.parseAttribute(valueAttr, "value", result.attributes))
1099 return failure();
1100
1101 // If the attribute is a symbol reference or array, then we expect a trailing
1102 // type.
1103 Type type;
1104 if (!valueAttr.isa<SymbolRefAttr, ArrayAttr>())
1105 type = valueAttr.getType();
1106 else if (parser.parseColonType(type))
1107 return failure();
1108
1109 // Add the attribute type to the list.
1110 return parser.addTypeToList(type, result.types);
1111 }
1112
1113 /// The constant op requires an attribute, and furthermore requires that it
1114 /// matches the return type.
verify(ConstantOp & op)1115 static LogicalResult verify(ConstantOp &op) {
1116 auto value = op.getValue();
1117 if (!value)
1118 return op.emitOpError("requires a 'value' attribute");
1119
1120 Type type = op.getType();
1121 if (!value.getType().isa<NoneType>() && type != value.getType())
1122 return op.emitOpError() << "requires attribute's type (" << value.getType()
1123 << ") to match op's return type (" << type << ")";
1124
1125 if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
1126 if (type.isa<IndexType>() || value.isa<BoolAttr>())
1127 return success();
1128 IntegerType intType = type.cast<IntegerType>();
1129 if (!intType.isSignless())
1130 return op.emitOpError("requires integer result types to be signless");
1131
1132 // If the type has a known bitwidth we verify that the value can be
1133 // represented with the given bitwidth.
1134 unsigned bitwidth = intType.getWidth();
1135 APInt intVal = intAttr.getValue();
1136 if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
1137 return op.emitOpError("requires 'value' to be an integer within the "
1138 "range of the integer result type");
1139 return success();
1140 }
1141
1142 if (auto complexTy = type.dyn_cast<ComplexType>()) {
1143 auto arrayAttr = value.dyn_cast<ArrayAttr>();
1144 if (!complexTy || arrayAttr.size() != 2)
1145 return op.emitOpError(
1146 "requires 'value' to be a complex constant, represented as array of "
1147 "two values");
1148 auto complexEltTy = complexTy.getElementType();
1149 if (complexEltTy != arrayAttr[0].getType() ||
1150 complexEltTy != arrayAttr[1].getType()) {
1151 return op.emitOpError()
1152 << "requires attribute's element types (" << arrayAttr[0].getType()
1153 << ", " << arrayAttr[1].getType()
1154 << ") to match the element type of the op's return type ("
1155 << complexEltTy << ")";
1156 }
1157 return success();
1158 }
1159
1160 if (type.isa<FloatType>()) {
1161 if (!value.isa<FloatAttr>())
1162 return op.emitOpError("requires 'value' to be a floating point constant");
1163 return success();
1164 }
1165
1166 if (type.isa<ShapedType>()) {
1167 if (!value.isa<ElementsAttr>())
1168 return op.emitOpError("requires 'value' to be a shaped constant");
1169 return success();
1170 }
1171
1172 if (type.isa<FunctionType>()) {
1173 auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
1174 if (!fnAttr)
1175 return op.emitOpError("requires 'value' to be a function reference");
1176
1177 // Try to find the referenced function.
1178 auto fn =
1179 op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
1180 if (!fn)
1181 return op.emitOpError()
1182 << "reference to undefined function '" << fnAttr.getValue() << "'";
1183
1184 // Check that the referenced function has the correct type.
1185 if (fn.getType() != type)
1186 return op.emitOpError("reference to function with mismatched type");
1187
1188 return success();
1189 }
1190
1191 if (type.isa<NoneType>() && value.isa<UnitAttr>())
1192 return success();
1193
1194 return op.emitOpError("unsupported 'value' attribute: ") << value;
1195 }
1196
fold(ArrayRef<Attribute> operands)1197 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
1198 assert(operands.empty() && "constant has no operands");
1199 return getValue();
1200 }
1201
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)1202 void ConstantOp::getAsmResultNames(
1203 function_ref<void(Value, StringRef)> setNameFn) {
1204 Type type = getType();
1205 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
1206 IntegerType intTy = type.dyn_cast<IntegerType>();
1207
1208 // Sugar i1 constants with 'true' and 'false'.
1209 if (intTy && intTy.getWidth() == 1)
1210 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1211
1212 // Otherwise, build a complex name with the value and type.
1213 SmallString<32> specialNameBuffer;
1214 llvm::raw_svector_ostream specialName(specialNameBuffer);
1215 specialName << 'c' << intCst.getInt();
1216 if (intTy)
1217 specialName << '_' << type;
1218 setNameFn(getResult(), specialName.str());
1219
1220 } else if (type.isa<FunctionType>()) {
1221 setNameFn(getResult(), "f");
1222 } else {
1223 setNameFn(getResult(), "cst");
1224 }
1225 }
1226
1227 /// Returns true if a constant operation can be built with the given value and
1228 /// result type.
isBuildableWith(Attribute value,Type type)1229 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
1230 // SymbolRefAttr can only be used with a function type.
1231 if (value.isa<SymbolRefAttr>())
1232 return type.isa<FunctionType>();
1233 // The attribute must have the same type as 'type'.
1234 if (!value.getType().isa<NoneType>() && value.getType() != type)
1235 return false;
1236 // If the type is an integer type, it must be signless.
1237 if (IntegerType integerTy = type.dyn_cast<IntegerType>())
1238 if (!integerTy.isSignless())
1239 return false;
1240 // Finally, check that the attribute kind is handled.
1241 if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
1242 auto complexTy = type.dyn_cast<ComplexType>();
1243 if (!complexTy)
1244 return false;
1245 auto complexEltTy = complexTy.getElementType();
1246 return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
1247 arrAttr[1].getType() == complexEltTy;
1248 }
1249 return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
1250 }
1251
build(OpBuilder & builder,OperationState & result,const APFloat & value,FloatType type)1252 void ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
1253 const APFloat &value, FloatType type) {
1254 ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value));
1255 }
1256
classof(Operation * op)1257 bool ConstantFloatOp::classof(Operation *op) {
1258 return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>();
1259 }
1260
1261 /// ConstantIntOp only matches values whose result type is an IntegerType.
classof(Operation * op)1262 bool ConstantIntOp::classof(Operation *op) {
1263 return ConstantOp::classof(op) &&
1264 op->getResult(0).getType().isSignlessInteger();
1265 }
1266
build(OpBuilder & builder,OperationState & result,int64_t value,unsigned width)1267 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1268 int64_t value, unsigned width) {
1269 Type type = builder.getIntegerType(width);
1270 ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1271 }
1272
1273 /// Build a constant int op producing an integer with the specified type,
1274 /// which must be an integer type.
build(OpBuilder & builder,OperationState & result,int64_t value,Type type)1275 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1276 int64_t value, Type type) {
1277 assert(type.isSignlessInteger() &&
1278 "ConstantIntOp can only have signless integer type");
1279 ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1280 }
1281
1282 /// ConstantIndexOp only matches values whose result type is Index.
classof(Operation * op)1283 bool ConstantIndexOp::classof(Operation *op) {
1284 return ConstantOp::classof(op) && op->getResult(0).getType().isIndex();
1285 }
1286
build(OpBuilder & builder,OperationState & result,int64_t value)1287 void ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
1288 int64_t value) {
1289 Type type = builder.getIndexType();
1290 ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1291 }
1292
// ---------------------------------------------------------------------------
// DivFOp
// ---------------------------------------------------------------------------
1296
fold(ArrayRef<Attribute> operands)1297 OpFoldResult DivFOp::fold(ArrayRef<Attribute> operands) {
1298 return constFoldBinaryOp<FloatAttr>(
1299 operands, [](APFloat a, APFloat b) { return a / b; });
1300 }
1301
//===----------------------------------------------------------------------===//
// FPExtOp
//===----------------------------------------------------------------------===//
1305
areCastCompatible(TypeRange inputs,TypeRange outputs)1306 bool FPExtOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1307 if (inputs.size() != 1 || outputs.size() != 1)
1308 return false;
1309 Type a = inputs.front(), b = outputs.front();
1310 if (auto fa = a.dyn_cast<FloatType>())
1311 if (auto fb = b.dyn_cast<FloatType>())
1312 return fa.getWidth() < fb.getWidth();
1313 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1314 }
1315
//===----------------------------------------------------------------------===//
// FPToSIOp
//===----------------------------------------------------------------------===//
1319
areCastCompatible(TypeRange inputs,TypeRange outputs)1320 bool FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1321 if (inputs.size() != 1 || outputs.size() != 1)
1322 return false;
1323 Type a = inputs.front(), b = outputs.front();
1324 if (a.isa<FloatType>() && b.isSignlessInteger())
1325 return true;
1326 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1327 }
1328
//===----------------------------------------------------------------------===//
// FPToUIOp
//===----------------------------------------------------------------------===//
1332
areCastCompatible(TypeRange inputs,TypeRange outputs)1333 bool FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1334 if (inputs.size() != 1 || outputs.size() != 1)
1335 return false;
1336 Type a = inputs.front(), b = outputs.front();
1337 if (a.isa<FloatType>() && b.isSignlessInteger())
1338 return true;
1339 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1340 }
1341
//===----------------------------------------------------------------------===//
// FPTruncOp
//===----------------------------------------------------------------------===//
1345
areCastCompatible(TypeRange inputs,TypeRange outputs)1346 bool FPTruncOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1347 if (inputs.size() != 1 || outputs.size() != 1)
1348 return false;
1349 Type a = inputs.front(), b = outputs.front();
1350 if (auto fa = a.dyn_cast<FloatType>())
1351 if (auto fb = b.dyn_cast<FloatType>())
1352 return fa.getWidth() > fb.getWidth();
1353 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1354 }
1355
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
1359
1360 // Index cast is applicable from index to integer and backwards.
areCastCompatible(TypeRange inputs,TypeRange outputs)1361 bool IndexCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1362 if (inputs.size() != 1 || outputs.size() != 1)
1363 return false;
1364 Type a = inputs.front(), b = outputs.front();
1365 if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
1366 auto aShaped = a.cast<ShapedType>();
1367 auto bShaped = b.cast<ShapedType>();
1368
1369 return (aShaped.getShape() == bShaped.getShape()) &&
1370 areCastCompatible(aShaped.getElementType(),
1371 bShaped.getElementType());
1372 }
1373
1374 return (a.isIndex() && b.isSignlessInteger()) ||
1375 (a.isSignlessInteger() && b.isIndex());
1376 }
1377
fold(ArrayRef<Attribute> cstOperands)1378 OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
1379 // Fold IndexCast(IndexCast(x)) -> x
1380 auto cast = getOperand().getDefiningOp<IndexCastOp>();
1381 if (cast && cast.getOperand().getType() == getType())
1382 return cast.getOperand();
1383
1384 // Fold IndexCast(constant) -> constant
1385 // A little hack because we go through int. Otherwise, the size
1386 // of the constant might need to change.
1387 if (auto value = cstOperands[0].dyn_cast_or_null<IntegerAttr>())
1388 return IntegerAttr::get(getType(), value.getInt());
1389
1390 return {};
1391 }
1392
1393 namespace {
1394 /// index_cast(sign_extend x) => index_cast(x)
1395 struct IndexCastOfSExt : public OpRewritePattern<IndexCastOp> {
1396 using OpRewritePattern<IndexCastOp>::OpRewritePattern;
1397
matchAndRewrite__anon476530360a11::IndexCastOfSExt1398 LogicalResult matchAndRewrite(IndexCastOp op,
1399 PatternRewriter &rewriter) const override {
1400
1401 if (auto extop = op.getOperand().getDefiningOp<SignExtendIOp>()) {
1402 op.setOperand(extop.getOperand());
1403 return success();
1404 }
1405 return failure();
1406 }
1407 };
1408
1409 } // namespace
1410
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1411 void IndexCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1412 MLIRContext *context) {
1413 results.insert<IndexCastOfSExt>(context);
1414 }
1415
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
1419
fold(ArrayRef<Attribute> operands)1420 OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
1421 return constFoldBinaryOp<FloatAttr>(
1422 operands, [](APFloat a, APFloat b) { return a * b; });
1423 }
1424
//===----------------------------------------------------------------------===//
// MulIOp
//===----------------------------------------------------------------------===//
1428
fold(ArrayRef<Attribute> operands)1429 OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
1430 /// muli(x, 0) -> 0
1431 if (matchPattern(rhs(), m_Zero()))
1432 return rhs();
1433 /// muli(x, 1) -> x
1434 if (matchPattern(rhs(), m_One()))
1435 return getOperand(0);
1436
1437 // TODO: Handle the overflow case.
1438 return constFoldBinaryOp<IntegerAttr>(operands,
1439 [](APInt a, APInt b) { return a * b; });
1440 }
1441
//===----------------------------------------------------------------------===//
// OrOp
//===----------------------------------------------------------------------===//
1445
fold(ArrayRef<Attribute> operands)1446 OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
1447 /// or(x, 0) -> x
1448 if (matchPattern(rhs(), m_Zero()))
1449 return lhs();
1450 /// or(x,x) -> x
1451 if (lhs() == rhs())
1452 return rhs();
1453
1454 return constFoldBinaryOp<IntegerAttr>(operands,
1455 [](APInt a, APInt b) { return a | b; });
1456 }
1457
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
1461
fold(ArrayRef<Attribute> operands)1462 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1463 // Constant fold rank when the rank of the operand is known.
1464 auto type = getOperand().getType();
1465 if (auto shapedType = type.dyn_cast<ShapedType>())
1466 if (shapedType.hasRank())
1467 return IntegerAttr::get(IndexType::get(getContext()),
1468 shapedType.getRank());
1469 return IntegerAttr();
1470 }
1471
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
1475
verify(ReturnOp op)1476 static LogicalResult verify(ReturnOp op) {
1477 auto function = cast<FuncOp>(op->getParentOp());
1478
1479 // The operand number and types must match the function signature.
1480 const auto &results = function.getType().getResults();
1481 if (op.getNumOperands() != results.size())
1482 return op.emitOpError("has ")
1483 << op.getNumOperands() << " operands, but enclosing function (@"
1484 << function.getName() << ") returns " << results.size();
1485
1486 for (unsigned i = 0, e = results.size(); i != e; ++i)
1487 if (op.getOperand(i).getType() != results[i])
1488 return op.emitError()
1489 << "type of return operand " << i << " ("
1490 << op.getOperand(i).getType()
1491 << ") doesn't match function result type (" << results[i] << ")"
1492 << " in function @" << function.getName();
1493
1494 return success();
1495 }
1496
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
1500
1501 // Transforms a select to a not, where relevant.
1502 //
1503 // select %arg, %false, %true
1504 //
1505 // becomes
1506 //
1507 // xor %arg, %true
1508 struct SelectToNot : public OpRewritePattern<SelectOp> {
1509 using OpRewritePattern<SelectOp>::OpRewritePattern;
1510
matchAndRewriteSelectToNot1511 LogicalResult matchAndRewrite(SelectOp op,
1512 PatternRewriter &rewriter) const override {
1513 if (!matchPattern(op.getTrueValue(), m_Zero()))
1514 return failure();
1515
1516 if (!matchPattern(op.getFalseValue(), m_One()))
1517 return failure();
1518
1519 if (!op.getType().isInteger(1))
1520 return failure();
1521
1522 rewriter.replaceOpWithNewOp<XOrOp>(op, op.condition(), op.getFalseValue());
1523 return success();
1524 }
1525 };
1526
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1527 void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1528 MLIRContext *context) {
1529 results.insert<SelectToNot>(context);
1530 }
1531
fold(ArrayRef<Attribute> operands)1532 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
1533 auto trueVal = getTrueValue();
1534 auto falseVal = getFalseValue();
1535 if (trueVal == falseVal)
1536 return trueVal;
1537
1538 auto condition = getCondition();
1539
1540 // select true, %0, %1 => %0
1541 if (matchPattern(condition, m_One()))
1542 return trueVal;
1543
1544 // select false, %0, %1 => %1
1545 if (matchPattern(condition, m_Zero()))
1546 return falseVal;
1547
1548 if (auto cmp = dyn_cast_or_null<CmpIOp>(condition.getDefiningOp())) {
1549 auto pred = cmp.predicate();
1550 if (pred == mlir::CmpIPredicate::eq || pred == mlir::CmpIPredicate::ne) {
1551 auto cmpLhs = cmp.lhs();
1552 auto cmpRhs = cmp.rhs();
1553
1554 // %0 = cmpi eq, %arg0, %arg1
1555 // %1 = select %0, %arg0, %arg1 => %arg1
1556
1557 // %0 = cmpi ne, %arg0, %arg1
1558 // %1 = select %0, %arg0, %arg1 => %arg0
1559
1560 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1561 (cmpRhs == trueVal && cmpLhs == falseVal))
1562 return pred == mlir::CmpIPredicate::ne ? trueVal : falseVal;
1563 }
1564 }
1565 return nullptr;
1566 }
1567
print(OpAsmPrinter & p,SelectOp op)1568 static void print(OpAsmPrinter &p, SelectOp op) {
1569 p << "select " << op.getOperands();
1570 p.printOptionalAttrDict(op->getAttrs());
1571 p << " : ";
1572 if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
1573 p << condType << ", ";
1574 p << op.getType();
1575 }
1576
parseSelectOp(OpAsmParser & parser,OperationState & result)1577 static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
1578 Type conditionType, resultType;
1579 SmallVector<OpAsmParser::OperandType, 3> operands;
1580 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1581 parser.parseOptionalAttrDict(result.attributes) ||
1582 parser.parseColonType(resultType))
1583 return failure();
1584
1585 // Check for the explicit condition type if this is a masked tensor or vector.
1586 if (succeeded(parser.parseOptionalComma())) {
1587 conditionType = resultType;
1588 if (parser.parseType(resultType))
1589 return failure();
1590 } else {
1591 conditionType = parser.getBuilder().getI1Type();
1592 }
1593
1594 result.addTypes(resultType);
1595 return parser.resolveOperands(operands,
1596 {conditionType, resultType, resultType},
1597 parser.getNameLoc(), result.operands);
1598 }
1599
verify(SelectOp op)1600 static LogicalResult verify(SelectOp op) {
1601 Type conditionType = op.getCondition().getType();
1602 if (conditionType.isSignlessInteger(1))
1603 return success();
1604
1605 // If the result type is a vector or tensor, the type can be a mask with the
1606 // same elements.
1607 Type resultType = op.getType();
1608 if (!resultType.isa<TensorType, VectorType>())
1609 return op.emitOpError()
1610 << "expected condition to be a signless i1, but got "
1611 << conditionType;
1612 Type shapedConditionType = getI1SameShape(resultType);
1613 if (conditionType != shapedConditionType)
1614 return op.emitOpError()
1615 << "expected condition type to have the same shape "
1616 "as the result type, expected "
1617 << shapedConditionType << ", but got " << conditionType;
1618 return success();
1619 }
1620
1621 //===----------------------------------------------------------------------===//
1622 // SignExtendIOp
1623 //===----------------------------------------------------------------------===//
1624
verify(SignExtendIOp op)1625 static LogicalResult verify(SignExtendIOp op) {
1626 // Get the scalar type (which is either directly the type of the operand
1627 // or the vector's/tensor's element type.
1628 auto srcType = getElementTypeOrSelf(op.getOperand().getType());
1629 auto dstType = getElementTypeOrSelf(op.getType());
1630
1631 // For now, index is forbidden for the source and the destination type.
1632 if (srcType.isa<IndexType>())
1633 return op.emitError() << srcType << " is not a valid operand type";
1634 if (dstType.isa<IndexType>())
1635 return op.emitError() << dstType << " is not a valid result type";
1636
1637 if (srcType.cast<IntegerType>().getWidth() >=
1638 dstType.cast<IntegerType>().getWidth())
1639 return op.emitError("result type ")
1640 << dstType << " must be wider than operand type " << srcType;
1641
1642 return success();
1643 }
1644
fold(ArrayRef<Attribute> operands)1645 OpFoldResult SignExtendIOp::fold(ArrayRef<Attribute> operands) {
1646 assert(operands.size() == 1 && "unary operation takes one operand");
1647
1648 if (!operands[0])
1649 return {};
1650
1651 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
1652 return IntegerAttr::get(
1653 getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
1654 }
1655
1656 return {};
1657 }
1658
1659 //===----------------------------------------------------------------------===//
1660 // SignedDivIOp
1661 //===----------------------------------------------------------------------===//
1662
/// Constant-folds signed division.
///
/// The lambda runs elementwise via constFoldBinaryOp; `overflowOrDiv0`
/// latches the first division by zero or signed-division overflow reported by
/// sdiv_ov, and once set the computed attribute is discarded at the end.
OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
    if (overflowOrDiv0 || !b) {
      overflowOrDiv0 = true;
      // The return value is unused once the flag is set; `a` is arbitrary.
      return a;
    }
    return a.sdiv_ov(b, overflowOrDiv0);
  });

  // Fold out division by one. Assumes all tensors of all ones are splats.
  // Note: this fires even when the lhs is not a constant.
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
    if (rhs.getValue() == 1)
      return lhs();
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
      return lhs();
  }

  return overflowOrDiv0 ? Attribute() : result;
}
1687
1688 //===----------------------------------------------------------------------===//
1689 // SignedFloorDivIOp
1690 //===----------------------------------------------------------------------===//
1691
signedCeilNonnegInputs(APInt a,APInt b,bool & overflow)1692 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
1693 // Returns (a-1)/b + 1
1694 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
1695 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
1696 return val.sadd_ov(one, overflow);
1697 }
1698
/// Constant-folds signed floor division (quotient rounded toward negative
/// infinity, unlike sdiv which truncates toward zero).
///
/// `overflowOrDiv0` latches the first overflow or division by zero seen by
/// the elementwise lambda; if set, the folded attribute is discarded at the
/// end. Each sign combination is reduced to a nonnegative division or a
/// negated ceiling division.
OpFoldResult SignedFloorDivIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
    if (overflowOrDiv0 || !b) {
      overflowOrDiv0 = true;
      // Return value is unused once the flag is set; `a` is arbitrary.
      return a;
    }
    unsigned bits = a.getBitWidth();
    APInt zero = APInt::getNullValue(bits);
    if (a.sge(zero) && b.sgt(zero)) {
      // Both positive (or a is zero), return a / b.
      return a.sdiv_ov(b, overflowOrDiv0);
    } else if (a.sle(zero) && b.slt(zero)) {
      // Both negative (or a is zero), return -a / -b.
      APInt posA = zero.ssub_ov(a, overflowOrDiv0);
      APInt posB = zero.ssub_ov(b, overflowOrDiv0);
      return posA.sdiv_ov(posB, overflowOrDiv0);
    } else if (a.slt(zero) && b.sgt(zero)) {
      // A is negative, b is positive, return - ceil(-a, b).
      APInt posA = zero.ssub_ov(a, overflowOrDiv0);
      APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
      return zero.ssub_ov(ceil, overflowOrDiv0);
    } else {
      // A is positive, b is negative, return - ceil(a, -b).
      APInt posB = zero.ssub_ov(b, overflowOrDiv0);
      APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
      return zero.ssub_ov(ceil, overflowOrDiv0);
    }
  });

  // Fold out floor division by one. Assumes all tensors of all ones are
  // splats. Note: this fires even when the lhs is not a constant.
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
    if (rhs.getValue() == 1)
      return lhs();
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
      return lhs();
  }

  return overflowOrDiv0 ? Attribute() : result;
}
1744
1745 //===----------------------------------------------------------------------===//
1746 // SignedCeilDivIOp
1747 //===----------------------------------------------------------------------===//
1748
/// Constant-folds signed ceiling division (quotient rounded toward positive
/// infinity).
///
/// `overflowOrDiv0` latches the first overflow or division by zero seen by
/// the elementwise lambda; if set, the folded attribute is discarded at the
/// end. Each sign combination is reduced to a nonnegative ceiling division or
/// a negated truncating division.
OpFoldResult SignedCeilDivIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
    if (overflowOrDiv0 || !b) {
      overflowOrDiv0 = true;
      // Return value is unused once the flag is set; `a` is arbitrary.
      return a;
    }
    unsigned bits = a.getBitWidth();
    APInt zero = APInt::getNullValue(bits);
    if (a.sgt(zero) && b.sgt(zero)) {
      // Both positive, return ceil(a, b).
      return signedCeilNonnegInputs(a, b, overflowOrDiv0);
    } else if (a.slt(zero) && b.slt(zero)) {
      // Both negative, return ceil(-a, -b).
      APInt posA = zero.ssub_ov(a, overflowOrDiv0);
      APInt posB = zero.ssub_ov(b, overflowOrDiv0);
      return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
    } else if (a.slt(zero) && b.sgt(zero)) {
      // A is negative, b is positive, return - ( -a / b).
      APInt posA = zero.ssub_ov(a, overflowOrDiv0);
      APInt div = posA.sdiv_ov(b, overflowOrDiv0);
      return zero.ssub_ov(div, overflowOrDiv0);
    } else {
      // A is positive (or zero), b is negative, return - (a / -b).
      APInt posB = zero.ssub_ov(b, overflowOrDiv0);
      APInt div = a.sdiv_ov(posB, overflowOrDiv0);
      return zero.ssub_ov(div, overflowOrDiv0);
    }
  });

  // Fold out ceil division by one. Assumes all tensors of all ones are
  // splats. Note: this fires even when the lhs is not a constant.
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
    if (rhs.getValue() == 1)
      return lhs();
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
      return lhs();
  }

  return overflowOrDiv0 ? Attribute() : result;
}
1794
1795 //===----------------------------------------------------------------------===//
1796 // SignedRemIOp
1797 //===----------------------------------------------------------------------===//
1798
fold(ArrayRef<Attribute> operands)1799 OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
1800 assert(operands.size() == 2 && "remi_signed takes two operands");
1801
1802 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
1803 if (!rhs)
1804 return {};
1805 auto rhsValue = rhs.getValue();
1806
1807 // x % 1 = 0
1808 if (rhsValue.isOneValue())
1809 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
1810
1811 // Don't fold if it requires division by zero.
1812 if (rhsValue.isNullValue())
1813 return {};
1814
1815 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1816 if (!lhs)
1817 return {};
1818 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
1819 }
1820
1821 //===----------------------------------------------------------------------===//
1822 // SIToFPOp
1823 //===----------------------------------------------------------------------===//
1824
1825 // sitofp is applicable from integer types to float types.
areCastCompatible(TypeRange inputs,TypeRange outputs)1826 bool SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1827 if (inputs.size() != 1 || outputs.size() != 1)
1828 return false;
1829 Type a = inputs.front(), b = outputs.front();
1830 if (a.isSignlessInteger() && b.isa<FloatType>())
1831 return true;
1832 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
1833 }
1834
1835 //===----------------------------------------------------------------------===//
1836 // SplatOp
1837 //===----------------------------------------------------------------------===//
1838
verify(SplatOp op)1839 static LogicalResult verify(SplatOp op) {
1840 // TODO: we could replace this by a trait.
1841 if (op.getOperand().getType() !=
1842 op.getType().cast<ShapedType>().getElementType())
1843 return op.emitError("operand should be of elemental type of result type");
1844
1845 return success();
1846 }
1847
1848 // Constant folding hook for SplatOp.
fold(ArrayRef<Attribute> operands)1849 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
1850 assert(operands.size() == 1 && "splat takes one operand");
1851
1852 auto constOperand = operands.front();
1853 if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
1854 return {};
1855
1856 auto shapedType = getType().cast<ShapedType>();
1857 assert(shapedType.getElementType() == constOperand.getType() &&
1858 "incorrect input attribute type for folding");
1859
1860 // SplatElementsAttr::get treats single value for second arg as being a splat.
1861 return SplatElementsAttr::get(shapedType, {constOperand});
1862 }
1863
1864 //===----------------------------------------------------------------------===//
1865 // SubFOp
1866 //===----------------------------------------------------------------------===//
1867
fold(ArrayRef<Attribute> operands)1868 OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
1869 return constFoldBinaryOp<FloatAttr>(
1870 operands, [](APFloat a, APFloat b) { return a - b; });
1871 }
1872
1873 //===----------------------------------------------------------------------===//
1874 // SubIOp
1875 //===----------------------------------------------------------------------===//
1876
fold(ArrayRef<Attribute> operands)1877 OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
1878 // subi(x,x) -> 0
1879 if (getOperand(0) == getOperand(1))
1880 return Builder(getContext()).getZeroAttr(getType());
1881 // subi(x,0) -> x
1882 if (matchPattern(rhs(), m_Zero()))
1883 return lhs();
1884
1885 return constFoldBinaryOp<IntegerAttr>(operands,
1886 [](APInt a, APInt b) { return a - b; });
1887 }
1888
/// Canonicalize a sub of a constant and (constant +/- something) to simply be
/// a single operation that merges the two constants.
///
/// Covered shapes (c0, c1 constants, x arbitrary):
///   c0 - (c1 + x)  ->  (c0 - c1) - x
///   c0 - (c1 - x)  ->  (c0 - c1) + x
///   c0 - (x - c1)  ->  (c0 + c1) - x
///   (c1 + x) - c0  ->  (c1 - c0) + x
///   (c1 - x) - c0  ->  (c1 - c0) - x
///   (x - c1) - c0  ->  x - (c1 + c0)
struct SubConstantReorder : public OpRewritePattern<SubIOp> {
  using OpRewritePattern<SubIOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubIOp subOp,
                                PatternRewriter &rewriter) const override {
    APInt origConst;
    APInt midConst;

    // Case 1: the lhs of the sub is the constant.
    if (matchPattern(subOp.getOperand(0), m_ConstantInt(&origConst))) {
      if (auto midAddOp = subOp.getOperand(1).getDefiningOp<AddIOp>()) {
        // origConst - (midConst + something) == (origConst - midConst) -
        // something
        // Addition is commutative, so check both operands of the add.
        for (int j = 0; j < 2; j++) {
          if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
            auto nextConstant = rewriter.create<ConstantOp>(
                subOp.getLoc(),
                rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
            rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
                                                midAddOp.getOperand(1 - j));
            return success();
          }
        }
      }

      // NOTE(review): operand 0 just matched m_ConstantInt above, so it is
      // produced by a constant op and cannot also be defined by a subi; this
      // branch looks unreachable (the symmetric rewrites are handled in the
      // operand-1-constant case below) — confirm before removing.
      if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
          // (midConst - something) - origConst == (midConst - origConst) -
          // something
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
          rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
                                              midSubOp.getOperand(1));
          return success();
        }

        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
          // (something - midConst) - origConst == something - (origConst +
          // midConst)
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
          rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
                                              nextConstant);
          return success();
        }
      }

      if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
          // origConst - (midConst - something) == (origConst - midConst) +
          // something
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
          rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
                                              midSubOp.getOperand(1));
          return success();
        }

        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
          // origConst - (something - midConst) == (origConst + midConst) -
          // something
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
          rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
                                              midSubOp.getOperand(0));
          return success();
        }
      }
    }

    // Case 2: the rhs of the sub is the constant.
    if (matchPattern(subOp.getOperand(1), m_ConstantInt(&origConst))) {
      if (auto midAddOp = subOp.getOperand(0).getDefiningOp<AddIOp>()) {
        // (midConst + something) - origConst == (midConst - origConst) +
        // something
        // Addition is commutative, so check both operands of the add.
        for (int j = 0; j < 2; j++) {
          if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
            auto nextConstant = rewriter.create<ConstantOp>(
                subOp.getLoc(),
                rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
            rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
                                                midAddOp.getOperand(1 - j));
            return success();
          }
        }
      }

      if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
          // (midConst - something) - origConst == (midConst - origConst) -
          // something
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
          rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
                                              midSubOp.getOperand(1));
          return success();
        }

        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
          // (something - midConst) - origConst == something - (midConst +
          // origConst)
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), midConst + origConst));
          rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
                                              nextConstant);
          return success();
        }
      }

      // NOTE(review): this nests under the operand-1-constant match, so the
      // rewrites below can only fire if operand 1 is both a constant and a
      // subi result, which cannot happen — likely dead code; confirm.
      if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
          // origConst - (midConst - something) == (origConst - midConst) +
          // something
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
          rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
                                              midSubOp.getOperand(1));
          return success();
        }
        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
          // origConst - (something - midConst) == (origConst - midConst) -
          // something
          auto nextConstant = rewriter.create<ConstantOp>(
              subOp.getLoc(),
              rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
          rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
                                              midSubOp.getOperand(0));
          return success();
        }
      }
    }
    return failure();
  }
};
2030
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2031 void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2032 MLIRContext *context) {
2033 results.insert<SubConstantReorder>(context);
2034 }
2035
2036 //===----------------------------------------------------------------------===//
2037 // UIToFPOp
2038 //===----------------------------------------------------------------------===//
2039
2040 // uitofp is applicable from integer types to float types.
areCastCompatible(TypeRange inputs,TypeRange outputs)2041 bool UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
2042 if (inputs.size() != 1 || outputs.size() != 1)
2043 return false;
2044 Type a = inputs.front(), b = outputs.front();
2045 if (a.isSignlessInteger() && b.isa<FloatType>())
2046 return true;
2047 return areVectorCastSimpleCompatible(a, b, areCastCompatible);
2048 }
2049
2050 //===----------------------------------------------------------------------===//
2051 // SwitchOp
2052 //===----------------------------------------------------------------------===//
2053
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,DenseIntElementsAttr caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)2054 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
2055 Block *defaultDestination, ValueRange defaultOperands,
2056 DenseIntElementsAttr caseValues,
2057 BlockRange caseDestinations,
2058 ArrayRef<ValueRange> caseOperands) {
2059 SmallVector<Value> flattenedCaseOperands;
2060 SmallVector<int32_t> caseOperandOffsets;
2061 int32_t offset = 0;
2062 for (ValueRange operands : caseOperands) {
2063 flattenedCaseOperands.append(operands.begin(), operands.end());
2064 caseOperandOffsets.push_back(offset);
2065 offset += operands.size();
2066 }
2067 DenseIntElementsAttr caseOperandOffsetsAttr;
2068 if (!caseOperandOffsets.empty())
2069 caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
2070
2071 build(builder, result, value, defaultOperands, flattenedCaseOperands,
2072 caseValues, caseOperandOffsetsAttr, defaultDestination,
2073 caseDestinations);
2074 }
2075
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<APInt> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)2076 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
2077 Block *defaultDestination, ValueRange defaultOperands,
2078 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
2079 ArrayRef<ValueRange> caseOperands) {
2080 DenseIntElementsAttr caseValuesAttr;
2081 if (!caseValues.empty()) {
2082 ShapedType caseValueType = VectorType::get(
2083 static_cast<int64_t>(caseValues.size()), value.getType());
2084 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
2085 }
2086 build(builder, result, value, defaultDestination, defaultOperands,
2087 caseValuesAttr, caseDestinations, caseOperands);
2088 }
2089
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
///
/// Parses the case list of a switch op. The mandatory `default` successor is
/// parsed first; each `, <value>: ^bb(...)` clause then appends its operands
/// and types flattened into `caseOperands`/`caseOperandTypes`, with
/// `caseOperandOffsets` recording where each case's segment starts. For a
/// default-only switch the case attributes are left unset.
static ParseResult
parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
                   Block *&defaultDestination,
                   SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
                   SmallVectorImpl<Type> &defaultOperandTypes,
                   DenseIntElementsAttr &caseValues,
                   SmallVectorImpl<Block *> &caseDestinations,
                   SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
                   SmallVectorImpl<Type> &caseOperandTypes,
                   DenseIntElementsAttr &caseOperandOffsets) {
  // `default : ^bb` is required.
  if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
      failed(parser.parseSuccessor(defaultDestination)))
    return failure();
  // Optional parenthesized operand list for the default successor.
  if (succeeded(parser.parseOptionalLParen())) {
    if (failed(parser.parseRegionArgumentList(defaultOperands)) ||
        failed(parser.parseColonTypeList(defaultOperandTypes)) ||
        failed(parser.parseRParen()))
      return failure();
  }

  SmallVector<APInt> values;
  SmallVector<int32_t> offsets;
  // Case values are materialized at the bit width of the flag type.
  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
  int64_t offset = 0;
  while (succeeded(parser.parseOptionalComma())) {
    int64_t value = 0;
    if (failed(parser.parseInteger(value)))
      return failure();
    values.push_back(APInt(bitWidth, value));

    Block *destination;
    SmallVector<OpAsmParser::OperandType> operands;
    if (failed(parser.parseColon()) ||
        failed(parser.parseSuccessor(destination)))
      return failure();
    if (succeeded(parser.parseOptionalLParen())) {
      // Types are appended directly to the shared flattened type list.
      if (failed(parser.parseRegionArgumentList(operands)) ||
          failed(parser.parseColonTypeList(caseOperandTypes)) ||
          failed(parser.parseRParen()))
        return failure();
    }
    caseDestinations.push_back(destination);
    caseOperands.append(operands.begin(), operands.end());
    offsets.push_back(offset);
    offset += operands.size();
  }

  // Default-only switch: leave caseValues/caseOperandOffsets unset.
  if (values.empty())
    return success();

  Builder &builder = parser.getBuilder();
  ShapedType caseValueType =
      VectorType::get(static_cast<int64_t>(values.size()), flagType);
  caseValues = DenseIntElementsAttr::get(caseValueType, values);
  caseOperandOffsets = builder.getI32VectorAttr(offsets);

  return success();
}
2150
/// Prints the case list of a switch op (inverse of parseSwitchOpCases):
///   default: ^bb(...), <value>: ^bb(...), ...
static void printSwitchOpCases(
    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
    OperandRange defaultOperands, TypeRange defaultOperandTypes,
    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
    OperandRange caseOperands, TypeRange caseOperandTypes,
    ElementsAttr caseOperandOffsets) {
  p << " default: ";
  p.printSuccessorAndUseList(defaultDestination, defaultOperands);

  // Default-only switch: nothing more to print.
  if (!caseValues)
    return;

  for (int64_t i = 0, size = caseValues.size(); i < size; ++i) {
    p << ',';
    p.printNewline();
    p << "  ";
    // NOTE(review): getLimitedValue() truncates to the low 64 bits and prints
    // unsigned, so negative or wider-than-64-bit case values may not
    // round-trip through print/parse — confirm whether such values can occur.
    p << caseValues.getValue<APInt>(i).getLimitedValue();
    p << ": ";
    p.printSuccessorAndUseList(caseDestinations[i], op.getCaseOperands(i));
  }
  p.printNewline();
}
2173
verify(SwitchOp op)2174 static LogicalResult verify(SwitchOp op) {
2175 auto caseValues = op.case_values();
2176 auto caseDestinations = op.caseDestinations();
2177
2178 if (!caseValues && caseDestinations.empty())
2179 return success();
2180
2181 Type flagType = op.flag().getType();
2182 Type caseValueType = caseValues->getType().getElementType();
2183 if (caseValueType != flagType)
2184 return op.emitOpError()
2185 << "'flag' type (" << flagType << ") should match case value type ("
2186 << caseValueType << ")";
2187
2188 if (caseValues &&
2189 caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
2190 return op.emitOpError() << "number of case values (" << caseValues->size()
2191 << ") should match number of "
2192 "case destinations ("
2193 << caseDestinations.size() << ")";
2194 return success();
2195 }
2196
getCaseOperands(unsigned index)2197 OperandRange SwitchOp::getCaseOperands(unsigned index) {
2198 return getCaseOperandsMutable(index);
2199 }
2200
/// Returns a mutable range over the operands forwarded to case successor
/// `index`, sliced out of the flattened case operand list using the
/// `case_operand_offsets` table.
MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
  MutableOperandRange caseOperands = caseOperandsMutable();
  if (!case_operand_offsets()) {
    // A missing offsets attribute is only legal when there are no case
    // operands at all; every per-case range is then the same empty range.
    assert(caseOperands.size() == 0 &&
           "non-empty case operands must have offsets");
    return caseOperands;
  }

  ElementsAttr offsets = case_operand_offsets().getValue();
  assert(index < offsets.size() && "invalid case operand offset index");

  // Case `index` owns the segment from its offset up to the next case's
  // offset, or to the end of the flattened list for the last case.
  int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
  int64_t end = index + 1 == offsets.size()
                    ? caseOperands.size()
                    : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
  return caseOperandsMutable().slice(begin, end - begin);
}
2218
2219 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)2220 SwitchOp::getMutableSuccessorOperands(unsigned index) {
2221 assert(index < getNumSuccessors() && "invalid successor index");
2222 return index == 0 ? defaultOperandsMutable()
2223 : getCaseOperandsMutable(index - 1);
2224 }
2225
getSuccessorForOperands(ArrayRef<Attribute> operands)2226 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
2227 Optional<DenseIntElementsAttr> caseValues = case_values();
2228
2229 if (!caseValues)
2230 return defaultDestination();
2231
2232 SuccessorRange caseDests = caseDestinations();
2233 if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
2234 for (int64_t i = 0, size = case_values()->size(); i < size; ++i)
2235 if (value == caseValues->getValue<IntegerAttr>(i))
2236 return caseDests[i];
2237 return defaultDestination();
2238 }
2239 return nullptr;
2240 }
2241
2242 /// switch %flag : i32, [
2243 /// default: ^bb1
2244 /// ]
2245 /// -> br ^bb1
simplifySwitchWithOnlyDefault(SwitchOp op,PatternRewriter & rewriter)2246 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
2247 PatternRewriter &rewriter) {
2248 if (!op.caseDestinations().empty())
2249 return failure();
2250
2251 rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
2252 op.defaultOperands());
2253 return success();
2254 }
2255
2256 /// switch %flag : i32, [
2257 /// default: ^bb1,
2258 /// 42: ^bb1,
2259 /// 43: ^bb2
2260 /// ]
2261 /// ->
2262 /// switch %flag : i32, [
2263 /// default: ^bb1,
2264 /// 43: ^bb2
2265 /// ]
2266 static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op,PatternRewriter & rewriter)2267 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
2268 SmallVector<Block *> newCaseDestinations;
2269 SmallVector<ValueRange> newCaseOperands;
2270 SmallVector<APInt> newCaseValues;
2271 bool requiresChange = false;
2272 auto caseValues = op.case_values();
2273 auto caseDests = op.caseDestinations();
2274
2275 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2276 if (caseDests[i] == op.defaultDestination() &&
2277 op.getCaseOperands(i) == op.defaultOperands()) {
2278 requiresChange = true;
2279 continue;
2280 }
2281 newCaseDestinations.push_back(caseDests[i]);
2282 newCaseOperands.push_back(op.getCaseOperands(i));
2283 newCaseValues.push_back(caseValues->getValue<APInt>(i));
2284 }
2285
2286 if (!requiresChange)
2287 return failure();
2288
2289 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
2290 op.defaultOperands(), newCaseValues,
2291 newCaseDestinations, newCaseOperands);
2292 return success();
2293 }
2294
2295 /// Helper for folding a switch with a constant value.
2296 /// switch %c_42 : i32, [
2297 /// default: ^bb1 ,
2298 /// 42: ^bb2,
2299 /// 43: ^bb3
2300 /// ]
2301 /// -> br ^bb2
foldSwitch(SwitchOp op,PatternRewriter & rewriter,APInt caseValue)2302 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
2303 APInt caseValue) {
2304 auto caseValues = op.case_values();
2305 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
2306 if (caseValues->getValue<APInt>(i) == caseValue) {
2307 rewriter.replaceOpWithNewOp<BranchOp>(op, op.caseDestinations()[i],
2308 op.getCaseOperands(i));
2309 return;
2310 }
2311 }
2312 rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
2313 op.defaultOperands());
2314 }
2315
2316 /// switch %c_42 : i32, [
2317 /// default: ^bb1,
2318 /// 42: ^bb2,
2319 /// 43: ^bb3
2320 /// ]
2321 /// -> br ^bb2
simplifyConstSwitchValue(SwitchOp op,PatternRewriter & rewriter)2322 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
2323 PatternRewriter &rewriter) {
2324 APInt caseValue;
2325 if (!matchPattern(op.flag(), m_ConstantInt(&caseValue)))
2326 return failure();
2327
2328 foldSwitch(op, rewriter, caseValue);
2329 return success();
2330 }
2331
/// switch %c_42 : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   br ^bb3
/// ->
/// switch %c_42 : i32, [
///   default: ^bb1,
///   42: ^bb3,
/// ]
static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
                                               PatternRewriter &rewriter) {
  SmallVector<Block *> newCaseDests;
  SmallVector<ValueRange> newCaseOperands;
  // Backing storage for operand vectors materialized by collapseBranch; it
  // must outlive the ValueRanges collected above. NOTE(review): growing this
  // vector moves the inner SmallVectors — presumably collapseBranch leaves the
  // stored ranges valid (e.g. pointing at heap storage), but confirm the
  // ranges never reference inline storage of a moved inner vector.
  SmallVector<SmallVector<Value>> argStorage;
  auto caseValues = op.case_values();
  auto caseDests = op.caseDestinations();
  bool requiresChange = false;
  // Try to collapse every case destination that merely forwards control to
  // another block through an unconditional branch.
  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
    Block *caseDest = caseDests[i];
    ValueRange caseOperands = op.getCaseOperands(i);
    argStorage.emplace_back();
    if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
      requiresChange = true;

    newCaseDests.push_back(caseDest);
    newCaseOperands.push_back(caseOperands);
  }

  // The default destination may be a pass-through block as well.
  Block *defaultDest = op.defaultDestination();
  ValueRange defaultOperands = op.defaultOperands();
  argStorage.emplace_back();

  if (succeeded(
          collapseBranch(defaultDest, defaultOperands, argStorage.back())))
    requiresChange = true;

  // Only rewrite the op if at least one destination was actually collapsed.
  if (!requiresChange)
    return failure();

  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), defaultDest,
                                        defaultOperands, caseValues.getValue(),
                                        newCaseDests, newCaseOperands);
  return success();
}
2378
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   switch %flag : i32, [
///     default: ^bb3,
///     42: ^bb4
///   ]
/// ->
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   br ^bb4
///
///  and
///
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   switch %flag : i32, [
///     default: ^bb3,
///     43: ^bb4
///   ]
/// ->
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb2:
///   br ^bb3
static LogicalResult
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
                                        PatternRewriter &rewriter) {
  // Check that we have a single distinct predecessor.
  Block *currentBlock = op->getBlock();
  Block *predecessor = currentBlock->getSinglePredecessor();
  if (!predecessor)
    return failure();

  // Check that the predecessor terminates with a switch branch to this block
  // and that it branches on the same condition and that this branch isn't the
  // default destination.
  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
  if (!predSwitch || op.flag() != predSwitch.flag() ||
      predSwitch.defaultDestination() == currentBlock)
    return failure();

  // Fold this switch to an unconditional branch. Because the predecessor
  // reached this block through a specific case, the flag value is known here:
  // it is the case value whose destination is this block.
  APInt caseValue;
  bool isDefault = true;
  SuccessorRange predDests = predSwitch.caseDestinations();
  Optional<DenseIntElementsAttr> predCaseValues = predSwitch.case_values();
  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) {
    if (currentBlock == predDests[i]) {
      caseValue = predCaseValues->getValue<APInt>(i);
      isDefault = false;
      break;
    }
  }
  // isDefault can only remain true if no case of the predecessor targets this
  // block; given the check above this is not expected, but handled defensively
  // by branching to this op's own default destination.
  if (isDefault)
    rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
                                          op.defaultOperands());
  else
    foldSwitch(op, rewriter, caseValue);
  return success();
}
2450
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2
/// ]
/// ^bb1:
///   switch %flag : i32, [
///     default: ^bb3,
///     42: ^bb4,
///     43: ^bb5
///   ]
/// ->
/// switch %flag : i32, [
///   default: ^bb1,
///   42: ^bb2,
/// ]
/// ^bb1:
///   switch %flag : i32, [
///     default: ^bb3,
///     43: ^bb5
///   ]
static LogicalResult
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
                                               PatternRewriter &rewriter) {
  // Check that we have a single distinct predecessor.
  Block *currentBlock = op->getBlock();
  Block *predecessor = currentBlock->getSinglePredecessor();
  if (!predecessor)
    return failure();

  // Check that the predecessor terminates with a switch branch to this block
  // and that it branches on the same condition and that this branch is the
  // default destination.
  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
  if (!predSwitch || op.flag() != predSwitch.flag() ||
      predSwitch.defaultDestination() != currentBlock)
    return failure();

  // Delete case values that are not possible here: control only reaches this
  // block through the predecessor's *default* branch, so the flag cannot equal
  // any value the predecessor already dispatched to a different block.
  DenseSet<APInt> caseValuesToRemove;
  auto predDests = predSwitch.caseDestinations();
  auto predCaseValues = predSwitch.case_values();
  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
    if (currentBlock != predDests[i])
      caseValuesToRemove.insert(predCaseValues->getValue<APInt>(i));

  SmallVector<Block *> newCaseDestinations;
  SmallVector<ValueRange> newCaseOperands;
  SmallVector<APInt> newCaseValues;
  bool requiresChange = false;

  // Rebuild the case list, skipping impossible values.
  auto caseValues = op.case_values();
  auto caseDests = op.caseDestinations();
  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
    if (caseValuesToRemove.contains(caseValues->getValue<APInt>(i))) {
      requiresChange = true;
      continue;
    }
    newCaseDestinations.push_back(caseDests[i]);
    newCaseOperands.push_back(op.getCaseOperands(i));
    newCaseValues.push_back(caseValues->getValue<APInt>(i));
  }

  // Only rewrite if at least one case was pruned.
  if (!requiresChange)
    return failure();

  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
                                        op.defaultOperands(), newCaseValues,
                                        newCaseDestinations, newCaseOperands);
  return success();
}
2521
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2522 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
2523 MLIRContext *context) {
2524 results.add(&simplifySwitchWithOnlyDefault)
2525 .add(&dropSwitchCasesThatMatchDefault)
2526 .add(&simplifyConstSwitchValue)
2527 .add(&simplifyPassThroughSwitch)
2528 .add(&simplifySwitchFromSwitchOnSameCondition)
2529 .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
2530 }
2531
2532 //===----------------------------------------------------------------------===//
2533 // TruncateIOp
2534 //===----------------------------------------------------------------------===//
2535
verify(TruncateIOp op)2536 static LogicalResult verify(TruncateIOp op) {
2537 auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2538 auto dstType = getElementTypeOrSelf(op.getType());
2539
2540 if (srcType.isa<IndexType>())
2541 return op.emitError() << srcType << " is not a valid operand type";
2542 if (dstType.isa<IndexType>())
2543 return op.emitError() << dstType << " is not a valid result type";
2544
2545 if (srcType.cast<IntegerType>().getWidth() <=
2546 dstType.cast<IntegerType>().getWidth())
2547 return op.emitError("operand type ")
2548 << srcType << " must be wider than result type " << dstType;
2549
2550 return success();
2551 }
2552
fold(ArrayRef<Attribute> operands)2553 OpFoldResult TruncateIOp::fold(ArrayRef<Attribute> operands) {
2554 // trunci(zexti(a)) -> a
2555 // trunci(sexti(a)) -> a
2556 if (matchPattern(getOperand(), m_Op<ZeroExtendIOp>()) ||
2557 matchPattern(getOperand(), m_Op<SignExtendIOp>()))
2558 return getOperand().getDefiningOp()->getOperand(0);
2559
2560 assert(operands.size() == 1 && "unary operation takes one operand");
2561
2562 if (!operands[0])
2563 return {};
2564
2565 if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
2566
2567 return IntegerAttr::get(
2568 getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
2569 }
2570
2571 return {};
2572 }
2573
2574 //===----------------------------------------------------------------------===//
2575 // UnsignedDivIOp
2576 //===----------------------------------------------------------------------===//
2577
fold(ArrayRef<Attribute> operands)2578 OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
2579 assert(operands.size() == 2 && "binary operation takes two operands");
2580
2581 // Don't fold if it would require a division by zero.
2582 bool div0 = false;
2583 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2584 if (div0 || !b) {
2585 div0 = true;
2586 return a;
2587 }
2588 return a.udiv(b);
2589 });
2590
2591 // Fold out division by one. Assumes all tensors of all ones are splats.
2592 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2593 if (rhs.getValue() == 1)
2594 return lhs();
2595 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2596 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2597 return lhs();
2598 }
2599
2600 return div0 ? Attribute() : result;
2601 }
2602
2603 //===----------------------------------------------------------------------===//
2604 // UnsignedRemIOp
2605 //===----------------------------------------------------------------------===//
2606
fold(ArrayRef<Attribute> operands)2607 OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
2608 assert(operands.size() == 2 && "remi_unsigned takes two operands");
2609
2610 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
2611 if (!rhs)
2612 return {};
2613 auto rhsValue = rhs.getValue();
2614
2615 // x % 1 = 0
2616 if (rhsValue.isOneValue())
2617 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
2618
2619 // Don't fold if it requires division by zero.
2620 if (rhsValue.isNullValue())
2621 return {};
2622
2623 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
2624 if (!lhs)
2625 return {};
2626 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
2627 }
2628
2629 //===----------------------------------------------------------------------===//
2630 // XOrOp
2631 //===----------------------------------------------------------------------===//
2632
fold(ArrayRef<Attribute> operands)2633 OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
2634 /// xor(x, 0) -> x
2635 if (matchPattern(rhs(), m_Zero()))
2636 return lhs();
2637 /// xor(x,x) -> 0
2638 if (lhs() == rhs())
2639 return Builder(getContext()).getZeroAttr(getType());
2640
2641 return constFoldBinaryOp<IntegerAttr>(operands,
2642 [](APInt a, APInt b) { return a ^ b; });
2643 }
2644
2645 namespace {
2646 /// Replace a not of a comparison operation, for example: not(cmp eq A, B) =>
2647 /// cmp ne A, B. Note that a logical not is implemented as xor 1, val.
2648 struct NotICmp : public OpRewritePattern<XOrOp> {
2649 using OpRewritePattern<XOrOp>::OpRewritePattern;
2650
matchAndRewrite__anon476530361511::NotICmp2651 LogicalResult matchAndRewrite(XOrOp op,
2652 PatternRewriter &rewriter) const override {
2653 // Commutative ops (such as xor) have the constant appear second, which
2654 // we assume here.
2655
2656 APInt constValue;
2657 if (!matchPattern(op.getOperand(1), m_ConstantInt(&constValue)))
2658 return failure();
2659
2660 if (constValue != 1)
2661 return failure();
2662
2663 auto prev = op.getOperand(0).getDefiningOp<CmpIOp>();
2664 if (!prev)
2665 return failure();
2666
2667 switch (prev.predicate()) {
2668 case CmpIPredicate::eq:
2669 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ne, prev.lhs(),
2670 prev.rhs());
2671 return success();
2672 case CmpIPredicate::ne:
2673 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::eq, prev.lhs(),
2674 prev.rhs());
2675 return success();
2676
2677 case CmpIPredicate::slt:
2678 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sge, prev.lhs(),
2679 prev.rhs());
2680 return success();
2681 case CmpIPredicate::sle:
2682 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sgt, prev.lhs(),
2683 prev.rhs());
2684 return success();
2685 case CmpIPredicate::sgt:
2686 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::sle, prev.lhs(),
2687 prev.rhs());
2688 return success();
2689 case CmpIPredicate::sge:
2690 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, prev.lhs(),
2691 prev.rhs());
2692 return success();
2693
2694 case CmpIPredicate::ult:
2695 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::uge, prev.lhs(),
2696 prev.rhs());
2697 return success();
2698 case CmpIPredicate::ule:
2699 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ugt, prev.lhs(),
2700 prev.rhs());
2701 return success();
2702 case CmpIPredicate::ugt:
2703 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ule, prev.lhs(),
2704 prev.rhs());
2705 return success();
2706 case CmpIPredicate::uge:
2707 rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::ult, prev.lhs(),
2708 prev.rhs());
2709 return success();
2710 }
2711 return failure();
2712 }
2713 };
2714 } // namespace
2715
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2716 void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2717 MLIRContext *context) {
2718 results.insert<NotICmp>(context);
2719 }
2720
2721 //===----------------------------------------------------------------------===//
2722 // ZeroExtendIOp
2723 //===----------------------------------------------------------------------===//
2724
verify(ZeroExtendIOp op)2725 static LogicalResult verify(ZeroExtendIOp op) {
2726 auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2727 auto dstType = getElementTypeOrSelf(op.getType());
2728
2729 if (srcType.isa<IndexType>())
2730 return op.emitError() << srcType << " is not a valid operand type";
2731 if (dstType.isa<IndexType>())
2732 return op.emitError() << dstType << " is not a valid result type";
2733
2734 if (srcType.cast<IntegerType>().getWidth() >=
2735 dstType.cast<IntegerType>().getWidth())
2736 return op.emitError("result type ")
2737 << dstType << " must be wider than operand type " << srcType;
2738
2739 return success();
2740 }
2741
2742 //===----------------------------------------------------------------------===//
2743 // TableGen'd op method definitions
2744 //===----------------------------------------------------------------------===//
2745
2746 #define GET_OP_CLASSES
2747 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
2748