1 //===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/StandardOps/IR/Ops.h"
10
11 #include "mlir/Dialect/CommonFolders.h"
12 #include "mlir/IR/AffineExpr.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/Function.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/Module.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/StandardTypes.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/IR/Value.h"
23 #include "mlir/Support/MathExtras.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/StringSwitch.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/raw_ostream.h"
28
29 // Pull in all enum type definitions and utility function declarations.
30 #include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc"
31
32 using namespace mlir;
33
34 //===----------------------------------------------------------------------===//
35 // StandardOpsDialect Interfaces
36 //===----------------------------------------------------------------------===//
37 namespace {
38 /// This class defines the interface for handling inlining with standard
39 /// operations.
40 struct StdInlinerInterface : public DialectInlinerInterface {
41 using DialectInlinerInterface::DialectInlinerInterface;
42
43 //===--------------------------------------------------------------------===//
44 // Analysis Hooks
45 //===--------------------------------------------------------------------===//
46
47 /// All operations within standard ops can be inlined.
isLegalToInline__anone4b94cca0111::StdInlinerInterface48 bool isLegalToInline(Operation *, Region *,
49 BlockAndValueMapping &) const final {
50 return true;
51 }
52
53 //===--------------------------------------------------------------------===//
54 // Transformation Hooks
55 //===--------------------------------------------------------------------===//
56
57 /// Handle the given inlined terminator by replacing it with a new operation
58 /// as necessary.
handleTerminator__anone4b94cca0111::StdInlinerInterface59 void handleTerminator(Operation *op, Block *newDest) const final {
60 // Only "std.return" needs to be handled here.
61 auto returnOp = dyn_cast<ReturnOp>(op);
62 if (!returnOp)
63 return;
64
65 // Replace the return with a branch to the dest.
66 OpBuilder builder(op);
67 builder.create<BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
68 op->erase();
69 }
70
71 /// Handle the given inlined terminator by replacing it with a new operation
72 /// as necessary.
handleTerminator__anone4b94cca0111::StdInlinerInterface73 void handleTerminator(Operation *op,
74 ArrayRef<Value> valuesToRepl) const final {
75 // Only "std.return" needs to be handled here.
76 auto returnOp = cast<ReturnOp>(op);
77
78 // Replace the values directly with the return operands.
79 assert(returnOp.getNumOperands() == valuesToRepl.size());
80 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
81 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
82 }
83 };
84 } // end anonymous namespace
85
86 //===----------------------------------------------------------------------===//
87 // StandardOpsDialect
88 //===----------------------------------------------------------------------===//
89
90 /// A custom unary operation printer that omits the "std." prefix from the
91 /// operation names.
printStandardUnaryOp(Operation * op,OpAsmPrinter & p)92 static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
93 assert(op->getNumOperands() == 1 && "unary op should have one operand");
94 assert(op->getNumResults() == 1 && "unary op should have one result");
95
96 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
97 p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
98 << op->getOperand(0);
99 p.printOptionalAttrDict(op->getAttrs());
100 p << " : " << op->getOperand(0).getType();
101 }
102
103 /// A custom binary operation printer that omits the "std." prefix from the
104 /// operation names.
printStandardBinaryOp(Operation * op,OpAsmPrinter & p)105 static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
106 assert(op->getNumOperands() == 2 && "binary op should have two operands");
107 assert(op->getNumResults() == 1 && "binary op should have one result");
108
109 // If not all the operand and result types are the same, just use the
110 // generic assembly form to avoid omitting information in printing.
111 auto resultType = op->getResult(0).getType();
112 if (op->getOperand(0).getType() != resultType ||
113 op->getOperand(1).getType() != resultType) {
114 p.printGenericOp(op);
115 return;
116 }
117
118 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
119 p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
120 << op->getOperand(0) << ", " << op->getOperand(1);
121 p.printOptionalAttrDict(op->getAttrs());
122
123 // Now we can output only one type for all operands and the result.
124 p << " : " << op->getResult(0).getType();
125 }
126
127 /// A custom cast operation printer that omits the "std." prefix from the
128 /// operation names.
printStandardCastOp(Operation * op,OpAsmPrinter & p)129 static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
130 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
131 p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
132 << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to "
133 << op->getResult(0).getType();
134 }
135
136 /// A custom cast operation verifier.
137 template <typename T>
verifyCastOp(T op)138 static LogicalResult verifyCastOp(T op) {
139 auto opType = op.getOperand().getType();
140 auto resType = op.getType();
141 if (!T::areCastCompatible(opType, resType))
142 return op.emitError("operand type ") << opType << " and result type "
143 << resType << " are cast incompatible";
144
145 return success();
146 }
147
/// Constructs the standard dialect. DmaStartOp and DmaWaitOp are registered
/// explicitly. The remaining operations come from the ODS-generated list in
/// Ops.cpp.inc, which is expanded into the template argument list via
/// GET_OP_LIST. StdInlinerInterface is attached as the dialect's inliner
/// interface.
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<DmaStartOp, DmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
                >();
  addInterfaces<StdInlinerInterface>();
}
156
157 /// Materialize a single constant operation from a given attribute value with
158 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)159 Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
160 Attribute value, Type type,
161 Location loc) {
162 return builder.create<ConstantOp>(loc, type, value);
163 }
164
printDimAndSymbolList(Operation::operand_iterator begin,Operation::operand_iterator end,unsigned numDims,OpAsmPrinter & p)165 void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
166 Operation::operand_iterator end,
167 unsigned numDims, OpAsmPrinter &p) {
168 Operation::operand_range operands(begin, end);
169 p << '(' << operands.take_front(numDims) << ')';
170 if (operands.size() != numDims)
171 p << '[' << operands.drop_front(numDims) << ']';
172 }
173
174 // Parses dimension and symbol list, and sets 'numDims' to the number of
175 // dimension operands parsed.
176 // Returns 'false' on success and 'true' on error.
parseDimAndSymbolList(OpAsmParser & parser,SmallVectorImpl<Value> & operands,unsigned & numDims)177 ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
178 SmallVectorImpl<Value> &operands,
179 unsigned &numDims) {
180 SmallVector<OpAsmParser::OperandType, 8> opInfos;
181 if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
182 return failure();
183 // Store number of dimensions for validation by caller.
184 numDims = opInfos.size();
185
186 // Parse the optional symbol operands.
187 auto indexTy = parser.getBuilder().getIndexType();
188 if (parser.parseOperandList(opInfos,
189 OpAsmParser::Delimiter::OptionalSquare) ||
190 parser.resolveOperands(opInfos, indexTy, operands))
191 return failure();
192 return success();
193 }
194
195 /// Matches a ConstantIndexOp.
196 /// TODO: This should probably just be a general matcher that uses m_Constant
197 /// and checks the operation for an index type.
m_ConstantIndex()198 static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
199 return detail::op_matcher<ConstantIndexOp>();
200 }
201
202 //===----------------------------------------------------------------------===//
203 // Common canonicalization pattern support logic
204 //===----------------------------------------------------------------------===//
205
206 /// This is a common class used for patterns of the form
207 /// "someop(memrefcast) -> someop". It folds the source of any memref_cast
208 /// into the root operation directly.
foldMemRefCast(Operation * op)209 static LogicalResult foldMemRefCast(Operation *op) {
210 bool folded = false;
211 for (OpOperand &operand : op->getOpOperands()) {
212 auto cast = operand.get().getDefiningOp<MemRefCastOp>();
213 if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
214 operand.set(cast.getOperand());
215 folded = true;
216 }
217 }
218 return success(folded);
219 }
220
221 //===----------------------------------------------------------------------===//
222 // AddFOp
223 //===----------------------------------------------------------------------===//
224
fold(ArrayRef<Attribute> operands)225 OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
226 return constFoldBinaryOp<FloatAttr>(
227 operands, [](APFloat a, APFloat b) { return a + b; });
228 }
229
230 //===----------------------------------------------------------------------===//
231 // AddIOp
232 //===----------------------------------------------------------------------===//
233
fold(ArrayRef<Attribute> operands)234 OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
235 /// addi(x, 0) -> x
236 if (matchPattern(rhs(), m_Zero()))
237 return lhs();
238
239 return constFoldBinaryOp<IntegerAttr>(operands,
240 [](APInt a, APInt b) { return a + b; });
241 }
242
243 //===----------------------------------------------------------------------===//
244 // AllocOp / AllocaOp
245 //===----------------------------------------------------------------------===//
246
247 template <typename AllocLikeOp>
printAllocLikeOp(OpAsmPrinter & p,AllocLikeOp op,StringRef name)248 static void printAllocLikeOp(OpAsmPrinter &p, AllocLikeOp op, StringRef name) {
249 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
250 "applies to only alloc or alloca");
251 p << name;
252
253 // Print dynamic dimension operands.
254 MemRefType type = op.getType();
255 printDimAndSymbolList(op.operand_begin(), op.operand_end(),
256 type.getNumDynamicDims(), p);
257 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
258 p << " : " << type;
259 }
260
print(OpAsmPrinter & p,AllocOp op)261 static void print(OpAsmPrinter &p, AllocOp op) {
262 printAllocLikeOp(p, op, "alloc");
263 }
264
print(OpAsmPrinter & p,AllocaOp op)265 static void print(OpAsmPrinter &p, AllocaOp op) {
266 printAllocLikeOp(p, op, "alloca");
267 }
268
parseAllocLikeOp(OpAsmParser & parser,OperationState & result)269 static ParseResult parseAllocLikeOp(OpAsmParser &parser,
270 OperationState &result) {
271 MemRefType type;
272
273 // Parse the dimension operands and optional symbol operands, followed by a
274 // memref type.
275 unsigned numDimOperands;
276 if (parseDimAndSymbolList(parser, result.operands, numDimOperands) ||
277 parser.parseOptionalAttrDict(result.attributes) ||
278 parser.parseColonType(type))
279 return failure();
280
281 // Check numDynamicDims against number of question marks in memref type.
282 // Note: this check remains here (instead of in verify()), because the
283 // partition between dim operands and symbol operands is lost after parsing.
284 // Verification still checks that the total number of operands matches
285 // the number of symbols in the affine map, plus the number of dynamic
286 // dimensions in the memref.
287 if (numDimOperands != type.getNumDynamicDims())
288 return parser.emitError(parser.getNameLoc())
289 << "dimension operand count does not equal memref dynamic dimension "
290 "count";
291 result.types.push_back(type);
292 return success();
293 }
294
295 template <typename AllocLikeOp>
verify(AllocLikeOp op)296 static LogicalResult verify(AllocLikeOp op) {
297 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
298 "applies to only alloc or alloca");
299 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
300 if (!memRefType)
301 return op.emitOpError("result must be a memref");
302
303 unsigned numSymbols = 0;
304 if (!memRefType.getAffineMaps().empty()) {
305 // Store number of symbols used in affine map (used in subsequent check).
306 AffineMap affineMap = memRefType.getAffineMaps()[0];
307 numSymbols = affineMap.getNumSymbols();
308 }
309
310 // Check that the total number of operands matches the number of symbols in
311 // the affine map, plus the number of dynamic dimensions specified in the
312 // memref type.
313 unsigned numDynamicDims = memRefType.getNumDynamicDims();
314 if (op.getNumOperands() != numDynamicDims + numSymbols)
315 return op.emitOpError(
316 "operand count does not equal dimension plus symbol operand count");
317
318 // Verify that all operands are of type Index.
319 for (auto operandType : op.getOperandTypes())
320 if (!operandType.isIndex())
321 return op.emitOpError("requires operands to be of type Index");
322
323 if (std::is_same<AllocLikeOp, AllocOp>::value)
324 return success();
325
326 // An alloca op needs to have an ancestor with an allocation scope trait.
327 if (!op.template getParentWithTrait<OpTrait::AutomaticAllocationScope>())
328 return op.emitOpError(
329 "requires an ancestor op with AutomaticAllocationScope trait");
330
331 return success();
332 }
333
334 namespace {
335 /// Fold constant dimensions into an alloc like operation.
336 template <typename AllocLikeOp>
337 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
338 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
339
matchAndRewrite__anone4b94cca0411::SimplifyAllocConst340 LogicalResult matchAndRewrite(AllocLikeOp alloc,
341 PatternRewriter &rewriter) const override {
342 // Check to see if any dimensions operands are constants. If so, we can
343 // substitute and drop them.
344 if (llvm::none_of(alloc.getOperands(), [](Value operand) {
345 return matchPattern(operand, m_ConstantIndex());
346 }))
347 return failure();
348
349 auto memrefType = alloc.getType();
350
351 // Ok, we have one or more constant operands. Collect the non-constant ones
352 // and keep track of the resultant memref type to build.
353 SmallVector<int64_t, 4> newShapeConstants;
354 newShapeConstants.reserve(memrefType.getRank());
355 SmallVector<Value, 4> newOperands;
356
357 unsigned dynamicDimPos = 0;
358 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
359 int64_t dimSize = memrefType.getDimSize(dim);
360 // If this is already static dimension, keep it.
361 if (dimSize != -1) {
362 newShapeConstants.push_back(dimSize);
363 continue;
364 }
365 auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
366 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
367 // Dynamic shape dimension will be folded.
368 newShapeConstants.push_back(constantIndexOp.getValue());
369 } else {
370 // Dynamic shape dimension not folded; copy operand from old memref.
371 newShapeConstants.push_back(-1);
372 newOperands.push_back(alloc.getOperand(dynamicDimPos));
373 }
374 dynamicDimPos++;
375 }
376
377 // Create new memref type (which will have fewer dynamic dimensions).
378 MemRefType newMemRefType =
379 MemRefType::Builder(memrefType).setShape(newShapeConstants);
380 assert(static_cast<int64_t>(newOperands.size()) ==
381 newMemRefType.getNumDynamicDims());
382
383 // Create and insert the alloc op for the new memref.
384 auto newAlloc = rewriter.create<AllocLikeOp>(alloc.getLoc(), newMemRefType,
385 newOperands, IntegerAttr());
386 // Insert a cast so we have the same type as the old alloc.
387 auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
388 alloc.getType());
389
390 rewriter.replaceOp(alloc, {resultCast});
391 return success();
392 }
393 };
394
395 /// Fold alloc operations with no uses. Alloc has side effects on the heap,
396 /// but can still be deleted if it has zero uses.
397 struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
398 using OpRewritePattern<AllocOp>::OpRewritePattern;
399
matchAndRewrite__anone4b94cca0411::SimplifyDeadAlloc400 LogicalResult matchAndRewrite(AllocOp alloc,
401 PatternRewriter &rewriter) const override {
402 if (alloc.use_empty()) {
403 rewriter.eraseOp(alloc);
404 return success();
405 }
406 return failure();
407 }
408 };
409 } // end anonymous namespace.
410
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)411 void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
412 MLIRContext *context) {
413 results.insert<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
414 }
415
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)416 void AllocaOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
417 MLIRContext *context) {
418 results.insert<SimplifyAllocConst<AllocaOp>>(context);
419 }
420
421 //===----------------------------------------------------------------------===//
422 // AndOp
423 //===----------------------------------------------------------------------===//
424
fold(ArrayRef<Attribute> operands)425 OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
426 /// and(x, 0) -> 0
427 if (matchPattern(rhs(), m_Zero()))
428 return rhs();
429 /// and(x, allOnes) -> x
430 APInt intValue;
431 if (matchPattern(rhs(), m_ConstantInt(&intValue)) &&
432 intValue.isAllOnesValue())
433 return lhs();
434 /// and(x,x) -> x
435 if (lhs() == rhs())
436 return rhs();
437
438 return constFoldBinaryOp<IntegerAttr>(operands,
439 [](APInt a, APInt b) { return a & b; });
440 }
441
442 //===----------------------------------------------------------------------===//
443 // AssertOp
444 //===----------------------------------------------------------------------===//
445
446 namespace {
447 struct EraseRedundantAssertions : public OpRewritePattern<AssertOp> {
448 using OpRewritePattern<AssertOp>::OpRewritePattern;
449
matchAndRewrite__anone4b94cca0711::EraseRedundantAssertions450 LogicalResult matchAndRewrite(AssertOp op,
451 PatternRewriter &rewriter) const override {
452 // Erase assertion if argument is constant true.
453 if (matchPattern(op.arg(), m_One())) {
454 rewriter.eraseOp(op);
455 return success();
456 }
457 return failure();
458 }
459 };
460 } // namespace
461
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)462 void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
463 MLIRContext *context) {
464 patterns.insert<EraseRedundantAssertions>(context);
465 }
466
467 //===----------------------------------------------------------------------===//
468 // AssumeAlignmentOp
469 //===----------------------------------------------------------------------===//
470
verify(AssumeAlignmentOp op)471 static LogicalResult verify(AssumeAlignmentOp op) {
472 unsigned alignment = op.alignment().getZExtValue();
473 if (!llvm::isPowerOf2_32(alignment))
474 return op.emitOpError("alignment must be power of 2");
475 return success();
476 }
477
478 //===----------------------------------------------------------------------===//
479 // AtomicRMWOp
480 //===----------------------------------------------------------------------===//
481
verify(AtomicRMWOp op)482 static LogicalResult verify(AtomicRMWOp op) {
483 if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
484 return op.emitOpError(
485 "expects the number of subscripts to be equal to memref rank");
486 switch (op.kind()) {
487 case AtomicRMWKind::addf:
488 case AtomicRMWKind::maxf:
489 case AtomicRMWKind::minf:
490 case AtomicRMWKind::mulf:
491 if (!op.value().getType().isa<FloatType>())
492 return op.emitOpError()
493 << "with kind '" << stringifyAtomicRMWKind(op.kind())
494 << "' expects a floating-point type";
495 break;
496 case AtomicRMWKind::addi:
497 case AtomicRMWKind::maxs:
498 case AtomicRMWKind::maxu:
499 case AtomicRMWKind::mins:
500 case AtomicRMWKind::minu:
501 case AtomicRMWKind::muli:
502 if (!op.value().getType().isa<IntegerType>())
503 return op.emitOpError()
504 << "with kind '" << stringifyAtomicRMWKind(op.kind())
505 << "' expects an integer type";
506 break;
507 default:
508 break;
509 }
510 return success();
511 }
512
513 //===----------------------------------------------------------------------===//
514 // GenericAtomicRMWOp
515 //===----------------------------------------------------------------------===//
516
build(OpBuilder & builder,OperationState & result,Value memref,ValueRange ivs)517 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
518 Value memref, ValueRange ivs) {
519 result.addOperands(memref);
520 result.addOperands(ivs);
521
522 if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
523 Type elementType = memrefType.getElementType();
524 result.addTypes(elementType);
525
526 Region *bodyRegion = result.addRegion();
527 bodyRegion->push_back(new Block());
528 bodyRegion->addArgument(elementType);
529 }
530 }
531
verify(GenericAtomicRMWOp op)532 static LogicalResult verify(GenericAtomicRMWOp op) {
533 auto &body = op.body();
534 if (body.getNumArguments() != 1)
535 return op.emitOpError("expected single number of entry block arguments");
536
537 if (op.getResult().getType() != body.getArgument(0).getType())
538 return op.emitOpError(
539 "expected block argument of the same type result type");
540
541 bool hasSideEffects =
542 body.walk([&](Operation *nestedOp) {
543 if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
544 return WalkResult::advance();
545 nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
546 "only operations with no side effects");
547 return WalkResult::interrupt();
548 })
549 .wasInterrupted();
550 return hasSideEffects ? failure() : success();
551 }
552
parseGenericAtomicRMWOp(OpAsmParser & parser,OperationState & result)553 static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
554 OperationState &result) {
555 OpAsmParser::OperandType memref;
556 Type memrefType;
557 SmallVector<OpAsmParser::OperandType, 4> ivs;
558
559 Type indexType = parser.getBuilder().getIndexType();
560 if (parser.parseOperand(memref) ||
561 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
562 parser.parseColonType(memrefType) ||
563 parser.resolveOperand(memref, memrefType, result.operands) ||
564 parser.resolveOperands(ivs, indexType, result.operands))
565 return failure();
566
567 Region *body = result.addRegion();
568 if (parser.parseRegion(*body, llvm::None, llvm::None))
569 return failure();
570 result.types.push_back(memrefType.cast<MemRefType>().getElementType());
571 return success();
572 }
573
print(OpAsmPrinter & p,GenericAtomicRMWOp op)574 static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
575 p << op.getOperationName() << ' ' << op.memref() << "[" << op.indices()
576 << "] : " << op.memref().getType();
577 p.printRegion(op.body());
578 p.printOptionalAttrDict(op.getAttrs());
579 }
580
581 //===----------------------------------------------------------------------===//
582 // AtomicYieldOp
583 //===----------------------------------------------------------------------===//
584
verify(AtomicYieldOp op)585 static LogicalResult verify(AtomicYieldOp op) {
586 Type parentType = op.getParentOp()->getResultTypes().front();
587 Type resultType = op.result().getType();
588 if (parentType != resultType)
589 return op.emitOpError() << "types mismatch between yield op: " << resultType
590 << " and its parent: " << parentType;
591 return success();
592 }
593
594 //===----------------------------------------------------------------------===//
595 // BranchOp
596 //===----------------------------------------------------------------------===//
597
598 /// Given a successor, try to collapse it to a new destination if it only
599 /// contains a passthrough unconditional branch. If the successor is
600 /// collapsable, `successor` and `successorOperands` are updated to reference
601 /// the new destination and values. `argStorage` is an optional storage to use
602 /// if operands to the collapsed successor need to be remapped.
collapseBranch(Block * & successor,ValueRange & successorOperands,SmallVectorImpl<Value> & argStorage)603 static LogicalResult collapseBranch(Block *&successor,
604 ValueRange &successorOperands,
605 SmallVectorImpl<Value> &argStorage) {
606 // Check that the successor only contains a unconditional branch.
607 if (std::next(successor->begin()) != successor->end())
608 return failure();
609 // Check that the terminator is an unconditional branch.
610 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
611 if (!successorBranch)
612 return failure();
613 // Check that the arguments are only used within the terminator.
614 for (BlockArgument arg : successor->getArguments()) {
615 for (Operation *user : arg.getUsers())
616 if (user != successorBranch)
617 return failure();
618 }
619 // Don't try to collapse branches to infinite loops.
620 Block *successorDest = successorBranch.getDest();
621 if (successorDest == successor)
622 return failure();
623
624 // Update the operands to the successor. If the branch parent has no
625 // arguments, we can use the branch operands directly.
626 OperandRange operands = successorBranch.getOperands();
627 if (successor->args_empty()) {
628 successor = successorDest;
629 successorOperands = operands;
630 return success();
631 }
632
633 // Otherwise, we need to remap any argument operands.
634 for (Value operand : operands) {
635 BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
636 if (argOperand && argOperand.getOwner() == successor)
637 argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
638 else
639 argStorage.push_back(operand);
640 }
641 successor = successorDest;
642 successorOperands = argStorage;
643 return success();
644 }
645
646 namespace {
647 /// Simplify a branch to a block that has a single predecessor. This effectively
648 /// merges the two blocks.
649 struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
650 using OpRewritePattern<BranchOp>::OpRewritePattern;
651
matchAndRewrite__anone4b94cca0911::SimplifyBrToBlockWithSinglePred652 LogicalResult matchAndRewrite(BranchOp op,
653 PatternRewriter &rewriter) const override {
654 // Check that the successor block has a single predecessor.
655 Block *succ = op.getDest();
656 Block *opParent = op.getOperation()->getBlock();
657 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
658 return failure();
659
660 // Merge the successor into the current block and erase the branch.
661 rewriter.mergeBlocks(succ, opParent, op.getOperands());
662 rewriter.eraseOp(op);
663 return success();
664 }
665 };
666
667 /// br ^bb1
668 /// ^bb1
669 /// br ^bbN(...)
670 ///
671 /// -> br ^bbN(...)
672 ///
673 struct SimplifyPassThroughBr : public OpRewritePattern<BranchOp> {
674 using OpRewritePattern<BranchOp>::OpRewritePattern;
675
matchAndRewrite__anone4b94cca0911::SimplifyPassThroughBr676 LogicalResult matchAndRewrite(BranchOp op,
677 PatternRewriter &rewriter) const override {
678 Block *dest = op.getDest();
679 ValueRange destOperands = op.getOperands();
680 SmallVector<Value, 4> destOperandStorage;
681
682 // Try to collapse the successor if it points somewhere other than this
683 // block.
684 if (dest == op.getOperation()->getBlock() ||
685 failed(collapseBranch(dest, destOperands, destOperandStorage)))
686 return failure();
687
688 // Create a new branch with the collapsed successor.
689 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
690 return success();
691 }
692 };
693 } // end anonymous namespace.
694
getDest()695 Block *BranchOp::getDest() { return getSuccessor(); }
696
setDest(Block * block)697 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
698
eraseOperand(unsigned index)699 void BranchOp::eraseOperand(unsigned index) {
700 getOperation()->eraseOperand(index);
701 }
702
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)703 void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
704 MLIRContext *context) {
705 results.insert<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>(
706 context);
707 }
708
709 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)710 BranchOp::getMutableSuccessorOperands(unsigned index) {
711 assert(index == 0 && "invalid successor index");
712 return destOperandsMutable();
713 }
714
getSuccessorForOperands(ArrayRef<Attribute>)715 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
716
717 //===----------------------------------------------------------------------===//
718 // CallOp
719 //===----------------------------------------------------------------------===//
720
verify(CallOp op)721 static LogicalResult verify(CallOp op) {
722 // Check that the callee attribute was specified.
723 auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
724 if (!fnAttr)
725 return op.emitOpError("requires a 'callee' symbol reference attribute");
726 auto fn =
727 op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
728 if (!fn)
729 return op.emitOpError() << "'" << fnAttr.getValue()
730 << "' does not reference a valid function";
731
732 // Verify that the operand and result types match the callee.
733 auto fnType = fn.getType();
734 if (fnType.getNumInputs() != op.getNumOperands())
735 return op.emitOpError("incorrect number of operands for callee");
736
737 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
738 if (op.getOperand(i).getType() != fnType.getInput(i))
739 return op.emitOpError("operand type mismatch");
740
741 if (fnType.getNumResults() != op.getNumResults())
742 return op.emitOpError("incorrect number of results for callee");
743
744 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
745 if (op.getResult(i).getType() != fnType.getResult(i))
746 return op.emitOpError("result type mismatch");
747
748 return success();
749 }
750
getCalleeType()751 FunctionType CallOp::getCalleeType() {
752 SmallVector<Type, 8> argTypes(getOperandTypes());
753 return FunctionType::get(argTypes, getResultTypes(), getContext());
754 }
755
756 //===----------------------------------------------------------------------===//
757 // CallIndirectOp
758 //===----------------------------------------------------------------------===//
759 namespace {
760 /// Fold indirect calls that have a constant function as the callee operand.
761 struct SimplifyIndirectCallWithKnownCallee
762 : public OpRewritePattern<CallIndirectOp> {
763 using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
764
matchAndRewrite__anone4b94cca0a11::SimplifyIndirectCallWithKnownCallee765 LogicalResult matchAndRewrite(CallIndirectOp indirectCall,
766 PatternRewriter &rewriter) const override {
767 // Check that the callee is a constant callee.
768 SymbolRefAttr calledFn;
769 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
770 return failure();
771
772 // Replace with a direct call.
773 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
774 indirectCall.getResultTypes(),
775 indirectCall.getArgOperands());
776 return success();
777 }
778 };
779 } // end anonymous namespace.
780
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)781 void CallIndirectOp::getCanonicalizationPatterns(
782 OwningRewritePatternList &results, MLIRContext *context) {
783 results.insert<SimplifyIndirectCallWithKnownCallee>(context);
784 }
785
786 //===----------------------------------------------------------------------===//
787 // General helpers for comparison ops
788 //===----------------------------------------------------------------------===//
789
790 // Return the type of the same shape (scalar, vector or tensor) containing i1.
getI1SameShape(Type type)791 static Type getI1SameShape(Type type) {
792 auto i1Type = IntegerType::get(1, type.getContext());
793 if (auto tensorType = type.dyn_cast<RankedTensorType>())
794 return RankedTensorType::get(tensorType.getShape(), i1Type);
795 if (type.isa<UnrankedTensorType>())
796 return UnrankedTensorType::get(i1Type);
797 if (auto vectorType = type.dyn_cast<VectorType>())
798 return VectorType::get(vectorType.getShape(), i1Type);
799 return i1Type;
800 }
801
802 //===----------------------------------------------------------------------===//
803 // CmpIOp
804 //===----------------------------------------------------------------------===//
805
buildCmpIOp(OpBuilder & build,OperationState & result,CmpIPredicate predicate,Value lhs,Value rhs)806 static void buildCmpIOp(OpBuilder &build, OperationState &result,
807 CmpIPredicate predicate, Value lhs, Value rhs) {
808 result.addOperands({lhs, rhs});
809 result.types.push_back(getI1SameShape(lhs.getType()));
810 result.addAttribute(CmpIOp::getPredicateAttrName(),
811 build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
812 }
813
814 // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
815 // comparison predicates.
applyCmpPredicate(CmpIPredicate predicate,const APInt & lhs,const APInt & rhs)816 bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
817 const APInt &rhs) {
818 switch (predicate) {
819 case CmpIPredicate::eq:
820 return lhs.eq(rhs);
821 case CmpIPredicate::ne:
822 return lhs.ne(rhs);
823 case CmpIPredicate::slt:
824 return lhs.slt(rhs);
825 case CmpIPredicate::sle:
826 return lhs.sle(rhs);
827 case CmpIPredicate::sgt:
828 return lhs.sgt(rhs);
829 case CmpIPredicate::sge:
830 return lhs.sge(rhs);
831 case CmpIPredicate::ult:
832 return lhs.ult(rhs);
833 case CmpIPredicate::ule:
834 return lhs.ule(rhs);
835 case CmpIPredicate::ugt:
836 return lhs.ugt(rhs);
837 case CmpIPredicate::uge:
838 return lhs.uge(rhs);
839 }
840 llvm_unreachable("unknown comparison predicate");
841 }
842
843 // Constant folding hook for comparisons.
fold(ArrayRef<Attribute> operands)844 OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
845 assert(operands.size() == 2 && "cmpi takes two arguments");
846
847 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
848 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
849 if (!lhs || !rhs)
850 return {};
851
852 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
853 return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
854 }
855
856 //===----------------------------------------------------------------------===//
857 // CmpFOp
858 //===----------------------------------------------------------------------===//
859
buildCmpFOp(OpBuilder & build,OperationState & result,CmpFPredicate predicate,Value lhs,Value rhs)860 static void buildCmpFOp(OpBuilder &build, OperationState &result,
861 CmpFPredicate predicate, Value lhs, Value rhs) {
862 result.addOperands({lhs, rhs});
863 result.types.push_back(getI1SameShape(lhs.getType()));
864 result.addAttribute(CmpFOp::getPredicateAttrName(),
865 build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
866 }
867
868 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
869 /// comparison predicates.
applyCmpPredicate(CmpFPredicate predicate,const APFloat & lhs,const APFloat & rhs)870 bool mlir::applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
871 const APFloat &rhs) {
872 auto cmpResult = lhs.compare(rhs);
873 switch (predicate) {
874 case CmpFPredicate::AlwaysFalse:
875 return false;
876 case CmpFPredicate::OEQ:
877 return cmpResult == APFloat::cmpEqual;
878 case CmpFPredicate::OGT:
879 return cmpResult == APFloat::cmpGreaterThan;
880 case CmpFPredicate::OGE:
881 return cmpResult == APFloat::cmpGreaterThan ||
882 cmpResult == APFloat::cmpEqual;
883 case CmpFPredicate::OLT:
884 return cmpResult == APFloat::cmpLessThan;
885 case CmpFPredicate::OLE:
886 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
887 case CmpFPredicate::ONE:
888 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
889 case CmpFPredicate::ORD:
890 return cmpResult != APFloat::cmpUnordered;
891 case CmpFPredicate::UEQ:
892 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
893 case CmpFPredicate::UGT:
894 return cmpResult == APFloat::cmpUnordered ||
895 cmpResult == APFloat::cmpGreaterThan;
896 case CmpFPredicate::UGE:
897 return cmpResult == APFloat::cmpUnordered ||
898 cmpResult == APFloat::cmpGreaterThan ||
899 cmpResult == APFloat::cmpEqual;
900 case CmpFPredicate::ULT:
901 return cmpResult == APFloat::cmpUnordered ||
902 cmpResult == APFloat::cmpLessThan;
903 case CmpFPredicate::ULE:
904 return cmpResult == APFloat::cmpUnordered ||
905 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
906 case CmpFPredicate::UNE:
907 return cmpResult != APFloat::cmpEqual;
908 case CmpFPredicate::UNO:
909 return cmpResult == APFloat::cmpUnordered;
910 case CmpFPredicate::AlwaysTrue:
911 return true;
912 }
913 llvm_unreachable("unknown comparison predicate");
914 }
915
916 // Constant folding hook for comparisons.
fold(ArrayRef<Attribute> operands)917 OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
918 assert(operands.size() == 2 && "cmpf takes two arguments");
919
920 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
921 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
922
923 // TODO: We could actually do some intelligent things if we know only one
924 // of the operands, but it's inf or nan.
925 if (!lhs || !rhs)
926 return {};
927
928 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
929 return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
930 }
931
932 //===----------------------------------------------------------------------===//
933 // CondBranchOp
934 //===----------------------------------------------------------------------===//
935
936 namespace {
937 /// cond_br true, ^bb1, ^bb2
938 /// -> br ^bb1
939 /// cond_br false, ^bb1, ^bb2
940 /// -> br ^bb2
941 ///
942 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
943 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
944
matchAndRewrite__anone4b94cca0b11::SimplifyConstCondBranchPred945 LogicalResult matchAndRewrite(CondBranchOp condbr,
946 PatternRewriter &rewriter) const override {
947 if (matchPattern(condbr.getCondition(), m_NonZero())) {
948 // True branch taken.
949 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
950 condbr.getTrueOperands());
951 return success();
952 } else if (matchPattern(condbr.getCondition(), m_Zero())) {
953 // False branch taken.
954 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
955 condbr.getFalseOperands());
956 return success();
957 }
958 return failure();
959 }
960 };
961
962 /// cond_br %cond, ^bb1, ^bb2
963 /// ^bb1
964 /// br ^bbN(...)
965 /// ^bb2
966 /// br ^bbK(...)
967 ///
968 /// -> cond_br %cond, ^bbN(...), ^bbK(...)
969 ///
970 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
971 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
972
matchAndRewrite__anone4b94cca0b11::SimplifyPassThroughCondBranch973 LogicalResult matchAndRewrite(CondBranchOp condbr,
974 PatternRewriter &rewriter) const override {
975 Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest();
976 ValueRange trueDestOperands = condbr.getTrueOperands();
977 ValueRange falseDestOperands = condbr.getFalseOperands();
978 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
979
980 // Try to collapse one of the current successors.
981 LogicalResult collapsedTrue =
982 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
983 LogicalResult collapsedFalse =
984 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
985 if (failed(collapsedTrue) && failed(collapsedFalse))
986 return failure();
987
988 // Create a new branch with the collapsed successors.
989 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
990 trueDest, trueDestOperands,
991 falseDest, falseDestOperands);
992 return success();
993 }
994 };
995
996 /// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
997 /// -> br ^bb1(A, ..., N)
998 ///
999 /// cond_br %cond, ^bb1(A), ^bb1(B)
1000 /// -> %select = select %cond, A, B
1001 /// br ^bb1(%select)
1002 ///
1003 struct SimplifyCondBranchIdenticalSuccessors
1004 : public OpRewritePattern<CondBranchOp> {
1005 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
1006
matchAndRewrite__anone4b94cca0b11::SimplifyCondBranchIdenticalSuccessors1007 LogicalResult matchAndRewrite(CondBranchOp condbr,
1008 PatternRewriter &rewriter) const override {
1009 // Check that the true and false destinations are the same and have the same
1010 // operands.
1011 Block *trueDest = condbr.trueDest();
1012 if (trueDest != condbr.falseDest())
1013 return failure();
1014
1015 // If all of the operands match, no selects need to be generated.
1016 OperandRange trueOperands = condbr.getTrueOperands();
1017 OperandRange falseOperands = condbr.getFalseOperands();
1018 if (trueOperands == falseOperands) {
1019 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
1020 return success();
1021 }
1022
1023 // Otherwise, if the current block is the only predecessor insert selects
1024 // for any mismatched branch operands.
1025 if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock())
1026 return failure();
1027
1028 // Generate a select for any operands that differ between the two.
1029 SmallVector<Value, 8> mergedOperands;
1030 mergedOperands.reserve(trueOperands.size());
1031 Value condition = condbr.getCondition();
1032 for (auto it : llvm::zip(trueOperands, falseOperands)) {
1033 if (std::get<0>(it) == std::get<1>(it))
1034 mergedOperands.push_back(std::get<0>(it));
1035 else
1036 mergedOperands.push_back(rewriter.create<SelectOp>(
1037 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
1038 }
1039
1040 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
1041 return success();
1042 }
1043 };
1044 } // end anonymous namespace
1045
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1046 void CondBranchOp::getCanonicalizationPatterns(
1047 OwningRewritePatternList &results, MLIRContext *context) {
1048 results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
1049 SimplifyCondBranchIdenticalSuccessors>(context);
1050 }
1051
1052 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1053 CondBranchOp::getMutableSuccessorOperands(unsigned index) {
1054 assert(index < getNumSuccessors() && "invalid successor index");
1055 return index == trueIndex ? trueDestOperandsMutable()
1056 : falseDestOperandsMutable();
1057 }
1058
getSuccessorForOperands(ArrayRef<Attribute> operands)1059 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1060 if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
1061 return condAttr.getValue().isOneValue() ? trueDest() : falseDest();
1062 return nullptr;
1063 }
1064
1065 //===----------------------------------------------------------------------===//
1066 // Constant*Op
1067 //===----------------------------------------------------------------------===//
1068
print(OpAsmPrinter & p,ConstantOp & op)1069 static void print(OpAsmPrinter &p, ConstantOp &op) {
1070 p << "constant ";
1071 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
1072
1073 if (op.getAttrs().size() > 1)
1074 p << ' ';
1075 p << op.getValue();
1076
1077 // If the value is a symbol reference, print a trailing type.
1078 if (op.getValue().isa<SymbolRefAttr>())
1079 p << " : " << op.getType();
1080 }
1081
parseConstantOp(OpAsmParser & parser,OperationState & result)1082 static ParseResult parseConstantOp(OpAsmParser &parser,
1083 OperationState &result) {
1084 Attribute valueAttr;
1085 if (parser.parseOptionalAttrDict(result.attributes) ||
1086 parser.parseAttribute(valueAttr, "value", result.attributes))
1087 return failure();
1088
1089 // If the attribute is a symbol reference, then we expect a trailing type.
1090 Type type;
1091 if (!valueAttr.isa<SymbolRefAttr>())
1092 type = valueAttr.getType();
1093 else if (parser.parseColonType(type))
1094 return failure();
1095
1096 // Add the attribute type to the list.
1097 return parser.addTypeToList(type, result.types);
1098 }
1099
1100 /// The constant op requires an attribute, and furthermore requires that it
1101 /// matches the return type.
verify(ConstantOp & op)1102 static LogicalResult verify(ConstantOp &op) {
1103 auto value = op.getValue();
1104 if (!value)
1105 return op.emitOpError("requires a 'value' attribute");
1106
1107 auto type = op.getType();
1108 if (!value.getType().isa<NoneType>() && type != value.getType())
1109 return op.emitOpError() << "requires attribute's type (" << value.getType()
1110 << ") to match op's return type (" << type << ")";
1111
1112 if (type.isa<IndexType>() || value.isa<BoolAttr>())
1113 return success();
1114
1115 if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
1116 // If the type has a known bitwidth we verify that the value can be
1117 // represented with the given bitwidth.
1118 auto bitwidth = type.cast<IntegerType>().getWidth();
1119 auto intVal = intAttr.getValue();
1120 if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
1121 return op.emitOpError("requires 'value' to be an integer within the "
1122 "range of the integer result type");
1123 return success();
1124 }
1125
1126 if (type.isa<FloatType>()) {
1127 if (!value.isa<FloatAttr>())
1128 return op.emitOpError("requires 'value' to be a floating point constant");
1129 return success();
1130 }
1131
1132 if (type.isa<ShapedType>()) {
1133 if (!value.isa<ElementsAttr>())
1134 return op.emitOpError("requires 'value' to be a shaped constant");
1135 return success();
1136 }
1137
1138 if (type.isa<FunctionType>()) {
1139 auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
1140 if (!fnAttr)
1141 return op.emitOpError("requires 'value' to be a function reference");
1142
1143 // Try to find the referenced function.
1144 auto fn =
1145 op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
1146 if (!fn)
1147 return op.emitOpError()
1148 << "reference to undefined function '" << fnAttr.getValue() << "'";
1149
1150 // Check that the referenced function has the correct type.
1151 if (fn.getType() != type)
1152 return op.emitOpError("reference to function with mismatched type");
1153
1154 return success();
1155 }
1156
1157 if (type.isa<NoneType>() && value.isa<UnitAttr>())
1158 return success();
1159
1160 return op.emitOpError("unsupported 'value' attribute: ") << value;
1161 }
1162
fold(ArrayRef<Attribute> operands)1163 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
1164 assert(operands.empty() && "constant has no operands");
1165 return getValue();
1166 }
1167
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)1168 void ConstantOp::getAsmResultNames(
1169 function_ref<void(Value, StringRef)> setNameFn) {
1170 Type type = getType();
1171 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
1172 IntegerType intTy = type.dyn_cast<IntegerType>();
1173
1174 // Sugar i1 constants with 'true' and 'false'.
1175 if (intTy && intTy.getWidth() == 1)
1176 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1177
1178 // Otherwise, build a complex name with the value and type.
1179 SmallString<32> specialNameBuffer;
1180 llvm::raw_svector_ostream specialName(specialNameBuffer);
1181 specialName << 'c' << intCst.getInt();
1182 if (intTy)
1183 specialName << '_' << type;
1184 setNameFn(getResult(), specialName.str());
1185
1186 } else if (type.isa<FunctionType>()) {
1187 setNameFn(getResult(), "f");
1188 } else {
1189 setNameFn(getResult(), "cst");
1190 }
1191 }
1192
1193 /// Returns true if a constant operation can be built with the given value and
1194 /// result type.
isBuildableWith(Attribute value,Type type)1195 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
1196 // SymbolRefAttr can only be used with a function type.
1197 if (value.isa<SymbolRefAttr>())
1198 return type.isa<FunctionType>();
1199 // Otherwise, the attribute must have the same type as 'type'.
1200 if (value.getType() != type)
1201 return false;
1202 // Finally, check that the attribute kind is handled.
1203 return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
1204 }
1205
build(OpBuilder & builder,OperationState & result,const APFloat & value,FloatType type)1206 void ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
1207 const APFloat &value, FloatType type) {
1208 ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value));
1209 }
1210
classof(Operation * op)1211 bool ConstantFloatOp::classof(Operation *op) {
1212 return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>();
1213 }
1214
1215 /// ConstantIntOp only matches values whose result type is an IntegerType.
classof(Operation * op)1216 bool ConstantIntOp::classof(Operation *op) {
1217 return ConstantOp::classof(op) &&
1218 op->getResult(0).getType().isSignlessInteger();
1219 }
1220
build(OpBuilder & builder,OperationState & result,int64_t value,unsigned width)1221 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1222 int64_t value, unsigned width) {
1223 Type type = builder.getIntegerType(width);
1224 ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1225 }
1226
1227 /// Build a constant int op producing an integer with the specified type,
1228 /// which must be an integer type.
build(OpBuilder & builder,OperationState & result,int64_t value,Type type)1229 void ConstantIntOp::build(OpBuilder &builder, OperationState &result,
1230 int64_t value, Type type) {
1231 assert(type.isSignlessInteger() &&
1232 "ConstantIntOp can only have signless integer type");
1233 ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1234 }
1235
1236 /// ConstantIndexOp only matches values whose result type is Index.
classof(Operation * op)1237 bool ConstantIndexOp::classof(Operation *op) {
1238 return ConstantOp::classof(op) && op->getResult(0).getType().isIndex();
1239 }
1240
build(OpBuilder & builder,OperationState & result,int64_t value)1241 void ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
1242 int64_t value) {
1243 Type type = builder.getIndexType();
1244 ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value));
1245 }
1246
1247 //===----------------------------------------------------------------------===//
1248 // DeallocOp
1249 //===----------------------------------------------------------------------===//
1250 namespace {
1251 /// Fold Dealloc operations that are deallocating an AllocOp that is only used
1252 /// by other Dealloc operations.
1253 struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
1254 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1255
matchAndRewrite__anone4b94cca0c11::SimplifyDeadDealloc1256 LogicalResult matchAndRewrite(DeallocOp dealloc,
1257 PatternRewriter &rewriter) const override {
1258 // Check that the memref operand's defining operation is an AllocOp.
1259 Value memref = dealloc.memref();
1260 if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
1261 return failure();
1262
1263 // Check that all of the uses of the AllocOp are other DeallocOps.
1264 for (auto *user : memref.getUsers())
1265 if (!isa<DeallocOp>(user))
1266 return failure();
1267
1268 // Erase the dealloc operation.
1269 rewriter.eraseOp(dealloc);
1270 return success();
1271 }
1272 };
1273 } // end anonymous namespace.
1274
verify(DeallocOp op)1275 static LogicalResult verify(DeallocOp op) {
1276 if (!op.memref().getType().isa<MemRefType>())
1277 return op.emitOpError("operand must be a memref");
1278 return success();
1279 }
1280
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1281 void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1282 MLIRContext *context) {
1283 results.insert<SimplifyDeadDealloc>(context);
1284 }
1285
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1286 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
1287 SmallVectorImpl<OpFoldResult> &results) {
1288 /// dealloc(memrefcast) -> dealloc
1289 return foldMemRefCast(*this);
1290 }
1291
1292 //===----------------------------------------------------------------------===//
1293 // DimOp
1294 //===----------------------------------------------------------------------===//
1295
build(OpBuilder & builder,OperationState & result,Value memrefOrTensor,int64_t index)1296 void DimOp::build(OpBuilder &builder, OperationState &result,
1297 Value memrefOrTensor, int64_t index) {
1298 auto loc = result.location;
1299 Value indexValue = builder.create<ConstantIndexOp>(loc, index);
1300 build(builder, result, memrefOrTensor, indexValue);
1301 }
1302
build(OpBuilder & builder,OperationState & result,Value memrefOrTensor,Value index)1303 void DimOp::build(OpBuilder &builder, OperationState &result,
1304 Value memrefOrTensor, Value index) {
1305 auto indexTy = builder.getIndexType();
1306 build(builder, result, indexTy, memrefOrTensor, index);
1307 }
1308
getConstantIndex()1309 Optional<int64_t> DimOp::getConstantIndex() {
1310 if (auto constantOp = index().getDefiningOp<ConstantOp>())
1311 return constantOp.getValue().cast<IntegerAttr>().getInt();
1312 return {};
1313 }
1314
verify(DimOp op)1315 static LogicalResult verify(DimOp op) {
1316
1317 // Assume unknown index to be in range.
1318 Optional<int64_t> index = op.getConstantIndex();
1319 if (!index.hasValue())
1320 return success();
1321
1322 // Check that constant index is not knowingly out of range.
1323 auto type = op.memrefOrTensor().getType();
1324 if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
1325 if (index.getValue() >= tensorType.getRank())
1326 return op.emitOpError("index is out of range");
1327 } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
1328 if (index.getValue() >= memrefType.getRank())
1329 return op.emitOpError("index is out of range");
1330 } else if (type.isa<UnrankedTensorType>()) {
1331 // Assume index to be in range.
1332 } else {
1333 llvm_unreachable("expected operand with tensor or memref type");
1334 }
1335
1336 return success();
1337 }
1338
fold(ArrayRef<Attribute> operands)1339 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
1340 auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
1341
1342 // All forms of folding require a known index.
1343 if (!index)
1344 return {};
1345
1346 // Fold if the shape extent along the given index is known.
1347 auto argTy = memrefOrTensor().getType();
1348 if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
1349 if (!shapedTy.isDynamicDim(index.getInt())) {
1350 Builder builder(getContext());
1351 return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
1352 }
1353 }
1354
1355 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1356 auto memrefType = argTy.dyn_cast<MemRefType>();
1357 if (!memrefType)
1358 return {};
1359
1360 // The size at the given index is now known to be a dynamic size of a memref.
1361 auto memref = memrefOrTensor().getDefiningOp();
1362 unsigned unsignedIndex = index.getValue().getZExtValue();
1363 if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
1364 return *(alloc.getDynamicSizes().begin() +
1365 memrefType.getDynamicDimIndex(unsignedIndex));
1366
1367 if (auto view = dyn_cast_or_null<ViewOp>(memref))
1368 return *(view.getDynamicSizes().begin() +
1369 memrefType.getDynamicDimIndex(unsignedIndex));
1370
1371 if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
1372 assert(subview.isDynamicSize(unsignedIndex) &&
1373 "Expected dynamic subview size");
1374 return subview.getDynamicSize(unsignedIndex);
1375 }
1376
1377 // dim(memrefcast) -> dim
1378 if (succeeded(foldMemRefCast(*this)))
1379 return getResult();
1380
1381 return {};
1382 }
1383
1384 // ---------------------------------------------------------------------------
1385 // DmaStartOp
1386 // ---------------------------------------------------------------------------
1387
build(OpBuilder & builder,OperationState & result,Value srcMemRef,ValueRange srcIndices,Value destMemRef,ValueRange destIndices,Value numElements,Value tagMemRef,ValueRange tagIndices,Value stride,Value elementsPerStride)1388 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1389 Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1390 ValueRange destIndices, Value numElements,
1391 Value tagMemRef, ValueRange tagIndices, Value stride,
1392 Value elementsPerStride) {
1393 result.addOperands(srcMemRef);
1394 result.addOperands(srcIndices);
1395 result.addOperands(destMemRef);
1396 result.addOperands(destIndices);
1397 result.addOperands({numElements, tagMemRef});
1398 result.addOperands(tagIndices);
1399 if (stride)
1400 result.addOperands({stride, elementsPerStride});
1401 }
1402
print(OpAsmPrinter & p)1403 void DmaStartOp::print(OpAsmPrinter &p) {
1404 p << "dma_start " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1405 << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1406 << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1407 if (isStrided())
1408 p << ", " << getStride() << ", " << getNumElementsPerStride();
1409
1410 p.printOptionalAttrDict(getAttrs());
1411 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1412 << ", " << getTagMemRef().getType();
1413 }
1414
1415 // Parse DmaStartOp.
1416 // Ex:
1417 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1418 // %tag[%index], %stride, %num_elt_per_stride :
1419 // : memref<3076 x f32, 0>,
1420 // memref<1024 x f32, 2>,
1421 // memref<1 x i32>
1422 //
parse(OpAsmParser & parser,OperationState & result)1423 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1424 OpAsmParser::OperandType srcMemRefInfo;
1425 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
1426 OpAsmParser::OperandType dstMemRefInfo;
1427 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
1428 OpAsmParser::OperandType numElementsInfo;
1429 OpAsmParser::OperandType tagMemrefInfo;
1430 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
1431 SmallVector<OpAsmParser::OperandType, 2> strideInfo;
1432
1433 SmallVector<Type, 3> types;
1434 auto indexType = parser.getBuilder().getIndexType();
1435
1436 // Parse and resolve the following list of operands:
1437 // *) source memref followed by its indices (in square brackets).
1438 // *) destination memref followed by its indices (in square brackets).
1439 // *) dma size in KiB.
1440 if (parser.parseOperand(srcMemRefInfo) ||
1441 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1442 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1443 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1444 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1445 parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1446 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1447 return failure();
1448
1449 // Parse optional stride and elements per stride.
1450 if (parser.parseTrailingOperandList(strideInfo))
1451 return failure();
1452
1453 bool isStrided = strideInfo.size() == 2;
1454 if (!strideInfo.empty() && !isStrided) {
1455 return parser.emitError(parser.getNameLoc(),
1456 "expected two stride related operands");
1457 }
1458
1459 if (parser.parseColonTypeList(types))
1460 return failure();
1461 if (types.size() != 3)
1462 return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1463
1464 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1465 parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1466 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1467 parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1468 // size should be an index.
1469 parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1470 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1471 // tag indices should be index.
1472 parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1473 return failure();
1474
1475 if (isStrided) {
1476 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1477 return failure();
1478 }
1479
1480 return success();
1481 }
1482
verify()1483 LogicalResult DmaStartOp::verify() {
1484 unsigned numOperands = getNumOperands();
1485
1486 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1487 // the number of elements.
1488 if (numOperands < 4)
1489 return emitOpError("expected at least 4 operands");
1490
1491 // Check types of operands. The order of these calls is important: the later
1492 // calls rely on some type properties to compute the operand position.
1493 // 1. Source memref.
1494 if (!getSrcMemRef().getType().isa<MemRefType>())
1495 return emitOpError("expected source to be of memref type");
1496 if (numOperands < getSrcMemRefRank() + 4)
1497 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1498 << " operands";
1499 if (!getSrcIndices().empty() &&
1500 !llvm::all_of(getSrcIndices().getTypes(),
1501 [](Type t) { return t.isIndex(); }))
1502 return emitOpError("expected source indices to be of index type");
1503
1504 // 2. Destination memref.
1505 if (!getDstMemRef().getType().isa<MemRefType>())
1506 return emitOpError("expected destination to be of memref type");
1507 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1508 if (numOperands < numExpectedOperands)
1509 return emitOpError() << "expected at least " << numExpectedOperands
1510 << " operands";
1511 if (!getDstIndices().empty() &&
1512 !llvm::all_of(getDstIndices().getTypes(),
1513 [](Type t) { return t.isIndex(); }))
1514 return emitOpError("expected destination indices to be of index type");
1515
1516 // 3. Number of elements.
1517 if (!getNumElements().getType().isIndex())
1518 return emitOpError("expected num elements to be of index type");
1519
1520 // 4. Tag memref.
1521 if (!getTagMemRef().getType().isa<MemRefType>())
1522 return emitOpError("expected tag to be of memref type");
1523 numExpectedOperands += getTagMemRefRank();
1524 if (numOperands < numExpectedOperands)
1525 return emitOpError() << "expected at least " << numExpectedOperands
1526 << " operands";
1527 if (!getTagIndices().empty() &&
1528 !llvm::all_of(getTagIndices().getTypes(),
1529 [](Type t) { return t.isIndex(); }))
1530 return emitOpError("expected tag indices to be of index type");
1531
1532 // DMAs from different memory spaces supported.
1533 if (getSrcMemorySpace() == getDstMemorySpace())
1534 return emitOpError("DMA should be between different memory spaces");
1535
1536 // Optional stride-related operands must be either both present or both
1537 // absent.
1538 if (numOperands != numExpectedOperands &&
1539 numOperands != numExpectedOperands + 2)
1540 return emitOpError("incorrect number of operands");
1541
1542 // 5. Strides.
1543 if (isStrided()) {
1544 if (!getStride().getType().isIndex() ||
1545 !getNumElementsPerStride().getType().isIndex())
1546 return emitOpError(
1547 "expected stride and num elements per stride to be of type index");
1548 }
1549
1550 return success();
1551 }
1552
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1553 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1554 SmallVectorImpl<OpFoldResult> &results) {
1555 /// dma_start(memrefcast) -> dma_start
1556 return foldMemRefCast(*this);
1557 }
1558
1559 // ---------------------------------------------------------------------------
1560 // DmaWaitOp
1561 // ---------------------------------------------------------------------------
1562
build(OpBuilder & builder,OperationState & result,Value tagMemRef,ValueRange tagIndices,Value numElements)1563 void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
1564 Value tagMemRef, ValueRange tagIndices,
1565 Value numElements) {
1566 result.addOperands(tagMemRef);
1567 result.addOperands(tagIndices);
1568 result.addOperands(numElements);
1569 }
1570
print(OpAsmPrinter & p)1571 void DmaWaitOp::print(OpAsmPrinter &p) {
1572 p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], "
1573 << getNumElements();
1574 p.printOptionalAttrDict(getAttrs());
1575 p << " : " << getTagMemRef().getType();
1576 }
1577
1578 // Parse DmaWaitOp.
1579 // Eg:
1580 // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
1581 //
parse(OpAsmParser & parser,OperationState & result)1582 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
1583 OpAsmParser::OperandType tagMemrefInfo;
1584 SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
1585 Type type;
1586 auto indexType = parser.getBuilder().getIndexType();
1587 OpAsmParser::OperandType numElementsInfo;
1588
1589 // Parse tag memref, its indices, and dma size.
1590 if (parser.parseOperand(tagMemrefInfo) ||
1591 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
1592 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1593 parser.parseColonType(type) ||
1594 parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
1595 parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
1596 parser.resolveOperand(numElementsInfo, indexType, result.operands))
1597 return failure();
1598
1599 return success();
1600 }
1601
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1602 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1603 SmallVectorImpl<OpFoldResult> &results) {
1604 /// dma_wait(memrefcast) -> dma_wait
1605 return foldMemRefCast(*this);
1606 }
1607
verify()1608 LogicalResult DmaWaitOp::verify() {
1609 // Mandatory non-variadic operands are tag and the number of elements.
1610 if (getNumOperands() < 2)
1611 return emitOpError() << "expected at least 2 operands";
1612
1613 // Check types of operands. The order of these calls is important: the later
1614 // calls rely on some type properties to compute the operand position.
1615 if (!getTagMemRef().getType().isa<MemRefType>())
1616 return emitOpError() << "expected tag to be of memref type";
1617
1618 if (getNumOperands() != 2 + getTagMemRefRank())
1619 return emitOpError() << "expected " << 2 + getTagMemRefRank()
1620 << " operands";
1621
1622 if (!getTagIndices().empty() &&
1623 !llvm::all_of(getTagIndices().getTypes(),
1624 [](Type t) { return t.isIndex(); }))
1625 return emitOpError() << "expected tag indices to be of index type";
1626
1627 if (!getNumElements().getType().isIndex())
1628 return emitOpError()
1629 << "expected the number of elements to be of index type";
1630
1631 return success();
1632 }
1633
1634 //===----------------------------------------------------------------------===//
1635 // ExtractElementOp
1636 //===----------------------------------------------------------------------===//
1637
verify(ExtractElementOp op)1638 static LogicalResult verify(ExtractElementOp op) {
1639 // Verify the # indices match if we have a ranked type.
1640 auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
1641 if (aggregateType.hasRank() &&
1642 aggregateType.getRank() != op.getNumOperands() - 1)
1643 return op.emitOpError("incorrect number of indices for extract_element");
1644
1645 return success();
1646 }
1647
fold(ArrayRef<Attribute> operands)1648 OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
1649 assert(!operands.empty() && "extract_element takes at least one operand");
1650
1651 // The aggregate operand must be a known constant.
1652 Attribute aggregate = operands.front();
1653 if (!aggregate)
1654 return {};
1655
1656 // If this is a splat elements attribute, simply return the value. All of the
1657 // elements of a splat attribute are the same.
1658 if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
1659 return splatAggregate.getSplatValue();
1660
1661 // Otherwise, collect the constant indices into the aggregate.
1662 SmallVector<uint64_t, 8> indices;
1663 for (Attribute indice : llvm::drop_begin(operands, 1)) {
1664 if (!indice || !indice.isa<IntegerAttr>())
1665 return {};
1666 indices.push_back(indice.cast<IntegerAttr>().getInt());
1667 }
1668
1669 // If this is an elements attribute, query the value at the given indices.
1670 auto elementsAttr = aggregate.dyn_cast<ElementsAttr>();
1671 if (elementsAttr && elementsAttr.isValidIndex(indices))
1672 return elementsAttr.getValue(indices);
1673 return {};
1674 }
1675
1676 //===----------------------------------------------------------------------===//
1677 // TensorFromElementsOp
1678 //===----------------------------------------------------------------------===//
1679
parseTensorFromElementsOp(OpAsmParser & parser,OperationState & result)1680 static ParseResult parseTensorFromElementsOp(OpAsmParser &parser,
1681 OperationState &result) {
1682 SmallVector<OpAsmParser::OperandType, 4> elementsOperands;
1683 Type resultType;
1684 if (parser.parseLParen() || parser.parseOperandList(elementsOperands) ||
1685 parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
1686 parser.parseColon() || parser.parseType(resultType))
1687 return failure();
1688
1689 if (parser.resolveOperands(elementsOperands,
1690 resultType.cast<ShapedType>().getElementType(),
1691 result.operands))
1692 return failure();
1693
1694 result.addTypes(resultType);
1695 return success();
1696 }
1697
print(OpAsmPrinter & p,TensorFromElementsOp op)1698 static void print(OpAsmPrinter &p, TensorFromElementsOp op) {
1699 p << "tensor_from_elements(" << op.elements() << ')';
1700 p.printOptionalAttrDict(op.getAttrs());
1701 p << " : " << op.result().getType();
1702 }
1703
verify(TensorFromElementsOp op)1704 static LogicalResult verify(TensorFromElementsOp op) {
1705 auto resultTensorType = op.result().getType().dyn_cast<RankedTensorType>();
1706 if (!resultTensorType)
1707 return op.emitOpError("expected result type to be a ranked tensor");
1708
1709 int64_t elementsCount = static_cast<int64_t>(op.elements().size());
1710 if (resultTensorType.getRank() != 1 ||
1711 resultTensorType.getShape().front() != elementsCount)
1712 return op.emitOpError()
1713 << "expected result type to be a 1D tensor with " << elementsCount
1714 << (elementsCount == 1 ? " element" : " elements");
1715 return success();
1716 }
1717
1718 namespace {
1719
1720 // Canonicalizes the pattern of the form
1721 //
1722 // %tensor = "tensor_from_elements(%element) : (i32) -> tensor<1xi32>
1723 // %extracted_element = extract_element %tensor[%c0] : tensor<1xi32>
1724 //
1725 // to just %element.
1726 struct ExtractElementFromTensorFromElements
1727 : public OpRewritePattern<ExtractElementOp> {
1728 using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
1729
matchAndRewrite__anone4b94cca1111::ExtractElementFromTensorFromElements1730 LogicalResult matchAndRewrite(ExtractElementOp extract,
1731 PatternRewriter &rewriter) const final {
1732 if (extract.indices().size() != 1)
1733 return failure();
1734
1735 auto tensor_from_elements = dyn_cast_or_null<TensorFromElementsOp>(
1736 extract.aggregate().getDefiningOp());
1737 if (tensor_from_elements == nullptr)
1738 return failure();
1739
1740 APInt index;
1741 if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
1742 return failure();
1743 rewriter.replaceOp(extract,
1744 tensor_from_elements.getOperand(index.getZExtValue()));
1745 return success();
1746 }
1747 };
1748
1749 } // namespace
1750
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1751 void TensorFromElementsOp::getCanonicalizationPatterns(
1752 OwningRewritePatternList &results, MLIRContext *context) {
1753 results.insert<ExtractElementFromTensorFromElements>(context);
1754 }
1755
1756 //===----------------------------------------------------------------------===//
1757 // FPExtOp
1758 //===----------------------------------------------------------------------===//
1759
areCastCompatible(Type a,Type b)1760 bool FPExtOp::areCastCompatible(Type a, Type b) {
1761 if (auto fa = a.dyn_cast<FloatType>())
1762 if (auto fb = b.dyn_cast<FloatType>())
1763 return fa.getWidth() < fb.getWidth();
1764 if (auto va = a.dyn_cast<VectorType>())
1765 if (auto vb = b.dyn_cast<VectorType>())
1766 return va.getShape().equals(vb.getShape()) &&
1767 areCastCompatible(va.getElementType(), vb.getElementType());
1768 return false;
1769 }
1770
1771 //===----------------------------------------------------------------------===//
1772 // FPToSIOp
1773 //===----------------------------------------------------------------------===//
1774
areCastCompatible(Type a,Type b)1775 bool FPToSIOp::areCastCompatible(Type a, Type b) {
1776 return a.isa<FloatType>() && b.isSignlessInteger();
1777 }
1778
1779 //===----------------------------------------------------------------------===//
1780 // FPTruncOp
1781 //===----------------------------------------------------------------------===//
1782
areCastCompatible(Type a,Type b)1783 bool FPTruncOp::areCastCompatible(Type a, Type b) {
1784 if (auto fa = a.dyn_cast<FloatType>())
1785 if (auto fb = b.dyn_cast<FloatType>())
1786 return fa.getWidth() > fb.getWidth();
1787 if (auto va = a.dyn_cast<VectorType>())
1788 if (auto vb = b.dyn_cast<VectorType>())
1789 return va.getShape().equals(vb.getShape()) &&
1790 areCastCompatible(va.getElementType(), vb.getElementType());
1791 return false;
1792 }
1793
1794 //===----------------------------------------------------------------------===//
1795 // IndexCastOp
1796 //===----------------------------------------------------------------------===//
1797
1798 // Index cast is applicable from index to integer and backwards.
areCastCompatible(Type a,Type b)1799 bool IndexCastOp::areCastCompatible(Type a, Type b) {
1800 if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
1801 auto aShaped = a.cast<ShapedType>();
1802 auto bShaped = b.cast<ShapedType>();
1803
1804 return (aShaped.getShape() == bShaped.getShape()) &&
1805 areCastCompatible(aShaped.getElementType(),
1806 bShaped.getElementType());
1807 }
1808
1809 return (a.isIndex() && b.isSignlessInteger()) ||
1810 (a.isSignlessInteger() && b.isIndex());
1811 }
1812
fold(ArrayRef<Attribute> cstOperands)1813 OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
1814 // Fold IndexCast(IndexCast(x)) -> x
1815 auto cast = getOperand().getDefiningOp<IndexCastOp>();
1816 if (cast && cast.getOperand().getType() == getType())
1817 return cast.getOperand();
1818
1819 // Fold IndexCast(constant) -> constant
1820 // A little hack because we go through int. Otherwise, the size
1821 // of the constant might need to change.
1822 if (auto value = cstOperands[0].dyn_cast_or_null<IntegerAttr>())
1823 return IntegerAttr::get(getType(), value.getInt());
1824
1825 return {};
1826 }
1827
1828 //===----------------------------------------------------------------------===//
1829 // LoadOp
1830 //===----------------------------------------------------------------------===//
1831
verify(LoadOp op)1832 static LogicalResult verify(LoadOp op) {
1833 if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1834 return op.emitOpError("incorrect number of indices for load");
1835 return success();
1836 }
1837
fold(ArrayRef<Attribute> cstOperands)1838 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1839 /// load(memrefcast) -> load
1840 if (succeeded(foldMemRefCast(*this)))
1841 return getResult();
1842 return OpFoldResult();
1843 }
1844
1845 //===----------------------------------------------------------------------===//
1846 // MemRefCastOp
1847 //===----------------------------------------------------------------------===//
1848
getViewSource()1849 Value MemRefCastOp::getViewSource() { return source(); }
1850
areCastCompatible(Type a,Type b)1851 bool MemRefCastOp::areCastCompatible(Type a, Type b) {
1852 auto aT = a.dyn_cast<MemRefType>();
1853 auto bT = b.dyn_cast<MemRefType>();
1854
1855 auto uaT = a.dyn_cast<UnrankedMemRefType>();
1856 auto ubT = b.dyn_cast<UnrankedMemRefType>();
1857
1858 if (aT && bT) {
1859 if (aT.getElementType() != bT.getElementType())
1860 return false;
1861 if (aT.getAffineMaps() != bT.getAffineMaps()) {
1862 int64_t aOffset, bOffset;
1863 SmallVector<int64_t, 4> aStrides, bStrides;
1864 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
1865 failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
1866 aStrides.size() != bStrides.size())
1867 return false;
1868
1869 // Strides along a dimension/offset are compatible if the value in the
1870 // source memref is static and the value in the target memref is the
1871 // same. They are also compatible if either one is dynamic (see
1872 // description of MemRefCastOp for details).
1873 auto checkCompatible = [](int64_t a, int64_t b) {
1874 return (a == MemRefType::getDynamicStrideOrOffset() ||
1875 b == MemRefType::getDynamicStrideOrOffset() || a == b);
1876 };
1877 if (!checkCompatible(aOffset, bOffset))
1878 return false;
1879 for (auto aStride : enumerate(aStrides))
1880 if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
1881 return false;
1882 }
1883 if (aT.getMemorySpace() != bT.getMemorySpace())
1884 return false;
1885
1886 // They must have the same rank, and any specified dimensions must match.
1887 if (aT.getRank() != bT.getRank())
1888 return false;
1889
1890 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
1891 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
1892 if (aDim != -1 && bDim != -1 && aDim != bDim)
1893 return false;
1894 }
1895 return true;
1896 } else {
1897 if (!aT && !uaT)
1898 return false;
1899 if (!bT && !ubT)
1900 return false;
1901 // Unranked to unranked casting is unsupported
1902 if (uaT && ubT)
1903 return false;
1904
1905 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
1906 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
1907 if (aEltType != bEltType)
1908 return false;
1909
1910 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
1911 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
1912 if (aMemSpace != bMemSpace)
1913 return false;
1914
1915 return true;
1916 }
1917
1918 return false;
1919 }
1920
fold(ArrayRef<Attribute> operands)1921 OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
1922 return impl::foldCastOp(*this);
1923 }
1924
1925 //===----------------------------------------------------------------------===//
1926 // MulFOp
1927 //===----------------------------------------------------------------------===//
1928
fold(ArrayRef<Attribute> operands)1929 OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
1930 return constFoldBinaryOp<FloatAttr>(
1931 operands, [](APFloat a, APFloat b) { return a * b; });
1932 }
1933
1934 //===----------------------------------------------------------------------===//
1935 // MulIOp
1936 //===----------------------------------------------------------------------===//
1937
fold(ArrayRef<Attribute> operands)1938 OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
1939 /// muli(x, 0) -> 0
1940 if (matchPattern(rhs(), m_Zero()))
1941 return rhs();
1942 /// muli(x, 1) -> x
1943 if (matchPattern(rhs(), m_One()))
1944 return getOperand(0);
1945
1946 // TODO: Handle the overflow case.
1947 return constFoldBinaryOp<IntegerAttr>(operands,
1948 [](APInt a, APInt b) { return a * b; });
1949 }
1950
1951 //===----------------------------------------------------------------------===//
1952 // OrOp
1953 //===----------------------------------------------------------------------===//
1954
fold(ArrayRef<Attribute> operands)1955 OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
1956 /// or(x, 0) -> x
1957 if (matchPattern(rhs(), m_Zero()))
1958 return lhs();
1959 /// or(x,x) -> x
1960 if (lhs() == rhs())
1961 return rhs();
1962
1963 return constFoldBinaryOp<IntegerAttr>(operands,
1964 [](APInt a, APInt b) { return a | b; });
1965 }
1966
1967 //===----------------------------------------------------------------------===//
1968 // PrefetchOp
1969 //===----------------------------------------------------------------------===//
1970
print(OpAsmPrinter & p,PrefetchOp op)1971 static void print(OpAsmPrinter &p, PrefetchOp op) {
1972 p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
1973 p.printOperands(op.indices());
1974 p << ']' << ", " << (op.isWrite() ? "write" : "read");
1975 p << ", locality<" << op.localityHint();
1976 p << ">, " << (op.isDataCache() ? "data" : "instr");
1977 p.printOptionalAttrDict(
1978 op.getAttrs(),
1979 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1980 p << " : " << op.getMemRefType();
1981 }
1982
parsePrefetchOp(OpAsmParser & parser,OperationState & result)1983 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1984 OperationState &result) {
1985 OpAsmParser::OperandType memrefInfo;
1986 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1987 IntegerAttr localityHint;
1988 MemRefType type;
1989 StringRef readOrWrite, cacheType;
1990
1991 auto indexTy = parser.getBuilder().getIndexType();
1992 auto i32Type = parser.getBuilder().getIntegerType(32);
1993 if (parser.parseOperand(memrefInfo) ||
1994 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1995 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1996 parser.parseComma() || parser.parseKeyword("locality") ||
1997 parser.parseLess() ||
1998 parser.parseAttribute(localityHint, i32Type, "localityHint",
1999 result.attributes) ||
2000 parser.parseGreater() || parser.parseComma() ||
2001 parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
2002 parser.resolveOperand(memrefInfo, type, result.operands) ||
2003 parser.resolveOperands(indexInfo, indexTy, result.operands))
2004 return failure();
2005
2006 if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
2007 return parser.emitError(parser.getNameLoc(),
2008 "rw specifier has to be 'read' or 'write'");
2009 result.addAttribute(
2010 PrefetchOp::getIsWriteAttrName(),
2011 parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
2012
2013 if (!cacheType.equals("data") && !cacheType.equals("instr"))
2014 return parser.emitError(parser.getNameLoc(),
2015 "cache type has to be 'data' or 'instr'");
2016
2017 result.addAttribute(
2018 PrefetchOp::getIsDataCacheAttrName(),
2019 parser.getBuilder().getBoolAttr(cacheType.equals("data")));
2020
2021 return success();
2022 }
2023
verify(PrefetchOp op)2024 static LogicalResult verify(PrefetchOp op) {
2025 if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
2026 return op.emitOpError("too few indices");
2027
2028 return success();
2029 }
2030
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)2031 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
2032 SmallVectorImpl<OpFoldResult> &results) {
2033 // prefetch(memrefcast) -> prefetch
2034 return foldMemRefCast(*this);
2035 }
2036
2037 //===----------------------------------------------------------------------===//
2038 // RankOp
2039 //===----------------------------------------------------------------------===//
2040
fold(ArrayRef<Attribute> operands)2041 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
2042 // Constant fold rank when the rank of the tensor is known.
2043 auto type = getOperand().getType();
2044 if (auto tensorType = type.dyn_cast<RankedTensorType>())
2045 return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank());
2046 return IntegerAttr();
2047 }
2048
2049 //===----------------------------------------------------------------------===//
2050 // ReturnOp
2051 //===----------------------------------------------------------------------===//
2052
verify(ReturnOp op)2053 static LogicalResult verify(ReturnOp op) {
2054 auto function = cast<FuncOp>(op.getParentOp());
2055
2056 // The operand number and types must match the function signature.
2057 const auto &results = function.getType().getResults();
2058 if (op.getNumOperands() != results.size())
2059 return op.emitOpError("has ")
2060 << op.getNumOperands() << " operands, but enclosing function (@"
2061 << function.getName() << ") returns " << results.size();
2062
2063 for (unsigned i = 0, e = results.size(); i != e; ++i)
2064 if (op.getOperand(i).getType() != results[i])
2065 return op.emitError()
2066 << "type of return operand " << i << " ("
2067 << op.getOperand(i).getType()
2068 << ") doesn't match function result type (" << results[i] << ")"
2069 << " in function @" << function.getName();
2070
2071 return success();
2072 }
2073
2074 //===----------------------------------------------------------------------===//
2075 // SelectOp
2076 //===----------------------------------------------------------------------===//
2077
fold(ArrayRef<Attribute> operands)2078 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
2079 auto condition = getCondition();
2080
2081 // select true, %0, %1 => %0
2082 if (matchPattern(condition, m_One()))
2083 return getTrueValue();
2084
2085 // select false, %0, %1 => %1
2086 if (matchPattern(condition, m_Zero()))
2087 return getFalseValue();
2088 return nullptr;
2089 }
2090
print(OpAsmPrinter & p,SelectOp op)2091 static void print(OpAsmPrinter &p, SelectOp op) {
2092 p << "select " << op.getOperands();
2093 p.printOptionalAttrDict(op.getAttrs());
2094 p << " : ";
2095 if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
2096 p << condType << ", ";
2097 p << op.getType();
2098 }
2099
parseSelectOp(OpAsmParser & parser,OperationState & result)2100 static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
2101 Type conditionType, resultType;
2102 SmallVector<OpAsmParser::OperandType, 3> operands;
2103 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2104 parser.parseOptionalAttrDict(result.attributes) ||
2105 parser.parseColonType(resultType))
2106 return failure();
2107
2108 // Check for the explicit condition type if this is a masked tensor or vector.
2109 if (succeeded(parser.parseOptionalComma())) {
2110 conditionType = resultType;
2111 if (parser.parseType(resultType))
2112 return failure();
2113 } else {
2114 conditionType = parser.getBuilder().getI1Type();
2115 }
2116
2117 result.addTypes(resultType);
2118 return parser.resolveOperands(operands,
2119 {conditionType, resultType, resultType},
2120 parser.getNameLoc(), result.operands);
2121 }
2122
verify(SelectOp op)2123 static LogicalResult verify(SelectOp op) {
2124 Type conditionType = op.getCondition().getType();
2125 if (conditionType.isSignlessInteger(1))
2126 return success();
2127
2128 // If the result type is a vector or tensor, the type can be a mask with the
2129 // same elements.
2130 Type resultType = op.getType();
2131 if (!resultType.isa<TensorType, VectorType>())
2132 return op.emitOpError()
2133 << "expected condition to be a signless i1, but got "
2134 << conditionType;
2135 Type shapedConditionType = getI1SameShape(resultType);
2136 if (conditionType != shapedConditionType)
2137 return op.emitOpError()
2138 << "expected condition type to have the same shape "
2139 "as the result type, expected "
2140 << shapedConditionType << ", but got " << conditionType;
2141 return success();
2142 }
2143
2144 //===----------------------------------------------------------------------===//
2145 // SignExtendIOp
2146 //===----------------------------------------------------------------------===//
2147
verify(SignExtendIOp op)2148 static LogicalResult verify(SignExtendIOp op) {
2149 // Get the scalar type (which is either directly the type of the operand
2150 // or the vector's/tensor's element type.
2151 auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2152 auto dstType = getElementTypeOrSelf(op.getType());
2153
2154 // For now, index is forbidden for the source and the destination type.
2155 if (srcType.isa<IndexType>())
2156 return op.emitError() << srcType << " is not a valid operand type";
2157 if (dstType.isa<IndexType>())
2158 return op.emitError() << dstType << " is not a valid result type";
2159
2160 if (srcType.cast<IntegerType>().getWidth() >=
2161 dstType.cast<IntegerType>().getWidth())
2162 return op.emitError("result type ")
2163 << dstType << " must be wider than operand type " << srcType;
2164
2165 return success();
2166 }
2167
2168 //===----------------------------------------------------------------------===//
2169 // SignedDivIOp
2170 //===----------------------------------------------------------------------===//
2171
fold(ArrayRef<Attribute> operands)2172 OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
2173 assert(operands.size() == 2 && "binary operation takes two operands");
2174
2175 // Don't fold if it would overflow or if it requires a division by zero.
2176 bool overflowOrDiv0 = false;
2177 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
2178 if (overflowOrDiv0 || !b) {
2179 overflowOrDiv0 = true;
2180 return a;
2181 }
2182 return a.sdiv_ov(b, overflowOrDiv0);
2183 });
2184
2185 // Fold out division by one. Assumes all tensors of all ones are splats.
2186 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
2187 if (rhs.getValue() == 1)
2188 return lhs();
2189 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
2190 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
2191 return lhs();
2192 }
2193
2194 return overflowOrDiv0 ? Attribute() : result;
2195 }
2196
2197 //===----------------------------------------------------------------------===//
2198 // SignedRemIOp
2199 //===----------------------------------------------------------------------===//
2200
fold(ArrayRef<Attribute> operands)2201 OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
2202 assert(operands.size() == 2 && "remi_signed takes two operands");
2203
2204 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
2205 if (!rhs)
2206 return {};
2207 auto rhsValue = rhs.getValue();
2208
2209 // x % 1 = 0
2210 if (rhsValue.isOneValue())
2211 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
2212
2213 // Don't fold if it requires division by zero.
2214 if (rhsValue.isNullValue())
2215 return {};
2216
2217 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
2218 if (!lhs)
2219 return {};
2220 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
2221 }
2222
2223 //===----------------------------------------------------------------------===//
2224 // SIToFPOp
2225 //===----------------------------------------------------------------------===//
2226
2227 // sitofp is applicable from integer types to float types.
areCastCompatible(Type a,Type b)2228 bool SIToFPOp::areCastCompatible(Type a, Type b) {
2229 return a.isSignlessInteger() && b.isa<FloatType>();
2230 }
2231
2232 //===----------------------------------------------------------------------===//
2233 // SplatOp
2234 //===----------------------------------------------------------------------===//
2235
verify(SplatOp op)2236 static LogicalResult verify(SplatOp op) {
2237 // TODO: we could replace this by a trait.
2238 if (op.getOperand().getType() !=
2239 op.getType().cast<ShapedType>().getElementType())
2240 return op.emitError("operand should be of elemental type of result type");
2241
2242 return success();
2243 }
2244
2245 // Constant folding hook for SplatOp.
fold(ArrayRef<Attribute> operands)2246 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
2247 assert(operands.size() == 1 && "splat takes one operand");
2248
2249 auto constOperand = operands.front();
2250 if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
2251 return {};
2252
2253 auto shapedType = getType().cast<ShapedType>();
2254 assert(shapedType.getElementType() == constOperand.getType() &&
2255 "incorrect input attribute type for folding");
2256
2257 // SplatElementsAttr::get treats single value for second arg as being a splat.
2258 return SplatElementsAttr::get(shapedType, {constOperand});
2259 }
2260
2261 //===----------------------------------------------------------------------===//
2262 // StoreOp
2263 //===----------------------------------------------------------------------===//
2264
verify(StoreOp op)2265 static LogicalResult verify(StoreOp op) {
2266 if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
2267 return op.emitOpError("store index operand count not equal to memref rank");
2268
2269 return success();
2270 }
2271
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)2272 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
2273 SmallVectorImpl<OpFoldResult> &results) {
2274 /// store(memrefcast) -> store
2275 return foldMemRefCast(*this);
2276 }
2277
2278 //===----------------------------------------------------------------------===//
2279 // SubFOp
2280 //===----------------------------------------------------------------------===//
2281
fold(ArrayRef<Attribute> operands)2282 OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
2283 return constFoldBinaryOp<FloatAttr>(
2284 operands, [](APFloat a, APFloat b) { return a - b; });
2285 }
2286
2287 //===----------------------------------------------------------------------===//
2288 // SubIOp
2289 //===----------------------------------------------------------------------===//
2290
fold(ArrayRef<Attribute> operands)2291 OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
2292 // subi(x,x) -> 0
2293 if (getOperand(0) == getOperand(1))
2294 return Builder(getContext()).getZeroAttr(getType());
2295 // subi(x,0) -> x
2296 if (matchPattern(rhs(), m_Zero()))
2297 return lhs();
2298
2299 return constFoldBinaryOp<IntegerAttr>(operands,
2300 [](APInt a, APInt b) { return a - b; });
2301 }
2302
2303 //===----------------------------------------------------------------------===//
2304 // SubViewOp
2305 //===----------------------------------------------------------------------===//
2306
2307 /// Print a list with either (1) the static integer value in `arrayAttr` if
2308 /// `isDynamic` evaluates to false or (2) the next value otherwise.
2309 /// This allows idiomatic printing of mixed value and integer attributes in a
2310 /// list. E.g. `[%arg0, 7, 42, %arg42]`.
printSubViewListOfOperandsOrIntegers(OpAsmPrinter & p,ValueRange values,ArrayAttr arrayAttr,llvm::function_ref<bool (int64_t)> isDynamic)2311 static void printSubViewListOfOperandsOrIntegers(
2312 OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
2313 llvm::function_ref<bool(int64_t)> isDynamic) {
2314 p << "[";
2315 unsigned idx = 0;
2316 llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
2317 int64_t val = a.cast<IntegerAttr>().getInt();
2318 if (isDynamic(val))
2319 p << values[idx++];
2320 else
2321 p << val;
2322 });
2323 p << "] ";
2324 }
2325
2326 /// Parse a mixed list with either (1) static integer values or (2) SSA values.
2327 /// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
2328 /// encode the position of SSA values. Add the parsed SSA values to `ssa`
2329 /// in-order.
2330 //
2331 /// E.g. after parsing "[%arg0, 7, 42, %arg42]":
2332 /// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
2333 /// 2. `ssa` is filled with "[%arg0, %arg1]".
2334 static ParseResult
parseListOfOperandsOrIntegers(OpAsmParser & parser,OperationState & result,StringRef attrName,int64_t dynVal,SmallVectorImpl<OpAsmParser::OperandType> & ssa)2335 parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
2336 StringRef attrName, int64_t dynVal,
2337 SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
2338 if (failed(parser.parseLSquare()))
2339 return failure();
2340 // 0-D.
2341 if (succeeded(parser.parseOptionalRSquare()))
2342 return success();
2343
2344 SmallVector<int64_t, 4> attrVals;
2345 while (true) {
2346 OpAsmParser::OperandType operand;
2347 auto res = parser.parseOptionalOperand(operand);
2348 if (res.hasValue() && succeeded(res.getValue())) {
2349 ssa.push_back(operand);
2350 attrVals.push_back(dynVal);
2351 } else {
2352 Attribute attr;
2353 NamedAttrList placeholder;
2354 if (failed(parser.parseAttribute(attr, "_", placeholder)) ||
2355 !attr.isa<IntegerAttr>())
2356 return parser.emitError(parser.getNameLoc())
2357 << "expected SSA value or integer";
2358 attrVals.push_back(attr.cast<IntegerAttr>().getInt());
2359 }
2360
2361 if (succeeded(parser.parseOptionalComma()))
2362 continue;
2363 if (failed(parser.parseRSquare()))
2364 return failure();
2365 else
2366 break;
2367 }
2368
2369 auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
2370 result.addAttribute(attrName, arrayAttr);
2371 return success();
2372 }
2373
2374 namespace {
2375 /// Helpers to write more idiomatic operations.
2376 namespace saturated_arith {
2377 struct Wrapper {
Wrapper__anone4b94cca1a11::saturated_arith::Wrapper2378 explicit Wrapper(int64_t v) : v(v) {}
operator int64_t__anone4b94cca1a11::saturated_arith::Wrapper2379 operator int64_t() { return v; }
2380 int64_t v;
2381 };
operator +(Wrapper a,int64_t b)2382 Wrapper operator+(Wrapper a, int64_t b) {
2383 if (ShapedType::isDynamicStrideOrOffset(a) ||
2384 ShapedType::isDynamicStrideOrOffset(b))
2385 return Wrapper(ShapedType::kDynamicStrideOrOffset);
2386 return Wrapper(a.v + b);
2387 }
operator *(Wrapper a,int64_t b)2388 Wrapper operator*(Wrapper a, int64_t b) {
2389 if (ShapedType::isDynamicStrideOrOffset(a) ||
2390 ShapedType::isDynamicStrideOrOffset(b))
2391 return Wrapper(ShapedType::kDynamicStrideOrOffset);
2392 return Wrapper(a.v * b);
2393 }
2394 } // end namespace saturated_arith
2395 } // end namespace
2396
2397 /// A subview result type can be fully inferred from the source type and the
2398 /// static representation of offsets, sizes and strides. Special sentinels
2399 /// encode the dynamic case.
inferSubViewResultType(MemRefType sourceMemRefType,ArrayRef<int64_t> staticOffsets,ArrayRef<int64_t> staticSizes,ArrayRef<int64_t> staticStrides)2400 Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
2401 ArrayRef<int64_t> staticOffsets,
2402 ArrayRef<int64_t> staticSizes,
2403 ArrayRef<int64_t> staticStrides) {
2404 unsigned rank = sourceMemRefType.getRank();
2405 (void)rank;
2406 assert(staticOffsets.size() == rank &&
2407 "unexpected staticOffsets size mismatch");
2408 assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
2409 assert(staticStrides.size() == rank &&
2410 "unexpected staticStrides size mismatch");
2411
2412 // Extract source offset and strides.
2413 int64_t sourceOffset;
2414 SmallVector<int64_t, 4> sourceStrides;
2415 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
2416 assert(succeeded(res) && "SubViewOp expected strided memref type");
2417 (void)res;
2418
2419 // Compute target offset whose value is:
2420 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2421 int64_t targetOffset = sourceOffset;
2422 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2423 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
2424 using namespace saturated_arith;
2425 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
2426 }
2427
2428 // Compute target stride whose value is:
2429 // `sourceStrides_i * staticStrides_i`.
2430 SmallVector<int64_t, 4> targetStrides;
2431 targetStrides.reserve(staticOffsets.size());
2432 for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2433 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2434 using namespace saturated_arith;
2435 targetStrides.push_back(Wrapper(sourceStride) * staticStride);
2436 }
2437
2438 // The type is now known.
2439 return MemRefType::get(
2440 staticSizes, sourceMemRefType.getElementType(),
2441 makeStridedLinearLayoutMap(targetStrides, targetOffset,
2442 sourceMemRefType.getContext()),
2443 sourceMemRefType.getMemorySpace());
2444 }
2445
2446 /// Print SubViewOp in the form:
2447 /// ```
2448 /// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
2449 /// `:` strided-memref-type `to` strided-memref-type
2450 /// ```
print(OpAsmPrinter & p,SubViewOp op)2451 static void print(OpAsmPrinter &p, SubViewOp op) {
2452 int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
2453 p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
2454 p << op.getOperand(0);
2455 printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
2456 ShapedType::isDynamicStrideOrOffset);
2457 printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
2458 ShapedType::isDynamic);
2459 printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
2460 ShapedType::isDynamicStrideOrOffset);
2461 p.printOptionalAttrDict(op.getAttrs(),
2462 /*elidedAttrs=*/{SubViewOp::getSpecialAttrNames()});
2463 p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2464 }
2465
2466 /// Parse SubViewOp of the form:
2467 /// ```
2468 /// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
2469 /// `:` strided-memref-type `to` strided-memref-type
2470 /// ```
parseSubViewOp(OpAsmParser & parser,OperationState & result)2471 static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
2472 OpAsmParser::OperandType srcInfo;
2473 SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
2474 auto indexType = parser.getBuilder().getIndexType();
2475 Type srcType, dstType;
2476 if (parser.parseOperand(srcInfo))
2477 return failure();
2478 if (parseListOfOperandsOrIntegers(
2479 parser, result, SubViewOp::getStaticOffsetsAttrName(),
2480 ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
2481 parseListOfOperandsOrIntegers(parser, result,
2482 SubViewOp::getStaticSizesAttrName(),
2483 ShapedType::kDynamicSize, sizesInfo) ||
2484 parseListOfOperandsOrIntegers(
2485 parser, result, SubViewOp::getStaticStridesAttrName(),
2486 ShapedType::kDynamicStrideOrOffset, stridesInfo))
2487 return failure();
2488
2489 auto b = parser.getBuilder();
2490 SmallVector<int, 4> segmentSizes{1, static_cast<int>(offsetsInfo.size()),
2491 static_cast<int>(sizesInfo.size()),
2492 static_cast<int>(stridesInfo.size())};
2493 result.addAttribute(SubViewOp::getOperandSegmentSizeAttr(),
2494 b.getI32VectorAttr(segmentSizes));
2495
2496 return failure(
2497 parser.parseOptionalAttrDict(result.attributes) ||
2498 parser.parseColonType(srcType) ||
2499 parser.resolveOperand(srcInfo, srcType, result.operands) ||
2500 parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
2501 parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2502 parser.resolveOperands(stridesInfo, indexType, result.operands) ||
2503 parser.parseKeywordType("to", dstType) ||
2504 parser.addTypeToList(dstType, result.types));
2505 }
2506
build(OpBuilder & b,OperationState & result,Value source,ArrayRef<int64_t> staticOffsets,ArrayRef<int64_t> staticSizes,ArrayRef<int64_t> staticStrides,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)2507 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2508 ArrayRef<int64_t> staticOffsets,
2509 ArrayRef<int64_t> staticSizes,
2510 ArrayRef<int64_t> staticStrides, ValueRange offsets,
2511 ValueRange sizes, ValueRange strides,
2512 ArrayRef<NamedAttribute> attrs) {
2513 auto sourceMemRefType = source.getType().cast<MemRefType>();
2514 auto resultType = inferSubViewResultType(sourceMemRefType, staticOffsets,
2515 staticSizes, staticStrides);
2516 build(b, result, resultType, source, offsets, sizes, strides,
2517 b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
2518 b.getI64ArrayAttr(staticStrides));
2519 result.addAttributes(attrs);
2520 }
2521
2522 /// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes`
2523 /// and `staticStrides` are automatically filled with source-memref-rank
2524 /// sentinel values that encode dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)2525 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2526 ValueRange offsets, ValueRange sizes,
2527 ValueRange strides,
2528 ArrayRef<NamedAttribute> attrs) {
2529 auto sourceMemRefType = source.getType().cast<MemRefType>();
2530 unsigned rank = sourceMemRefType.getRank();
2531 SmallVector<int64_t, 4> staticOffsetsVector;
2532 staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
2533 SmallVector<int64_t, 4> staticSizesVector;
2534 staticSizesVector.assign(rank, ShapedType::kDynamicSize);
2535 SmallVector<int64_t, 4> staticStridesVector;
2536 staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
2537 build(b, result, source, staticOffsetsVector, staticSizesVector,
2538 staticStridesVector, offsets, sizes, strides, attrs);
2539 }
2540
2541 /// Verify that a particular offset/size/stride static attribute is well-formed.
2542 static LogicalResult
verifySubViewOpPart(SubViewOp op,StringRef name,StringRef attrName,ArrayAttr attr,llvm::function_ref<bool (int64_t)> isDynamic,ValueRange values)2543 verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName,
2544 ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
2545 ValueRange values) {
2546 /// Check static and dynamic offsets/sizes/strides breakdown.
2547 if (attr.size() != op.getRank())
2548 return op.emitError("expected ")
2549 << op.getRank() << " " << name << " values";
2550 unsigned expectedNumDynamicEntries =
2551 llvm::count_if(attr.getValue(), [&](Attribute attr) {
2552 return isDynamic(attr.cast<IntegerAttr>().getInt());
2553 });
2554 if (values.size() != expectedNumDynamicEntries)
2555 return op.emitError("expected ")
2556 << expectedNumDynamicEntries << " dynamic " << name << " values";
2557 return success();
2558 }
2559
2560 /// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
extractFromI64ArrayAttr(Attribute attr)2561 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
2562 return llvm::to_vector<4>(
2563 llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
2564 return a.cast<IntegerAttr>().getInt();
2565 }));
2566 }
2567
2568 /// Verifier for SubViewOp.
verify(SubViewOp op)2569 static LogicalResult verify(SubViewOp op) {
2570 auto baseType = op.getBaseMemRefType().cast<MemRefType>();
2571 auto subViewType = op.getType();
2572
2573 // The base memref and the view memref should be in the same memory space.
2574 if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2575 return op.emitError("different memory spaces specified for base memref "
2576 "type ")
2577 << baseType << " and subview memref type " << subViewType;
2578
2579 // Verify that the base memref type has a strided layout map.
2580 if (!isStrided(baseType))
2581 return op.emitError("base type ") << baseType << " is not strided";
2582
2583 // Verify static attributes offsets/sizes/strides.
2584 if (failed(verifySubViewOpPart(
2585 op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
2586 ShapedType::isDynamicStrideOrOffset, op.offsets())))
2587 return failure();
2588
2589 if (failed(verifySubViewOpPart(op, "size", op.getStaticSizesAttrName(),
2590 op.static_sizes(), ShapedType::isDynamic,
2591 op.sizes())))
2592 return failure();
2593 if (failed(verifySubViewOpPart(
2594 op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
2595 ShapedType::isDynamicStrideOrOffset, op.strides())))
2596 return failure();
2597
2598 // Verify result type against inferred type.
2599 auto expectedType = SubViewOp::inferSubViewResultType(
2600 op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
2601 extractFromI64ArrayAttr(op.static_sizes()),
2602 extractFromI64ArrayAttr(op.static_strides()));
2603 if (op.getType() != expectedType)
2604 return op.emitError("expected result type to be ") << expectedType;
2605
2606 return success();
2607 }
2608
operator <<(raw_ostream & os,SubViewOp::Range & range)2609 raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
2610 return os << "range " << range.offset << ":" << range.size << ":"
2611 << range.stride;
2612 }
2613
getNumDynamicEntriesUpToIdx(ArrayAttr attr,llvm::function_ref<bool (int64_t)> isDynamic,unsigned idx)2614 static unsigned getNumDynamicEntriesUpToIdx(
2615 ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
2616 return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx,
2617 [&](Attribute attr) {
2618 return isDynamic(attr.cast<IntegerAttr>().getInt());
2619 });
2620 }
2621
isDynamicOffset(unsigned idx)2622 bool SubViewOp::isDynamicOffset(unsigned idx) {
2623 return ShapedType::isDynamicStrideOrOffset(
2624 extractFromI64ArrayAttr(static_offsets())[idx]);
2625 }
isDynamicSize(unsigned idx)2626 bool SubViewOp::isDynamicSize(unsigned idx) {
2627 return ShapedType::isDynamic(extractFromI64ArrayAttr(static_sizes())[idx]);
2628 }
isDynamicStride(unsigned idx)2629 bool SubViewOp::isDynamicStride(unsigned idx) {
2630 return ShapedType::isDynamicStrideOrOffset(
2631 extractFromI64ArrayAttr(static_strides())[idx]);
2632 }
2633
getIndexOfDynamicOffset(unsigned idx)2634 unsigned SubViewOp::getIndexOfDynamicOffset(unsigned idx) {
2635 assert(isDynamicOffset(idx) && "expected static offset");
2636 auto numDynamic =
2637 getNumDynamicEntriesUpToIdx(static_offsets().cast<ArrayAttr>(),
2638 ShapedType::isDynamicStrideOrOffset, idx);
2639 return 1 + numDynamic;
2640 }
getIndexOfDynamicSize(unsigned idx)2641 unsigned SubViewOp::getIndexOfDynamicSize(unsigned idx) {
2642 assert(isDynamicSize(idx) && "expected static size");
2643 auto numDynamic = getNumDynamicEntriesUpToIdx(
2644 static_sizes().cast<ArrayAttr>(), ShapedType::isDynamic, idx);
2645 return 1 + offsets().size() + numDynamic;
2646 }
getIndexOfDynamicStride(unsigned idx)2647 unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) {
2648 assert(isDynamicStride(idx) && "expected static stride");
2649 auto numDynamic =
2650 getNumDynamicEntriesUpToIdx(static_strides().cast<ArrayAttr>(),
2651 ShapedType::isDynamicStrideOrOffset, idx);
2652 return 1 + offsets().size() + sizes().size() + numDynamic;
2653 }
2654
2655 /// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each Range
2656 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2657 /// with `b` at location `loc`.
getOrCreateRanges(OpBuilder & b,Location loc)2658 SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
2659 Location loc) {
2660 SmallVector<Range, 8> res;
2661 unsigned rank = getType().getRank();
2662 res.reserve(rank);
2663 for (unsigned idx = 0; idx < rank; ++idx) {
2664 auto offset = isDynamicOffset(idx)
2665 ? getDynamicOffset(idx)
2666 : b.create<ConstantIndexOp>(loc, getStaticOffset(idx));
2667 auto size = isDynamicSize(idx)
2668 ? getDynamicSize(idx)
2669 : b.create<ConstantIndexOp>(loc, getStaticSize(idx));
2670 auto stride = isDynamicStride(idx)
2671 ? getDynamicStride(idx)
2672 : b.create<ConstantIndexOp>(loc, getStaticStride(idx));
2673 res.emplace_back(Range{offset, size, stride});
2674 }
2675 return res;
2676 }
2677
getOrCreateOffsets(OpBuilder & b,Location loc)2678 SmallVector<Value, 4> SubViewOp::getOrCreateOffsets(OpBuilder &b,
2679 Location loc) {
2680 unsigned dynamicIdx = 1;
2681 return llvm::to_vector<4>(llvm::map_range(
2682 static_offsets().cast<ArrayAttr>(), [&](Attribute a) -> Value {
2683 int64_t staticOffset = a.cast<IntegerAttr>().getInt();
2684 if (ShapedType::isDynamicStrideOrOffset(staticOffset))
2685 return getOperand(dynamicIdx++);
2686 else
2687 return b.create<ConstantIndexOp>(loc, staticOffset);
2688 }));
2689 }
2690
getOrCreateSizes(OpBuilder & b,Location loc)2691 SmallVector<Value, 4> SubViewOp::getOrCreateSizes(OpBuilder &b, Location loc) {
2692 unsigned dynamicIdx = 1 + offsets().size();
2693 return llvm::to_vector<4>(llvm::map_range(
2694 static_sizes().cast<ArrayAttr>(), [&](Attribute a) -> Value {
2695 int64_t staticSize = a.cast<IntegerAttr>().getInt();
2696 if (ShapedType::isDynamic(staticSize))
2697 return getOperand(dynamicIdx++);
2698 else
2699 return b.create<ConstantIndexOp>(loc, staticSize);
2700 }));
2701 }
2702
getOrCreateStrides(OpBuilder & b,Location loc)2703 SmallVector<Value, 4> SubViewOp::getOrCreateStrides(OpBuilder &b,
2704 Location loc) {
2705 unsigned dynamicIdx = 1 + offsets().size() + sizes().size();
2706 return llvm::to_vector<4>(llvm::map_range(
2707 static_strides().cast<ArrayAttr>(), [&](Attribute a) -> Value {
2708 int64_t staticStride = a.cast<IntegerAttr>().getInt();
2709 if (ShapedType::isDynamicStrideOrOffset(staticStride))
2710 return getOperand(dynamicIdx++);
2711 else
2712 return b.create<ConstantIndexOp>(loc, staticStride);
2713 }));
2714 }
2715
2716 LogicalResult
getStaticStrides(SmallVectorImpl<int64_t> & staticStrides)2717 SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
2718 if (!strides().empty())
2719 return failure();
2720 staticStrides = extractFromI64ArrayAttr(static_strides());
2721 return success();
2722 }
2723
getViewSource()2724 Value SubViewOp::getViewSource() { return source(); }
2725
2726 namespace {
2727
2728 /// Take a list of `values` with potential new constant to extract and a list
2729 /// of `constantValues` with`values.size()` sentinel that evaluate to true by
2730 /// applying `isDynamic`.
2731 /// Detects the `values` produced by a ConstantIndexOp and places the new
2732 /// constant in place of the corresponding sentinel value.
canonicalizeSubViewPart(SmallVectorImpl<Value> & values,SmallVectorImpl<int64_t> & constantValues,llvm::function_ref<bool (int64_t)> isDynamic)2733 void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
2734 SmallVectorImpl<int64_t> &constantValues,
2735 llvm::function_ref<bool(int64_t)> isDynamic) {
2736 bool hasNewStaticValue = llvm::any_of(
2737 values, [](Value val) { return matchPattern(val, m_ConstantIndex()); });
2738 if (hasNewStaticValue) {
2739 for (unsigned cstIdx = 0, valIdx = 0, e = constantValues.size();
2740 cstIdx != e; ++cstIdx) {
2741 // Was already static, skip.
2742 if (!isDynamic(constantValues[cstIdx]))
2743 continue;
2744 // Newly static, move from Value to constant.
2745 if (matchPattern(values[valIdx], m_ConstantIndex())) {
2746 constantValues[cstIdx] =
2747 cast<ConstantIndexOp>(values[valIdx].getDefiningOp()).getValue();
2748 // Erase for impl. simplicity. Reverse iterator if we really must.
2749 values.erase(std::next(values.begin(), valIdx));
2750 continue;
2751 }
2752 // Remains dynamic move to next value.
2753 ++valIdx;
2754 }
2755 }
2756 }
2757
2758 /// Pattern to rewrite a subview op with constant arguments.
2759 class SubViewOpConstantArgumentFolder final
2760 : public OpRewritePattern<SubViewOp> {
2761 public:
2762 using OpRewritePattern<SubViewOp>::OpRewritePattern;
2763
matchAndRewrite(SubViewOp subViewOp,PatternRewriter & rewriter) const2764 LogicalResult matchAndRewrite(SubViewOp subViewOp,
2765 PatternRewriter &rewriter) const override {
2766 // No constant operand, just return;
2767 if (llvm::none_of(subViewOp.getOperands(), [](Value operand) {
2768 return matchPattern(operand, m_ConstantIndex());
2769 }))
2770 return failure();
2771
2772 // At least one of offsets/sizes/strides is a new constant.
2773 // Form the new list of operands and constant attributes from the existing.
2774 SmallVector<Value, 8> newOffsets(subViewOp.offsets());
2775 SmallVector<int64_t, 8> newStaticOffsets =
2776 extractFromI64ArrayAttr(subViewOp.static_offsets());
2777 assert(newStaticOffsets.size() == subViewOp.getRank());
2778 canonicalizeSubViewPart(newOffsets, newStaticOffsets,
2779 ShapedType::isDynamicStrideOrOffset);
2780
2781 SmallVector<Value, 8> newSizes(subViewOp.sizes());
2782 SmallVector<int64_t, 8> newStaticSizes =
2783 extractFromI64ArrayAttr(subViewOp.static_sizes());
2784 assert(newStaticOffsets.size() == subViewOp.getRank());
2785 canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);
2786
2787 SmallVector<Value, 8> newStrides(subViewOp.strides());
2788 SmallVector<int64_t, 8> newStaticStrides =
2789 extractFromI64ArrayAttr(subViewOp.static_strides());
2790 assert(newStaticOffsets.size() == subViewOp.getRank());
2791 canonicalizeSubViewPart(newStrides, newStaticStrides,
2792 ShapedType::isDynamicStrideOrOffset);
2793
2794 // Create the new op in canonical form.
2795 auto newSubViewOp = rewriter.create<SubViewOp>(
2796 subViewOp.getLoc(), subViewOp.source(), newStaticOffsets,
2797 newStaticSizes, newStaticStrides, newOffsets, newSizes, newStrides);
2798
2799 // Insert a memref_cast for compatibility of the uses of the op.
2800 rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
2801 subViewOp.getType());
2802
2803 return success();
2804 }
2805 };
2806
2807 } // end anonymous namespace
2808
2809 /// Determines whether MemRefCastOp casts to a more dynamic version of the
2810 /// source memref. This is useful to to fold a memref_cast into a consuming op
2811 /// and implement canonicalization patterns for ops in different dialects that
2812 /// may consume the results of memref_cast operations. Such foldable memref_cast
2813 /// operations are typically inserted as `view` and `subview` ops are
2814 /// canonicalized, to preserve the type compatibility of their uses.
2815 ///
2816 /// Returns true when all conditions are met:
2817 /// 1. source and result are ranked memrefs with strided semantics and same
2818 /// element type and rank.
2819 /// 2. each of the source's size, offset or stride has more static information
2820 /// than the corresponding result's size, offset or stride.
2821 ///
2822 /// Example 1:
2823 /// ```mlir
2824 /// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
2825 /// %2 = consumer %1 ... : memref<?x?xf32> ...
2826 /// ```
2827 ///
2828 /// may fold into:
2829 ///
2830 /// ```mlir
2831 /// %2 = consumer %0 ... : memref<8x16xf32> ...
2832 /// ```
2833 ///
2834 /// Example 2:
2835 /// ```
2836 /// %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
2837 /// to memref<?x?xf32>
2838 /// consumer %1 : memref<?x?xf32> ...
2839 /// ```
2840 ///
2841 /// may fold into:
2842 ///
2843 /// ```
2844 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
2845 /// ```
canFoldIntoConsumerOp(MemRefCastOp castOp)2846 bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
2847 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
2848 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
2849
2850 // Requires ranked MemRefType.
2851 if (!sourceType || !resultType)
2852 return false;
2853
2854 // Requires same elemental type.
2855 if (sourceType.getElementType() != resultType.getElementType())
2856 return false;
2857
2858 // Requires same rank.
2859 if (sourceType.getRank() != resultType.getRank())
2860 return false;
2861
2862 // Only fold casts between strided memref forms.
2863 int64_t sourceOffset, resultOffset;
2864 SmallVector<int64_t, 4> sourceStrides, resultStrides;
2865 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
2866 failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
2867 return false;
2868
2869 // If cast is towards more static sizes along any dimension, don't fold.
2870 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
2871 auto ss = std::get<0>(it), st = std::get<1>(it);
2872 if (ss != st)
2873 if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
2874 return false;
2875 }
2876
2877 // If cast is towards more static offset along any dimension, don't fold.
2878 if (sourceOffset != resultOffset)
2879 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
2880 !MemRefType::isDynamicStrideOrOffset(resultOffset))
2881 return false;
2882
2883 // If cast is towards more static strides along any dimension, don't fold.
2884 for (auto it : llvm::zip(sourceStrides, resultStrides)) {
2885 auto ss = std::get<0>(it), st = std::get<1>(it);
2886 if (ss != st)
2887 if (MemRefType::isDynamicStrideOrOffset(ss) &&
2888 !MemRefType::isDynamicStrideOrOffset(st))
2889 return false;
2890 }
2891
2892 return true;
2893 }
2894
2895 namespace {
2896 /// Pattern to rewrite a subview op with MemRefCast arguments.
2897 /// This essentially pushes memref_cast past its consuming subview when
2898 /// `canFoldIntoConsumerOp` is true.
2899 ///
2900 /// Example:
2901 /// ```
2902 /// %0 = memref_cast %V : memref<16x16xf32> to memref<?x?xf32>
2903 /// %1 = subview %0[0, 0][3, 4][1, 1] :
2904 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
2905 /// ```
2906 /// is rewritten into:
2907 /// ```
2908 /// %0 = subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
2909 /// %1 = memref_cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
2910 /// memref<3x4xf32, offset:?, strides:[?, 1]>
2911 /// ```
2912 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2913 public:
2914 using OpRewritePattern<SubViewOp>::OpRewritePattern;
2915
matchAndRewrite(SubViewOp subViewOp,PatternRewriter & rewriter) const2916 LogicalResult matchAndRewrite(SubViewOp subViewOp,
2917 PatternRewriter &rewriter) const override {
2918 // Any constant operand, just return to let SubViewOpConstantFolder kick in.
2919 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2920 return matchPattern(operand, m_ConstantIndex());
2921 }))
2922 return failure();
2923
2924 auto castOp = subViewOp.source().getDefiningOp<MemRefCastOp>();
2925 if (!castOp)
2926 return failure();
2927
2928 if (!canFoldIntoConsumerOp(castOp))
2929 return failure();
2930
2931 /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
2932 /// the cast source operand type and the SubViewOp static information. This
2933 /// is the resulting type if the MemRefCastOp were folded.
2934 Type resultType = SubViewOp::inferSubViewResultType(
2935 castOp.source().getType().cast<MemRefType>(),
2936 extractFromI64ArrayAttr(subViewOp.static_offsets()),
2937 extractFromI64ArrayAttr(subViewOp.static_sizes()),
2938 extractFromI64ArrayAttr(subViewOp.static_strides()));
2939 Value newSubView = rewriter.create<SubViewOp>(
2940 subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
2941 subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
2942 subViewOp.static_sizes(), subViewOp.static_strides());
2943 rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
2944 newSubView);
2945 return success();
2946 }
2947 };
2948 } // namespace
2949
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2950 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2951 MLIRContext *context) {
2952 results.insert<SubViewOpConstantArgumentFolder, SubViewOpMemRefCastFolder>(
2953 context);
2954 }
2955
2956 //===----------------------------------------------------------------------===//
2957 // TensorCastOp
2958 //===----------------------------------------------------------------------===//
2959
areCastCompatible(Type a,Type b)2960 bool TensorCastOp::areCastCompatible(Type a, Type b) {
2961 auto aT = a.dyn_cast<TensorType>();
2962 auto bT = b.dyn_cast<TensorType>();
2963 if (!aT || !bT)
2964 return false;
2965
2966 if (aT.getElementType() != bT.getElementType())
2967 return false;
2968
2969 return succeeded(verifyCompatibleShape(aT, bT));
2970 }
2971
fold(ArrayRef<Attribute> operands)2972 OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
2973 return impl::foldCastOp(*this);
2974 }
2975
2976 //===----------------------------------------------------------------------===//
2977 // Helpers for Tensor[Load|Store]Op
2978 //===----------------------------------------------------------------------===//
2979
getTensorTypeFromMemRefType(Type type)2980 static Type getTensorTypeFromMemRefType(Type type) {
2981 if (auto memref = type.dyn_cast<MemRefType>())
2982 return RankedTensorType::get(memref.getShape(), memref.getElementType());
2983 return NoneType::get(type.getContext());
2984 }
2985
2986 //===----------------------------------------------------------------------===//
2987 // TruncateIOp
2988 //===----------------------------------------------------------------------===//
2989
verify(TruncateIOp op)2990 static LogicalResult verify(TruncateIOp op) {
2991 auto srcType = getElementTypeOrSelf(op.getOperand().getType());
2992 auto dstType = getElementTypeOrSelf(op.getType());
2993
2994 if (srcType.isa<IndexType>())
2995 return op.emitError() << srcType << " is not a valid operand type";
2996 if (dstType.isa<IndexType>())
2997 return op.emitError() << dstType << " is not a valid result type";
2998
2999 if (srcType.cast<IntegerType>().getWidth() <=
3000 dstType.cast<IntegerType>().getWidth())
3001 return op.emitError("operand type ")
3002 << srcType << " must be wider than result type " << dstType;
3003
3004 return success();
3005 }
3006
3007 //===----------------------------------------------------------------------===//
3008 // UnsignedDivIOp
3009 //===----------------------------------------------------------------------===//
3010
fold(ArrayRef<Attribute> operands)3011 OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
3012 assert(operands.size() == 2 && "binary operation takes two operands");
3013
3014 // Don't fold if it would require a division by zero.
3015 bool div0 = false;
3016 auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
3017 if (div0 || !b) {
3018 div0 = true;
3019 return a;
3020 }
3021 return a.udiv(b);
3022 });
3023
3024 // Fold out division by one. Assumes all tensors of all ones are splats.
3025 if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
3026 if (rhs.getValue() == 1)
3027 return lhs();
3028 } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
3029 if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
3030 return lhs();
3031 }
3032
3033 return div0 ? Attribute() : result;
3034 }
3035
3036 //===----------------------------------------------------------------------===//
3037 // UnsignedRemIOp
3038 //===----------------------------------------------------------------------===//
3039
fold(ArrayRef<Attribute> operands)3040 OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
3041 assert(operands.size() == 2 && "remi_unsigned takes two operands");
3042
3043 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
3044 if (!rhs)
3045 return {};
3046 auto rhsValue = rhs.getValue();
3047
3048 // x % 1 = 0
3049 if (rhsValue.isOneValue())
3050 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
3051
3052 // Don't fold if it requires division by zero.
3053 if (rhsValue.isNullValue())
3054 return {};
3055
3056 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
3057 if (!lhs)
3058 return {};
3059 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
3060 }
3061
3062 //===----------------------------------------------------------------------===//
3063 // ViewOp
3064 //===----------------------------------------------------------------------===//
3065
parseViewOp(OpAsmParser & parser,OperationState & result)3066 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
3067 OpAsmParser::OperandType srcInfo;
3068 SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
3069 SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
3070 auto indexType = parser.getBuilder().getIndexType();
3071 Type srcType, dstType;
3072 llvm::SMLoc offsetLoc;
3073 if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
3074 parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
3075 return failure();
3076
3077 if (offsetInfo.size() != 1)
3078 return parser.emitError(offsetLoc) << "expects 1 offset operand";
3079
3080 return failure(
3081 parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
3082 parser.parseOptionalAttrDict(result.attributes) ||
3083 parser.parseColonType(srcType) ||
3084 parser.resolveOperand(srcInfo, srcType, result.operands) ||
3085 parser.resolveOperands(offsetInfo, indexType, result.operands) ||
3086 parser.resolveOperands(sizesInfo, indexType, result.operands) ||
3087 parser.parseKeywordType("to", dstType) ||
3088 parser.addTypeToList(dstType, result.types));
3089 }
3090
print(OpAsmPrinter & p,ViewOp op)3091 static void print(OpAsmPrinter &p, ViewOp op) {
3092 p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
3093 p.printOperand(op.byte_shift());
3094 p << "][" << op.sizes() << ']';
3095 p.printOptionalAttrDict(op.getAttrs());
3096 p << " : " << op.getOperand(0).getType() << " to " << op.getType();
3097 }
3098
verify(ViewOp op)3099 static LogicalResult verify(ViewOp op) {
3100 auto baseType = op.getOperand(0).getType().cast<MemRefType>();
3101 auto viewType = op.getType();
3102
3103 // The base memref should have identity layout map (or none).
3104 if (baseType.getAffineMaps().size() > 1 ||
3105 (baseType.getAffineMaps().size() == 1 &&
3106 !baseType.getAffineMaps()[0].isIdentity()))
3107 return op.emitError("unsupported map for base memref type ") << baseType;
3108
3109 // The result memref should have identity layout map (or none).
3110 if (viewType.getAffineMaps().size() > 1 ||
3111 (viewType.getAffineMaps().size() == 1 &&
3112 !viewType.getAffineMaps()[0].isIdentity()))
3113 return op.emitError("unsupported map for result memref type ") << viewType;
3114
3115 // The base memref and the view memref should be in the same memory space.
3116 if (baseType.getMemorySpace() != viewType.getMemorySpace())
3117 return op.emitError("different memory spaces specified for base memref "
3118 "type ")
3119 << baseType << " and view memref type " << viewType;
3120
3121 // Verify that we have the correct number of sizes for the result type.
3122 unsigned numDynamicDims = viewType.getNumDynamicDims();
3123 if (op.sizes().size() != numDynamicDims)
3124 return op.emitError("incorrect number of size operands for type ")
3125 << viewType;
3126
3127 return success();
3128 }
3129
getViewSource()3130 Value ViewOp::getViewSource() { return source(); }
3131
3132 namespace {
3133
3134 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3135 using OpRewritePattern<ViewOp>::OpRewritePattern;
3136
matchAndRewrite__anone4b94cca2711::ViewOpShapeFolder3137 LogicalResult matchAndRewrite(ViewOp viewOp,
3138 PatternRewriter &rewriter) const override {
3139 // Return if none of the operands are constants.
3140 if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3141 return matchPattern(operand, m_ConstantIndex());
3142 }))
3143 return failure();
3144
3145 // Get result memref type.
3146 auto memrefType = viewOp.getType();
3147
3148 // Get offset from old memref view type 'memRefType'.
3149 int64_t oldOffset;
3150 SmallVector<int64_t, 4> oldStrides;
3151 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
3152 return failure();
3153 assert(oldOffset == 0 && "Expected 0 offset");
3154
3155 SmallVector<Value, 4> newOperands;
3156
3157 // Offset cannot be folded into result type.
3158
3159 // Fold any dynamic dim operands which are produced by a constant.
3160 SmallVector<int64_t, 4> newShapeConstants;
3161 newShapeConstants.reserve(memrefType.getRank());
3162
3163 unsigned dynamicDimPos = 0;
3164 unsigned rank = memrefType.getRank();
3165 for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3166 int64_t dimSize = memrefType.getDimSize(dim);
3167 // If this is already static dimension, keep it.
3168 if (!ShapedType::isDynamic(dimSize)) {
3169 newShapeConstants.push_back(dimSize);
3170 continue;
3171 }
3172 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
3173 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
3174 // Dynamic shape dimension will be folded.
3175 newShapeConstants.push_back(constantIndexOp.getValue());
3176 } else {
3177 // Dynamic shape dimension not folded; copy operand from old memref.
3178 newShapeConstants.push_back(dimSize);
3179 newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
3180 }
3181 dynamicDimPos++;
3182 }
3183
3184 // Create new memref type with constant folded dims.
3185 MemRefType newMemRefType =
3186 MemRefType::Builder(memrefType).setShape(newShapeConstants);
3187 // Nothing new, don't fold.
3188 if (newMemRefType == memrefType)
3189 return failure();
3190
3191 // Create new ViewOp.
3192 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
3193 viewOp.getOperand(0),
3194 viewOp.byte_shift(), newOperands);
3195 // Insert a cast so we have the same type as the old memref type.
3196 rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
3197 viewOp.getType());
3198 return success();
3199 }
3200 };
3201
3202 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3203 using OpRewritePattern<ViewOp>::OpRewritePattern;
3204
matchAndRewrite__anone4b94cca2711::ViewOpMemrefCastFolder3205 LogicalResult matchAndRewrite(ViewOp viewOp,
3206 PatternRewriter &rewriter) const override {
3207 Value memrefOperand = viewOp.getOperand(0);
3208 MemRefCastOp memrefCastOp = memrefOperand.getDefiningOp<MemRefCastOp>();
3209 if (!memrefCastOp)
3210 return failure();
3211 Value allocOperand = memrefCastOp.getOperand();
3212 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3213 if (!allocOp)
3214 return failure();
3215 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3216 viewOp.byte_shift(), viewOp.sizes());
3217 return success();
3218 }
3219 };
3220
3221 } // end anonymous namespace
3222
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)3223 void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
3224 MLIRContext *context) {
3225 results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3226 }
3227
3228 //===----------------------------------------------------------------------===//
3229 // XOrOp
3230 //===----------------------------------------------------------------------===//
3231
fold(ArrayRef<Attribute> operands)3232 OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
3233 /// xor(x, 0) -> x
3234 if (matchPattern(rhs(), m_Zero()))
3235 return lhs();
3236 /// xor(x,x) -> 0
3237 if (lhs() == rhs())
3238 return Builder(getContext()).getZeroAttr(getType());
3239
3240 return constFoldBinaryOp<IntegerAttr>(operands,
3241 [](APInt a, APInt b) { return a ^ b; });
3242 }
3243
3244 //===----------------------------------------------------------------------===//
3245 // ZeroExtendIOp
3246 //===----------------------------------------------------------------------===//
3247
verify(ZeroExtendIOp op)3248 static LogicalResult verify(ZeroExtendIOp op) {
3249 auto srcType = getElementTypeOrSelf(op.getOperand().getType());
3250 auto dstType = getElementTypeOrSelf(op.getType());
3251
3252 if (srcType.isa<IndexType>())
3253 return op.emitError() << srcType << " is not a valid operand type";
3254 if (dstType.isa<IndexType>())
3255 return op.emitError() << dstType << " is not a valid result type";
3256
3257 if (srcType.cast<IntegerType>().getWidth() >=
3258 dstType.cast<IntegerType>().getWidth())
3259 return op.emitError("result type ")
3260 << dstType << " must be wider than operand type " << srcType;
3261
3262 return success();
3263 }
3264
3265 //===----------------------------------------------------------------------===//
3266 // TableGen'd op method definitions
3267 //===----------------------------------------------------------------------===//
3268
3269 #define GET_OP_CLASSES
3270 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"
3271