//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
#include <numeric>

#include "mlir/Dialect/Vector/VectorOpsDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/VectorOpsEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a 1-D mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
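///
/// For example (illustrative snippets, not taken from a specific test):
///
/// ```mlir
///   %0 = constant dense<true> : vector<4xi1>      // classified AllTrue
///   %1 = vector.constant_mask [0] : vector<4xi1>  // classified AllFalse
/// ```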
static MaskFormat get1DMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = c.value().dyn_cast<DenseIntElementsAttr>()) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayAttr masks = m.mask_dim_sizes();
    assert(masks.size() == 1);
    int64_t i = masks[0].cast<IntegerAttr>().getInt();
    int64_t u = m.getType().getDimSize(0);
    if (i >= u)
      return MaskFormat::AllTrue;
    if (i <= 0)
      return MaskFormat::AllFalse;
  }
  return MaskFormat::Unknown;
}

// Helper for verifying combining kinds in contractions and reductions.
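// For example, ADD is accepted for both i32 and f32 element types, while a
// bitwise kind such as AND is rejected for f32.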
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
  case CombiningKind::MIN:
  case CombiningKind::MAX:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  }
  return false;
}

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
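/// For example, a memref<4x8xf32> with the default row-major layout has
/// strides [8, 1], so its last dimension has unit stride.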
bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir

CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
                                         MLIRContext *context) {
  return Base::get(context, static_cast<uint64_t>(kind));
}

CombiningKind CombiningKindAttr::getKind() const {
  return static_cast<CombiningKind>(getImpl()->value);
}

static constexpr const CombiningKind combiningKindsList[] = {
    // clang-format off
    CombiningKind::ADD,
    CombiningKind::MUL,
    CombiningKind::MIN,
    CombiningKind::MAX,
    CombiningKind::AND,
    CombiningKind::OR,
    CombiningKind::XOR,
    // clang-format on
};

void CombiningKindAttr::print(DialectAsmPrinter &printer) const {
  printer << "kind<";
  auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
    return bitEnumContains(this->getKind(), kind);
  });
  llvm::interleaveComma(kinds, printer,
                        [&](auto kind) { printer << stringifyEnum(kind); });
  printer << ">";
}

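// Parses the body of the combining kind attribute, i.e. the `<add>` part of
// a textual form such as `#vector.kind<add>`.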
Attribute CombiningKindAttr::parse(DialectAsmParser &parser) {
  if (failed(parser.parseLess()))
    return {};

  StringRef elemName;
  if (failed(parser.parseKeyword(&elemName)))
    return {};

  auto kind = symbolizeCombiningKind(elemName);
  if (!kind) {
    parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
        << elemName;
    return {};
  }

  if (failed(parser.parseGreater()))
    return {};

  return CombiningKindAttr::get(kind.getValue(),
                                parser.getBuilder().getContext());
}

Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
                                        Type type) const {
  StringRef attrKind;
  if (parser.parseKeyword(&attrKind))
    return {};

  if (attrKind == "kind")
    return CombiningKindAttr::parse(parser);

  parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
  return {};
}

void VectorDialect::printAttribute(Attribute attr,
                                   DialectAsmPrinter &os) const {
  if (auto ck = attr.dyn_cast<CombiningKindAttr>())
    ck.print(os);
  else
    llvm_unreachable("Unknown attribute type");
}

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

void VectorDialect::initialize() {
  addAttributes<CombiningKindAttr>();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
      >();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<ConstantOp>(loc, type, value);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  result.addOperands(source);
  auto sourceVectorType = source.getType().cast<VectorType>();
  auto targetShape = MultiDimReductionOp::inferDestShape(
      sourceVectorType.getShape(), reductionMask);
  auto targetVectorType =
      VectorType::get(targetShape, sourceVectorType.getElementType());
  result.addTypes(targetVectorType);

  SmallVector<int64_t> reductionDims;
  for (auto en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  result.addAttribute(getReductionDimsAttrName(),
                      builder.getI64ArrayAttr(reductionDims));
  result.addAttribute(getKindAttrName(),
                      CombiningKindAttr::get(kind, builder.getContext()));
}

static LogicalResult verify(MultiDimReductionOp op) {
  auto reductionMask = op.getReductionMask();
  auto targetShape = MultiDimReductionOp::inferDestShape(
      op.getSourceVectorType().getShape(), reductionMask);
  auto targetVectorType =
      VectorType::get(targetShape, op.getSourceVectorType().getElementType());
  if (targetVectorType != op.getDestVectorType())
    return op.emitError("invalid output vector type: ")
           << op.getDestVectorType() << " (expected: " << targetVectorType
           << ")";
  return success();
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReductionOp op) {
  // Verify for 1-D vector.
  int64_t rank = op.getVectorType().getRank();
  if (rank != 1)
    return op.emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  auto kind = op.kind();
  Type eltType = op.dest().getType();
  if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
    if (!eltType.isIntOrIndexOrFloat())
      return op.emitOpError("unsupported reduction type");
  } else if (kind == "and" || kind == "or" || kind == "xor") {
    if (!eltType.isIntOrIndex())
      return op.emitOpError("unsupported reduction type");
  } else {
    return op.emitOpError("unknown reduction kind: ") << kind;
  }

  // Verify optional accumulator.
  if (!op.acc().empty()) {
    if (kind != "add" && kind != "mul")
      return op.emitOpError("no accumulator for reduction kind: ") << kind;
    if (!eltType.isa<FloatType>())
      return op.emitOpError("no accumulator for type: ") << eltType;
  }

  return success();
}

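// Example of the custom assembly form handled below (illustrative):
//
//   %0 = vector.reduction "add", %v : vector<16xf32> into f32
//   %1 = vector.reduction "mul", %v, %acc : vector<16xf32> into f32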
static ParseResult parseReductionOp(OpAsmParser &parser,
                                    OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
  Type redType;
  Type resType;
  Attribute attr;
  if (parser.parseAttribute(attr, "kind", result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (operandsInfo.size() > 0 &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  if (operandsInfo.size() < 1 || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}

static void print(OpAsmPrinter &p, ReductionOp op) {
  p << op.getOperationName() << " \"" << op.kind() << "\", " << op.vector();
  if (!op.acc().empty())
    p << ", " << op.acc();
  p << " : " << op.vector().getType() << " into " << op.dest().getType();
}

Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder,
                                         Location loc, Value vector) {
  Type scalarType = vector.getType().cast<ShapedType>().getElementType();
  switch (op) {
  case AtomicRMWKind::addf:
  case AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("add"),
                                               vector, ValueRange{});
  case AtomicRMWKind::mulf:
  case AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("mul"),
                                               vector, ValueRange{});
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<StringRef> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(),
                      builder.getAffineMapArrayAttr(
                          AffineMap::inferFromExprList(indexingExprs)));
  result.addAttribute(getIteratorTypesAttrName(),
                      builder.getStrArrayAttr(iteratorTypes));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
  result.addAttribute(ContractionOp::getKindAttrName(),
                      CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                             builder.getContext()));
}

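// Example of the custom assembly form handled below (illustrative; the
// leading dictionary carries the indexing_maps and iterator_types, and
// #map_a/#map_b/#map_c are hypothetical affine map aliases):
//
//   %0 = vector.contract {indexing_maps = [#map_a, #map_b, #map_c],
//                         iterator_types = ["parallel", "parallel",
//                                           "reduction"]}
//     %a, %b, %acc : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>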
static ParseResult parseContractionOp(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::OperandType lhsInfo;
  OpAsmParser::OperandType rhsInfo;
  OpAsmParser::OperandType accInfo;
  SmallVector<OpAsmParser::OperandType, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
      parser.parseOperand(lhsInfo) || parser.parseComma() ||
      parser.parseOperand(rhsInfo) || parser.parseComma() ||
      parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());
  if (!result.attributes.get(ContractionOp::getKindAttrName())) {
    result.addAttribute(ContractionOp::getKindAttrName(),
                        CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                               result.getContext()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = types[0].cast<VectorType>();
  auto rhsType = types[1].cast<VectorType>();
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<Type, 2> maskTypes = {
      VectorType::get(lhsType.getShape(), maskElementType),
      VectorType::get(rhsType.getShape(), maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

static void print(OpAsmPrinter &p, ContractionOp op) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = op.getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : op->getAttrs())
    if (traitAttrsSet.count(attr.first.strref()) > 0)
      attrs.push_back(attr);

  auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);
  p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
  p << op.rhs() << ", " << op.acc();
  if (op.masks().size() == 2)
    p << ", " << op.masks();

  p.printOptionalAttrDict(op->getAttrs(), attrNames);
  p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
    << op.getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.size() == 0) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMaps()[0];
    AffineMap rhsMap = op.getIndexingMaps()[1];
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
           "expected extent along all dimensions.");

    AffineMap resMap = op.getIndexingMaps()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(
               expectedMap.getResults(),
               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}

static LogicalResult verify(ContractionOp op) {
  auto lhsType = op.getLhsType();
  auto rhsType = op.getRhsType();
  auto accType = op.getAccType();
  auto resType = op.getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (op.indexing_maps().size() != 3)
    return op.emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = op.iterator_types().getValue().size();
  for (auto it : llvm::enumerate(op.indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    if (map.getNumSymbols() != 0)
      return op.emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and
    // indices. This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return op.emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = op.getContractingDimMap();
  auto batchDimMap = op.getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return op.emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return op.emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return op.emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = op.getLHSVectorMaskType();
  auto rhsMaskType = op.getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return op.emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return op.emitOpError("invalid vector mask rank");
  }

  // Verify supported combining kind.
  auto vectorType = resType.dyn_cast<VectorType>();
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(op.kind(), elementType))
    return op.emitOpError("unsupported contraction type");

  return success();
}

ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[3] = {getIndexingMapsAttrName(),
                                         getIteratorTypesAttrName(),
                                         ContractionOp::getKindAttrName()};
  return llvm::makeArrayRef(names);
}

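/// Return the position of `targetExpr` within the results of `map`, or -1 if
/// it does not appear as a result.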
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

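/// Build the list of (lhs dim, rhs dim) pairs for all iterators of kind
/// `targetIteratorTypeName` that are accessed by both the lhs and the rhs
/// indexing maps.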
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          StringRef targetIteratorTypeName, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (auto it : llvm::enumerate(iteratorTypes)) {
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName != targetIteratorTypeName)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.push_back({lhsDim, rhsDim});
  }
  return dimMap;
}

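/// Populate `iterationBounds` with the extent of each iterator: reduction
/// iterators read their size from the lhs shape, parallel iterators from the
/// result vector shape.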
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = getResultType().dyn_cast<VectorType>();
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  SmallVector<int64_t, 2> iterationShape;
  for (auto it : llvm::enumerate(iterator_types())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName == getReductionIteratorTypeName()) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = indexing_maps().getValue().size();
  iterationIndexMap.resize(numMaps);
  for (auto it : llvm::enumerate(indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = map.getResult(i).cast<AffineDimExpr>();
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getReductionIteratorTypeName(), getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getParallelIteratorTypeName(), getContext());
}

SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
  return llvm::to_vector<4>(
      llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
        return mapAttr.cast<AffineMapAttr>().getValue();
      }));
}

Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which represents a pattern such as:
///
/// ```mlir
///   %c0 = vector.constant 0: ...
///   %c = vector.contract %a, %b, %c0: ...
///   %e = add %c, %d: ...
/// ```
///
/// by:
///
/// ```mlir
///   %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<ConstantOp>(
              contractionOp.acc().getDefiningOp())) {
        if (maybeZero.value() ==
            rewriter.getZeroAttr(contractionOp.acc().getType())) {
          BlockAndValueMapping bvm;
          bvm.map(contractionOp.acc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    // Only report success when a rewrite actually happened; returning success
    // without mutating the IR would make the pattern driver loop.
    return contract ? success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<AddIOp>, CanonicalizeContractAdd<AddFOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void vector::ExtractElementOp::build(OpBuilder &builder,
                                     OperationState &result, Value source,
                                     Value position) {
  result.addOperands({source, position});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

void vector::ExtractElementOp::build(OpBuilder &builder,
                                     OperationState &result, Value source,
                                     int64_t position) {
  Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
  build(builder, result, source, pos);
}

static LogicalResult verify(vector::ExtractElementOp op) {
  VectorType vectorType = op.getVectorType();
  if (vectorType.getRank() != 1)
    return op.emitOpError("expected 1-D vector");
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static Type inferExtractOpResultType(VectorType vectorType,
                                     ArrayAttr position) {
  if (static_cast<int64_t>(position.size()) == vectorType.getRank())
    return vectorType.getElementType();
  return VectorType::get(vectorType.getShape().drop_front(position.size()),
                         vectorType.getElementType());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  result.addOperands(source);
  auto positionAttr = getVectorSubscriptAttr(builder, position);
  result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
                                           positionAttr));
  result.addAttribute(getPositionAttrName(), positionAttr);
}

// Convenience builder which assumes the values are constant indices.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<ConstantIndexOp>().getValue();
      }));
  build(builder, result, source, positionConstants);
}

static void print(OpAsmPrinter &p, vector::ExtractOp op) {
  p << op.getOperationName() << " " << op.vector() << op.position();
  p.printOptionalAttrDict(op->getAttrs(), {"position"});
  p << " : " << op.vector().getType();
}

static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc attributeLoc, typeLoc;
  NamedAttrList attrs;
  OpAsmParser::OperandType vector;
  Type type;
  Attribute attr;
  if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(attr, "position", attrs) ||
      parser.parseOptionalAttrDict(attrs) ||
      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
    return failure();

  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typeLoc, "expected vector type");

  auto positionAttr = attr.dyn_cast<ArrayAttr>();
  if (!positionAttr ||
      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
    return parser.emitError(
        attributeLoc,
        "expected position attribute of rank smaller than vector rank");

  Type resType = inferExtractOpResultType(vectorType, positionAttr);
  result.attributes = attrs;
  return failure(parser.resolveOperand(vector, type, result.operands) ||
                 parser.addTypeToList(resType, result.types));
}

static LogicalResult verify(vector::ExtractOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than vector rank");
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= op.getVectorType().getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
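///
/// For example (illustrative):
///
/// ```mlir
///   %1 = vector.extract %0[0] : vector<2x3x4xf32>
///   %2 = vector.extract %1[1] : vector<3x4xf32>
/// ```
///
/// is rewritten in place as `vector.extract %0[0, 1] : vector<2x3x4xf32>`.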
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.vector().getDefiningOp<ExtractOp>())
    return failure();

  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  auto extractedPos = extractVector<int64_t>(currentOp.position());
  globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
  while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    auto extractedPos = extractVector<int64_t>(currentOp.position());
    globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
  }
  extractOp.setOperand(currentOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(globalPosition));
  return success();
}

/// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
  if (!transposeOp)
    return failure();

  auto permutation = extractVector<unsigned>(transposeOp.transp());
  auto extractedPos = extractVector<int64_t>(extractOp.position());

  // If transposition permutation is larger than the ExtractOp, all minor
  // dimensions must be an identity for folding to occur. If not, individual
  // elements within the extracted value are transposed and this is not just a
  // simple folding.
  unsigned minorRank = permutation.size() - extractedPos.size();
  MLIRContext *ctx = extractOp.getContext();
  AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
  AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
  if (minorMap && !minorMap.isMinorIdentity())
    return failure();

  //   %1 = transpose %0[x, y, z] : vector<axbxcxf32>
  //   %2 = extract %1[u, v] : vector<..xf32>
  // may turn into:
  //   %2 = extract %0[w, x] : vector<..xf32>
  // iff z == 2 and [w, x] = [x, y]^-1 o [u, v] here o denotes composition and
  // -1 denotes the inverse.
  permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
  // The major submap has fewer results but the same number of dims. To compose
  // cleanly, we need to drop dims to form a "square matrix". This is possible
  // because:
  //   (a) this is a permutation map and
  //   (b) the minor map has already been checked to be identity.
  // Therefore, the major map cannot contain dims of position greater than or
  // equal to the number of results.
  assert(llvm::all_of(permutationMap.getResults(),
                      [&](AffineExpr e) {
                        auto dim = e.dyn_cast<AffineDimExpr>();
                        return dim && dim.getPosition() <
                                          permutationMap.getNumResults();
                      }) &&
         "Unexpected map results depend on higher rank positions");
  // Project on the first domain dimensions to allow composition.
  permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
                                  permutationMap.getResults(), ctx);

  extractOp.setOperand(transposeOp.vector());
  // Compose the inverse permutation map with the extractedPos.
  auto newExtractedPos =
      inversePermutation(permutationMap).compose(extractedPos);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(newExtractedPos));

  return success();
}

/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
/// result is always the input to some InsertOp.
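///
/// For example (illustrative):
///
/// ```mlir
///   %0 = vector.insert %s, %v[0, 1] : f32 into vector<2x3xf32>
///   %1 = vector.extract %0[0, 1] : vector<2x3xf32>
/// ```
///
/// folds %1 to %s.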
static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
  MLIRContext *context = extractOp.getContext();
  AffineMap permutationMap;
  auto extractedPos = extractVector<unsigned>(extractOp.position());
  // Walk back a chain of InsertOp/TransposeOp until we hit a match.
  // Compose TransposeOp permutations as we walk back.
  auto insertOp = extractOp.vector().getDefiningOp<vector::InsertOp>();
  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
  while (insertOp || transposeOp) {
    if (transposeOp) {
      // If it is transposed, compose the map and iterate.
      auto permutation = extractVector<unsigned>(transposeOp.transp());
      AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
      if (!permutationMap)
        permutationMap = newMap;
      else if (newMap.getNumInputs() != permutationMap.getNumResults())
        return Value();
      else
        permutationMap = newMap.compose(permutationMap);
      // Compute insert/transpose for the next iteration.
      Value transposed = transposeOp.vector();
      insertOp = transposed.getDefiningOp<vector::InsertOp>();
      transposeOp = transposed.getDefiningOp<vector::TransposeOp>();
      continue;
    }

    assert(insertOp);
    Value insertionDest = insertOp.dest();
    // If it is inserted into, either the position matches and we have a
    // successful folding; or we iterate until we run out of
    // InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector`
    // produces a new vector with 1 modified value/slice in exactly the static
    // position we need to match.
    auto insertedPos = extractVector<unsigned>(insertOp.position());
    // Trivial permutations are solved with position equality checks.
    if (!permutationMap || permutationMap.isIdentity()) {
      if (extractedPos == insertedPos)
        return insertOp.source();
      // Fallthrough: if the position does not match, just skip to the next
      // producing `vector.insert` / `vector.transpose`.
      // Compute insert/transpose for the next iteration.
      insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
      transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
      continue;
    }

    // More advanced permutations require application of the permutation.
    // However, the rank of `insertedPos` may be different from that of the
    // `permutationMap`. To support such cases, we need to:
    //   1. apply on the `insertedPos.size()` major dimensions
    //   2. check the other dimensions of the permutation form a minor
    //      identity.
    assert(permutationMap.isPermutation() && "expected a permutation");
    if (insertedPos.size() == extractedPos.size()) {
      bool fold = true;
      for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
        auto pos = permutationMap.getDimPosition(idx);
        if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
          fold = false;
          break;
        }
      }
      if (fold) {
        assert(permutationMap.getNumResults() >= insertedPos.size() &&
               "expected map of rank larger than insert indexing");
        unsigned minorRank =
            permutationMap.getNumResults() - insertedPos.size();
        AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
        if (!minorMap || minorMap.isMinorIdentity())
          return insertOp.source();
      }
    }

    // If we haven't found a match, just continue to the next producing
    // `vector.insert` / `vector.transpose`.
    // Compute insert/transpose for the next iteration.
    insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
    transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
  }
  return Value();
}

/// Fold extractOp with scalar result coming from BroadcastOp.
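///
/// For example (illustrative):
///
/// ```mlir
///   %0 = vector.broadcast %s : f32 to vector<4x8xf32>
///   %1 = vector.extract %0[1, 2] : vector<4x8xf32>
/// ```
///
/// folds %1 to %s.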
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
  if (!broadcastOp)
    return Value();
  if (extractOp.getType() == broadcastOp.getSourceType())
    return broadcastOp.source();
  auto getRank = [](Type type) {
    return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
  };
  unsigned broadcastSrcRank = getRank(broadcastOp.getSourceType());
  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank < broadcastSrcRank) {
    auto extractPos = extractVector<int64_t>(extractOp.position());
    unsigned rankDiff = broadcastSrcRank - extractResultRank;
    extractPos.erase(
        extractPos.begin(),
        std::next(extractPos.begin(), extractPos.size() - rankDiff));
    extractOp.setOperand(broadcastOp.source());
    // OpBuilder is only used as a helper to build an I64ArrayAttr.
    OpBuilder b(extractOp.getContext());
    extractOp->setAttr(ExtractOp::getPositionAttrName(),
                       b.getI64ArrayAttr(extractPos));
    return extractOp.getResult();
  }
  // TODO: In case the rank of the broadcast source is greater than the rank of
  // the extract result, this can be combined into a new broadcast op. This
  // needs to be added as a canonicalization pattern if needed.
  return Value();
}

// Fold extractOp with source coming from ShapeCast op.
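// For example (illustrative):
//
//   %0 = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>
//   %1 = vector.extract %0[6] : vector<8xf32>
//
// is rewritten in place to extract directly from %src at the delinearized
// position [1, 2].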
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimension of the destination must match the lowest
      // dimension of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.source());
  return extractOp.getResult();
}

OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
  if (position().empty())
    return vector();
  if (succeeded(foldExtractOpFromExtractChain(*this)))
    return getResult();
  if (succeeded(foldExtractOpFromTranspose(*this)))
    return getResult();
  if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
    return val;
  if (auto val = foldExtractFromBroadcast(*this))
    return val;
  if (auto val = foldExtractFromShapeCast(*this))
    return val;
  return OpFoldResult();
}

namespace {

// If extractOp is only removing unit dimensions it can be transformed to a
// shapecast.
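//
// For example (illustrative):
//
//   %1 = vector.extract %0[0] : vector<1x4xf32>
//
// becomes:
//
//   %1 = vector.shape_cast %0 : vector<1x4xf32> to vector<4xf32>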
1181 class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
1182 public:
1183 using OpRewritePattern<ExtractOp>::OpRewritePattern;
1184
matchAndRewrite(ExtractOp extractOp,PatternRewriter & rewriter) const1185 LogicalResult matchAndRewrite(ExtractOp extractOp,
1186 PatternRewriter &rewriter) const override {
1187 auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
1188 if (!dstVecType || extractOp.getVectorType().getNumElements() !=
1189 dstVecType.getNumElements())
1190 return failure();
1191 rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
1192 extractOp.vector());
1193 return success();
1194 }
1195 };
1196
1197 } // namespace
1198
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1199 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1200 MLIRContext *context) {
1201 results.add<ExtractToShapeCast>(context);
1202 }
1203
populateFromInt64AttrArray(ArrayAttr arrayAttr,SmallVectorImpl<int64_t> & results)1204 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
1205 SmallVectorImpl<int64_t> &results) {
1206 for (auto attr : arrayAttr)
1207 results.push_back(attr.cast<IntegerAttr>().getInt());
1208 }
1209
1210 //===----------------------------------------------------------------------===//
1211 // ExtractMapOp
1212 //===----------------------------------------------------------------------===//
1213
build(OpBuilder & builder,OperationState & result,Value vector,ValueRange ids,ArrayRef<int64_t> multiplicity,AffineMap permutationMap)1214 void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
1215 Value vector, ValueRange ids,
1216 ArrayRef<int64_t> multiplicity,
1217 AffineMap permutationMap) {
1218 assert(ids.size() == multiplicity.size() &&
1219 ids.size() == permutationMap.getNumResults());
1220 assert(permutationMap.isProjectedPermutation());
1221 VectorType type = vector.getType().cast<VectorType>();
1222 SmallVector<int64_t, 4> newShape(type.getShape().begin(),
1223 type.getShape().end());
1224 for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
1225 AffineExpr expr = permutationMap.getResult(i);
1226 auto dim = expr.cast<AffineDimExpr>();
1227 newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
1228 }
1229 VectorType resultType = VectorType::get(newShape, type.getElementType());
1230 ExtractMapOp::build(builder, result, resultType, vector, ids);
1231 }
1232
verify(ExtractMapOp op)1233 static LogicalResult verify(ExtractMapOp op) {
1234 if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
1235 return op.emitOpError(
1236 "expected source and destination vectors of same rank");
1237 unsigned numId = 0;
1238 for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
1239 if (op.getSourceVectorType().getDimSize(i) %
1240 op.getResultType().getDimSize(i) !=
1241 0)
1242 return op.emitOpError("source vector dimensions must be a multiple of "
1243 "destination vector dimensions");
1244 if (op.getSourceVectorType().getDimSize(i) !=
1245 op.getResultType().getDimSize(i))
1246 numId++;
1247 }
1248 if (numId != op.ids().size())
1249 return op.emitOpError("expected number of ids must match the number of "
1250 "dimensions distributed");
1251 return success();
1252 }
1253
fold(ArrayRef<Attribute> operands)1254 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
1255 auto insert = vector().getDefiningOp<vector::InsertMapOp>();
1256 if (insert == nullptr || getType() != insert.vector().getType() ||
1257 ids() != insert.ids())
1258 return {};
1259 return insert.vector();
1260 }
1261
getMultiplicity(SmallVectorImpl<int64_t> & multiplicity)1262 void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
1263 assert(multiplicity.empty());
1264 for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
1265 if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
1266 multiplicity.push_back(getSourceVectorType().getDimSize(i) /
1267 getResultType().getDimSize(i));
1268 }
1269 }
1270
1271 template <typename MapOp>
calculateImplicitMap(MapOp op)1272 AffineMap calculateImplicitMap(MapOp op) {
1273 SmallVector<AffineExpr, 4> perm;
1274 // Check which dimension have a multiplicity greater than 1 and associated
1275 // them to the IDs in order.
1276 for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
1277 if (op.getSourceVectorType().getDimSize(i) !=
1278 op.getResultType().getDimSize(i))
1279 perm.push_back(getAffineDimExpr(i, op.getContext()));
1280 }
1281 auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
1282 op.getContext());
1283 return map;
1284 }

AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }

//===----------------------------------------------------------------------===//
// FMAOp
//===----------------------------------------------------------------------===//

Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(BroadcastOp op) {
  VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
  VectorType dstVectorType = op.getVectorType();
  // A scalar-to-vector broadcast is always valid. A vector-to-vector
  // broadcast needs some additional checking.
  if (srcVectorType) {
    int64_t srcRank = srcVectorType.getRank();
    int64_t dstRank = dstVectorType.getRank();
    if (srcRank > dstRank)
      return op.emitOpError("source rank higher than destination rank");
    // The source must have an exact match or a singleton value for all
    // trailing dimensions (all leading dimensions are simply duplicated).
    int64_t lead = dstRank - srcRank;
    for (int64_t r = 0; r < srcRank; ++r) {
      int64_t srcDim = srcVectorType.getDimSize(r);
      int64_t dstDim = dstVectorType.getDimSize(lead + r);
      if (srcDim != 1 && srcDim != dstDim)
        return op.emitOpError("dimension mismatch (")
               << srcDim << " vs. " << dstDim << ")";
    }
  }
  return success();
}
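
// Illustrative examples (hypothetical shapes): broadcasting vector<1x4xf32>
// to vector<3x3x4xf32> is valid (the leading dimension is duplicated and the
// size-1 dimension is stretched), whereas broadcasting vector<2xf32> to
// vector<4xf32> fails with "dimension mismatch (2 vs. 4)".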

OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0])
    return {};
  auto vectorType = getVectorType();
  if (operands[0].getType().isIntOrIndexOrFloat())
    return DenseElementsAttr::get(vectorType, operands[0]);
  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
    return DenseElementsAttr::get(vectorType, attr.getSplatValue());
  return {};
}

namespace {

// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
// the degenerate case where the broadcast only adds dimensions of size 1 it
// can be replaced by a ShapeCastOp. This canonicalization checks whether the
// total number of elements is the same before and after the broadcast to
// detect whether the only change in the vector type is new dimensions of
// size 1.
class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
public:
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
    if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<ShapeCastOp>(
        broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
    return success();
  }
};
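
// E.g. (illustrative): a broadcast from vector<4xf32> to vector<1x1x4xf32>
// keeps the element count at 4, so the pattern above rewrites it as
//   vector.shape_cast %a : vector<4xf32> to vector<1x1x4xf32>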

// Fold broadcast1(broadcast2(x)) into broadcast1(x).
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>();
    if (!srcBroadcast)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source());
    return success();
  }
};
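
// E.g. (illustrative): broadcasting %a : f32 to vector<4xf32> and then to
// vector<2x4xf32> folds to a single broadcast of %a : f32 to vector<2x4xf32>.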
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastToShapeCast, BroadcastFolder>(context);
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
                      Value v2, ArrayRef<int64_t> mask) {
  result.addOperands({v1, v2});
  auto maskAttr = getVectorSubscriptAttr(builder, mask);
  result.addTypes(v1.getType());
  result.addAttribute(getMaskAttrName(), maskAttr);
}

static void print(OpAsmPrinter &p, ShuffleOp op) {
  p << op.getOperationName() << " " << op.v1() << ", " << op.v2() << " "
    << op.mask();
  p.printOptionalAttrDict(op->getAttrs(), {ShuffleOp::getMaskAttrName()});
  p << " : " << op.v1().getType() << ", " << op.v2().getType();
}

static LogicalResult verify(ShuffleOp op) {
  VectorType resultType = op.getVectorType();
  VectorType v1Type = op.getV1VectorType();
  VectorType v2Type = op.getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  if (resRank != v1Rank || v1Rank != v2Rank)
    return op.emitOpError("rank mismatch");
  // Verify all but leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return op.emitOpError("dimension mismatch");
  }
  // Verify mask length.
  auto maskAttr = op.mask().getValue();
  int64_t maskLength = maskAttr.size();
  if (maskLength != resultType.getDimSize(0))
    return op.emitOpError("mask length mismatch");
  // Verify all indices.
  int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
  for (auto en : llvm::enumerate(maskAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
      return op.emitOpError("mask index #")
             << (en.index() + 1) << " out of range";
  }
  return success();
}
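
// Illustrative example (hypothetical operands):
//   vector.shuffle %a, %b [0, 2, 3] : vector<2xf32>, vector<2xf32>
// selects from a combined index space of size 4 and yields a vector<3xf32>
// (the leading result dimension equals the mask length).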

static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType v1, v2;
  Attribute attr;
  VectorType v1Type, v2Type;
  if (parser.parseOperand(v1) || parser.parseComma() ||
      parser.parseOperand(v2) ||
      parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
                            result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(v1Type) || parser.parseComma() ||
      parser.parseType(v2Type) ||
      parser.resolveOperand(v1, v1Type, result.operands) ||
      parser.resolveOperand(v2, v2Type, result.operands))
    return failure();
  // Construct resulting type: leading dimension matches mask length,
  // all trailing dimensions match the operands.
  auto maskAttr = attr.dyn_cast<ArrayAttr>();
  if (!maskAttr)
    return parser.emitError(parser.getNameLoc(), "missing mask attribute");
  int64_t maskLength = maskAttr.size();
  if (maskLength <= 0)
    return parser.emitError(parser.getNameLoc(), "invalid mask length");
  int64_t v1Rank = v1Type.getRank();
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(maskLength);
  for (int64_t r = 1; r < v1Rank; ++r)
    shape.push_back(v1Type.getDimSize(r));
  VectorType resType = VectorType::get(shape, v1Type.getElementType());
  parser.addTypeToList(resType, result.types);
  return success();
}

//===----------------------------------------------------------------------===//
// InsertElementOp
//===----------------------------------------------------------------------===//

void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest, Value position) {
  result.addOperands({source, dest, position});
  result.addTypes(dest.getType());
}

void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest, int64_t position) {
  Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
  build(builder, result, source, dest, pos);
}

static LogicalResult verify(InsertElementOp op) {
  auto dstVectorType = op.getDestVectorType();
  if (dstVectorType.getRank() != 1)
    return op.emitOpError("expected 1-D vector");
  return success();
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
                     Value dest, ArrayRef<int64_t> position) {
  result.addOperands({source, dest});
  auto positionAttr = getVectorSubscriptAttr(builder, position);
  result.addTypes(dest.getType());
  result.addAttribute(getPositionAttrName(), positionAttr);
}

// Convenience builder which assumes the values are constant indices.
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
                     Value dest, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<ConstantIndexOp>().getValue();
      }));
  build(builder, result, source, dest, positionConstants);
}

static LogicalResult verify(InsertOp op) {
  auto positionAttr = op.position().getValue();
  auto destVectorType = op.getDestVectorType();
  if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than dest vector rank");
  auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return op.emitOpError("expected position attribute rank + source rank to "
                          "match dest vector rank");
  if (!srcVectorType && (positionAttr.size() !=
                         static_cast<unsigned>(destVectorType.getRank())))
    return op.emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= destVectorType.getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "dest vector dimension";
  }
  return success();
}
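
// Illustrative example (hypothetical operands):
//   vector.insert %s, %d [1, 2] : vector<4xf32> into vector<2x8x4xf32>
// is valid: position rank (2) plus source rank (1) equals dest rank (3), and
// the positions satisfy 1 < 2 and 2 < 8.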

namespace {

// If the insert op only adds unit dimensions, it can be rewritten as a
// shape_cast.
class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern<InsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<ShapeCastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.source());
    return success();
  }
};

} // namespace

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToShapeCast>(context);
}

// Eliminates insert operations that produce values identical to their source
// value. This happens when the source and destination vectors have identical
// sizes.
OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
  if (position().empty())
    return source();
  return {};
}
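
// E.g. (illustrative): with an empty position the insert overwrites all of
// the destination, so vector.insert %s, %d [] folds to %s.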

//===----------------------------------------------------------------------===//
// InsertMapOp
//===----------------------------------------------------------------------===//

void InsertMapOp::build(OpBuilder &builder, OperationState &result,
                        Value vector, Value dest, ValueRange ids) {
  InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
}

static LogicalResult verify(InsertMapOp op) {
  if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
    return op.emitOpError(
        "expected source and destination vectors of same rank");
  unsigned numId = 0;
  for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) {
    if (op.getResultType().getDimSize(i) %
            op.getSourceVectorType().getDimSize(i) !=
        0)
      return op.emitOpError(
          "destination vector size must be a multiple of source vector size");
    if (op.getResultType().getDimSize(i) !=
        op.getSourceVectorType().getDimSize(i))
      numId++;
  }
  if (numId != op.ids().size())
    return op.emitOpError("expected number of ids to match the number of "
                          "distributed dimensions");
  return success();
}

AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }

//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//

void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                 Value source, Value dest,
                                 ArrayRef<int64_t> offsets,
                                 ArrayRef<int64_t> strides) {
  result.addOperands({source, dest});
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(dest.getType());
  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
  result.addAttribute(getStridesAttrName(), stridesAttr);
}

// TODO: Should be moved to Tablegen Confined attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank smaller than vector rank";
  return success();
}

// Returns success if all integers in `arrayAttr` are in the admissible
// interval. If `halfOpen` is true the admissible interval is [min, max);
// otherwise it is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = attr.cast<IntegerAttr>().getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}

// Returns success if all integers in `arrayAttr` are in the admissible
// interval, where the upper bound for each dimension is taken from `shape`.
// If `halfOpen` is true the admissible interval is [min, max); otherwise it
// is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  assert(arrayAttr.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr, shape)) {
    auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto max = std::get<1>(it);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
    ++index;
  }
  return success();
}

// Returns success if, for each dimension, the sum of the corresponding
// integers in `arrayAttr1` and `arrayAttr2` is in the admissible interval,
// where the upper bound is taken from `shape`. If `halfOpen` is true the
// admissible interval is [min, max); otherwise it is [min, max].
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
    auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
    auto max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
    ++index;
  }
  return success();
}

static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}

static LogicalResult verify(InsertStridedSliceOp op) {
  auto sourceVectorType = op.getSourceVectorType();
  auto destVectorType = op.getDestVectorType();
  auto offsets = op.offsets();
  auto strides = op.strides();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return op.emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return op.emitOpError(
        "expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return op.emitOpError(
        "expected source rank to be no greater than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(
          isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          op, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  return success();
}
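
// Illustrative example (hypothetical operands): inserting
// %src : vector<4xf32> into %dst : vector<2x16xf32> with offsets = [0, 2] and
// strides = [1] is valid: offsets has dest rank (2), strides has source rank
// (1), and the slice [2, 6) fits in the trailing dimension of size 16.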

//===----------------------------------------------------------------------===//
// OuterProductOp
//===----------------------------------------------------------------------===//

/// Build an op without mask, use the type of `acc` as the return type.
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
                           Value lhs, Value rhs, Value acc) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
}

static void print(OpAsmPrinter &p, OuterProductOp op) {
  p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
  if (!op.acc().empty()) {
    p << ", " << op.acc();
    p.printOptionalAttrDict(op->getAttrs());
  }
  p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
}

static ParseResult parseOuterProductOp(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = tLHS.dyn_cast<VectorType>();
  VectorType vRHS = tRHS.dyn_cast<VectorType>();
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");
  VectorType resType =
      vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                             vLHS.getElementType())
           : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());

  if (!result.attributes.get(OuterProductOp::getKindAttrName())) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(),
        CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
                               result.getContext()));
  }

  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}

static LogicalResult verify(OuterProductOp op) {
  Type tRHS = op.getOperandTypeRHS();
  VectorType vLHS = op.getOperandVectorTypeLHS(),
             vRHS = tRHS.dyn_cast<VectorType>(),
             vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();

  if (vLHS.getRank() != 1)
    return op.emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return op.emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return op.emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return op.emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return op.emitOpError("expected #2 operand dim to match result dim #2");
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return op.emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return op.emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return op.emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(op.kind(), vRES.getElementType()))
    return op.emitOpError("unsupported outerproduct type");

  return success();
}
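
// Illustrative examples (hypothetical operands): a proper outer product
//   vector.outerproduct %x, %y, %acc : vector<4xf32>, vector<8xf32>
// yields a vector<4x8xf32>, while the AXPY form
//   vector.outerproduct %x, %s, %acc : vector<4xf32>, f32
// takes a scalar RHS and yields a vector<4xf32>.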

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReshapeOp op) {
  // Verify that the number of input/output shape operands plus the number of
  // fixed vector sizes matches the rank of the corresponding vector type.
  auto inputVectorType = op.getInputVectorType();
  auto outputVectorType = op.getOutputVectorType();
  int64_t inputShapeRank = op.getNumInputShapeSizes();
  int64_t outputShapeRank = op.getNumOutputShapeSizes();
  SmallVector<int64_t, 4> fixedVectorSizes;
  op.getFixedVectorSizes(fixedVectorSizes);
  int64_t numFixedVectorSizes = fixedVectorSizes.size();

  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
    return op.emitError("invalid input shape for vector type ")
           << inputVectorType;

  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
    return op.emitError("invalid output shape for vector type ")
           << outputVectorType;

  // Verify that the 'fixedVectorSizes' match an input/output vector shape
  // suffix.
  unsigned inputVectorRank = inputVectorType.getRank();
  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
    unsigned index = inputVectorRank - numFixedVectorSizes - i;
    if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
      return op.emitError("fixed vector size must match input vector for dim ")
             << i;
  }

  unsigned outputVectorRank = outputVectorType.getRank();
  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
    unsigned index = outputVectorRank - numFixedVectorSizes - i;
    if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
      return op.emitError("fixed vector size must match output vector for dim ")
             << i;
  }

  // If all shape operands are produced by constant ops, verify that the
  // products of dimensions for the input and output shapes match.
  auto isDefByConstant = [](Value operand) {
    return isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
  };
  if (llvm::all_of(op.input_shape(), isDefByConstant) &&
      llvm::all_of(op.output_shape(), isDefByConstant)) {
    int64_t numInputElements = 1;
    for (auto operand : op.input_shape())
      numInputElements *=
          cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
    int64_t numOutputElements = 1;
    for (auto operand : op.output_shape())
      numOutputElements *=
          cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
    if (numInputElements != numOutputElements)
      return op.emitError("product of input and output shape sizes must match");
  }
  return success();
}

void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(fixed_vector_sizes(), results);
}

//===----------------------------------------------------------------------===//
// ExtractStridedSliceOp
//===----------------------------------------------------------------------===//

// Inference works as follows:
//   1. Add 'sizes' from prefix of dims in 'offsets'.
//   2. Add sizes from 'vectorType' for remaining dims.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType());
}
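
// E.g. (illustrative): for a source vector<4x8x16xf32> with rank-2 offsets,
// sizes = [2, 2], and rank-2 strides, the inferred result type is
// vector<2x2x16xf32>: the leading dims come from 'sizes', the rest from the
// source type.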

void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> sizes,
                                  ArrayRef<int64_t> strides) {
  result.addOperands(source);
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(
      inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
                                    offsetsAttr, sizesAttr, stridesAttr));
  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
  result.addAttribute(getSizesAttrName(), sizesAttr);
  result.addAttribute(getStridesAttrName(), stridesAttr);
}

static LogicalResult verify(ExtractStridedSliceOp op) {
  auto type = op.getVectorType();
  auto offsets = op.offsets();
  auto sizes = op.sizes();
  auto strides = op.strides();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
    op.emitOpError(
        "expected offsets, sizes and strides attributes of same size");
    return failure();
  }

  auto shape = type.getShape();
  auto offName = ExtractStridedSliceOp::getOffsetsAttrName();
  auto sizesName = ExtractStridedSliceOp::getSizesAttrName();
  auto stridesName = ExtractStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
                                                stridesName)) ||
      failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
                                                    offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  auto resultType = inferStridedSliceOpResultType(
      op.getVectorType(), op.offsets(), op.sizes(), op.strides());
  if (op.getResult().getType() != resultType) {
    op.emitOpError("expected result type to be ") << resultType;
    return failure();
  }

  return success();
}

// When the source of an ExtractStridedSliceOp comes from a chain of
// InsertStridedSliceOps, try to replace it with the source of one of those
// inserts when the extracted vector can be shown to be a subset of an
// inserted vector.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  // Helper to extract an integer out of an ArrayAttr.
  auto getElement = [](ArrayAttr array, int idx) {
    return array[idx].cast<IntegerAttr>().getInt();
  };
  ArrayAttr extractOffsets = op.offsets();
  ArrayAttr extractStrides = op.strides();
  ArrayAttr extractSizes = op.sizes();
  auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.offsets();
    ArrayAttr insertStrides = insertOp.strides();
    // If the rank of the extract is greater than the rank of the insert, we
    // are likely extracting a partial chunk of the vector inserted.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool partialOverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the start of the extract offset is in the inserted interval.
      if (start <= offset && offset < end) {
        // If the extract interval overlaps but is not fully included, we may
        // have a partial overlap that will prevent any folding.
        if (offset + size > end)
          partialOverlap = true;
        offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk is a subset of the inserted chunk.
    if (!disjoint && !partialOverlap) {
      op.setOperand(insertOp.source());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op->setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
                  b.getI64ArrayAttr(offsetDiffs));
      return success();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep looking
    // in the insert chain.
    if (disjoint) {
      insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
    } else {
      // The extracted vector partially overlaps the inserted vector; we cannot
      // fold.
      return failure();
    }
  }
  return failure();
}
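
// Illustrative example (hypothetical shapes): if %v : vector<4x8xf32> was
// inserted at offsets [2, 0], an extract with offsets [3, 4] and sizes [2, 4]
// lies entirely inside the inserted chunk, so it folds to an extract directly
// from %v with rebased offsets [1, 4].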

OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
  if (getVectorType() == getResult().getType())
    return vector();
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();
  return {};
}

void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(offsets(), results);
}

namespace {

// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'extractStridedSliceOp' operand is not defined by a
    // ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.vector().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if 'extractStridedSliceOp' has non-unit strides.
    if (llvm::any_of(extractStridedSliceOp.strides(), [](Attribute attr) {
          return attr.cast<IntegerAttr>().getInt() != 1;
        }))
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t, 4> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets);
    SmallVector<int64_t, 4> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes);

    // Compute the slice of the vector mask region.
    SmallVector<int64_t, 4> sliceMaskDimSizes;
    assert(sliceOffsets.size() == maskDimSizes.size());
    for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t maskDimSize = std::get<0>(it);
      int64_t sliceOffset = std::get<1>(it);
      int64_t sliceSize = std::get<2>(it);
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
    // region is a conjunction of mask dim intervals).
    if (llvm::any_of(sliceMaskDimSizes, [](int64_t sz) { return sz == 0; }))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);

    // Replace 'extractStridedSliceOp' with a ConstantMaskOp with the sliced
    // mask region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
    return success();
  }
};
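
// E.g. (illustrative): slicing vector.constant_mask [2, 3] : vector<4x4xi1>
// with offsets = [1, 0] and sizes = [2, 2] yields
// vector.constant_mask [1, 2] : vector<2x2xi1>
// (dim 0: min(1 + 2, 2) - 1 = 1; dim 1: min(0 + 2, 3) - 0 = 2).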

// Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
class StridedSliceConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'extractStridedSliceOp' operand is not defined by a
    // ConstantOp.
    auto constantOp =
        extractStridedSliceOp.vector().getDefiningOp<ConstantOp>();
    if (!constantOp)
      return failure();
    auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
    if (!dense)
      return failure();
    auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
                                          dense.getSplatValue());
    rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
    return success();
  }
};

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStridedSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = op.getType().cast<VectorType>();
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the innermost dimensions of the source of the broadcast are
    // the same as the innermost dimensions of the destination of the extract.
    // If this is the case we can just use a broadcast, as those dimensions
    // are untouched.
    bool lowerDimMatch = true;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        lowerDimMatch = false;
        break;
      }
    }
    Value source = broadcast.source();
    if (!lowerDimMatch) {
      // The inner dimensions don't match, which means we need to extract from
      // the source of the original broadcast and then broadcast the extracted
      // value.
      source = rewriter.create<ExtractStridedSliceOp>(
          op->getLoc(), source,
          getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
          getI64SubArray(op.sizes(), /*dropFront=*/rankDiff),
          getI64SubArray(op.strides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};
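
// E.g. (illustrative): extracting a vector<2x8xf32> slice at offsets [0, 0]
// from a broadcast of vector<8xf32> to vector<4x8xf32> leaves the inner
// dimension intact, so the pattern emits a direct broadcast from
// vector<8xf32> to vector<2x8xf32> instead.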

} // end anonymous namespace

void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
  results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
              StridedSliceBroadcast>(context);
}

//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//

template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = expr.dyn_cast<AffineDimExpr>();
    auto zero = expr.dyn_cast<AffineConstantExpr>();
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
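
// Illustrative examples: (d0, d1, d2) -> (d2, 0, d0) is a valid projected
// permutation (each result is a distinct dim or the zero constant), whereas
// (d0, d1) -> (d0, d0) is rejected because d0 appears twice.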

static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
                                      VectorType vectorType,
                                      VectorType maskType,
                                      AffineMap permutationMap,
                                      ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!shapedType.isa<MemRefType, RankedTensorType>())
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");
  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    VectorType expectedMaskType =
        vector::detail::transferMaskType(vectorType, permutationMap);
    if (maskType && expectedMaskType != maskType)
      return op->emitOpError("expects mask type consistent with permutation "
                             "map: ")
             << maskType;
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");
  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (inBounds) {
    if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
      return op->emitOpError("expects the optional in_bounds attr of same rank "
                             "as permutation_map results: ")
             << AffineMapAttr::get(permutationMap);
    for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
      if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
          !inBounds.getValue()[i].cast<BoolAttr>().getValue())
        return op->emitOpError("requires broadcast dimensions to be in-bounds");
  }

  return success();
}

/// Builder that sets padding to zero.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMap permutationMap,
                           ArrayRef<bool> inBounds) {
  Type elemType = source.getType().cast<ShapedType>().getElementType();
  Value padding = builder.create<ConstantOp>(result.location, elemType,
                                             builder.getZeroAttr(elemType));
  if (inBounds.empty())
    return build(builder, result, vectorType, source, indices, permutationMap,
                 padding, ArrayAttr());
  ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
  build(builder, result, vectorType, source, indices, permutationMap, padding,
        inBoundsArrayAttr);
}

/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, Value padding,
                           ArrayRef<bool> inBounds) {
  auto permMap = getTransferMinorIdentityMap(
      source.getType().cast<ShapedType>(), vectorType);
  if (inBounds.empty())
    return build(builder, result, vectorType, source, indices, permMap,
                 padding, ArrayAttr());
  ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
  build(builder, result, vectorType, source, indices, permMap, padding,
        inBoundsArrayAttr);
}

/// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
/// (resp. zero).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, ArrayRef<bool> inBounds) {
  auto permMap = getTransferMinorIdentityMap(
      source.getType().cast<ShapedType>(), vectorType);
  build(builder, result, vectorType, source, indices, permMap, inBounds);
}

/// Builder that does not provide a mask.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           Type vectorType, Value source, ValueRange indices,
                           AffineMap permutationMap, Value padding,
                           ArrayAttr inBounds) {
  build(builder, result, vectorType, source, indices, permutationMap, padding,
        /*mask=*/Value(), inBounds);
}

/// Builder that does not provide a mask.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           Type vectorType, Value source, ValueRange indices,
                           AffineMapAttr permutationMap, Value padding,
                           ArrayAttr inBounds) {
  build(builder, result, vectorType, source, indices, permutationMap, padding,
        /*mask=*/Value(), inBounds);
}

static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
  if (op.permutation_map().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrName());
  bool elideInBounds = true;
  if (auto inBounds = op.in_bounds()) {
    for (auto attr : *inBounds) {
      if (attr.template cast<BoolAttr>().getValue()) {
        elideInBounds = false;
        break;
      }
    }
  }
  if (elideInBounds)
    elidedAttrs.push_back(op.getInBoundsAttrName());
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

static void print(OpAsmPrinter &p, TransferReadOp op) {
  p << op.getOperationName() << " " << op.source() << "[" << op.indices()
    << "], " << op.padding();
  if (op.mask())
    p << ", " << op.mask();
  printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
  p << " : " << op.getShapedType() << ", " << op.getVectorType();
}

static ParseResult parseTransferReadOp(OpAsmParser &parser,
                                       OperationState &result) {
  auto &builder = parser.getBuilder();
  llvm::SMLoc typesLoc;
  OpAsmParser::OperandType sourceInfo;
  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
  OpAsmParser::OperandType paddingInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::OperandType maskInfo;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = types[0].dyn_cast<ShapedType>();
  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = types[1].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
  Attribute mapAttr = result.attributes.get(permutationAttrName);
  if (!mapAttr) {
    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    mapAttr = AffineMapAttr::get(permMap);
    result.attributes.set(permutationAttrName, mapAttr);
  }
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (shapedType.getElementType().dyn_cast<VectorType>())
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
    // Instead of adding the mask type as an op type, compute it based on the
    // vector type and the permutation map (to keep the type signature small).
    auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(
      TransferReadOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1,
                                static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}

static LogicalResult verify(TransferReadOp op) {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = op.getShapedType();
  VectorType vectorType = op.getVectorType();
  VectorType maskType = op.getMaskType();
  auto paddingType = op.padding().getType();
  auto permutationMap = op.permutation_map();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
    return op.emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
                              maskType, permutationMap,
                              op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
    return failure();

  if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
    // Source has vector element type.
    // Check that 'sourceVectorElementType' and 'paddingType' types match.
    if (sourceVectorElementType != paddingType)
      return op.emitOpError(
          "requires source element type and padding type to match.");

  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return op.emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return op.emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&op](Twine t) { return op.emitOpError(t); });
}

/// This is a common utility used for patterns of the form
/// ```
/// someop(memrefcast) -> someop
/// ```
/// It folds the source of the memref.cast into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

static LogicalResult foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.indices()[indicesIdx];
  auto cstOp = index.getDefiningOp<ConstantIndexOp>();
  if (!cstOp)
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

  return cstOp.getValue() + vectorSize <= sourceSize;
}
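
// E.g. (illustrative): reading a vector dimension of size 4 at constant index
// 2 from a static source dimension of size 8 is in-bounds (2 + 4 <= 8); a
// dynamic source dimension is conservatively treated as out-of-bounds.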
2559
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  AffineMap permutationMap = op.permutation_map();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // Currently out-of-bounds, check whether we can statically determine it is
    // inBounds.
    auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
    assert(dimExpr && "Broadcast dims must be in-bounds");
    auto inBounds =
        isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
    newInBounds.push_back(inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  }
  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build a BoolArrayAttr.
  OpBuilder b(op.getContext());
  op->setAttr(TransferOp::getInBoundsAttrName(),
              b.getBoolArrayAttr(newInBounds));
  return success();
}

/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
///   : tensor<4x4xf32>, vector<1x4xf32>
/// ```
/// -> Folds into
/// ```
/// %v0
/// ```
static Value foldRAW(TransferReadOp readOp) {
  if (!readOp.getShapedType().isa<RankedTensorType>())
    return {};
  auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.vector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
  if (Value vec = foldRAW(*this))
    return vec;
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  if (succeeded(foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}

Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (getShapedType().isa<MemRefType>())
    effects.emplace_back(MemoryEffects::Read::get(), source(),
                         SideEffects::DefaultResource::get());
}

//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//

/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            ArrayRef<bool> inBounds) {
  auto vectorType = vector.getType().cast<VectorType>();
  auto permMap = getTransferMinorIdentityMap(
      source.getType().cast<ShapedType>(), vectorType);
  if (inBounds.empty())
    return build(builder, result, vector, source, indices, permMap,
                 ArrayAttr());
  ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
  build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            AffineMap permutationMap) {
  build(builder, result, vector, source, indices, permutationMap,
        /*inBounds=*/ArrayAttr());
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            AffineMapAttr permutationMap,
                            /*optional*/ ArrayAttr inBounds) {
  Type resultType = source.getType().dyn_cast<RankedTensorType>();
  build(builder, result, resultType, vector, source, indices, permutationMap,
        /*mask=*/Value(), inBounds);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            AffineMap permutationMap,
                            /*optional*/ ArrayAttr inBounds) {
  Type resultType = source.getType().dyn_cast<RankedTensorType>();
  build(builder, result, resultType, vector, source, indices, permutationMap,
        /*mask=*/Value(), inBounds);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            AffineMap permutationMap, /*optional*/ Value mask,
                            /*optional*/ ArrayAttr inBounds) {
  Type resultType = source.getType().dyn_cast<RankedTensorType>();
  build(builder, result, resultType, vector, source, indices, permutationMap,
        mask, inBounds);
}

static ParseResult parseTransferWriteOp(OpAsmParser &parser,
                                        OperationState &result) {
  auto &builder = parser.getBuilder();
  llvm::SMLoc typesLoc;
  OpAsmParser::OperandType vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::OperandType maskInfo;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = types[0].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = types[1].dyn_cast<ShapedType>();
  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
  auto attr = result.attributes.get(permutationAttrName);
  if (!attr) {
    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (shapedType.getElementType().dyn_cast<VectorType>())
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(
      TransferWriteOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()),
                                static_cast<int32_t>(hasMask.succeeded())}));
  return failure(shapedType.isa<RankedTensorType>() &&
                 parser.addTypeToList(shapedType, result.types));
}

static void print(OpAsmPrinter &p, TransferWriteOp op) {
  p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "["
    << op.indices() << "]";
  if (op.mask())
    p << ", " << op.mask();
  printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
  p << " : " << op.getVectorType() << ", " << op.getShapedType();
}

static LogicalResult verify(TransferWriteOp op) {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = op.getShapedType();
  VectorType vectorType = op.getVectorType();
  VectorType maskType = op.getMaskType();
  auto permutationMap = op.permutation_map();

  if (llvm::size(op.indices()) != shapedType.getRank())
    return op.emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (op.hasBroadcastDim())
    return op.emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
                              maskType, permutationMap,
                              op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&op](Twine t) { return op.emitOpError(t); });
}

/// Fold:
/// ```
///    %t1 = ...
///    %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
///      tensor<static_sizesxf32>, vector<static_sizesxf32>
///    %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
///      vector<static_sizesxf32>, tensor<static_sizesxf32>
/// ```
///
/// into:
///
/// ```
///    %t0
/// ```
///
/// The producer of t1 may or may not be DCE'd depending on whether it is a
/// block argument or has side effects.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>();
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // For now, only accept minor identity permutation maps. Future: accept maps
  // whose composition is a minor identity.
  if (!read.permutation_map().isMinorIdentity() ||
      !write.permutation_map().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.source().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and Tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // If any index is nonzero, bail.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = v.getDefiningOp<ConstantIndexOp>();
    return !cstOp || cstOp.getValue() != 0;
  };
  if (llvm::any_of(read.indices(), isNotConstantZero) ||
      llvm::any_of(write.indices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.source());
  return success();
}

static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.source() == write.source() &&
         read.indices() == write.indices() &&
         read.permutation_map() == write.permutation_map() &&
         read.getVectorType() == write.getVectorType() && !read.mask() &&
         !write.mask();
}

/// Fold transfer_write write after read:
/// ```
///    %t0 = ...
///    %v = vector.transfer_read %t0[%c0...] :
///      tensor<static_sizesxf32>, vector<static_sizesxf32>
///    %t1 = vector.transfer_write %v, %t0[%c0...] :
///      vector<static_sizesxf32>, tensor<static_sizesxf32>
/// ```
///
/// into:
///
/// ```
///    %t0
/// ```
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!write.source().getType().isa<RankedTensorType>())
    return failure();
  auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();

  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.source());
  return success();
}

LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, operands, results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return success();
  return foldMemRefCast(*this);
}

Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (getShapedType().isa<MemRefType>())
    effects.emplace_back(MemoryEffects::Write::get(), source(),
                         SideEffects::DefaultResource::get());
}

namespace {
/// Remove a dead transfer write from the SSA chain so that it can be
/// eliminated by DCE:
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// into:
///
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
/// any other uses.
class foldWAW final : public OpRewritePattern<TransferWriteOp> {
public:
  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!writeOp.getShapedType().isa<RankedTensorType>())
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        writeToModify.sourceMutable().assign(defWrite.source());
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely look
      // at the previous store to see if it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
} // namespace

void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<foldWAW>(context);
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 MemRefType memRefTy) {
  if (!isLastMemrefDimUnitStride(memRefTy))
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}

static LogicalResult verify(vector::LoadOp op) {
  VectorType resVecTy = op.getVectorType();
  MemRefType memRefTy = op.getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
    return failure();

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
    if (memVecTy != resVecTy)
      return op.emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return op.emitOpError("base and result element types should match");
  if (llvm::size(op.indices()) != memRefTy.getRank())
    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(vector::StoreOp op) {
  VectorType valueVecTy = op.getVectorType();
  MemRefType memRefTy = op.getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
    return failure();

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
    if (memVecTy != valueVecTy)
      return op.emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return op.emitOpError("base and valueToStore element type should match");
  if (llvm::size(op.indices()) != memRefTy.getRank())
    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
                            SmallVectorImpl<OpFoldResult> &results) {
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(MaskedLoadOp op) {
  VectorType maskVType = op.getMaskVectorType();
  VectorType passVType = op.getPassThruVectorType();
  VectorType resVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and result element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return op.emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
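/// Rewrites a MaskedLoadOp whose mask is statically known: an all-true mask
/// becomes a plain vector.load, and an all-false mask folds to the pass_thru
/// operand.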
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(load.mask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(),
                                                  load.base(), load.indices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.pass_thru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};
} // namespace

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}

OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// MaskedStoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(MaskedStoreOp op) {
  VectorType maskVType = op.getMaskVectorType();
  VectorType valueVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and valueToStore element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
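/// Rewrites a MaskedStoreOp whose mask is statically known: an all-true mask
/// becomes a plain vector.store, and an all-false mask makes the store a
/// no-op that can simply be erased.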
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(store.mask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.valueToStore(), store.base(), store.indices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};
} // namespace

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(GatherOp op) {
  VectorType indVType = op.getIndexVectorType();
  VectorType maskVType = op.getMaskVectorType();
  VectorType resVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and result element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != indVType.getDimSize(0))
    return op.emitOpError("expected result dim to match indices dim");
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected result dim to match mask dim");
  if (resVType != op.getPassThruVectorType())
    return op.emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern<GatherOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(gather.mask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.pass_thru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};
} // namespace

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder>(context);
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ScatterOp op) {
  VectorType indVType = op.getIndexVectorType();
  VectorType maskVType = op.getMaskVectorType();
  VectorType valueVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and valueToStore element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
    return op.emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern<ScatterOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(scatter.mask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.eraseOp(scatter);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};
} // namespace

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder>(context);
}

//===----------------------------------------------------------------------===//
// ExpandLoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExpandLoadOp op) {
  VectorType maskVType = op.getMaskVectorType();
  VectorType passVType = op.getPassThruVectorType();
  VectorType resVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and result element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return op.emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(expand.mask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.base(), expand.indices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.pass_thru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
} // namespace

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}

//===----------------------------------------------------------------------===//
// CompressStoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(CompressStoreOp op) {
  VectorType maskVType = op.getMaskVectorType();
  VectorType valueVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and valueToStore element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(compress.mask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.valueToStore(), compress.base(),
          compress.indices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
} // namespace

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}

//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//

/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
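///
/// For example (an illustrative case):
///   isValidShapeCast({4, 6}, {2, 2, 3, 2}) == true, since 4 == 2 * 2 and
///   6 == 3 * 2, whereas isValidShapeCast({4, 6}, {2, 3, 2, 2}) == false,
///   since no prefix of {2, 3, 2, 2} multiplies to exactly 4.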
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  unsigned rankA = a.size();
  unsigned rankB = b.size();
  assert(rankA < rankB);

  unsigned i = 0;
  unsigned j = 0;
  while (i < rankA && j < rankB) {
    int64_t dimA = a[i];
    int64_t dimB = 1;
    while (dimB < dimA && j < rankB)
      dimB *= b[j++];
    if (dimA != dimB)
      break;
    ++i;

    // Handle the case when trailing dimensions are of size 1.
    // Include them into the contiguous sequence.
    auto isOne = [](int64_t v) { return v == 1; };
    if (i < rankA && llvm::all_of(a.slice(i), isOne))
      i = rankA;
    if (j < rankB && llvm::all_of(b.slice(j), isOne))
      j = rankB;
  }

  return i == rankA && j == rankB;
}

static LogicalResult verifyVectorShapeCast(Operation *op,
                                           VectorType sourceVectorType,
                                           VectorType resultVectorType) {
  // Check that element type is the same.
  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
    return op->emitOpError("source/result vectors must have same element type");
  auto sourceShape = sourceVectorType.getShape();
  auto resultShape = resultVectorType.getShape();

  // Check that product of source dim sizes matches product of result dim
  // sizes.
  int64_t sourceDimProduct = std::accumulate(
      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
  int64_t resultDimProduct = std::accumulate(
      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
  if (sourceDimProduct != resultDimProduct)
    return op->emitOpError("source/result number of elements must match");

  // Check the expanding/contracting rank cases.
  unsigned sourceRank = sourceVectorType.getRank();
  unsigned resultRank = resultVectorType.getRank();
  if (sourceRank < resultRank) {
    if (!isValidShapeCast(sourceShape, resultShape))
      return op->emitOpError("invalid shape cast");
  } else if (sourceRank > resultRank) {
    if (!isValidShapeCast(resultShape, sourceShape))
      return op->emitOpError("invalid shape cast");
  }
  return success();
}

static LogicalResult verify(ShapeCastOp op) {
  auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>();
  auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>();

  // Check if source/result are of vector type.
  if (sourceVectorType && resultVectorType)
    return verifyVectorShapeCast(op, sourceVectorType, resultVectorType);

  return success();
}

OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
  // Nop shape cast.
  if (source().getType() == result().getType())
    return source();

  // Canceling shape casts.
  if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
    if (result().getType() == otherOp.source().getType())
      return otherOp.source();
    setOperand(otherOp.source());
    return getResult();
  }
  return {};
}

namespace {
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern<ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>();
    if (!constantOp)
      return failure();
    // Only handle splat for now.
    auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
    if (!dense)
      return failure();
    auto newAttr = DenseElementsAttr::get(
        shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
    rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr);
    return success();
  }
};

} // namespace

void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
  results.add<ShapeCastConstantFolder>(context);
}

//===----------------------------------------------------------------------===//
// BitCastOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(BitCastOp op) {
  auto sourceVectorType = op.getSourceVectorType();
  auto resultVectorType = op.getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return op.emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(op);
  if (dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()) *
          sourceVectorType.getShape().back() !=
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType()) *
          resultVectorType.getShape().back())
    return op.emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");

  return success();
}

OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
  // Nop cast.
  if (source().getType() == result().getType())
    return source();

  // Canceling bitcasts.
  if (auto otherOp = source().getDefiningOp<BitCastOp>())
    if (result().getType() == otherOp.source().getType())
      return otherOp.source();

  Attribute sourceConstant = operands.front();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern.
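        // Each f32 lane is two adjacent f16 lanes of the same splat; e.g. a
        // splat of 1.0 : f16 (bits 0x3C00) yields f32 lanes with bit pattern
        // 0x3C003C00.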
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//

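/// Returns the shape of `memRefType`, extended with the shape of its vector
/// element type, if any.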
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
  SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
                              memRefType.getShape().end());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = source.getType().cast<MemRefType>();
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(
      MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace()));
}

static LogicalResult verify(TypeCastOp op) {
  MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
  if (!canonicalType.getAffineMaps().empty())
    return op.emitOpError("expects operand to be a memref with no layout");
  if (!op.getResultMemRefType().getAffineMaps().empty())
    return op.emitOpError("expects result to be a memref with no layout");
  if (op.getResultMemRefType().getMemorySpace() !=
      op.getMemRefType().getMemorySpace())
    return op.emitOpError("expects result in same memory space");

  auto sourceType = op.getMemRefType();
  auto resultType = op.getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return op.emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return op.emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> transp) {
  VectorType vt = vector.getType().cast<VectorType>();
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  for (unsigned i = 0; i < transp.size(); ++i)
    transposedShape[i] = vt.getShape()[transp[i]];

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
  result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
}

// Eliminates transpose operations that produce values identical to their input
// values. This happens when the dimensions of the input vector remain in their
// original order after the transpose operation.
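// For example (illustrative):
//   %0 = vector.transpose %v, [0, 1] : vector<2x3xf32> to vector<2x3xf32>
// folds to %v, since [0, 1] is the identity permutation.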
OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
  SmallVector<int64_t, 4> transp;
  getTransp(transp);

  // Check if the permutation of the dimensions contains sequential values:
  // {0, 1, 2, ...}.
  for (int64_t i = 0, e = transp.size(); i < e; i++) {
    if (transp[i] != i)
      return {};
  }

  return vector();
}

static LogicalResult verify(vector::TransposeOp op) {
  VectorType vectorType = op.getVectorType();
  VectorType resultType = op.getResultType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return op.emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  auto transpAttr = op.transp().getValue();
  int64_t size = transpAttr.size();
  if (rank != size)
    return op.emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (auto ta : llvm::enumerate(transpAttr)) {
    int64_t i = ta.value().cast<IntegerAttr>().getInt();
    if (i < 0 || i >= rank)
      return op.emitOpError("transposition index out of range: ") << i;
    if (seen[i])
      return op.emitOpError("duplicate position index: ") << i;
    seen[i] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
      return op.emitOpError("dimension size mismatch at: ") << i;
  }
  return success();
}

namespace {

// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
    auto getPermutation = [](vector::TransposeOp transpose) {
      SmallVector<int64_t, 4> permutation;
      transpose.getTransp(permutation);
      return permutation;
    };

    // Composes two permutations: result[i] = permutation1[permutation2[i]].
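    // E.g. permutation1 = [1, 2, 0] and permutation2 = [2, 0, 1] compose to
    // [0, 1, 2], since the two permutations are mutually inverse.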
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.vector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        getPermutation(parentTransposeOp), getPermutation(transposeOp));
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.vector(),
        vector::getVectorSubscriptAttr(rewriter, permutation));
    return success();
  }
};

} // end anonymous namespace

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<TransposeFolder>(context);
}

void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(transp(), results);
}

//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ConstantMaskOp &op) {
  // Verify that array attr size matches the rank of the vector result.
  auto resultType = op.getResult().getType().cast<VectorType>();
  if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
    return op.emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of corresponding vector
  // result dimension size.
  auto resultShape = resultType.getShape();
  SmallVector<int64_t, 4> maskDimSizes;
  for (auto it : llvm::enumerate(op.mask_dim_sizes())) {
    int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
    if (attrValue < 0 || attrValue > resultShape[it.index()])
      return op.emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    maskDimSizes.push_back(attrValue);
  }
  // Verify that if one mask dim size is zero, they all should be zero (because
  // the mask region is a conjunction of each mask dimension interval).
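  // E.g. mask dim sizes [2, 0] would describe an empty mask region, so the
  // canonical all-false form [0, 0] is required instead.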
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return op.emitOpError("expected all mask dim sizes to be zeros, "
                          "as a result of conjunction with zero mask dim");
  return success();
}

//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(CreateMaskOp op) {
  // Verify that an operand was specified for each result vector dimension.
  if (op.getNumOperands() !=
      op.getResult().getType().cast<VectorType>().getRank())
    return op.emitOpError(
        "must specify an operand for each result vector dimension");
  return success();
}

namespace {

// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
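// For example (an illustrative case, with %c3 assumed to be a constant
// 3 : index):
//   %m = vector.create_mask %c3 : vector<4xi1>
// becomes:
//   %m = vector.constant_mask [3] : vector<4xi1>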
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern<CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    // Return if any of 'createMaskOp' operands are not defined by a constant.
    auto isNotDefByConstant = [](Value operand) {
      return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
    };
    if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    for (auto operand : createMaskOp.operands()) {
      auto defOp = operand.getDefiningOp();
      maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
    }
    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        createMaskOp, createMaskOp.getResult().getType(),
        vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
    return success();
  }
};

} // end anonymous namespace

void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}

void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"