1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/Vector/VectorOps.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20 #include "mlir/Dialect/Vector/VectorUtils.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/BlockAndValueMapping.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/DialectImplementation.h"
27 #include "mlir/IR/OpImplementation.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/IR/TypeUtilities.h"
30 #include "mlir/Support/LLVM.h"
31 #include "mlir/Support/MathExtras.h"
32 #include "llvm/ADT/StringSet.h"
33 #include "llvm/ADT/bit.h"
34 #include <numeric>
35
36 #include "mlir/Dialect/Vector/VectorOpsDialect.cpp.inc"
37 // Pull in all enum type and utility function definitions.
38 #include "mlir/Dialect/Vector/VectorOpsEnums.cpp.inc"
39
40 using namespace mlir;
41 using namespace mlir::vector;
42
/// Helper enum to classify mask value.
enum class MaskFormat {
  // Every inspected mask bit is set.
  AllTrue = 0,
  // Every inspected mask bit is cleared.
  AllFalse = 1,
  // The mask is mixed, or could not be inspected statically.
  Unknown = 2,
};
49
50 /// Helper method to classify a 1-D mask value. Currently, the method
51 /// looks "under the hood" of a constant value with dense attributes
52 /// and a constant mask operation (since the client may be called at
53 /// various stages during progressive lowering).
get1DMaskFormat(Value mask)54 static MaskFormat get1DMaskFormat(Value mask) {
55 if (auto c = mask.getDefiningOp<ConstantOp>()) {
56 // Inspect constant dense values. We count up for bits that
57 // are set, count down for bits that are cleared, and bail
58 // when a mix is detected.
59 if (auto denseElts = c.value().dyn_cast<DenseIntElementsAttr>()) {
60 int64_t val = 0;
61 for (bool b : denseElts.getValues<bool>())
62 if (b && val >= 0)
63 val++;
64 else if (!b && val <= 0)
65 val--;
66 else
67 return MaskFormat::Unknown;
68 if (val > 0)
69 return MaskFormat::AllTrue;
70 if (val < 0)
71 return MaskFormat::AllFalse;
72 }
73 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
74 // Inspect constant mask index. If the index exceeds the
75 // dimension size, all bits are set. If the index is zero
76 // or less, no bits are set.
77 ArrayAttr masks = m.mask_dim_sizes();
78 assert(masks.size() == 1);
79 int64_t i = masks[0].cast<IntegerAttr>().getInt();
80 int64_t u = m.getType().getDimSize(0);
81 if (i >= u)
82 return MaskFormat::AllTrue;
83 if (i <= 0)
84 return MaskFormat::AllFalse;
85 }
86 return MaskFormat::Unknown;
87 }
88
89 // Helper for verifying combining kinds in contractions and reductions.
isSupportedCombiningKind(CombiningKind combiningKind,Type elementType)90 static bool isSupportedCombiningKind(CombiningKind combiningKind,
91 Type elementType) {
92 switch (combiningKind) {
93 case CombiningKind::ADD:
94 case CombiningKind::MUL:
95 return elementType.isIntOrIndexOrFloat();
96 case CombiningKind::MINUI:
97 case CombiningKind::MINSI:
98 case CombiningKind::MAXUI:
99 case CombiningKind::MAXSI:
100 case CombiningKind::AND:
101 case CombiningKind::OR:
102 case CombiningKind::XOR:
103 return elementType.isIntOrIndex();
104 case CombiningKind::MINF:
105 case CombiningKind::MAXF:
106 return elementType.isa<FloatType>();
107 }
108 return false;
109 }
110
111 /// Return true if the last dimension of the MemRefType has unit stride. Also
112 /// return true for memrefs with no strides.
isLastMemrefDimUnitStride(MemRefType type)113 bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
114 int64_t offset;
115 SmallVector<int64_t> strides;
116 auto successStrides = getStridesAndOffset(type, strides, offset);
117 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
118 }
119
120 //===----------------------------------------------------------------------===//
121 // CombiningKindAttr
122 //===----------------------------------------------------------------------===//
123
124 namespace mlir {
125 namespace vector {
126 namespace detail {
127 struct BitmaskEnumStorage : public AttributeStorage {
128 using KeyTy = uint64_t;
129
BitmaskEnumStoragemlir::vector::detail::BitmaskEnumStorage130 BitmaskEnumStorage(KeyTy val) : value(val) {}
131
operator ==mlir::vector::detail::BitmaskEnumStorage132 bool operator==(const KeyTy &key) const { return value == key; }
133
constructmlir::vector::detail::BitmaskEnumStorage134 static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
135 const KeyTy &key) {
136 return new (allocator.allocate<BitmaskEnumStorage>())
137 BitmaskEnumStorage(key);
138 }
139
140 KeyTy value = 0;
141 };
142 } // namespace detail
143 } // namespace vector
144 } // namespace mlir
145
get(CombiningKind kind,MLIRContext * context)146 CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
147 MLIRContext *context) {
148 return Base::get(context, static_cast<uint64_t>(kind));
149 }
150
getKind() const151 CombiningKind CombiningKindAttr::getKind() const {
152 return static_cast<CombiningKind>(getImpl()->value);
153 }
154
/// Table of combining kinds that `CombiningKindAttr::print` walks, in this
/// order, printing each entry contained in the attribute's value.
static constexpr const CombiningKind combiningKindsList[] = {
    // clang-format off
    CombiningKind::ADD,
    CombiningKind::MUL,
    CombiningKind::MINUI,
    CombiningKind::MINSI,
    CombiningKind::MINF,
    CombiningKind::MAXUI,
    CombiningKind::MAXSI,
    CombiningKind::MAXF,
    CombiningKind::AND,
    CombiningKind::OR,
    CombiningKind::XOR,
    // clang-format on
};
170
print(DialectAsmPrinter & printer) const171 void CombiningKindAttr::print(DialectAsmPrinter &printer) const {
172 printer << "kind<";
173 auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
174 return bitEnumContains(this->getKind(), kind);
175 });
176 llvm::interleaveComma(kinds, printer,
177 [&](auto kind) { printer << stringifyEnum(kind); });
178 printer << ">";
179 }
180
parse(DialectAsmParser & parser)181 Attribute CombiningKindAttr::parse(DialectAsmParser &parser) {
182 if (failed(parser.parseLess()))
183 return {};
184
185 StringRef elemName;
186 if (failed(parser.parseKeyword(&elemName)))
187 return {};
188
189 auto kind = symbolizeCombiningKind(elemName);
190 if (!kind) {
191 parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
192 << elemName;
193 return {};
194 }
195
196 if (failed(parser.parseGreater()))
197 return {};
198
199 return CombiningKindAttr::get(kind.getValue(), parser.getContext());
200 }
201
parseAttribute(DialectAsmParser & parser,Type type) const202 Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
203 Type type) const {
204 StringRef attrKind;
205 if (parser.parseKeyword(&attrKind))
206 return {};
207
208 if (attrKind == "kind")
209 return CombiningKindAttr::parse(parser);
210
211 parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
212 return {};
213 }
214
printAttribute(Attribute attr,DialectAsmPrinter & os) const215 void VectorDialect::printAttribute(Attribute attr,
216 DialectAsmPrinter &os) const {
217 if (auto ck = attr.dyn_cast<CombiningKindAttr>())
218 ck.print(os);
219 else
220 llvm_unreachable("Unknown attribute type");
221 }
222
223 //===----------------------------------------------------------------------===//
224 // VectorDialect
225 //===----------------------------------------------------------------------===//
226
/// Registers the dialect's attributes and operations.
void VectorDialect::initialize() {
  // The combining-kind attribute is the only attribute registered here.
  addAttributes<CombiningKindAttr>();

  // The operation list is supplied by the included .inc file under
  // GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
      >();
}
235
236 /// Materialize a single constant operation from a given attribute value with
237 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)238 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
239 Attribute value, Type type,
240 Location loc) {
241 return builder.create<ConstantOp>(loc, type, value);
242 }
243
getVectorSubscriptType(Builder & builder)244 IntegerType vector::getVectorSubscriptType(Builder &builder) {
245 return builder.getIntegerType(64);
246 }
247
getVectorSubscriptAttr(Builder & builder,ArrayRef<int64_t> values)248 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
249 ArrayRef<int64_t> values) {
250 return builder.getI64ArrayAttr(values);
251 }
252
253 //===----------------------------------------------------------------------===//
254 // MultiDimReductionOp
255 //===----------------------------------------------------------------------===//
256
build(OpBuilder & builder,OperationState & result,Value source,ArrayRef<bool> reductionMask,CombiningKind kind)257 void vector::MultiDimReductionOp::build(OpBuilder &builder,
258 OperationState &result, Value source,
259 ArrayRef<bool> reductionMask,
260 CombiningKind kind) {
261 result.addOperands(source);
262 auto sourceVectorType = source.getType().cast<VectorType>();
263 auto targetType = MultiDimReductionOp::inferDestType(
264 sourceVectorType.getShape(), reductionMask,
265 sourceVectorType.getElementType());
266 result.addTypes(targetType);
267
268 SmallVector<int64_t> reductionDims;
269 for (auto en : llvm::enumerate(reductionMask))
270 if (en.value())
271 reductionDims.push_back(en.index());
272 result.addAttribute(getReductionDimsAttrName(),
273 builder.getI64ArrayAttr(reductionDims));
274 result.addAttribute(getKindAttrName(),
275 CombiningKindAttr::get(kind, builder.getContext()));
276 }
277
verify(MultiDimReductionOp op)278 static LogicalResult verify(MultiDimReductionOp op) {
279 auto reductionMask = op.getReductionMask();
280 auto targetType = MultiDimReductionOp::inferDestType(
281 op.getSourceVectorType().getShape(), reductionMask,
282 op.getSourceVectorType().getElementType());
283 // TODO: update to support 0-d vectors when available.
284 if (targetType != op.getDestType())
285 return op.emitError("invalid output vector type: ")
286 << op.getDestType() << " (expected: " << targetType << ")";
287 return success();
288 }
289
fold(ArrayRef<Attribute> operands)290 OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
291 // Single parallel dim, this is a noop.
292 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
293 return source();
294 return {};
295 }
296
297 //===----------------------------------------------------------------------===//
298 // ReductionOp
299 //===----------------------------------------------------------------------===//
300
verify(ReductionOp op)301 static LogicalResult verify(ReductionOp op) {
302 // Verify for 1-D vector.
303 int64_t rank = op.getVectorType().getRank();
304 if (rank != 1)
305 return op.emitOpError("unsupported reduction rank: ") << rank;
306
307 // Verify supported reduction kind.
308 StringRef strKind = op.kind();
309 auto maybeKind = symbolizeCombiningKind(strKind);
310 if (!maybeKind)
311 return op.emitOpError("unknown reduction kind: ") << strKind;
312
313 Type eltType = op.dest().getType();
314 if (!isSupportedCombiningKind(*maybeKind, eltType))
315 return op.emitOpError("unsupported reduction type '")
316 << eltType << "' for kind '" << op.kind() << "'";
317
318 // Verify optional accumulator.
319 if (!op.acc().empty()) {
320 if (strKind != "add" && strKind != "mul")
321 return op.emitOpError("no accumulator for reduction kind: ") << strKind;
322 if (!eltType.isa<FloatType>())
323 return op.emitOpError("no accumulator for type: ") << eltType;
324 }
325
326 return success();
327 }
328
parseReductionOp(OpAsmParser & parser,OperationState & result)329 static ParseResult parseReductionOp(OpAsmParser &parser,
330 OperationState &result) {
331 SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
332 Type redType;
333 Type resType;
334 Attribute attr;
335 if (parser.parseAttribute(attr, "kind", result.attributes) ||
336 parser.parseComma() || parser.parseOperandList(operandsInfo) ||
337 parser.parseColonType(redType) ||
338 parser.parseKeywordType("into", resType) ||
339 (operandsInfo.size() > 0 &&
340 parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
341 (operandsInfo.size() > 1 &&
342 parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
343 parser.addTypeToList(resType, result.types))
344 return failure();
345 if (operandsInfo.size() < 1 || operandsInfo.size() > 2)
346 return parser.emitError(parser.getNameLoc(),
347 "unsupported number of operands");
348 return success();
349 }
350
print(OpAsmPrinter & p,ReductionOp op)351 static void print(OpAsmPrinter &p, ReductionOp op) {
352 p << " \"" << op.kind() << "\", " << op.vector();
353 if (!op.acc().empty())
354 p << ", " << op.acc();
355 p << " : " << op.vector().getType() << " into " << op.dest().getType();
356 }
357
getVectorReductionOp(AtomicRMWKind op,OpBuilder & builder,Location loc,Value vector)358 Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder,
359 Location loc, Value vector) {
360 Type scalarType = vector.getType().cast<ShapedType>().getElementType();
361 switch (op) {
362 case AtomicRMWKind::addf:
363 case AtomicRMWKind::addi:
364 return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
365 builder.getStringAttr("add"),
366 vector, ValueRange{});
367 case AtomicRMWKind::mulf:
368 case AtomicRMWKind::muli:
369 return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
370 builder.getStringAttr("mul"),
371 vector, ValueRange{});
372 case AtomicRMWKind::minf:
373 case AtomicRMWKind::mins:
374 case AtomicRMWKind::minu:
375 return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
376 builder.getStringAttr("min"),
377 vector, ValueRange{});
378 case AtomicRMWKind::maxf:
379 case AtomicRMWKind::maxs:
380 case AtomicRMWKind::maxu:
381 return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
382 builder.getStringAttr("max"),
383 vector, ValueRange{});
384 // TODO: Add remaining reduction operations.
385 default:
386 (void)emitOptionalError(loc, "Reduction operation type not supported");
387 break;
388 }
389 return nullptr;
390 }
391
392 //===----------------------------------------------------------------------===//
393 // ContractionOp
394 //===----------------------------------------------------------------------===//
395
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc,ArrayRef<ArrayRef<AffineExpr>> indexingExprs,ArrayRef<StringRef> iteratorTypes)396 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
397 Value lhs, Value rhs, Value acc,
398 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
399 ArrayRef<StringRef> iteratorTypes) {
400 result.addOperands({lhs, rhs, acc});
401 result.addTypes(acc.getType());
402 result.addAttribute(getIndexingMapsAttrName(),
403 builder.getAffineMapArrayAttr(
404 AffineMap::inferFromExprList(indexingExprs)));
405 result.addAttribute(getIteratorTypesAttrName(),
406 builder.getStrArrayAttr(iteratorTypes));
407 }
408
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc,ArrayAttr indexingMaps,ArrayAttr iteratorTypes)409 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
410 Value lhs, Value rhs, Value acc,
411 ArrayAttr indexingMaps,
412 ArrayAttr iteratorTypes) {
413 result.addOperands({lhs, rhs, acc});
414 result.addTypes(acc.getType());
415 result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
416 result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
417 result.addAttribute(ContractionOp::getKindAttrName(),
418 CombiningKindAttr::get(ContractionOp::getDefaultKind(),
419 builder.getContext()));
420 }
421
parseContractionOp(OpAsmParser & parser,OperationState & result)422 static ParseResult parseContractionOp(OpAsmParser &parser,
423 OperationState &result) {
424 OpAsmParser::OperandType lhsInfo;
425 OpAsmParser::OperandType rhsInfo;
426 OpAsmParser::OperandType accInfo;
427 SmallVector<OpAsmParser::OperandType, 2> masksInfo;
428 SmallVector<Type, 2> types;
429 Type resultType;
430 auto loc = parser.getCurrentLocation();
431 DictionaryAttr dictAttr;
432 // TODO: Unify linalg op attribute parsing.
433 if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
434 parser.parseOperand(lhsInfo) || parser.parseComma() ||
435 parser.parseOperand(rhsInfo) || parser.parseComma() ||
436 parser.parseOperand(accInfo) ||
437 parser.parseTrailingOperandList(masksInfo) ||
438 parser.parseOptionalAttrDict(result.attributes) ||
439 parser.parseColonTypeList(types) ||
440 parser.parseKeywordType("into", resultType) ||
441 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
442 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
443 parser.resolveOperand(accInfo, resultType, result.operands) ||
444 parser.addTypeToList(resultType, result.types))
445 return failure();
446 result.attributes.assign(dictAttr.getValue().begin(),
447 dictAttr.getValue().end());
448 if (!result.attributes.get(ContractionOp::getKindAttrName())) {
449 result.addAttribute(ContractionOp::getKindAttrName(),
450 CombiningKindAttr::get(ContractionOp::getDefaultKind(),
451 result.getContext()));
452 }
453 if (masksInfo.empty())
454 return success();
455 if (masksInfo.size() != 2)
456 return parser.emitError(parser.getNameLoc(),
457 "expected zero or exactly 2 vector mask operands");
458 auto lhsType = types[0].cast<VectorType>();
459 auto rhsType = types[1].cast<VectorType>();
460 auto maskElementType = parser.getBuilder().getI1Type();
461 std::array<Type, 2> maskTypes = {
462 VectorType::get(lhsType.getShape(), maskElementType),
463 VectorType::get(rhsType.getShape(), maskElementType)};
464 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
465 return failure();
466 return success();
467 }
468
print(OpAsmPrinter & p,ContractionOp op)469 static void print(OpAsmPrinter &p, ContractionOp op) {
470 // TODO: Unify printing code with linalg ops.
471 auto attrNames = op.getTraitAttrNames();
472 llvm::StringSet<> traitAttrsSet;
473 traitAttrsSet.insert(attrNames.begin(), attrNames.end());
474 SmallVector<NamedAttribute, 8> attrs;
475 for (auto attr : op->getAttrs())
476 if (traitAttrsSet.count(attr.first.strref()) > 0)
477 attrs.push_back(attr);
478
479 auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);
480 p << " " << dictAttr << " " << op.lhs() << ", ";
481 p << op.rhs() << ", " << op.acc();
482 if (op.masks().size() == 2)
483 p << ", " << op.masks();
484
485 p.printOptionalAttrDict(op->getAttrs(), attrNames);
486 p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
487 << op.getResultType();
488 }
489
verifyDimMap(VectorType lhsType,VectorType rhsType,const std::vector<std::pair<int64_t,int64_t>> & map)490 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
491 const std::vector<std::pair<int64_t, int64_t>> &map) {
492 for (auto &dimPair : map) {
493 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
494 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
495 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
496 return false;
497 }
498 return true;
499 }
500
verifyOutputShape(ContractionOp op,VectorType lhsType,VectorType rhsType,Type accType,Type resType,const std::vector<std::pair<int64_t,int64_t>> & contractingDimMap,const std::vector<std::pair<int64_t,int64_t>> & batchDimMap)501 static LogicalResult verifyOutputShape(
502 ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
503 Type resType,
504 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
505 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
506 DenseSet<int64_t> lhsContractingDimSet;
507 DenseSet<int64_t> rhsContractingDimSet;
508 for (auto &dimPair : contractingDimMap) {
509 lhsContractingDimSet.insert(dimPair.first);
510 rhsContractingDimSet.insert(dimPair.second);
511 }
512 DenseSet<int64_t> rhsBatchDimSet;
513 for (auto &dimPair : batchDimMap)
514 rhsBatchDimSet.insert(dimPair.second);
515
516 // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
517 SmallVector<int64_t, 4> expectedResultDims;
518 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
519 if (lhsContractingDimSet.count(i) > 0)
520 continue;
521 expectedResultDims.push_back(lhsType.getDimSize(i));
522 }
523
524 // Add free dimensions from 'rhsType' to 'expectedResultDims'.
525 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
526 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
527 continue;
528 expectedResultDims.push_back(rhsType.getDimSize(i));
529 }
530
531 // Verify 'expectedResultDims'.
532 if (expectedResultDims.size() == 0) {
533 // No batch or free dimension implies a scalar result.
534 if (resType.isa<VectorType>() || accType.isa<VectorType>())
535 return op.emitOpError("invalid accumulator/result vector shape");
536 } else {
537 // At least one batch or free dimension implies a vector result.
538 auto resVectorType = resType.dyn_cast<VectorType>();
539 auto accVectorType = accType.dyn_cast<VectorType>();
540 if (!resVectorType || !accVectorType)
541 return op.emitOpError("invalid accumulator/result vector shape");
542
543 // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
544 // types fully define the result vector type. This assumes the affine maps
545 // are well-formed, which must have been verified already.
546 MLIRContext *ctx = op.getContext();
547 AffineMap lhsMap = op.getIndexingMaps()[0];
548 AffineMap rhsMap = op.getIndexingMaps()[1];
549 SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
550 for (auto pair :
551 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
552 VectorType v = pair.first;
553 auto map = pair.second;
554 for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
555 unsigned pos = map.getDimPosition(idx);
556 if (!extents[pos])
557 extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
558 }
559 }
560 assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
561 "expected extent along all dimensions.");
562
563 AffineMap resMap = op.getIndexingMaps()[2];
564 auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
565 /*symCount=*/0, extents, ctx);
566 // Compose the resMap with the extentsMap, which is a constant map.
567 AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
568 assert(llvm::all_of(
569 expectedMap.getResults(),
570 [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
571 "expected constant extent along all dimensions.");
572 // Extract the expected shape and build the type.
573 auto expectedShape = llvm::to_vector<4>(
574 llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
575 return e.cast<AffineConstantExpr>().getValue();
576 }));
577 auto expected =
578 VectorType::get(expectedShape, resVectorType.getElementType());
579 if (resVectorType != expected || accVectorType != expected)
580 return op.emitOpError(
581 "invalid accumulator/result vector shape, expected: ")
582 << expected;
583 }
584 return success();
585 }
586
verify(ContractionOp op)587 static LogicalResult verify(ContractionOp op) {
588 auto lhsType = op.getLhsType();
589 auto rhsType = op.getRhsType();
590 auto accType = op.getAccType();
591 auto resType = op.getResultType();
592
593 // Verify that an indexing map was specified for each vector operand.
594 if (op.indexing_maps().size() != 3)
595 return op.emitOpError("expected an indexing map for each vector operand");
596
597 // Verify that each index map has 'numIterators' inputs, no symbols, and
598 // that the number of map outputs equals the rank of its associated
599 // vector operand.
600 unsigned numIterators = op.iterator_types().getValue().size();
601 for (auto it : llvm::enumerate(op.indexing_maps())) {
602 auto index = it.index();
603 auto map = it.value().cast<AffineMapAttr>().getValue();
604 if (map.getNumSymbols() != 0)
605 return op.emitOpError("expected indexing map ")
606 << index << " to have no symbols";
607 auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
608 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
609 // Verify that the map has the right number of inputs, outputs, and indices.
610 // This also correctly accounts for (..) -> () for rank-0 results.
611 if (map.getNumDims() != numIterators)
612 return op.emitOpError("expected indexing map ")
613 << index << " to have " << numIterators << " number of inputs";
614 if (map.getNumResults() != rank)
615 return op.emitOpError("expected indexing map ")
616 << index << " to have " << rank << " number of outputs";
617 if (!map.isProjectedPermutation())
618 return op.emitOpError("expected indexing map ")
619 << index << " to be a projected permutation of its inputs";
620 }
621
622 auto contractingDimMap = op.getContractingDimMap();
623 auto batchDimMap = op.getBatchDimMap();
624
625 // Verify at least one contracting dimension pair was specified.
626 if (contractingDimMap.empty())
627 return op.emitOpError("expected at least one contracting dimension pair");
628
629 // Verify contracting dimension map was properly constructed.
630 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
631 return op.emitOpError("invalid contracting dimension map");
632
633 // Verify batch dimension map was properly constructed.
634 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
635 return op.emitOpError("invalid batch dimension map");
636
637 // Verify 'accType' and 'resType' shape.
638 if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
639 contractingDimMap, batchDimMap)))
640 return failure();
641
642 // Verify that either two vector masks are set or none are set.
643 auto lhsMaskType = op.getLHSVectorMaskType();
644 auto rhsMaskType = op.getRHSVectorMaskType();
645 if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
646 return op.emitOpError("invalid number of vector masks specified");
647 if (lhsMaskType && rhsMaskType) {
648 // Verify mask rank == argument rank.
649 if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
650 rhsMaskType.getShape().size() != rhsType.getShape().size())
651 return op.emitOpError("invalid vector mask rank");
652 }
653
654 // Verify supported combining kind.
655 auto vectorType = resType.dyn_cast<VectorType>();
656 auto elementType = vectorType ? vectorType.getElementType() : resType;
657 if (!isSupportedCombiningKind(op.kind(), elementType))
658 return op.emitOpError("unsupported contraction type");
659
660 return success();
661 }
662
getTraitAttrNames()663 ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
664 static constexpr StringRef names[3] = {getIndexingMapsAttrName(),
665 getIteratorTypesAttrName(),
666 ContractionOp::getKindAttrName()};
667 return llvm::makeArrayRef(names);
668 }
669
getResultIndex(AffineMap map,AffineExpr targetExpr)670 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
671 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
672 if (targetExpr == map.getResult(i))
673 return i;
674 return -1;
675 }
676
677 static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps,ArrayAttr iteratorTypes,StringRef targetIteratorTypeName,MLIRContext * context)678 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
679 StringRef targetIteratorTypeName, MLIRContext *context) {
680 std::vector<std::pair<int64_t, int64_t>> dimMap;
681 for (auto it : llvm::enumerate(iteratorTypes)) {
682 auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
683 if (iteratorTypeName != targetIteratorTypeName)
684 continue;
685 // Search lhs/rhs map results for 'targetExpr'.
686 auto targetExpr = getAffineDimExpr(it.index(), context);
687 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
688 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
689 if (lhsDim >= 0 && rhsDim >= 0)
690 dimMap.push_back({lhsDim, rhsDim});
691 }
692 return dimMap;
693 }
694
getIterationBounds(SmallVectorImpl<int64_t> & iterationBounds)695 void ContractionOp::getIterationBounds(
696 SmallVectorImpl<int64_t> &iterationBounds) {
697 auto lhsShape = getLhsType().getShape();
698 auto resVectorType = getResultType().dyn_cast<VectorType>();
699 SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
700 SmallVector<int64_t, 2> iterationShape;
701 for (auto it : llvm::enumerate(iterator_types())) {
702 // Search lhs/rhs map results for 'targetExpr'.
703 auto targetExpr = getAffineDimExpr(it.index(), getContext());
704 auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
705 if (iteratorTypeName == getReductionIteratorTypeName()) {
706 // Get reduction dim size from lhs shape (same size in rhsShape).
707 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
708 assert(lhsDimIndex >= 0);
709 iterationBounds.push_back(lhsShape[lhsDimIndex]);
710 continue;
711 }
712 // Get parallel dimension size from result shape.
713 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
714 assert(resDimIndex >= 0);
715 assert(resVectorType != nullptr);
716 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
717 }
718 }
719
getIterationIndexMap(std::vector<DenseMap<int64_t,int64_t>> & iterationIndexMap)720 void ContractionOp::getIterationIndexMap(
721 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
722 unsigned numMaps = indexing_maps().getValue().size();
723 iterationIndexMap.resize(numMaps);
724 for (auto it : llvm::enumerate(indexing_maps())) {
725 auto index = it.index();
726 auto map = it.value().cast<AffineMapAttr>().getValue();
727 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
728 auto dim = map.getResult(i).cast<AffineDimExpr>();
729 iterationIndexMap[index][dim.getPosition()] = i;
730 }
731 }
732 }
733
getContractingDimMap()734 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
735 SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
736 return getDimMap(indexingMaps, iterator_types(),
737 getReductionIteratorTypeName(), getContext());
738 }
739
getBatchDimMap()740 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
741 SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
742 return getDimMap(indexingMaps, iterator_types(),
743 getParallelIteratorTypeName(), getContext());
744 }
745
getIndexingMaps()746 SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
747 return llvm::to_vector<4>(
748 llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
749 return mapAttr.cast<AffineMapAttr>().getValue();
750 }));
751 }
752
getShapeForUnroll()753 Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
754 SmallVector<int64_t, 4> shape;
755 getIterationBounds(shape);
756 return shape;
757 }
758
759 /// Return a fused vector::ContractionOp which represents a patterns such as:
760 ///
761 /// ```mlir
762 /// %c0 = vector.constant 0: ...
763 /// %c = vector.contract %a, %b, %c0: ...
764 /// %e = add %c, %d: ...
765 /// ```
766 ///
767 /// by:
768 ///
769 /// ```mlir
770 /// %e = vector.contract %a, %b, %d: ...
771 /// ```
772 ///
773 /// Return null if the canonicalization does not apply.
774 // TODO: This should be a folding of Add into Contract in core but while they
775 // live in different dialects, it is not possible without unnatural
776 // dependencies.
777 template <typename AddOpType>
778 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
779 using OpRewritePattern<AddOpType>::OpRewritePattern;
780
matchAndRewriteCanonicalizeContractAdd781 LogicalResult matchAndRewrite(AddOpType addOp,
782 PatternRewriter &rewriter) const override {
783 auto canonicalize = [&](Value maybeContraction,
784 Value otherOperand) -> vector::ContractionOp {
785 vector::ContractionOp contractionOp =
786 dyn_cast_or_null<vector::ContractionOp>(
787 maybeContraction.getDefiningOp());
788 if (!contractionOp)
789 return vector::ContractionOp();
790 if (auto maybeZero = dyn_cast_or_null<ConstantOp>(
791 contractionOp.acc().getDefiningOp())) {
792 if (maybeZero.value() ==
793 rewriter.getZeroAttr(contractionOp.acc().getType())) {
794 BlockAndValueMapping bvm;
795 bvm.map(contractionOp.acc(), otherOperand);
796 auto newContraction =
797 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
798 rewriter.replaceOp(addOp, newContraction.getResult());
799 return newContraction;
800 }
801 }
802 return vector::ContractionOp();
803 };
804
805 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
806 vector::ContractionOp contract = canonicalize(a, b);
807 contract = contract ? contract : canonicalize(b, a);
808 return contract ? success() : failure();
809 }
810 };
811
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)812 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
813 MLIRContext *context) {
814 results.add<CanonicalizeContractAdd<AddIOp>, CanonicalizeContractAdd<AddFOp>>(
815 context);
816 }
817
818 //===----------------------------------------------------------------------===//
819 // ExtractElementOp
820 //===----------------------------------------------------------------------===//
821
build(OpBuilder & builder,OperationState & result,Value source,Value position)822 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
823 Value source, Value position) {
824 result.addOperands({source, position});
825 result.addTypes(source.getType().cast<VectorType>().getElementType());
826 }
827
build(OpBuilder & builder,OperationState & result,Value source,int64_t position)828 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
829 Value source, int64_t position) {
830 Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
831 build(builder, result, source, pos);
832 }
833
verify(vector::ExtractElementOp op)834 static LogicalResult verify(vector::ExtractElementOp op) {
835 VectorType vectorType = op.getVectorType();
836 if (vectorType.getRank() != 1)
837 return op.emitOpError("expected 1-D vector");
838 return success();
839 }
840
841 //===----------------------------------------------------------------------===//
842 // ExtractOp
843 //===----------------------------------------------------------------------===//
844
inferExtractOpResultType(VectorType vectorType,ArrayAttr position)845 static Type inferExtractOpResultType(VectorType vectorType,
846 ArrayAttr position) {
847 if (static_cast<int64_t>(position.size()) == vectorType.getRank())
848 return vectorType.getElementType();
849 return VectorType::get(vectorType.getShape().drop_front(position.size()),
850 vectorType.getElementType());
851 }
852
build(OpBuilder & builder,OperationState & result,Value source,ArrayRef<int64_t> position)853 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
854 Value source, ArrayRef<int64_t> position) {
855 result.addOperands(source);
856 auto positionAttr = getVectorSubscriptAttr(builder, position);
857 result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
858 positionAttr));
859 result.addAttribute(getPositionAttrName(), positionAttr);
860 }
861
862 // Convenience builder which assumes the values are constant indices.
build(OpBuilder & builder,OperationState & result,Value source,ValueRange position)863 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
864 Value source, ValueRange position) {
865 SmallVector<int64_t, 4> positionConstants =
866 llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
867 return pos.getDefiningOp<ConstantIndexOp>().getValue();
868 }));
869 build(builder, result, source, positionConstants);
870 }
871
print(OpAsmPrinter & p,vector::ExtractOp op)872 static void print(OpAsmPrinter &p, vector::ExtractOp op) {
873 p << " " << op.vector() << op.position();
874 p.printOptionalAttrDict(op->getAttrs(), {"position"});
875 p << " : " << op.vector().getType();
876 }
877
parseExtractOp(OpAsmParser & parser,OperationState & result)878 static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
879 llvm::SMLoc attributeLoc, typeLoc;
880 NamedAttrList attrs;
881 OpAsmParser::OperandType vector;
882 Type type;
883 Attribute attr;
884 if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
885 parser.parseAttribute(attr, "position", attrs) ||
886 parser.parseOptionalAttrDict(attrs) ||
887 parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
888 return failure();
889
890 auto vectorType = type.dyn_cast<VectorType>();
891 if (!vectorType)
892 return parser.emitError(typeLoc, "expected vector type");
893
894 auto positionAttr = attr.dyn_cast<ArrayAttr>();
895 if (!positionAttr ||
896 static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
897 return parser.emitError(
898 attributeLoc,
899 "expected position attribute of rank smaller than vector rank");
900
901 Type resType = inferExtractOpResultType(vectorType, positionAttr);
902 result.attributes = attrs;
903 return failure(parser.resolveOperand(vector, type, result.operands) ||
904 parser.addTypeToList(resType, result.types));
905 }
906
verify(vector::ExtractOp op)907 static LogicalResult verify(vector::ExtractOp op) {
908 auto positionAttr = op.position().getValue();
909 if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
910 return op.emitOpError(
911 "expected position attribute of rank smaller than vector rank");
912 for (auto en : llvm::enumerate(positionAttr)) {
913 auto attr = en.value().dyn_cast<IntegerAttr>();
914 if (!attr || attr.getInt() < 0 ||
915 attr.getInt() >= op.getVectorType().getDimSize(en.index()))
916 return op.emitOpError("expected position attribute #")
917 << (en.index() + 1)
918 << " to be a non-negative integer smaller than the corresponding "
919 "vector dimension";
920 }
921 return success();
922 }
923
924 template <typename IntType>
extractVector(ArrayAttr arrayAttr)925 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
926 return llvm::to_vector<4>(llvm::map_range(
927 arrayAttr.getAsRange<IntegerAttr>(),
928 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
929 }
930
931 /// Fold the result of chains of ExtractOp in place by simply concatenating the
932 /// positions.
foldExtractOpFromExtractChain(ExtractOp extractOp)933 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
934 if (!extractOp.vector().getDefiningOp<ExtractOp>())
935 return failure();
936
937 SmallVector<int64_t, 4> globalPosition;
938 ExtractOp currentOp = extractOp;
939 auto extractedPos = extractVector<int64_t>(currentOp.position());
940 globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
941 while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
942 currentOp = nextOp;
943 auto extractedPos = extractVector<int64_t>(currentOp.position());
944 globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
945 }
946 extractOp.setOperand(currentOp.vector());
947 // OpBuilder is only used as a helper to build an I64ArrayAttr.
948 OpBuilder b(extractOp.getContext());
949 std::reverse(globalPosition.begin(), globalPosition.end());
950 extractOp->setAttr(ExtractOp::getPositionAttrName(),
951 b.getI64ArrayAttr(globalPosition));
952 return success();
953 }
954
955 /// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
foldExtractOpFromTranspose(ExtractOp extractOp)956 static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
957 auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
958 if (!transposeOp)
959 return failure();
960
961 auto permutation = extractVector<unsigned>(transposeOp.transp());
962 auto extractedPos = extractVector<int64_t>(extractOp.position());
963
964 // If transposition permutation is larger than the ExtractOp, all minor
965 // dimensions must be an identity for folding to occur. If not, individual
966 // elements within the extracted value are transposed and this is not just a
967 // simple folding.
968 unsigned minorRank = permutation.size() - extractedPos.size();
969 MLIRContext *ctx = extractOp.getContext();
970 AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
971 AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
972 if (minorMap && !minorMap.isMinorIdentity())
973 return failure();
974
975 // %1 = transpose %0[x, y, z] : vector<axbxcxf32>
976 // %2 = extract %1[u, v] : vector<..xf32>
977 // may turn into:
978 // %2 = extract %0[w, x] : vector<..xf32>
979 // iff z == 2 and [w, x] = [x, y]^-1 o [u, v] here o denotes composition and
980 // -1 denotes the inverse.
981 permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
982 // The major submap has fewer results but the same number of dims. To compose
983 // cleanly, we need to drop dims to form a "square matrix". This is possible
984 // because:
985 // (a) this is a permutation map and
986 // (b) the minor map has already been checked to be identity.
987 // Therefore, the major map cannot contain dims of position greater or equal
988 // than the number of results.
989 assert(llvm::all_of(permutationMap.getResults(),
990 [&](AffineExpr e) {
991 auto dim = e.dyn_cast<AffineDimExpr>();
992 return dim && dim.getPosition() <
993 permutationMap.getNumResults();
994 }) &&
995 "Unexpected map results depend on higher rank positions");
996 // Project on the first domain dimensions to allow composition.
997 permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
998 permutationMap.getResults(), ctx);
999
1000 extractOp.setOperand(transposeOp.vector());
1001 // Compose the inverse permutation map with the extractedPos.
1002 auto newExtractedPos =
1003 inversePermutation(permutationMap).compose(extractedPos);
1004 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1005 OpBuilder b(extractOp.getContext());
1006 extractOp->setAttr(ExtractOp::getPositionAttrName(),
1007 b.getI64ArrayAttr(newExtractedPos));
1008
1009 return success();
1010 }
1011
1012 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
1013 /// result is always the input to some InsertOp.
foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp)1014 static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
1015 MLIRContext *context = extractOp.getContext();
1016 AffineMap permutationMap;
1017 auto extractedPos = extractVector<unsigned>(extractOp.position());
1018 // Walk back a chain of InsertOp/TransposeOp until we hit a match.
1019 // Compose TransposeOp permutations as we walk back.
1020 auto insertOp = extractOp.vector().getDefiningOp<vector::InsertOp>();
1021 auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
1022 while (insertOp || transposeOp) {
1023 if (transposeOp) {
1024 // If it is transposed, compose the map and iterate.
1025 auto permutation = extractVector<unsigned>(transposeOp.transp());
1026 AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
1027 if (!permutationMap)
1028 permutationMap = newMap;
1029 else if (newMap.getNumInputs() != permutationMap.getNumResults())
1030 return Value();
1031 else
1032 permutationMap = newMap.compose(permutationMap);
1033 // Compute insert/transpose for the next iteration.
1034 Value transposed = transposeOp.vector();
1035 insertOp = transposed.getDefiningOp<vector::InsertOp>();
1036 transposeOp = transposed.getDefiningOp<vector::TransposeOp>();
1037 continue;
1038 }
1039
1040 assert(insertOp);
1041 Value insertionDest = insertOp.dest();
1042 // If it is inserted into, either the position matches and we have a
1043 // successful folding; or we iterate until we run out of
1044 // InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector`
1045 // produces a new vector with 1 modified value/slice in exactly the static
1046 // position we need to match.
1047 auto insertedPos = extractVector<unsigned>(insertOp.position());
1048 // Trivial permutations are solved with position equality checks.
1049 if (!permutationMap || permutationMap.isIdentity()) {
1050 if (extractedPos == insertedPos)
1051 return insertOp.source();
1052 // Fallthrough: if the position does not match, just skip to the next
1053 // producing `vector.insert` / `vector.transpose`.
1054 // Compute insert/transpose for the next iteration.
1055 insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
1056 transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
1057 continue;
1058 }
1059
1060 // More advanced permutations require application of the permutation.
1061 // However, the rank of `insertedPos` may be different from that of the
1062 // `permutationMap`. To support such case, we need to:
1063 // 1. apply on the `insertedPos.size()` major dimensions
1064 // 2. check the other dimensions of the permutation form a minor identity.
1065 assert(permutationMap.isPermutation() && "expected a permutation");
1066 if (insertedPos.size() == extractedPos.size()) {
1067 bool fold = true;
1068 for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
1069 auto pos = permutationMap.getDimPosition(idx);
1070 if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
1071 fold = false;
1072 break;
1073 }
1074 }
1075 if (fold) {
1076 assert(permutationMap.getNumResults() >= insertedPos.size() &&
1077 "expected map of rank larger than insert indexing");
1078 unsigned minorRank =
1079 permutationMap.getNumResults() - insertedPos.size();
1080 AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
1081 if (!minorMap || minorMap.isMinorIdentity())
1082 return insertOp.source();
1083 }
1084 }
1085
1086 // If we haven't found a match, just continue to the next producing
1087 // `vector.insert` / `vector.transpose`.
1088 // Compute insert/transpose for the next iteration.
1089 insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
1090 transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
1091 }
1092 return Value();
1093 }
1094
1095 /// Fold extractOp with scalar result coming from BroadcastOp.
foldExtractFromBroadcast(ExtractOp extractOp)1096 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1097 auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
1098 if (!broadcastOp)
1099 return Value();
1100 if (extractOp.getType() == broadcastOp.getSourceType())
1101 return broadcastOp.source();
1102 auto getRank = [](Type type) {
1103 return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1104 };
1105 unsigned broadcasrSrcRank = getRank(broadcastOp.getSourceType());
1106 unsigned extractResultRank = getRank(extractOp.getType());
1107 if (extractResultRank < broadcasrSrcRank) {
1108 auto extractPos = extractVector<int64_t>(extractOp.position());
1109 unsigned rankDiff = broadcasrSrcRank - extractResultRank;
1110 extractPos.erase(
1111 extractPos.begin(),
1112 std::next(extractPos.begin(), extractPos.size() - rankDiff));
1113 extractOp.setOperand(broadcastOp.source());
1114 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1115 OpBuilder b(extractOp.getContext());
1116 extractOp->setAttr(ExtractOp::getPositionAttrName(),
1117 b.getI64ArrayAttr(extractPos));
1118 return extractOp.getResult();
1119 }
1120 // TODO: In case the rank of the broadcast source is greater than the rank of
1121 // the extract result this can be combined into a new broadcast op. This needs
1122 // to be added a canonicalization pattern if needed.
1123 return Value();
1124 }
1125
1126 // Fold extractOp with source coming from ShapeCast op.
foldExtractFromShapeCast(ExtractOp extractOp)1127 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1128 auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
1129 if (!shapeCastOp)
1130 return Value();
1131 // Get the nth dimension size starting from lowest dimension.
1132 auto getDimReverse = [](VectorType type, int64_t n) {
1133 return type.getShape().take_back(n + 1).front();
1134 };
1135 int64_t destinationRank =
1136 extractOp.getType().isa<VectorType>()
1137 ? extractOp.getType().cast<VectorType>().getRank()
1138 : 0;
1139 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1140 return Value();
1141 if (destinationRank > 0) {
1142 auto destinationType = extractOp.getResult().getType().cast<VectorType>();
1143 for (int64_t i = 0; i < destinationRank; i++) {
1144 // The lowest dimension of of the destination must match the lowest
1145 // dimension of the shapecast op source.
1146 // TODO: This case could be support in a canonicalization pattern.
1147 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1148 getDimReverse(destinationType, i))
1149 return Value();
1150 }
1151 }
1152 // Extract the strides associated with the extract op vector source. Then use
1153 // this to calculate a linearized position for the extract.
1154 auto extractedPos = extractVector<int64_t>(extractOp.position());
1155 std::reverse(extractedPos.begin(), extractedPos.end());
1156 SmallVector<int64_t, 4> strides;
1157 int64_t stride = 1;
1158 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1159 strides.push_back(stride);
1160 stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
1161 }
1162
1163 int64_t position = linearize(extractedPos, strides);
1164 // Then extract the strides associated to the shapeCast op vector source and
1165 // delinearize the position using those strides.
1166 SmallVector<int64_t, 4> newStrides;
1167 int64_t numDimension =
1168 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1169 stride = 1;
1170 for (int64_t i = 0; i < numDimension; i++) {
1171 newStrides.push_back(stride);
1172 stride *=
1173 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1174 }
1175 std::reverse(newStrides.begin(), newStrides.end());
1176 SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
1177 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1178 OpBuilder b(extractOp.getContext());
1179 extractOp->setAttr(ExtractOp::getPositionAttrName(),
1180 b.getI64ArrayAttr(newPosition));
1181 extractOp.setOperand(shapeCastOp.source());
1182 return extractOp.getResult();
1183 }
1184
fold(ArrayRef<Attribute>)1185 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
1186 if (position().empty())
1187 return vector();
1188 if (succeeded(foldExtractOpFromExtractChain(*this)))
1189 return getResult();
1190 if (succeeded(foldExtractOpFromTranspose(*this)))
1191 return getResult();
1192 if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
1193 return val;
1194 if (auto val = foldExtractFromBroadcast(*this))
1195 return val;
1196 if (auto val = foldExtractFromShapeCast(*this))
1197 return val;
1198 return OpFoldResult();
1199 }
1200
1201 namespace {
1202
1203 // If extractOp is only removing unit dimensions it can be transformed to a
1204 // shapecast.
1205 class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
1206 public:
1207 using OpRewritePattern<ExtractOp>::OpRewritePattern;
1208
matchAndRewrite(ExtractOp extractOp,PatternRewriter & rewriter) const1209 LogicalResult matchAndRewrite(ExtractOp extractOp,
1210 PatternRewriter &rewriter) const override {
1211 auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
1212 if (!dstVecType || extractOp.getVectorType().getNumElements() !=
1213 dstVecType.getNumElements())
1214 return failure();
1215 rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
1216 extractOp.vector());
1217 return success();
1218 }
1219 };
1220
1221 } // namespace
1222
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1223 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1224 MLIRContext *context) {
1225 results.add<ExtractToShapeCast>(context);
1226 }
1227
populateFromInt64AttrArray(ArrayAttr arrayAttr,SmallVectorImpl<int64_t> & results)1228 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
1229 SmallVectorImpl<int64_t> &results) {
1230 for (auto attr : arrayAttr)
1231 results.push_back(attr.cast<IntegerAttr>().getInt());
1232 }
1233
1234 //===----------------------------------------------------------------------===//
1235 // ExtractMapOp
1236 //===----------------------------------------------------------------------===//
1237
build(OpBuilder & builder,OperationState & result,Value vector,ValueRange ids,ArrayRef<int64_t> multiplicity,AffineMap permutationMap)1238 void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
1239 Value vector, ValueRange ids,
1240 ArrayRef<int64_t> multiplicity,
1241 AffineMap permutationMap) {
1242 assert(ids.size() == multiplicity.size() &&
1243 ids.size() == permutationMap.getNumResults());
1244 assert(permutationMap.isProjectedPermutation());
1245 VectorType type = vector.getType().cast<VectorType>();
1246 SmallVector<int64_t, 4> newShape(type.getShape().begin(),
1247 type.getShape().end());
1248 for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
1249 AffineExpr expr = permutationMap.getResult(i);
1250 auto dim = expr.cast<AffineDimExpr>();
1251 newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
1252 }
1253 VectorType resultType = VectorType::get(newShape, type.getElementType());
1254 ExtractMapOp::build(builder, result, resultType, vector, ids);
1255 }
1256
verify(ExtractMapOp op)1257 static LogicalResult verify(ExtractMapOp op) {
1258 if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
1259 return op.emitOpError(
1260 "expected source and destination vectors of same rank");
1261 unsigned numId = 0;
1262 for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
1263 if (op.getSourceVectorType().getDimSize(i) %
1264 op.getResultType().getDimSize(i) !=
1265 0)
1266 return op.emitOpError("source vector dimensions must be a multiple of "
1267 "destination vector dimensions");
1268 if (op.getSourceVectorType().getDimSize(i) !=
1269 op.getResultType().getDimSize(i))
1270 numId++;
1271 }
1272 if (numId != op.ids().size())
1273 return op.emitOpError("expected number of ids must match the number of "
1274 "dimensions distributed");
1275 return success();
1276 }
1277
fold(ArrayRef<Attribute> operands)1278 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
1279 auto insert = vector().getDefiningOp<vector::InsertMapOp>();
1280 if (insert == nullptr || getType() != insert.vector().getType() ||
1281 ids() != insert.ids())
1282 return {};
1283 return insert.vector();
1284 }
1285
getMultiplicity(SmallVectorImpl<int64_t> & multiplicity)1286 void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
1287 assert(multiplicity.empty());
1288 for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
1289 if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
1290 multiplicity.push_back(getSourceVectorType().getDimSize(i) /
1291 getResultType().getDimSize(i));
1292 }
1293 }
1294
1295 template <typename MapOp>
calculateImplicitMap(MapOp op)1296 AffineMap calculateImplicitMap(MapOp op) {
1297 SmallVector<AffineExpr, 4> perm;
1298 // Check which dimension have a multiplicity greater than 1 and associated
1299 // them to the IDs in order.
1300 for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
1301 if (op.getSourceVectorType().getDimSize(i) !=
1302 op.getResultType().getDimSize(i))
1303 perm.push_back(getAffineDimExpr(i, op.getContext()));
1304 }
1305 auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
1306 op.getContext());
1307 return map;
1308 }
1309
map()1310 AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }
1311
1312 //===----------------------------------------------------------------------===//
1313 // FmaOp
1314 //===----------------------------------------------------------------------===//
1315
getShapeForUnroll()1316 Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
1317 return llvm::to_vector<4>(getVectorType().getShape());
1318 }
1319
1320 //===----------------------------------------------------------------------===//
1321 // BroadcastOp
1322 //===----------------------------------------------------------------------===//
1323
1324 BroadcastableToResult
isBroadcastableTo(Type srcType,VectorType dstVectorType,std::pair<int,int> * mismatchingDims)1325 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
1326 std::pair<int, int> *mismatchingDims) {
1327 // Broadcast scalar to vector of the same element type.
1328 if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
1329 getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
1330 return BroadcastableToResult::Success;
1331 // From now on, only vectors broadcast.
1332 VectorType srcVectorType = srcType.dyn_cast<VectorType>();
1333 if (!srcVectorType)
1334 return BroadcastableToResult::SourceTypeNotAVector;
1335
1336 int64_t srcRank = srcVectorType.getRank();
1337 int64_t dstRank = dstVectorType.getRank();
1338 if (srcRank > dstRank)
1339 return BroadcastableToResult::SourceRankHigher;
1340 // Source has an exact match or singleton value for all trailing dimensions
1341 // (all leading dimensions are simply duplicated).
1342 int64_t lead = dstRank - srcRank;
1343 for (int64_t r = 0; r < srcRank; ++r) {
1344 int64_t srcDim = srcVectorType.getDimSize(r);
1345 int64_t dstDim = dstVectorType.getDimSize(lead + r);
1346 if (srcDim != 1 && srcDim != dstDim) {
1347 if (mismatchingDims) {
1348 mismatchingDims->first = srcDim;
1349 mismatchingDims->second = dstDim;
1350 }
1351 return BroadcastableToResult::DimensionMismatch;
1352 }
1353 }
1354
1355 return BroadcastableToResult::Success;
1356 }
1357
verify(BroadcastOp op)1358 static LogicalResult verify(BroadcastOp op) {
1359 std::pair<int, int> mismatchingDims;
1360 BroadcastableToResult res = isBroadcastableTo(
1361 op.getSourceType(), op.getVectorType(), &mismatchingDims);
1362 if (res == BroadcastableToResult::Success)
1363 return success();
1364 if (res == BroadcastableToResult::SourceRankHigher)
1365 return op.emitOpError("source rank higher than destination rank");
1366 if (res == BroadcastableToResult::DimensionMismatch)
1367 return op.emitOpError("dimension mismatch (")
1368 << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
1369 if (res == BroadcastableToResult::SourceTypeNotAVector)
1370 return op.emitOpError("source type is not a vector");
1371 llvm_unreachable("unexpected vector.broadcast op error");
1372 }
1373
fold(ArrayRef<Attribute> operands)1374 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
1375 if (getSourceType() == getVectorType())
1376 return source();
1377 if (!operands[0])
1378 return {};
1379 auto vectorType = getVectorType();
1380 if (operands[0].getType().isIntOrIndexOrFloat())
1381 return DenseElementsAttr::get(vectorType, operands[0]);
1382 if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
1383 return DenseElementsAttr::get(vectorType, attr.getSplatValue());
1384 return {};
1385 }
1386
1387 namespace {
1388
1389 // BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
1390 // the degenerated case where the broadcast only adds dimensions of size 1 it
1391 // can be replaced by a ShapeCastOp. This canonicalization checks if the total
1392 // number of elements is the same before and after the broadcast to detect if
1393 // the only change in the vector type are new dimensions of size 1.
1394 class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
1395 public:
1396 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
1397
matchAndRewrite(BroadcastOp broadcastOp,PatternRewriter & rewriter) const1398 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
1399 PatternRewriter &rewriter) const override {
1400 auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
1401 if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
1402 srcVecType.getNumElements())
1403 return failure();
1404 rewriter.replaceOpWithNewOp<ShapeCastOp>(
1405 broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
1406 return success();
1407 }
1408 };
1409
1410 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
1411 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
1412 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
1413
matchAndRewrite__anona11d586e0e11::BroadcastFolder1414 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
1415 PatternRewriter &rewriter) const override {
1416 auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>();
1417 if (!srcBroadcast)
1418 return failure();
1419 rewriter.replaceOpWithNewOp<BroadcastOp>(
1420 broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source());
1421 return success();
1422 }
1423 };
1424 } // namespace
1425
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1426 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1427 MLIRContext *context) {
1428 results.add<BroadcastToShapeCast, BroadcastFolder>(context);
1429 }
1430
1431 //===----------------------------------------------------------------------===//
1432 // ShuffleOp
1433 //===----------------------------------------------------------------------===//
1434
build(OpBuilder & builder,OperationState & result,Value v1,Value v2,ArrayRef<int64_t> mask)1435 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
1436 Value v2, ArrayRef<int64_t> mask) {
1437 result.addOperands({v1, v2});
1438 auto maskAttr = getVectorSubscriptAttr(builder, mask);
1439 result.addTypes(v1.getType());
1440 result.addAttribute(getMaskAttrName(), maskAttr);
1441 }
1442
print(OpAsmPrinter & p,ShuffleOp op)1443 static void print(OpAsmPrinter &p, ShuffleOp op) {
1444 p << " " << op.v1() << ", " << op.v2() << " " << op.mask();
1445 p.printOptionalAttrDict(op->getAttrs(), {ShuffleOp::getMaskAttrName()});
1446 p << " : " << op.v1().getType() << ", " << op.v2().getType();
1447 }
1448
verify(ShuffleOp op)1449 static LogicalResult verify(ShuffleOp op) {
1450 VectorType resultType = op.getVectorType();
1451 VectorType v1Type = op.getV1VectorType();
1452 VectorType v2Type = op.getV2VectorType();
1453 // Verify ranks.
1454 int64_t resRank = resultType.getRank();
1455 int64_t v1Rank = v1Type.getRank();
1456 int64_t v2Rank = v2Type.getRank();
1457 if (resRank != v1Rank || v1Rank != v2Rank)
1458 return op.emitOpError("rank mismatch");
1459 // Verify all but leading dimension sizes.
1460 for (int64_t r = 1; r < v1Rank; ++r) {
1461 int64_t resDim = resultType.getDimSize(r);
1462 int64_t v1Dim = v1Type.getDimSize(r);
1463 int64_t v2Dim = v2Type.getDimSize(r);
1464 if (resDim != v1Dim || v1Dim != v2Dim)
1465 return op.emitOpError("dimension mismatch");
1466 }
1467 // Verify mask length.
1468 auto maskAttr = op.mask().getValue();
1469 int64_t maskLength = maskAttr.size();
1470 if (maskLength != resultType.getDimSize(0))
1471 return op.emitOpError("mask length mismatch");
1472 // Verify all indices.
1473 int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
1474 for (auto en : llvm::enumerate(maskAttr)) {
1475 auto attr = en.value().dyn_cast<IntegerAttr>();
1476 if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
1477 return op.emitOpError("mask index #")
1478 << (en.index() + 1) << " out of range";
1479 }
1480 return success();
1481 }
1482
parseShuffleOp(OpAsmParser & parser,OperationState & result)1483 static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
1484 OpAsmParser::OperandType v1, v2;
1485 Attribute attr;
1486 VectorType v1Type, v2Type;
1487 if (parser.parseOperand(v1) || parser.parseComma() ||
1488 parser.parseOperand(v2) ||
1489 parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
1490 result.attributes) ||
1491 parser.parseOptionalAttrDict(result.attributes) ||
1492 parser.parseColonType(v1Type) || parser.parseComma() ||
1493 parser.parseType(v2Type) ||
1494 parser.resolveOperand(v1, v1Type, result.operands) ||
1495 parser.resolveOperand(v2, v2Type, result.operands))
1496 return failure();
1497 // Construct resulting type: leading dimension matches mask length,
1498 // all trailing dimensions match the operands.
1499 auto maskAttr = attr.dyn_cast<ArrayAttr>();
1500 if (!maskAttr)
1501 return parser.emitError(parser.getNameLoc(), "missing mask attribute");
1502 int64_t maskLength = maskAttr.size();
1503 if (maskLength <= 0)
1504 return parser.emitError(parser.getNameLoc(), "invalid mask length");
1505 int64_t v1Rank = v1Type.getRank();
1506 SmallVector<int64_t, 4> shape;
1507 shape.reserve(v1Rank);
1508 shape.push_back(maskLength);
1509 for (int64_t r = 1; r < v1Rank; ++r)
1510 shape.push_back(v1Type.getDimSize(r));
1511 VectorType resType = VectorType::get(shape, v1Type.getElementType());
1512 parser.addTypeToList(resType, result.types);
1513 return success();
1514 }
1515
1516 //===----------------------------------------------------------------------===//
1517 // InsertElementOp
1518 //===----------------------------------------------------------------------===//
1519
build(OpBuilder & builder,OperationState & result,Value source,Value dest,Value position)1520 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
1521 Value source, Value dest, Value position) {
1522 result.addOperands({source, dest, position});
1523 result.addTypes(dest.getType());
1524 }
1525
build(OpBuilder & builder,OperationState & result,Value source,Value dest,int64_t position)1526 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
1527 Value source, Value dest, int64_t position) {
1528 Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
1529 build(builder, result, source, dest, pos);
1530 }
1531
verify(InsertElementOp op)1532 static LogicalResult verify(InsertElementOp op) {
1533 auto dstVectorType = op.getDestVectorType();
1534 if (dstVectorType.getRank() != 1)
1535 return op.emitOpError("expected 1-D vector");
1536 return success();
1537 }
1538
1539 //===----------------------------------------------------------------------===//
1540 // InsertOp
1541 //===----------------------------------------------------------------------===//
1542
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> position)1543 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1544 Value dest, ArrayRef<int64_t> position) {
1545 result.addOperands({source, dest});
1546 auto positionAttr = getVectorSubscriptAttr(builder, position);
1547 result.addTypes(dest.getType());
1548 result.addAttribute(getPositionAttrName(), positionAttr);
1549 }
1550
1551 // Convenience builder which assumes the values are constant indices.
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ValueRange position)1552 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1553 Value dest, ValueRange position) {
1554 SmallVector<int64_t, 4> positionConstants =
1555 llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
1556 return pos.getDefiningOp<ConstantIndexOp>().getValue();
1557 }));
1558 build(builder, result, source, dest, positionConstants);
1559 }
1560
verify(InsertOp op)1561 static LogicalResult verify(InsertOp op) {
1562 auto positionAttr = op.position().getValue();
1563 auto destVectorType = op.getDestVectorType();
1564 if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
1565 return op.emitOpError(
1566 "expected position attribute of rank smaller than dest vector rank");
1567 auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
1568 if (srcVectorType &&
1569 (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
1570 static_cast<unsigned>(destVectorType.getRank())))
1571 return op.emitOpError("expected position attribute rank + source rank to "
1572 "match dest vector rank");
1573 else if (!srcVectorType && (positionAttr.size() !=
1574 static_cast<unsigned>(destVectorType.getRank())))
1575 return op.emitOpError(
1576 "expected position attribute rank to match the dest vector rank");
1577 for (auto en : llvm::enumerate(positionAttr)) {
1578 auto attr = en.value().dyn_cast<IntegerAttr>();
1579 if (!attr || attr.getInt() < 0 ||
1580 attr.getInt() >= destVectorType.getDimSize(en.index()))
1581 return op.emitOpError("expected position attribute #")
1582 << (en.index() + 1)
1583 << " to be a non-negative integer smaller than the corresponding "
1584 "dest vector dimension";
1585 }
1586 return success();
1587 }
1588
1589 namespace {
1590
1591 // If insertOp is only inserting unit dimensions it can be transformed to a
1592 // shapecast.
1593 class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
1594 public:
1595 using OpRewritePattern<InsertOp>::OpRewritePattern;
1596
matchAndRewrite(InsertOp insertOp,PatternRewriter & rewriter) const1597 LogicalResult matchAndRewrite(InsertOp insertOp,
1598 PatternRewriter &rewriter) const override {
1599 auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
1600 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
1601 srcVecType.getNumElements())
1602 return failure();
1603 rewriter.replaceOpWithNewOp<ShapeCastOp>(
1604 insertOp, insertOp.getDestVectorType(), insertOp.source());
1605 return success();
1606 }
1607 };
1608
1609 } // namespace
1610
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1611 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
1612 MLIRContext *context) {
1613 results.add<InsertToShapeCast>(context);
1614 }
1615
1616 // Eliminates insert operations that produce values identical to their source
1617 // value. This happens when the source and destination vectors have identical
1618 // sizes.
fold(ArrayRef<Attribute> operands)1619 OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
1620 if (position().empty())
1621 return source();
1622 return {};
1623 }
1624
1625 //===----------------------------------------------------------------------===//
1626 // InsertMapOp
1627 //===----------------------------------------------------------------------===//
1628
build(OpBuilder & builder,OperationState & result,Value vector,Value dest,ValueRange ids)1629 void InsertMapOp::build(OpBuilder &builder, OperationState &result,
1630 Value vector, Value dest, ValueRange ids) {
1631 InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
1632 }
1633
verify(InsertMapOp op)1634 static LogicalResult verify(InsertMapOp op) {
1635 if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
1636 return op.emitOpError(
1637 "expected source and destination vectors of same rank");
1638 unsigned numId = 0;
1639 for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) {
1640 if (op.getResultType().getDimSize(i) %
1641 op.getSourceVectorType().getDimSize(i) !=
1642 0)
1643 return op.emitOpError(
1644 "destination vector size must be a multiple of source vector size");
1645 if (op.getResultType().getDimSize(i) !=
1646 op.getSourceVectorType().getDimSize(i))
1647 numId++;
1648 }
1649 if (numId != op.ids().size())
1650 return op.emitOpError("expected number of ids must match the number of "
1651 "dimensions distributed");
1652 return success();
1653 }
1654
map()1655 AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }
1656
1657 //===----------------------------------------------------------------------===//
1658 // InsertStridedSliceOp
1659 //===----------------------------------------------------------------------===//
1660
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> offsets,ArrayRef<int64_t> strides)1661 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
1662 Value source, Value dest,
1663 ArrayRef<int64_t> offsets,
1664 ArrayRef<int64_t> strides) {
1665 result.addOperands({source, dest});
1666 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
1667 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
1668 result.addTypes(dest.getType());
1669 result.addAttribute(getOffsetsAttrName(), offsetsAttr);
1670 result.addAttribute(getStridesAttrName(), stridesAttr);
1671 }
1672
1673 // TODO: Should be moved to Tablegen Confined attributes.
1674 template <typename OpType>
isIntegerArrayAttrSmallerThanShape(OpType op,ArrayAttr arrayAttr,ArrayRef<int64_t> shape,StringRef attrName)1675 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
1676 ArrayAttr arrayAttr,
1677 ArrayRef<int64_t> shape,
1678 StringRef attrName) {
1679 if (arrayAttr.size() > shape.size())
1680 return op.emitOpError("expected ")
1681 << attrName << " attribute of rank smaller than vector rank";
1682 return success();
1683 }
1684
1685 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
1686 // interval. If `halfOpen` is true then the admissible interval is [min, max).
1687 // Otherwise, the admissible interval is [min, max].
1688 template <typename OpType>
1689 static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op,ArrayAttr arrayAttr,int64_t min,int64_t max,StringRef attrName,bool halfOpen=true)1690 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
1691 int64_t max, StringRef attrName,
1692 bool halfOpen = true) {
1693 for (auto attr : arrayAttr) {
1694 auto val = attr.cast<IntegerAttr>().getInt();
1695 auto upper = max;
1696 if (!halfOpen)
1697 upper += 1;
1698 if (val < min || val >= upper)
1699 return op.emitOpError("expected ") << attrName << " to be confined to ["
1700 << min << ", " << upper << ")";
1701 }
1702 return success();
1703 }
1704
1705 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
1706 // interval. If `halfOpen` is true then the admissible interval is [min, max).
1707 // Otherwise, the admissible interval is [min, max].
1708 template <typename OpType>
1709 static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op,ArrayAttr arrayAttr,ArrayRef<int64_t> shape,StringRef attrName,bool halfOpen=true,int64_t min=0)1710 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
1711 ArrayRef<int64_t> shape, StringRef attrName,
1712 bool halfOpen = true, int64_t min = 0) {
1713 assert(arrayAttr.size() <= shape.size());
1714 unsigned index = 0;
1715 for (auto it : llvm::zip(arrayAttr, shape)) {
1716 auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
1717 auto max = std::get<1>(it);
1718 if (!halfOpen)
1719 max += 1;
1720 if (val < min || val >= max)
1721 return op.emitOpError("expected ")
1722 << attrName << " dimension " << index << " to be confined to ["
1723 << min << ", " << max << ")";
1724 ++index;
1725 }
1726 return success();
1727 }
1728
1729 // Returns true if all integers in `arrayAttr` are in the interval [min, max}.
1730 // interval. If `halfOpen` is true then the admissible interval is [min, max).
1731 // Otherwise, the admissible interval is [min, max].
1732 template <typename OpType>
isSumOfIntegerArrayAttrConfinedToShape(OpType op,ArrayAttr arrayAttr1,ArrayAttr arrayAttr2,ArrayRef<int64_t> shape,StringRef attrName1,StringRef attrName2,bool halfOpen=true,int64_t min=1)1733 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
1734 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
1735 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
1736 bool halfOpen = true, int64_t min = 1) {
1737 assert(arrayAttr1.size() <= shape.size());
1738 assert(arrayAttr2.size() <= shape.size());
1739 unsigned index = 0;
1740 for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
1741 auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
1742 auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
1743 auto max = std::get<2>(it);
1744 if (!halfOpen)
1745 max += 1;
1746 if (val1 + val2 < 0 || val1 + val2 >= max)
1747 return op.emitOpError("expected sum(")
1748 << attrName1 << ", " << attrName2 << ") dimension " << index
1749 << " to be confined to [" << min << ", " << max << ")";
1750 ++index;
1751 }
1752 return success();
1753 }
1754
makeI64ArrayAttr(ArrayRef<int64_t> values,MLIRContext * context)1755 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
1756 MLIRContext *context) {
1757 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
1758 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
1759 });
1760 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
1761 }
1762
verify(InsertStridedSliceOp op)1763 static LogicalResult verify(InsertStridedSliceOp op) {
1764 auto sourceVectorType = op.getSourceVectorType();
1765 auto destVectorType = op.getDestVectorType();
1766 auto offsets = op.offsets();
1767 auto strides = op.strides();
1768 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
1769 return op.emitOpError(
1770 "expected offsets of same size as destination vector rank");
1771 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
1772 return op.emitOpError(
1773 "expected strides of same size as source vector rank");
1774 if (sourceVectorType.getRank() > destVectorType.getRank())
1775 return op.emitOpError(
1776 "expected source rank to be smaller than destination rank");
1777
1778 auto sourceShape = sourceVectorType.getShape();
1779 auto destShape = destVectorType.getShape();
1780 SmallVector<int64_t, 4> sourceShapeAsDestShape(
1781 destShape.size() - sourceShape.size(), 0);
1782 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
1783 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
1784 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
1785 if (failed(
1786 isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
1787 failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
1788 /*halfOpen=*/false)) ||
1789 failed(isSumOfIntegerArrayAttrConfinedToShape(
1790 op, offsets,
1791 makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
1792 offName, "source vector shape",
1793 /*halfOpen=*/false, /*min=*/1)))
1794 return failure();
1795
1796 return success();
1797 }
1798
fold(ArrayRef<Attribute> operands)1799 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
1800 if (getSourceVectorType() == getDestVectorType())
1801 return source();
1802 return {};
1803 }
1804
1805 //===----------------------------------------------------------------------===//
1806 // OuterProductOp
1807 //===----------------------------------------------------------------------===//
1808
1809 /// Build an op without mask, use the type of `acc` as the return type.
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc)1810 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
1811 Value lhs, Value rhs, Value acc) {
1812 result.addOperands({lhs, rhs, acc});
1813 result.addTypes(acc.getType());
1814 }
1815
print(OpAsmPrinter & p,OuterProductOp op)1816 static void print(OpAsmPrinter &p, OuterProductOp op) {
1817 p << " " << op.lhs() << ", " << op.rhs();
1818 if (!op.acc().empty()) {
1819 p << ", " << op.acc();
1820 p.printOptionalAttrDict(op->getAttrs());
1821 }
1822 p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
1823 }
1824
parseOuterProductOp(OpAsmParser & parser,OperationState & result)1825 static ParseResult parseOuterProductOp(OpAsmParser &parser,
1826 OperationState &result) {
1827 SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
1828 Type tLHS, tRHS;
1829 if (parser.parseOperandList(operandsInfo) ||
1830 parser.parseOptionalAttrDict(result.attributes) ||
1831 parser.parseColonType(tLHS) || parser.parseComma() ||
1832 parser.parseType(tRHS))
1833 return failure();
1834 if (operandsInfo.size() < 2)
1835 return parser.emitError(parser.getNameLoc(),
1836 "expected at least 2 operands");
1837 VectorType vLHS = tLHS.dyn_cast<VectorType>();
1838 VectorType vRHS = tRHS.dyn_cast<VectorType>();
1839 if (!vLHS)
1840 return parser.emitError(parser.getNameLoc(),
1841 "expected vector type for operand #1");
1842 VectorType resType =
1843 vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
1844 vLHS.getElementType())
1845 : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
1846
1847 if (!result.attributes.get(OuterProductOp::getKindAttrName())) {
1848 result.attributes.append(
1849 OuterProductOp::getKindAttrName(),
1850 CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
1851 result.getContext()));
1852 }
1853
1854 return failure(
1855 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
1856 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
1857 (operandsInfo.size() > 2 &&
1858 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
1859 parser.addTypeToList(resType, result.types));
1860 }
1861
verify(OuterProductOp op)1862 static LogicalResult verify(OuterProductOp op) {
1863 Type tRHS = op.getOperandTypeRHS();
1864 VectorType vLHS = op.getOperandVectorTypeLHS(),
1865 vRHS = tRHS.dyn_cast<VectorType>(),
1866 vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
1867
1868 if (vLHS.getRank() != 1)
1869 return op.emitOpError("expected 1-d vector for operand #1");
1870
1871 if (vRHS) {
1872 // Proper OUTER operation.
1873 if (vRHS.getRank() != 1)
1874 return op.emitOpError("expected 1-d vector for operand #2");
1875 if (vRES.getRank() != 2)
1876 return op.emitOpError("expected 2-d vector result");
1877 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
1878 return op.emitOpError("expected #1 operand dim to match result dim #1");
1879 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
1880 return op.emitOpError("expected #2 operand dim to match result dim #2");
1881 } else {
1882 // An AXPY operation.
1883 if (vRES.getRank() != 1)
1884 return op.emitOpError("expected 1-d vector result");
1885 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
1886 return op.emitOpError("expected #1 operand dim to match result dim #1");
1887 }
1888
1889 if (vACC && vACC != vRES)
1890 return op.emitOpError("expected operand #3 of same type as result type");
1891
1892 // Verify supported combining kind.
1893 if (!isSupportedCombiningKind(op.kind(), vRES.getElementType()))
1894 return op.emitOpError("unsupported outerproduct type");
1895
1896 return success();
1897 }
1898
1899 //===----------------------------------------------------------------------===//
1900 // ReshapeOp
1901 //===----------------------------------------------------------------------===//
1902
verify(ReshapeOp op)1903 static LogicalResult verify(ReshapeOp op) {
1904 // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
1905 auto inputVectorType = op.getInputVectorType();
1906 auto outputVectorType = op.getOutputVectorType();
1907 int64_t inputShapeRank = op.getNumInputShapeSizes();
1908 int64_t outputShapeRank = op.getNumOutputShapeSizes();
1909 SmallVector<int64_t, 4> fixedVectorSizes;
1910 op.getFixedVectorSizes(fixedVectorSizes);
1911 int64_t numFixedVectorSizes = fixedVectorSizes.size();
1912
1913 if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
1914 return op.emitError("invalid input shape for vector type ")
1915 << inputVectorType;
1916
1917 if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
1918 return op.emitError("invalid output shape for vector type ")
1919 << outputVectorType;
1920
1921 // Verify that the 'fixedVectorSizes' match an input/output vector shape
1922 // suffix.
1923 unsigned inputVectorRank = inputVectorType.getRank();
1924 for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
1925 unsigned index = inputVectorRank - numFixedVectorSizes - i;
1926 if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
1927 return op.emitError("fixed vector size must match input vector for dim ")
1928 << i;
1929 }
1930
1931 unsigned outputVectorRank = outputVectorType.getRank();
1932 for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
1933 unsigned index = outputVectorRank - numFixedVectorSizes - i;
1934 if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
1935 return op.emitError("fixed vector size must match output vector for dim ")
1936 << i;
1937 }
1938
1939 // If all shape operands are produced by constant ops, verify that product
1940 // of dimensions for input/output shape match.
1941 auto isDefByConstant = [](Value operand) {
1942 return isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
1943 };
1944 if (llvm::all_of(op.input_shape(), isDefByConstant) &&
1945 llvm::all_of(op.output_shape(), isDefByConstant)) {
1946 int64_t numInputElements = 1;
1947 for (auto operand : op.input_shape())
1948 numInputElements *=
1949 cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
1950 int64_t numOutputElements = 1;
1951 for (auto operand : op.output_shape())
1952 numOutputElements *=
1953 cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
1954 if (numInputElements != numOutputElements)
1955 return op.emitError("product of input and output shape sizes must match");
1956 }
1957 return success();
1958 }
1959
getFixedVectorSizes(SmallVectorImpl<int64_t> & results)1960 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
1961 populateFromInt64AttrArray(fixed_vector_sizes(), results);
1962 }
1963
1964 //===----------------------------------------------------------------------===//
1965 // ExtractStridedSliceOp
1966 //===----------------------------------------------------------------------===//
1967
1968 // Inference works as follows:
1969 // 1. Add 'sizes' from prefix of dims in 'offsets'.
1970 // 2. Add sizes from 'vectorType' for remaining dims.
inferStridedSliceOpResultType(VectorType vectorType,ArrayAttr offsets,ArrayAttr sizes,ArrayAttr strides)1971 static Type inferStridedSliceOpResultType(VectorType vectorType,
1972 ArrayAttr offsets, ArrayAttr sizes,
1973 ArrayAttr strides) {
1974 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
1975 SmallVector<int64_t, 4> shape;
1976 shape.reserve(vectorType.getRank());
1977 unsigned idx = 0;
1978 for (unsigned e = offsets.size(); idx < e; ++idx)
1979 shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
1980 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
1981 shape.push_back(vectorType.getShape()[idx]);
1982
1983 return VectorType::get(shape, vectorType.getElementType());
1984 }
1985
build(OpBuilder & builder,OperationState & result,Value source,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)1986 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
1987 Value source, ArrayRef<int64_t> offsets,
1988 ArrayRef<int64_t> sizes,
1989 ArrayRef<int64_t> strides) {
1990 result.addOperands(source);
1991 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
1992 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
1993 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
1994 result.addTypes(
1995 inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
1996 offsetsAttr, sizesAttr, stridesAttr));
1997 result.addAttribute(getOffsetsAttrName(), offsetsAttr);
1998 result.addAttribute(getSizesAttrName(), sizesAttr);
1999 result.addAttribute(getStridesAttrName(), stridesAttr);
2000 }
2001
verify(ExtractStridedSliceOp op)2002 static LogicalResult verify(ExtractStridedSliceOp op) {
2003 auto type = op.getVectorType();
2004 auto offsets = op.offsets();
2005 auto sizes = op.sizes();
2006 auto strides = op.strides();
2007 if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
2008 op.emitOpError(
2009 "expected offsets, sizes and strides attributes of same size");
2010 return failure();
2011 }
2012
2013 auto shape = type.getShape();
2014 auto offName = ExtractStridedSliceOp::getOffsetsAttrName();
2015 auto sizesName = ExtractStridedSliceOp::getSizesAttrName();
2016 auto stridesName = ExtractStridedSliceOp::getStridesAttrName();
2017 if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
2018 failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
2019 failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
2020 stridesName)) ||
2021 failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
2022 failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
2023 /*halfOpen=*/false,
2024 /*min=*/1)) ||
2025 failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
2026 /*halfOpen=*/false)) ||
2027 failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
2028 offName, sizesName,
2029 /*halfOpen=*/false)))
2030 return failure();
2031
2032 auto resultType = inferStridedSliceOpResultType(
2033 op.getVectorType(), op.offsets(), op.sizes(), op.strides());
2034 if (op.getResult().getType() != resultType) {
2035 op.emitOpError("expected result type to be ") << resultType;
2036 return failure();
2037 }
2038
2039 return success();
2040 }
2041
2042 // When the source of ExtractStrided comes from a chain of InsertStrided ops try
2043 // to use the source of the InsertStrided ops if we can detect that the
2044 // extracted vector is a subset of one of the vector inserted.
2045 static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op)2046 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
2047 // Helper to extract integer out of ArrayAttr.
2048 auto getElement = [](ArrayAttr array, int idx) {
2049 return array[idx].cast<IntegerAttr>().getInt();
2050 };
2051 ArrayAttr extractOffsets = op.offsets();
2052 ArrayAttr extractStrides = op.strides();
2053 ArrayAttr extractSizes = op.sizes();
2054 auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
2055 while (insertOp) {
2056 if (op.getVectorType().getRank() !=
2057 insertOp.getSourceVectorType().getRank())
2058 return failure();
2059 ArrayAttr insertOffsets = insertOp.offsets();
2060 ArrayAttr insertStrides = insertOp.strides();
2061 // If the rank of extract is greater than the rank of insert, we are likely
2062 // extracting a partial chunk of the vector inserted.
2063 if (extractOffsets.size() > insertOffsets.size())
2064 return failure();
2065 bool patialoverlap = false;
2066 bool disjoint = false;
2067 SmallVector<int64_t, 4> offsetDiffs;
2068 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
2069 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
2070 return failure();
2071 int64_t start = getElement(insertOffsets, dim);
2072 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
2073 int64_t offset = getElement(extractOffsets, dim);
2074 int64_t size = getElement(extractSizes, dim);
2075 // Check if the start of the extract offset is in the interval inserted.
2076 if (start <= offset && offset < end) {
2077 // If the extract interval overlaps but is not fully included we may
2078 // have a partial overlap that will prevent any folding.
2079 if (offset + size > end)
2080 patialoverlap = true;
2081 offsetDiffs.push_back(offset - start);
2082 continue;
2083 }
2084 disjoint = true;
2085 break;
2086 }
2087 // The extract element chunk is a subset of the insert element.
2088 if (!disjoint && !patialoverlap) {
2089 op.setOperand(insertOp.source());
2090 // OpBuilder is only used as a helper to build an I64ArrayAttr.
2091 OpBuilder b(op.getContext());
2092 op->setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
2093 b.getI64ArrayAttr(offsetDiffs));
2094 return success();
2095 }
2096 // If the chunk extracted is disjoint from the chunk inserted, keep looking
2097 // in the insert chain.
2098 if (disjoint)
2099 insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
2100 else {
2101 // The extracted vector partially overlap the inserted vector, we cannot
2102 // fold.
2103 return failure();
2104 }
2105 }
2106 return failure();
2107 }
2108
fold(ArrayRef<Attribute> operands)2109 OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2110 if (getVectorType() == getResult().getType())
2111 return vector();
2112 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
2113 return getResult();
2114 return {};
2115 }
2116
getOffsets(SmallVectorImpl<int64_t> & results)2117 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
2118 populateFromInt64AttrArray(offsets(), results);
2119 }
2120
2121 namespace {
2122
2123 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
2124 // ConstantMaskOp.
2125 class StridedSliceConstantMaskFolder final
2126 : public OpRewritePattern<ExtractStridedSliceOp> {
2127 public:
2128 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2129
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,PatternRewriter & rewriter) const2130 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2131 PatternRewriter &rewriter) const override {
2132 // Return if 'extractStridedSliceOp' operand is not defined by a
2133 // ConstantMaskOp.
2134 auto defOp = extractStridedSliceOp.vector().getDefiningOp();
2135 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
2136 if (!constantMaskOp)
2137 return failure();
2138 // Return if 'extractStridedSliceOp' has non-unit strides.
2139 if (llvm::any_of(extractStridedSliceOp.strides(), [](Attribute attr) {
2140 return attr.cast<IntegerAttr>().getInt() != 1;
2141 }))
2142 return failure();
2143 // Gather constant mask dimension sizes.
2144 SmallVector<int64_t, 4> maskDimSizes;
2145 populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
2146 // Gather strided slice offsets and sizes.
2147 SmallVector<int64_t, 4> sliceOffsets;
2148 populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets);
2149 SmallVector<int64_t, 4> sliceSizes;
2150 populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes);
2151
2152 // Compute slice of vector mask region.
2153 SmallVector<int64_t, 4> sliceMaskDimSizes;
2154 assert(sliceOffsets.size() == maskDimSizes.size());
2155 for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
2156 int64_t maskDimSize = std::get<0>(it);
2157 int64_t sliceOffset = std::get<1>(it);
2158 int64_t sliceSize = std::get<2>(it);
2159 int64_t sliceMaskDimSize = std::max(
2160 static_cast<int64_t>(0),
2161 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
2162 sliceMaskDimSizes.push_back(sliceMaskDimSize);
2163 }
2164 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
2165 // region is a conjunction of mask dim intervals).
2166 if (llvm::any_of(sliceMaskDimSizes, [](int64_t sz) { return sz == 0; }))
2167 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
2168
2169 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
2170 // region.
2171 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
2172 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
2173 vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
2174 return success();
2175 }
2176 };
2177
2178 // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
2179 class StridedSliceConstantFolder final
2180 : public OpRewritePattern<ExtractStridedSliceOp> {
2181 public:
2182 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2183
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,PatternRewriter & rewriter) const2184 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2185 PatternRewriter &rewriter) const override {
2186 // Return if 'extractStridedSliceOp' operand is not defined by a
2187 // ConstantOp.
2188 auto constantOp =
2189 extractStridedSliceOp.vector().getDefiningOp<ConstantOp>();
2190 if (!constantOp)
2191 return failure();
2192 auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
2193 if (!dense)
2194 return failure();
2195 auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
2196 dense.getSplatValue());
2197 rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
2198 return success();
2199 }
2200 };
2201
2202 // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
getI64SubArray(ArrayAttr arrayAttr,unsigned dropFront=0,unsigned dropBack=0)2203 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
2204 unsigned dropFront = 0,
2205 unsigned dropBack = 0) {
2206 assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
2207 auto range = arrayAttr.getAsRange<IntegerAttr>();
2208 SmallVector<int64_t, 4> res;
2209 res.reserve(arrayAttr.size() - dropFront - dropBack);
2210 for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
2211 it != eit; ++it)
2212 res.push_back((*it).getValue().getSExtValue());
2213 return res;
2214 }
2215
2216 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
2217 // BroadcastOp(ExtractStrideSliceOp).
2218 class StridedSliceBroadcast final
2219 : public OpRewritePattern<ExtractStridedSliceOp> {
2220 public:
2221 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2222
matchAndRewrite(ExtractStridedSliceOp op,PatternRewriter & rewriter) const2223 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
2224 PatternRewriter &rewriter) const override {
2225 auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
2226 if (!broadcast)
2227 return failure();
2228 auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
2229 unsigned srcRrank = srcVecType ? srcVecType.getRank() : 0;
2230 auto dstVecType = op.getType().cast<VectorType>();
2231 unsigned dstRank = dstVecType.getRank();
2232 unsigned rankDiff = dstRank - srcRrank;
2233 // Check if the most inner dimensions of the source of the broadcast are the
2234 // same as the destination of the extract. If this is the case we can just
2235 // use a broadcast as the original dimensions are untouched.
2236 bool lowerDimMatch = true;
2237 for (unsigned i = 0; i < srcRrank; i++) {
2238 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
2239 lowerDimMatch = false;
2240 break;
2241 }
2242 }
2243 Value source = broadcast.source();
2244 if (!lowerDimMatch) {
2245 // The inner dimensions don't match, it means we need to extract from the
2246 // source of the orignal broadcast and then broadcast the extracted value.
2247 source = rewriter.create<ExtractStridedSliceOp>(
2248 op->getLoc(), source,
2249 getI64SubArray(op.offsets(), /* dropFront=*/rankDiff),
2250 getI64SubArray(op.sizes(), /* dropFront=*/rankDiff),
2251 getI64SubArray(op.strides(), /* dropFront=*/rankDiff));
2252 }
2253 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
2254 return success();
2255 }
2256 };
2257
2258 } // end anonymous namespace
2259
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2260 void ExtractStridedSliceOp::getCanonicalizationPatterns(
2261 RewritePatternSet &results, MLIRContext *context) {
2262 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
2263 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
2264 results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
2265 StridedSliceBroadcast>(context);
2266 }
2267
2268 //===----------------------------------------------------------------------===//
2269 // TransferReadOp
2270 //===----------------------------------------------------------------------===//
2271
2272 template <typename EmitFun>
verifyPermutationMap(AffineMap permutationMap,EmitFun emitOpError)2273 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
2274 EmitFun emitOpError) {
2275 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
2276 for (auto expr : permutationMap.getResults()) {
2277 auto dim = expr.dyn_cast<AffineDimExpr>();
2278 auto zero = expr.dyn_cast<AffineConstantExpr>();
2279 if (zero) {
2280 if (zero.getValue() != 0) {
2281 return emitOpError(
2282 "requires a projected permutation_map (at most one dim or the zero "
2283 "constant can appear in each result)");
2284 }
2285 continue;
2286 }
2287 if (!dim) {
2288 return emitOpError("requires a projected permutation_map (at most one "
2289 "dim or the zero constant can appear in each result)");
2290 }
2291 if (seen[dim.getPosition()]) {
2292 return emitOpError(
2293 "requires a permutation_map that is a permutation (found one dim "
2294 "used more than once)");
2295 }
2296 seen[dim.getPosition()] = true;
2297 }
2298 return success();
2299 }
2300
2301 static LogicalResult
verifyTransferOp(VectorTransferOpInterface op,ShapedType shapedType,VectorType vectorType,VectorType maskType,AffineMap permutationMap,ArrayAttr inBounds)2302 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
2303 VectorType vectorType, VectorType maskType,
2304 AffineMap permutationMap, ArrayAttr inBounds) {
2305 if (shapedType.getRank() == 0 && !op.isZeroD())
2306 return op->emitOpError("0-d transfer requires vector<1xt> shape and () -> "
2307 "(0) permutation_map");
2308
2309 if (op->hasAttr("masked")) {
2310 return op->emitOpError("masked attribute has been removed. "
2311 "Use in_bounds instead.");
2312 }
2313
2314 if (!shapedType.isa<MemRefType, RankedTensorType>())
2315 return op->emitOpError(
2316 "requires source to be a memref or ranked tensor type");
2317 auto elementType = shapedType.getElementType();
2318 DataLayout dataLayout = DataLayout::closest(op);
2319 if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
2320 // Memref or tensor has vector element type.
2321 unsigned sourceVecSize =
2322 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
2323 vectorElementType.getShape().back();
2324 unsigned resultVecSize =
2325 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
2326 vectorType.getShape().back();
2327 if (resultVecSize % sourceVecSize != 0)
2328 return op->emitOpError(
2329 "requires the bitwidth of the minor 1-D vector to be an integral "
2330 "multiple of the bitwidth of the minor 1-D vector of the source");
2331
2332 unsigned sourceVecEltRank = vectorElementType.getRank();
2333 unsigned resultVecRank = vectorType.getRank();
2334 if (sourceVecEltRank > resultVecRank)
2335 return op->emitOpError(
2336 "requires source vector element and vector result ranks to match.");
2337 unsigned rankOffset = resultVecRank - sourceVecEltRank;
2338 // Check that permutation map results match 'rankOffset' of vector type.
2339 if (permutationMap.getNumResults() != rankOffset)
2340 return op->emitOpError("requires a permutation_map with result dims of "
2341 "the same rank as the vector type");
2342
2343 if (maskType)
2344 return op->emitOpError("does not support masks with vector element type");
2345 } else {
2346 // Memref or tensor has scalar element type.
2347 unsigned resultVecSize =
2348 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
2349 vectorType.getShape().back();
2350 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
2351 return op->emitOpError(
2352 "requires the bitwidth of the minor 1-D vector to be an integral "
2353 "multiple of the bitwidth of the source element type");
2354
2355 // Check that permutation map results match rank of vector type.
2356 if (permutationMap.getNumResults() != vectorType.getRank())
2357 return op->emitOpError("requires a permutation_map with result dims of "
2358 "the same rank as the vector type");
2359
2360 VectorType expectedMaskType =
2361 vector::detail::transferMaskType(vectorType, permutationMap);
2362 if (maskType && expectedMaskType != maskType)
2363 return op->emitOpError("expects mask type consistent with permutation "
2364 "map: ")
2365 << maskType;
2366 }
2367
2368 if (permutationMap.getNumSymbols() != 0)
2369 return op->emitOpError("requires permutation_map without symbols");
2370 // TODO: implement 0-d vector corner cases.
2371 if (!op.isZeroD() && permutationMap.getNumInputs() != shapedType.getRank())
2372 return op->emitOpError("requires a permutation_map with input dims of the "
2373 "same rank as the source type");
2374
2375 if (inBounds) {
2376 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
2377 return op->emitOpError("expects the optional in_bounds attr of same rank "
2378 "as permutation_map results: ")
2379 << AffineMapAttr::get(permutationMap);
2380 for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
2381 if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
2382 !inBounds.getValue()[i].cast<BoolAttr>().getValue())
2383 return op->emitOpError("requires broadcast dimensions to be in-bounds");
2384 }
2385
2386 return success();
2387 }
2388
2389 /// Builder that sets padding to zero.
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,AffineMap permutationMap,ArrayRef<bool> inBounds)2390 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2391 VectorType vectorType, Value source,
2392 ValueRange indices, AffineMap permutationMap,
2393 ArrayRef<bool> inBounds) {
2394 Type elemType = source.getType().cast<ShapedType>().getElementType();
2395 Value padding = builder.create<ConstantOp>(result.location, elemType,
2396 builder.getZeroAttr(elemType));
2397 if (inBounds.empty())
2398 return build(builder, result, vectorType, source, indices, permutationMap,
2399 padding, ArrayAttr());
2400 ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
2401 build(builder, result, vectorType, source, indices, permutationMap, padding,
2402 inBoundsArrayAttr);
2403 }
2404
2405 /// Builder that sets permutation map to 'getMinorIdentityMap'.
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,Value padding,ArrayRef<bool> inBounds)2406 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2407 VectorType vectorType, Value source,
2408 ValueRange indices, Value padding,
2409 ArrayRef<bool> inBounds) {
2410 auto permMap = getTransferMinorIdentityMap(
2411 source.getType().cast<ShapedType>(), vectorType);
2412 if (inBounds.empty())
2413 return build(builder, result, vectorType, source, indices, permMap, padding,
2414 ArrayAttr());
2415 ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
2416 build(builder, result, vectorType, source, indices, permMap, padding,
2417 inBoundsArrayAttr);
2418 }
2419
2420 /// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
2421 /// (resp. zero).
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,ArrayRef<bool> inBounds)2422 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2423 VectorType vectorType, Value source,
2424 ValueRange indices, ArrayRef<bool> inBounds) {
2425 auto permMap = getTransferMinorIdentityMap(
2426 source.getType().cast<ShapedType>(), vectorType);
2427 build(builder, result, vectorType, source, indices, permMap, inBounds);
2428 }
2429
2430 /// Builder that does not provide a mask.
build(OpBuilder & builder,OperationState & result,Type vectorType,Value source,ValueRange indices,AffineMap permutationMap,Value padding,ArrayAttr inBounds)2431 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2432 Type vectorType, Value source, ValueRange indices,
2433 AffineMap permutationMap, Value padding,
2434 ArrayAttr inBounds) {
2435 build(builder, result, vectorType, source, indices, permutationMap, padding,
2436 /*mask=*/Value(), inBounds);
2437 }
2438
2439 /// Builder that does not provide a mask.
build(OpBuilder & builder,OperationState & result,Type vectorType,Value source,ValueRange indices,AffineMapAttr permutationMap,Value padding,ArrayAttr inBounds)2440 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2441 Type vectorType, Value source, ValueRange indices,
2442 AffineMapAttr permutationMap, Value padding,
2443 ArrayAttr inBounds) {
2444 build(builder, result, vectorType, source, indices, permutationMap, padding,
2445 /*mask=*/Value(), inBounds);
2446 }
2447
createScalarOp(OpBuilder & builder,Location loc,Value source,ValueRange indices,ArrayRef<bool> inBounds)2448 Value TransferReadOp::createScalarOp(OpBuilder &builder, Location loc,
2449 Value source, ValueRange indices,
2450 ArrayRef<bool> inBounds) {
2451 Type elemType = source.getType().cast<ShapedType>().getElementType();
2452 auto vectorType = VectorType::get(ArrayRef<int64_t>{1}, elemType);
2453 AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
2454 getAffineConstantExpr(0, loc.getContext()));
2455 Value read = builder.create<vector::TransferReadOp>(loc, vectorType, source,
2456 indices, map, inBounds);
2457 return builder.create<vector::ExtractOp>(loc, read, ArrayRef<int64_t>{0});
2458 }
2459
printTransferAttrs(OpAsmPrinter & p,VectorTransferOpInterface op)2460 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
2461 SmallVector<StringRef, 3> elidedAttrs;
2462 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
2463 if (op.permutation_map().isMinorIdentity())
2464 elidedAttrs.push_back(op.getPermutationMapAttrName());
2465 bool elideInBounds = true;
2466 if (auto inBounds = op.in_bounds()) {
2467 for (auto attr : *inBounds) {
2468 if (attr.template cast<BoolAttr>().getValue()) {
2469 elideInBounds = false;
2470 break;
2471 }
2472 }
2473 }
2474 if (elideInBounds)
2475 elidedAttrs.push_back(op.getInBoundsAttrName());
2476 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2477 }
2478
print(OpAsmPrinter & p,TransferReadOp op)2479 static void print(OpAsmPrinter &p, TransferReadOp op) {
2480 p << " " << op.source() << "[" << op.indices() << "], " << op.padding();
2481 if (op.mask())
2482 p << ", " << op.mask();
2483 printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
2484 p << " : " << op.getShapedType() << ", " << op.getVectorType();
2485 }
2486
parseTransferReadOp(OpAsmParser & parser,OperationState & result)2487 static ParseResult parseTransferReadOp(OpAsmParser &parser,
2488 OperationState &result) {
2489 auto &builder = parser.getBuilder();
2490 llvm::SMLoc typesLoc;
2491 OpAsmParser::OperandType sourceInfo;
2492 SmallVector<OpAsmParser::OperandType, 8> indexInfo;
2493 OpAsmParser::OperandType paddingInfo;
2494 SmallVector<Type, 2> types;
2495 OpAsmParser::OperandType maskInfo;
2496 // Parsing with support for paddingValue.
2497 if (parser.parseOperand(sourceInfo) ||
2498 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
2499 parser.parseComma() || parser.parseOperand(paddingInfo))
2500 return failure();
2501 ParseResult hasMask = parser.parseOptionalComma();
2502 if (hasMask.succeeded()) {
2503 parser.parseOperand(maskInfo);
2504 }
2505 if (parser.parseOptionalAttrDict(result.attributes) ||
2506 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
2507 return failure();
2508 if (types.size() != 2)
2509 return parser.emitError(typesLoc, "requires two types");
2510 auto indexType = builder.getIndexType();
2511 auto shapedType = types[0].dyn_cast<ShapedType>();
2512 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
2513 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
2514 VectorType vectorType = types[1].dyn_cast<VectorType>();
2515 if (!vectorType)
2516 return parser.emitError(typesLoc, "requires vector type");
2517 auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
2518 Attribute mapAttr = result.attributes.get(permutationAttrName);
2519 if (!mapAttr) {
2520 auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
2521 mapAttr = AffineMapAttr::get(permMap);
2522 result.attributes.set(permutationAttrName, mapAttr);
2523 }
2524 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
2525 parser.resolveOperands(indexInfo, indexType, result.operands) ||
2526 parser.resolveOperand(paddingInfo, shapedType.getElementType(),
2527 result.operands))
2528 return failure();
2529 if (hasMask.succeeded()) {
2530 if (shapedType.getElementType().dyn_cast<VectorType>())
2531 return parser.emitError(
2532 maskInfo.location, "does not support masks with vector element type");
2533 auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
2534 // Instead of adding the mask type as an op type, compute it based on the
2535 // vector type and the permutation map (to keep the type signature small).
2536 auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
2537 if (parser.resolveOperand(maskInfo, maskType, result.operands))
2538 return failure();
2539 }
2540 result.addAttribute(
2541 TransferReadOp::getOperandSegmentSizeAttr(),
2542 builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1,
2543 static_cast<int32_t>(hasMask.succeeded())}));
2544 return parser.addTypeToList(vectorType, result.types);
2545 }
2546
verify(TransferReadOp op)2547 static LogicalResult verify(TransferReadOp op) {
2548 // Consistency of elemental types in source and vector.
2549 ShapedType shapedType = op.getShapedType();
2550 VectorType vectorType = op.getVectorType();
2551 VectorType maskType = op.getMaskType();
2552 auto paddingType = op.padding().getType();
2553 auto permutationMap = op.permutation_map();
2554 auto sourceElementType = shapedType.getElementType();
2555
2556 if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
2557 return op.emitOpError("requires ") << shapedType.getRank() << " indices";
2558
2559 if (failed(
2560 verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
2561 shapedType, vectorType, maskType, permutationMap,
2562 op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
2563 return failure();
2564
2565 if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
2566 // Source has vector element type.
2567 // Check that 'sourceVectorElementType' and 'paddingType' types match.
2568 if (sourceVectorElementType != paddingType)
2569 return op.emitOpError(
2570 "requires source element type and padding type to match.");
2571
2572 } else {
2573 // Check that 'paddingType' is valid to store in a vector type.
2574 if (!VectorType::isValidElementType(paddingType))
2575 return op.emitOpError("requires valid padding vector elemental type");
2576
2577 // Check that padding type and vector element types match.
2578 if (paddingType != sourceElementType)
2579 return op.emitOpError(
2580 "requires formal padding and source of the same elemental type");
2581 }
2582
2583 return verifyPermutationMap(permutationMap,
2584 [&op](Twine t) { return op.emitOpError(t); });
2585 }
2586
2587 /// This is a common class used for patterns of the form
2588 /// ```
2589 /// someop(memrefcast) -> someop
2590 /// ```
2591 /// It folds the source of the memref.cast into the root operation directly.
foldMemRefCast(Operation * op)2592 static LogicalResult foldMemRefCast(Operation *op) {
2593 bool folded = false;
2594 for (OpOperand &operand : op->getOpOperands()) {
2595 auto castOp = operand.get().getDefiningOp<memref::CastOp>();
2596 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
2597 operand.set(castOp.getOperand());
2598 folded = true;
2599 }
2600 }
2601 return success(folded);
2602 }
2603
foldTensorCast(Operation * op)2604 static LogicalResult foldTensorCast(Operation *op) {
2605 bool folded = false;
2606 for (OpOperand &operand : op->getOpOperands()) {
2607 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
2608 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
2609 operand.set(castOp.getOperand());
2610 folded = true;
2611 }
2612 }
2613 return success(folded);
2614 }
2615
2616 template <typename TransferOp>
isInBounds(TransferOp op,int64_t resultIdx,int64_t indicesIdx)2617 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
2618 // TODO: support more aggressive createOrFold on:
2619 // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
2620 if (op.getShapedType().isDynamicDim(indicesIdx))
2621 return false;
2622 Value index = op.indices()[indicesIdx];
2623 auto cstOp = index.getDefiningOp<ConstantIndexOp>();
2624 if (!cstOp)
2625 return false;
2626
2627 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
2628 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
2629
2630 return cstOp.getValue() + vectorSize <= sourceSize;
2631 }
2632
2633 template <typename TransferOp>
foldTransferInBoundsAttribute(TransferOp op)2634 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
2635 // TODO: Be less conservative once we have 0-d vectors.
2636 if (op.isZeroD())
2637 return failure();
2638 AffineMap permutationMap = op.permutation_map();
2639 bool changed = false;
2640 SmallVector<bool, 4> newInBounds;
2641 newInBounds.reserve(op.getTransferRank());
2642 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
2643 // Already marked as in-bounds, nothing to see here.
2644 if (op.isDimInBounds(i)) {
2645 newInBounds.push_back(true);
2646 continue;
2647 }
2648 // Currently out-of-bounds, check whether we can statically determine it is
2649 // inBounds.
2650 auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
2651 assert(dimExpr && "Broadcast dims must be in-bounds");
2652 auto inBounds =
2653 isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
2654 newInBounds.push_back(inBounds);
2655 // We commit the pattern if it is "more inbounds".
2656 changed |= inBounds;
2657 }
2658 if (!changed)
2659 return failure();
2660 // OpBuilder is only used as a helper to build an I64ArrayAttr.
2661 OpBuilder b(op.getContext());
2662 op->setAttr(TransferOp::getInBoundsAttrName(),
2663 b.getBoolArrayAttr(newInBounds));
2664 return success();
2665 }
2666
2667 /// ```
2668 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
2669 /// : vector<1x4xf32>, tensor<4x4xf32>
2670 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
2671 /// : tensor<4x4xf32>, vector<1x4xf32>
2672 /// ```
2673 /// -> Folds into
2674 /// ```
2675 /// %v0
2676 /// ```
foldRAW(TransferReadOp readOp)2677 static Value foldRAW(TransferReadOp readOp) {
2678 if (!readOp.getShapedType().isa<RankedTensorType>())
2679 return {};
2680 auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>();
2681 while (defWrite) {
2682 if (checkSameValueRAW(defWrite, readOp))
2683 return defWrite.vector();
2684 if (!isDisjointTransferIndices(
2685 cast<VectorTransferOpInterface>(defWrite.getOperation()),
2686 cast<VectorTransferOpInterface>(readOp.getOperation())))
2687 break;
2688 defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
2689 }
2690 return {};
2691 }
2692
fold(ArrayRef<Attribute>)2693 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
2694 if (Value vec = foldRAW(*this))
2695 return vec;
2696 /// transfer_read(memrefcast) -> transfer_read
2697 if (succeeded(foldTransferInBoundsAttribute(*this)))
2698 return getResult();
2699 if (succeeded(foldMemRefCast(*this)))
2700 return getResult();
2701 if (succeeded(foldTensorCast(*this)))
2702 return getResult();
2703 return OpFoldResult();
2704 }
2705
getShapeForUnroll()2706 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
2707 return llvm::to_vector<4>(getVectorType().getShape());
2708 }
2709
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)2710 void TransferReadOp::getEffects(
2711 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2712 &effects) {
2713 if (getShapedType().isa<MemRefType>())
2714 effects.emplace_back(MemoryEffects::Read::get(), source(),
2715 SideEffects::DefaultResource::get());
2716 }
2717
2718 namespace {
2719 /// Fold transfer_reads of a tensor.extract_slice op. E.g.:
2720 ///
2721 /// ```
2722 /// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
2723 /// : tensor<?x?xf32> to tensor<?x?xf32>
2724 /// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
2725 /// : tensor<?x?xf32>, vector<4x5xf32>
2726 /// ```
2727 /// is rewritten to:
2728 /// ```
2729 /// %p0 = addi %a, %e : index
2730 /// %p1 = addi %b, %f : index
2731 /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
2732 /// : tensor<?x?xf32>, vector<4x5xf32>
2733 /// ```
2734 struct FoldExtractSliceIntoTransferRead
2735 : public OpRewritePattern<TransferReadOp> {
2736 public:
2737 using OpRewritePattern<TransferReadOp>::OpRewritePattern;
2738
matchAndRewrite__anona11d586e1811::FoldExtractSliceIntoTransferRead2739 LogicalResult matchAndRewrite(TransferReadOp xferOp,
2740 PatternRewriter &rewriter) const override {
2741 if (xferOp.hasOutOfBoundsDim())
2742 return failure();
2743 if (!xferOp.permutation_map().isIdentity())
2744 return failure();
2745 if (xferOp.mask())
2746 return failure();
2747 auto extractOp = xferOp.source().getDefiningOp<tensor::ExtractSliceOp>();
2748 if (!extractOp)
2749 return failure();
2750 if (!extractOp.hasUnitStride())
2751 return failure();
2752
2753 int64_t rankReduced =
2754 extractOp.getSourceType().getRank() - extractOp.getType().getRank();
2755 SmallVector<Value> newIndices;
2756 // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
2757 // indices first.
2758 for (int64_t i = 0; i < rankReduced; ++i) {
2759 OpFoldResult offset = extractOp.getMixedOffsets()[i];
2760 newIndices.push_back(getValueOrCreateConstantIndexOp(
2761 rewriter, extractOp.getLoc(), offset));
2762 }
2763 for (auto it : llvm::enumerate(xferOp.indices())) {
2764 OpFoldResult offset =
2765 extractOp.getMixedOffsets()[it.index() + rankReduced];
2766 newIndices.push_back(
2767 rewriter.create<AddIOp>(xferOp->getLoc(), it.value(),
2768 getValueOrCreateConstantIndexOp(
2769 rewriter, extractOp.getLoc(), offset)));
2770 }
2771 SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
2772 rewriter.replaceOpWithNewOp<TransferReadOp>(xferOp, xferOp.getVectorType(),
2773 extractOp.source(), newIndices,
2774 xferOp.padding(), inBounds);
2775
2776 return success();
2777 }
2778 };
2779 } // namespace
2780
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2781 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2782 MLIRContext *context) {
2783 results.add<FoldExtractSliceIntoTransferRead>(context);
2784 }
2785
2786 //===----------------------------------------------------------------------===//
2787 // TransferWriteOp
2788 //===----------------------------------------------------------------------===//
2789
build(OpBuilder & builder,OperationState & result,Value vector,Value dest,ValueRange indices,AffineMap permutationMap,ArrayRef<bool> inBounds)2790 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
2791 Value vector, Value dest, ValueRange indices,
2792 AffineMap permutationMap, ArrayRef<bool> inBounds) {
2793 if (inBounds.empty())
2794 return build(builder, result, vector, dest, indices, permutationMap,
2795 /*mask=*/Value(), ArrayAttr());
2796 build(builder, result, vector, dest, indices, permutationMap,
2797 /*mask=*/Value(), builder.getBoolArrayAttr(inBounds));
2798 }
2799
2800 /// Builder that sets permutation map to 'getMinorIdentityMap'.
build(OpBuilder & builder,OperationState & result,Value vector,Value source,ValueRange indices,ArrayRef<bool> inBounds)2801 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
2802 Value vector, Value source, ValueRange indices,
2803 ArrayRef<bool> inBounds) {
2804 auto vectorType = vector.getType().cast<VectorType>();
2805 auto permMap = getTransferMinorIdentityMap(
2806 source.getType().cast<ShapedType>(), vectorType);
2807 if (inBounds.empty())
2808 return build(builder, result, vector, source, indices, permMap,
2809 ArrayAttr());
2810 ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
2811 build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr);
2812 }
2813
build(OpBuilder & builder,OperationState & result,Value vector,Value source,ValueRange indices,AffineMapAttr permutationMap,ArrayAttr inBounds)2814 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
2815 Value vector, Value source, ValueRange indices,
2816 AffineMapAttr permutationMap,
2817 /*optional*/ ArrayAttr inBounds) {
2818 Type resultType = source.getType().dyn_cast<RankedTensorType>();
2819 build(builder, result, resultType, vector, source, indices, permutationMap,
2820 /*mask=*/Value(), inBounds);
2821 }
2822
build(OpBuilder & builder,OperationState & result,Value vector,Value source,ValueRange indices,AffineMap permutationMap,ArrayAttr inBounds)2823 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
2824 Value vector, Value source, ValueRange indices,
2825 AffineMap permutationMap,
2826 /*optional*/ ArrayAttr inBounds) {
2827 Type resultType = source.getType().dyn_cast<RankedTensorType>();
2828 build(builder, result, resultType, vector, source, indices, permutationMap,
2829 /*mask=*/Value(), inBounds);
2830 }
2831
build(OpBuilder & builder,OperationState & result,Value vector,Value source,ValueRange indices,AffineMap permutationMap,Value mask,ArrayAttr inBounds)2832 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
2833 Value vector, Value source, ValueRange indices,
2834 AffineMap permutationMap, /*optional*/ Value mask,
2835 /*optional*/ ArrayAttr inBounds) {
2836 Type resultType = source.getType().dyn_cast<RankedTensorType>();
2837 build(builder, result, resultType, vector, source, indices, permutationMap,
2838 mask, inBounds);
2839 }
2840
createScalarOp(OpBuilder & builder,Location loc,Value value,Value dest,ValueRange indices,ArrayRef<bool> inBounds)2841 Operation *TransferWriteOp::createScalarOp(OpBuilder &builder, Location loc,
2842 Value value, Value dest,
2843 ValueRange indices,
2844 ArrayRef<bool> inBounds) {
2845 Value vectorOfAScalar = value;
2846 if (!value.getType().isa<VectorType>())
2847 vectorOfAScalar = builder.create<vector::BroadcastOp>(
2848 loc, VectorType::get({1}, value.getType()), value);
2849 AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
2850 getAffineConstantExpr(0, loc.getContext()));
2851 return builder.create<vector::TransferWriteOp>(loc, vectorOfAScalar, dest,
2852 indices, map, inBounds);
2853 }
2854
/// Parses the custom assembly of a TransferWriteOp, in this order:
///   vector operand `,` source operand `[` index operands `]`
///   (`,` mask operand)?  optional attribute dictionary
///   `:` vector type `,` memref or ranked tensor type
///
/// If the attribute dictionary does not carry a permutation map, the map
/// returned by `getTransferMinorIdentityMap` for the two parsed types is
/// installed. Operands are resolved in the order vector, source, indices,
/// mask, and the operand segment sizes attribute is set to match. A result
/// type is added only when the second type is a ranked tensor.
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
                                        OperationState &result) {
  auto &builder = parser.getBuilder();
  llvm::SMLoc typesLoc;
  OpAsmParser::OperandType vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::OperandType maskInfo;
  // Mandatory prefix: `%vector, %source[%indices...]`.
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  // A comma after the index list announces a mask operand.
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  // Remember where the type list starts so type errors point at it.
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = types[0].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = types[1].dyn_cast<ShapedType>();
  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  // Default the permutation map when the attribute dictionary omitted it.
  auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
  auto attr = result.attributes.get(permutationAttrName);
  if (!attr) {
    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
  }
  // Operand order matters: it must agree with the segment sizes set below.
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (shapedType.getElementType().dyn_cast<VectorType>())
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    // The mask is an i1 vector with the same shape as the written vector.
    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  // Segment sizes: 1 vector, 1 source, N indices, and 0 or 1 mask.
  result.addAttribute(
      TransferWriteOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()),
                                static_cast<int32_t>(hasMask.succeeded())}));
  // Only the tensor form has a result; the `&&` skips addTypeToList otherwise.
  return failure(shapedType.isa<RankedTensorType>() &&
                 parser.addTypeToList(shapedType, result.types));
}
2907
print(OpAsmPrinter & p,TransferWriteOp op)2908 static void print(OpAsmPrinter &p, TransferWriteOp op) {
2909 p << " " << op.vector() << ", " << op.source() << "[" << op.indices() << "]";
2910 if (op.mask())
2911 p << ", " << op.mask();
2912 printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
2913 p << " : " << op.getVectorType() << ", " << op.getShapedType();
2914 }
2915
verify(TransferWriteOp op)2916 static LogicalResult verify(TransferWriteOp op) {
2917 // Consistency of elemental types in shape and vector.
2918 ShapedType shapedType = op.getShapedType();
2919 VectorType vectorType = op.getVectorType();
2920 VectorType maskType = op.getMaskType();
2921 auto permutationMap = op.permutation_map();
2922
2923 if (llvm::size(op.indices()) != shapedType.getRank())
2924 return op.emitOpError("requires ") << shapedType.getRank() << " indices";
2925
2926 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
2927 // as the semantics is unclear. This can be revisited later if necessary.
2928 if (op.hasBroadcastDim())
2929 return op.emitOpError("should not have broadcast dimensions");
2930
2931 if (failed(
2932 verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
2933 shapedType, vectorType, maskType, permutationMap,
2934 op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
2935 return failure();
2936
2937 return verifyPermutationMap(permutationMap,
2938 [&op](Twine t) { return op.emitOpError(t); });
2939 }
2940
2941 /// Fold:
2942 /// ```
2943 /// %t1 = ...
2944 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
2945 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
2946 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
2947 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
2948 /// ```
2949 ///
2950 /// into:
2951 ///
2952 /// ```
2953 /// %t0
2954 /// ```
2955 ///
2956 /// The producer of t1 may or may not be DCE'd depending on whether it is a
2957 /// block argument or has side effects.
foldReadInitWrite(TransferWriteOp write,ArrayRef<Attribute>,SmallVectorImpl<OpFoldResult> & results)2958 static LogicalResult foldReadInitWrite(TransferWriteOp write,
2959 ArrayRef<Attribute>,
2960 SmallVectorImpl<OpFoldResult> &results) {
2961 auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>();
2962 // If not operating on tensors, bail.
2963 if (!rankedTensorType)
2964 return failure();
2965 // If no read, bail.
2966 auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
2967 if (!read)
2968 return failure();
2969 // For now, only accept minor identity. Future: composition is minor identity.
2970 if (!read.permutation_map().isMinorIdentity() ||
2971 !write.permutation_map().isMinorIdentity())
2972 return failure();
2973 // Bail on mismatching ranks.
2974 if (read.getTransferRank() != write.getTransferRank())
2975 return failure();
2976 // Bail on potential out-of-bounds accesses.
2977 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
2978 return failure();
2979 // Tensor types must be the same.
2980 if (read.source().getType() != rankedTensorType)
2981 return failure();
2982 // Vector types must be the same.
2983 if (read.getVectorType() != write.getVectorType())
2984 return failure();
2985 // Vector and Tensor shapes must match.
2986 if (read.getVectorType().getShape() != rankedTensorType.getShape())
2987 return failure();
2988 // If any index is nonzero.
2989 auto isNotConstantZero = [](Value v) {
2990 auto cstOp = v.getDefiningOp<ConstantIndexOp>();
2991 return !cstOp || cstOp.getValue() != 0;
2992 };
2993 if (llvm::any_of(read.indices(), isNotConstantZero) ||
2994 llvm::any_of(write.indices(), isNotConstantZero))
2995 return failure();
2996 // Success.
2997 results.push_back(read.source());
2998 return success();
2999 }
3000
checkSameValueWAR(vector::TransferReadOp read,vector::TransferWriteOp write)3001 static bool checkSameValueWAR(vector::TransferReadOp read,
3002 vector::TransferWriteOp write) {
3003 return read.source() == write.source() && read.indices() == write.indices() &&
3004 read.permutation_map() == write.permutation_map() &&
3005 read.getVectorType() == write.getVectorType() && !read.mask() &&
3006 !write.mask();
3007 }
3008 /// Fold transfer_write write after read:
3009 /// ```
3010 /// %t0 = ...
3011 /// %v = vector.transfer_read %t0[%c0...] :
3012 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
3013 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
3014 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
3015 /// ```
3016 ///
3017 /// into:
3018 ///
3019 /// ```
3020 /// %t0
3021 /// ```
foldWAR(TransferWriteOp write,SmallVectorImpl<OpFoldResult> & results)3022 static LogicalResult foldWAR(TransferWriteOp write,
3023 SmallVectorImpl<OpFoldResult> &results) {
3024 if (!write.source().getType().isa<RankedTensorType>())
3025 return failure();
3026 auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
3027 if (!read)
3028 return failure();
3029
3030 if (!checkSameValueWAR(read, write))
3031 return failure();
3032 results.push_back(read.source());
3033 return success();
3034 }
3035
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)3036 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
3037 SmallVectorImpl<OpFoldResult> &results) {
3038 if (succeeded(foldReadInitWrite(*this, operands, results)))
3039 return success();
3040 if (succeeded(foldWAR(*this, results)))
3041 return success();
3042 if (succeeded(foldTransferInBoundsAttribute(*this)))
3043 return success();
3044 return foldMemRefCast(*this);
3045 }
3046
/// Returns the shape of the written vector as the shape to use for unrolling.
Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
3050
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)3051 void TransferWriteOp::getEffects(
3052 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3053 &effects) {
3054 if (getShapedType().isa<MemRefType>())
3055 effects.emplace_back(MemoryEffects::Write::get(), source(),
3056 SideEffects::DefaultResource::get());
3057 }
3058
3059 namespace {
3060 /// Remove dead transfer write from the SSA chain so that it an be eliminated by
3061 /// DCE
3062 /// ```
3063 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3064 /// : vector<1x4xf32>, tensor<4x4xf32>
3065 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
3066 /// : vector<1x4xf32>, tensor<4x4xf32>
3067 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3068 /// : vector<1x4xf32>, tensor<4x4xf32>
3069 /// ```
3070 ///
3071 /// into:
3072 ///
3073 /// ```
3074 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3075 /// : vector<1x4xf32>, tensor<4x4xf32>
3076 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
3077 /// : vector<1x4xf32>, tensor<4x4xf32>
3078 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3079 /// : vector<1x4xf32>, tensor<4x4xf32>
3080 /// ```
3081 ///
3082 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
3083 /// any other uses.
3084 class foldWAW final : public OpRewritePattern<TransferWriteOp> {
3085 public:
3086 using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
matchAndRewrite(TransferWriteOp writeOp,PatternRewriter & rewriter) const3087 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
3088 PatternRewriter &rewriter) const override {
3089 if (!writeOp.getShapedType().isa<RankedTensorType>())
3090 return failure();
3091 vector::TransferWriteOp writeToModify = writeOp;
3092
3093 auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
3094 while (defWrite) {
3095 if (checkSameValueWAW(writeOp, defWrite)) {
3096 writeToModify.sourceMutable().assign(defWrite.source());
3097 return success();
3098 }
3099 if (!isDisjointTransferIndices(
3100 cast<VectorTransferOpInterface>(defWrite.getOperation()),
3101 cast<VectorTransferOpInterface>(writeOp.getOperation())))
3102 break;
3103 // If the previous write op doesn't have any other use we an safely look
3104 // at the previous store to see if it can be removed.
3105 if (!defWrite->hasOneUse())
3106 break;
3107 writeToModify = defWrite;
3108 defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
3109 }
3110 return failure();
3111 }
3112 };
3113
3114 /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
3115 /// could directly write to the insert_slice's destination. E.g.:
3116 ///
3117 /// ```
3118 /// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
3119 /// : vector<4x5xf32>, tensor<4x5xf32>
3120 /// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
3121 /// : tensor<4x5xf32> into tensor<?x?xf32>
3122 /// ```
3123 /// is rewritten to:
3124 /// ```
3125 /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
3126 /// : vector<4x5xf32>, tensor<?x?xf32>
3127 /// ```
3128 struct FoldInsertSliceIntoTransferWrite
3129 : public OpRewritePattern<tensor::InsertSliceOp> {
3130 public:
3131 using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
3132
matchAndRewrite__anona11d586e1b11::FoldInsertSliceIntoTransferWrite3133 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
3134 PatternRewriter &rewriter) const override {
3135 if (!insertOp.hasUnitStride())
3136 return failure();
3137 auto xferOp = insertOp.source().getDefiningOp<TransferWriteOp>();
3138 if (!xferOp)
3139 return failure();
3140 if (xferOp.hasOutOfBoundsDim())
3141 return failure();
3142 if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
3143 return failure();
3144 if (xferOp.mask())
3145 return failure();
3146 // Fold only if the TransferWriteOp completely overwrites the `source` with
3147 // a vector. I.e., the result of the TransferWriteOp is a new tensor who's
3148 // content is the data of the vector.
3149 if (!llvm::equal(xferOp.getVectorType().getShape(),
3150 xferOp.getShapedType().getShape()))
3151 return failure();
3152 if (!xferOp.permutation_map().isIdentity())
3153 return failure();
3154
3155 SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
3156 rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
3157 SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
3158 rewriter.replaceOpWithNewOp<TransferWriteOp>(
3159 insertOp, xferOp.vector(), insertOp.dest(), indices, inBounds);
3160 return success();
3161 }
3162 };
3163 } // namespace
3164
/// Registers the canonicalization patterns for TransferWriteOp: `foldWAW` and
/// `FoldInsertSliceIntoTransferWrite`.
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<foldWAW, FoldInsertSliceIntoTransferWrite>(context);
}
3169
3170 //===----------------------------------------------------------------------===//
3171 // LoadOp
3172 //===----------------------------------------------------------------------===//
3173
verifyLoadStoreMemRefLayout(Operation * op,MemRefType memRefTy)3174 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
3175 MemRefType memRefTy) {
3176 if (!isLastMemrefDimUnitStride(memRefTy))
3177 return op->emitOpError("most minor memref dim must have unit stride");
3178 return success();
3179 }
3180
verify(vector::LoadOp op)3181 static LogicalResult verify(vector::LoadOp op) {
3182 VectorType resVecTy = op.getVectorType();
3183 MemRefType memRefTy = op.getMemRefType();
3184
3185 if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
3186 return failure();
3187
3188 // Checks for vector memrefs.
3189 Type memElemTy = memRefTy.getElementType();
3190 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
3191 if (memVecTy != resVecTy)
3192 return op.emitOpError("base memref and result vector types should match");
3193 memElemTy = memVecTy.getElementType();
3194 }
3195
3196 if (resVecTy.getElementType() != memElemTy)
3197 return op.emitOpError("base and result element types should match");
3198 if (llvm::size(op.indices()) != memRefTy.getRank())
3199 return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
3200 return success();
3201 }
3202
fold(ArrayRef<Attribute>)3203 OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
3204 if (succeeded(foldMemRefCast(*this)))
3205 return getResult();
3206 return OpFoldResult();
3207 }
3208
3209 //===----------------------------------------------------------------------===//
3210 // StoreOp
3211 //===----------------------------------------------------------------------===//
3212
verify(vector::StoreOp op)3213 static LogicalResult verify(vector::StoreOp op) {
3214 VectorType valueVecTy = op.getVectorType();
3215 MemRefType memRefTy = op.getMemRefType();
3216
3217 if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
3218 return failure();
3219
3220 // Checks for vector memrefs.
3221 Type memElemTy = memRefTy.getElementType();
3222 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
3223 if (memVecTy != valueVecTy)
3224 return op.emitOpError(
3225 "base memref and valueToStore vector types should match");
3226 memElemTy = memVecTy.getElementType();
3227 }
3228
3229 if (valueVecTy.getElementType() != memElemTy)
3230 return op.emitOpError("base and valueToStore element type should match");
3231 if (llvm::size(op.indices()) != memRefTy.getRank())
3232 return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
3233 return success();
3234 }
3235
/// Folds memref casts into this op; `operands` and `results` are not used.
LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
                            SmallVectorImpl<OpFoldResult> &results) {
  return foldMemRefCast(*this);
}
3240
3241 //===----------------------------------------------------------------------===//
3242 // MaskedLoadOp
3243 //===----------------------------------------------------------------------===//
3244
verify(MaskedLoadOp op)3245 static LogicalResult verify(MaskedLoadOp op) {
3246 VectorType maskVType = op.getMaskVectorType();
3247 VectorType passVType = op.getPassThruVectorType();
3248 VectorType resVType = op.getVectorType();
3249 MemRefType memType = op.getMemRefType();
3250
3251 if (resVType.getElementType() != memType.getElementType())
3252 return op.emitOpError("base and result element type should match");
3253 if (llvm::size(op.indices()) != memType.getRank())
3254 return op.emitOpError("requires ") << memType.getRank() << " indices";
3255 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3256 return op.emitOpError("expected result dim to match mask dim");
3257 if (resVType != passVType)
3258 return op.emitOpError("expected pass_thru of same type as result type");
3259 return success();
3260 }
3261
3262 namespace {
3263 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
3264 public:
3265 using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
matchAndRewrite(MaskedLoadOp load,PatternRewriter & rewriter) const3266 LogicalResult matchAndRewrite(MaskedLoadOp load,
3267 PatternRewriter &rewriter) const override {
3268 switch (get1DMaskFormat(load.mask())) {
3269 case MaskFormat::AllTrue:
3270 rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(),
3271 load.base(), load.indices());
3272 return success();
3273 case MaskFormat::AllFalse:
3274 rewriter.replaceOp(load, load.pass_thru());
3275 return success();
3276 case MaskFormat::Unknown:
3277 return failure();
3278 }
3279 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
3280 }
3281 };
3282 } // namespace
3283
/// Registers `MaskedLoadFolder` as the canonicalization pattern for this op.
void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}
3288
fold(ArrayRef<Attribute>)3289 OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
3290 if (succeeded(foldMemRefCast(*this)))
3291 return getResult();
3292 return OpFoldResult();
3293 }
3294
3295 //===----------------------------------------------------------------------===//
3296 // MaskedStoreOp
3297 //===----------------------------------------------------------------------===//
3298
verify(MaskedStoreOp op)3299 static LogicalResult verify(MaskedStoreOp op) {
3300 VectorType maskVType = op.getMaskVectorType();
3301 VectorType valueVType = op.getVectorType();
3302 MemRefType memType = op.getMemRefType();
3303
3304 if (valueVType.getElementType() != memType.getElementType())
3305 return op.emitOpError("base and valueToStore element type should match");
3306 if (llvm::size(op.indices()) != memType.getRank())
3307 return op.emitOpError("requires ") << memType.getRank() << " indices";
3308 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
3309 return op.emitOpError("expected valueToStore dim to match mask dim");
3310 return success();
3311 }
3312
3313 namespace {
3314 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
3315 public:
3316 using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
matchAndRewrite(MaskedStoreOp store,PatternRewriter & rewriter) const3317 LogicalResult matchAndRewrite(MaskedStoreOp store,
3318 PatternRewriter &rewriter) const override {
3319 switch (get1DMaskFormat(store.mask())) {
3320 case MaskFormat::AllTrue:
3321 rewriter.replaceOpWithNewOp<vector::StoreOp>(
3322 store, store.valueToStore(), store.base(), store.indices());
3323 return success();
3324 case MaskFormat::AllFalse:
3325 rewriter.eraseOp(store);
3326 return success();
3327 case MaskFormat::Unknown:
3328 return failure();
3329 }
3330 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
3331 }
3332 };
3333 } // namespace
3334
/// Registers `MaskedStoreFolder` as the canonicalization pattern for this op.
void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}
3339
/// Folds memref casts into this op; `operands` and `results` are not used.
LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return foldMemRefCast(*this);
}
3344
3345 //===----------------------------------------------------------------------===//
3346 // GatherOp
3347 //===----------------------------------------------------------------------===//
3348
verify(GatherOp op)3349 static LogicalResult verify(GatherOp op) {
3350 VectorType indVType = op.getIndexVectorType();
3351 VectorType maskVType = op.getMaskVectorType();
3352 VectorType resVType = op.getVectorType();
3353 MemRefType memType = op.getMemRefType();
3354
3355 if (resVType.getElementType() != memType.getElementType())
3356 return op.emitOpError("base and result element type should match");
3357 if (llvm::size(op.indices()) != memType.getRank())
3358 return op.emitOpError("requires ") << memType.getRank() << " indices";
3359 if (resVType.getDimSize(0) != indVType.getDimSize(0))
3360 return op.emitOpError("expected result dim to match indices dim");
3361 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3362 return op.emitOpError("expected result dim to match mask dim");
3363 if (resVType != op.getPassThruVectorType())
3364 return op.emitOpError("expected pass_thru of same type as result type");
3365 return success();
3366 }
3367
3368 namespace {
3369 class GatherFolder final : public OpRewritePattern<GatherOp> {
3370 public:
3371 using OpRewritePattern<GatherOp>::OpRewritePattern;
matchAndRewrite(GatherOp gather,PatternRewriter & rewriter) const3372 LogicalResult matchAndRewrite(GatherOp gather,
3373 PatternRewriter &rewriter) const override {
3374 switch (get1DMaskFormat(gather.mask())) {
3375 case MaskFormat::AllTrue:
3376 return failure(); // no unmasked equivalent
3377 case MaskFormat::AllFalse:
3378 rewriter.replaceOp(gather, gather.pass_thru());
3379 return success();
3380 case MaskFormat::Unknown:
3381 return failure();
3382 }
3383 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
3384 }
3385 };
3386 } // namespace
3387
/// Registers `GatherFolder` as the canonicalization pattern for this op.
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder>(context);
}
3392
3393 //===----------------------------------------------------------------------===//
3394 // ScatterOp
3395 //===----------------------------------------------------------------------===//
3396
verify(ScatterOp op)3397 static LogicalResult verify(ScatterOp op) {
3398 VectorType indVType = op.getIndexVectorType();
3399 VectorType maskVType = op.getMaskVectorType();
3400 VectorType valueVType = op.getVectorType();
3401 MemRefType memType = op.getMemRefType();
3402
3403 if (valueVType.getElementType() != memType.getElementType())
3404 return op.emitOpError("base and valueToStore element type should match");
3405 if (llvm::size(op.indices()) != memType.getRank())
3406 return op.emitOpError("requires ") << memType.getRank() << " indices";
3407 if (valueVType.getDimSize(0) != indVType.getDimSize(0))
3408 return op.emitOpError("expected valueToStore dim to match indices dim");
3409 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
3410 return op.emitOpError("expected valueToStore dim to match mask dim");
3411 return success();
3412 }
3413
3414 namespace {
3415 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
3416 public:
3417 using OpRewritePattern<ScatterOp>::OpRewritePattern;
matchAndRewrite(ScatterOp scatter,PatternRewriter & rewriter) const3418 LogicalResult matchAndRewrite(ScatterOp scatter,
3419 PatternRewriter &rewriter) const override {
3420 switch (get1DMaskFormat(scatter.mask())) {
3421 case MaskFormat::AllTrue:
3422 return failure(); // no unmasked equivalent
3423 case MaskFormat::AllFalse:
3424 rewriter.eraseOp(scatter);
3425 return success();
3426 case MaskFormat::Unknown:
3427 return failure();
3428 }
3429 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
3430 }
3431 };
3432 } // namespace
3433
/// Registers `ScatterFolder` as the canonicalization pattern for this op.
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder>(context);
}
3438
3439 //===----------------------------------------------------------------------===//
3440 // ExpandLoadOp
3441 //===----------------------------------------------------------------------===//
3442
verify(ExpandLoadOp op)3443 static LogicalResult verify(ExpandLoadOp op) {
3444 VectorType maskVType = op.getMaskVectorType();
3445 VectorType passVType = op.getPassThruVectorType();
3446 VectorType resVType = op.getVectorType();
3447 MemRefType memType = op.getMemRefType();
3448
3449 if (resVType.getElementType() != memType.getElementType())
3450 return op.emitOpError("base and result element type should match");
3451 if (llvm::size(op.indices()) != memType.getRank())
3452 return op.emitOpError("requires ") << memType.getRank() << " indices";
3453 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3454 return op.emitOpError("expected result dim to match mask dim");
3455 if (resVType != passVType)
3456 return op.emitOpError("expected pass_thru of same type as result type");
3457 return success();
3458 }
3459
3460 namespace {
3461 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
3462 public:
3463 using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
matchAndRewrite(ExpandLoadOp expand,PatternRewriter & rewriter) const3464 LogicalResult matchAndRewrite(ExpandLoadOp expand,
3465 PatternRewriter &rewriter) const override {
3466 switch (get1DMaskFormat(expand.mask())) {
3467 case MaskFormat::AllTrue:
3468 rewriter.replaceOpWithNewOp<vector::LoadOp>(
3469 expand, expand.getType(), expand.base(), expand.indices());
3470 return success();
3471 case MaskFormat::AllFalse:
3472 rewriter.replaceOp(expand, expand.pass_thru());
3473 return success();
3474 case MaskFormat::Unknown:
3475 return failure();
3476 }
3477 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
3478 }
3479 };
3480 } // namespace
3481
/// Registers `ExpandLoadFolder` as the canonicalization pattern for this op.
void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}
3486
3487 //===----------------------------------------------------------------------===//
3488 // CompressStoreOp
3489 //===----------------------------------------------------------------------===//
3490
verify(CompressStoreOp op)3491 static LogicalResult verify(CompressStoreOp op) {
3492 VectorType maskVType = op.getMaskVectorType();
3493 VectorType valueVType = op.getVectorType();
3494 MemRefType memType = op.getMemRefType();
3495
3496 if (valueVType.getElementType() != memType.getElementType())
3497 return op.emitOpError("base and valueToStore element type should match");
3498 if (llvm::size(op.indices()) != memType.getRank())
3499 return op.emitOpError("requires ") << memType.getRank() << " indices";
3500 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
3501 return op.emitOpError("expected valueToStore dim to match mask dim");
3502 return success();
3503 }
3504
3505 namespace {
3506 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
3507 public:
3508 using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
matchAndRewrite(CompressStoreOp compress,PatternRewriter & rewriter) const3509 LogicalResult matchAndRewrite(CompressStoreOp compress,
3510 PatternRewriter &rewriter) const override {
3511 switch (get1DMaskFormat(compress.mask())) {
3512 case MaskFormat::AllTrue:
3513 rewriter.replaceOpWithNewOp<vector::StoreOp>(
3514 compress, compress.valueToStore(), compress.base(),
3515 compress.indices());
3516 return success();
3517 case MaskFormat::AllFalse:
3518 rewriter.eraseOp(compress);
3519 return success();
3520 case MaskFormat::Unknown:
3521 return failure();
3522 }
3523 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
3524 }
3525 };
3526 } // namespace
3527
/// Registers `CompressStoreFolder` as the canonicalization pattern for this
/// op.
void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}
3532
3533 //===----------------------------------------------------------------------===//
3534 // ShapeCastOp
3535 //===----------------------------------------------------------------------===//
3536
/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
///
/// 'a' must have strictly fewer elements than 'b' (asserted). Both shapes have
/// to be consumed entirely for the cast to be valid; trailing dimensions of
/// size 1 on either side are absorbed into the last matched group.
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  unsigned rankA = a.size();
  unsigned rankB = b.size();
  assert(rankA < rankB);

  // `i` walks the dims of 'a', `j` walks the dims of 'b'.
  unsigned i = 0;
  unsigned j = 0;
  while (i < rankA && j < rankB) {
    int64_t dimA = a[i];
    int64_t dimB = 1;
    // Multiply consecutive dims of 'b' until the product reaches a[i].
    while (dimB < dimA && j < rankB)
      dimB *= b[j++];
    // Overshooting (or running out of 'b') means a[i] has no matching group.
    if (dimA != dimB)
      break;
    ++i;

    // Handle the case when trailing dimensions are of size 1.
    // Include them into the contiguous sequence.
    auto isOne = [](int64_t v) { return v == 1; };
    if (i < rankA && llvm::all_of(a.slice(i), isOne))
      i = rankA;
    if (j < rankB && llvm::all_of(b.slice(j), isOne))
      j = rankB;
  }

  // Valid only if every dim on both sides was consumed.
  return i == rankA && j == rankB;
}
3566
verifyVectorShapeCast(Operation * op,VectorType sourceVectorType,VectorType resultVectorType)3567 static LogicalResult verifyVectorShapeCast(Operation *op,
3568 VectorType sourceVectorType,
3569 VectorType resultVectorType) {
3570 // Check that element type is the same.
3571 if (sourceVectorType.getElementType() != resultVectorType.getElementType())
3572 return op->emitOpError("source/result vectors must have same element type");
3573 auto sourceShape = sourceVectorType.getShape();
3574 auto resultShape = resultVectorType.getShape();
3575
3576 // Check that product of source dim sizes matches product of result dim sizes.
3577 int64_t sourceDimProduct = std::accumulate(
3578 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
3579 int64_t resultDimProduct = std::accumulate(
3580 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
3581 if (sourceDimProduct != resultDimProduct)
3582 return op->emitOpError("source/result number of elements must match");
3583
3584 // Check that expanding/contracting rank cases.
3585 unsigned sourceRank = sourceVectorType.getRank();
3586 unsigned resultRank = resultVectorType.getRank();
3587 if (sourceRank < resultRank) {
3588 if (!isValidShapeCast(sourceShape, resultShape))
3589 return op->emitOpError("invalid shape cast");
3590 } else if (sourceRank > resultRank) {
3591 if (!isValidShapeCast(resultShape, sourceShape))
3592 return op->emitOpError("invalid shape cast");
3593 }
3594 return success();
3595 }
3596
verify(ShapeCastOp op)3597 static LogicalResult verify(ShapeCastOp op) {
3598 auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>();
3599 auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>();
3600
3601 // Check if source/result are of vector type.
3602 if (sourceVectorType && resultVectorType)
3603 return verifyVectorShapeCast(op, sourceVectorType, resultVectorType);
3604
3605 return success();
3606 }
3607
fold(ArrayRef<Attribute> operands)3608 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
3609 // Nop shape cast.
3610 if (source().getType() == result().getType())
3611 return source();
3612
3613 // Canceling shape casts.
3614 if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
3615 if (result().getType() == otherOp.source().getType())
3616 return otherOp.source();
3617 setOperand(otherOp.source());
3618 return getResult();
3619 }
3620 return {};
3621 }
3622
3623 namespace {
3624 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
3625 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
3626 public:
3627 using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
3628
matchAndRewrite(ShapeCastOp shapeCastOp,PatternRewriter & rewriter) const3629 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
3630 PatternRewriter &rewriter) const override {
3631 auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>();
3632 if (!constantOp)
3633 return failure();
3634 // Only handle splat for now.
3635 auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
3636 if (!dense)
3637 return failure();
3638 auto newAttr = DenseElementsAttr::get(
3639 shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
3640 rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr);
3641 return success();
3642 }
3643 };
3644
3645 } // namespace
3646
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)3647 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3648 MLIRContext *context) {
3649 // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
3650 results.add<ShapeCastConstantFolder>(context);
3651 }
3652
3653 //===----------------------------------------------------------------------===//
3654 // VectorBitCastOp
3655 //===----------------------------------------------------------------------===//
3656
verify(BitCastOp op)3657 static LogicalResult verify(BitCastOp op) {
3658 auto sourceVectorType = op.getSourceVectorType();
3659 auto resultVectorType = op.getResultVectorType();
3660
3661 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
3662 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
3663 return op.emitOpError("dimension size mismatch at: ") << i;
3664 }
3665
3666 DataLayout dataLayout = DataLayout::closest(op);
3667 if (dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()) *
3668 sourceVectorType.getShape().back() !=
3669 dataLayout.getTypeSizeInBits(resultVectorType.getElementType()) *
3670 resultVectorType.getShape().back())
3671 return op.emitOpError(
3672 "source/result bitwidth of the minor 1-D vectors must be equal");
3673
3674 return success();
3675 }
3676
fold(ArrayRef<Attribute> operands)3677 OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
3678 // Nop cast.
3679 if (source().getType() == result().getType())
3680 return source();
3681
3682 // Canceling bitcasts.
3683 if (auto otherOp = source().getDefiningOp<BitCastOp>())
3684 if (result().getType() == otherOp.source().getType())
3685 return otherOp.source();
3686
3687 Attribute sourceConstant = operands.front();
3688 if (!sourceConstant)
3689 return {};
3690
3691 Type srcElemType = getSourceVectorType().getElementType();
3692 Type dstElemType = getResultVectorType().getElementType();
3693
3694 if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
3695 if (floatPack.isSplat()) {
3696 auto splat = floatPack.getSplatValue<FloatAttr>();
3697
3698 // Casting fp16 into fp32.
3699 if (srcElemType.isF16() && dstElemType.isF32()) {
3700 uint32_t bits = static_cast<uint32_t>(
3701 splat.getValue().bitcastToAPInt().getZExtValue());
3702 // Duplicate the 16-bit pattern.
3703 bits = (bits << 16) | (bits & 0xffff);
3704 APInt intBits(32, bits);
3705 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
3706 return DenseElementsAttr::get(getResultVectorType(), floatBits);
3707 }
3708 }
3709 }
3710
3711 return {};
3712 }
3713
3714 //===----------------------------------------------------------------------===//
3715 // TypeCastOp
3716 //===----------------------------------------------------------------------===//
3717
extractShape(MemRefType memRefType)3718 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
3719 auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
3720 SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
3721 memRefType.getShape().end());
3722 if (vectorType)
3723 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
3724 return res;
3725 }
3726
3727 /// Build the canonical memRefType with a single vector.
3728 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
build(OpBuilder & builder,OperationState & result,Value source)3729 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
3730 Value source) {
3731 result.addOperands(source);
3732 MemRefType memRefType = source.getType().cast<MemRefType>();
3733 VectorType vectorType =
3734 VectorType::get(extractShape(memRefType),
3735 getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
3736 result.addTypes(
3737 MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace()));
3738 }
3739
verify(TypeCastOp op)3740 static LogicalResult verify(TypeCastOp op) {
3741 MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
3742 if (!canonicalType.getAffineMaps().empty())
3743 return op.emitOpError("expects operand to be a memref with no layout");
3744 if (!op.getResultMemRefType().getAffineMaps().empty())
3745 return op.emitOpError("expects result to be a memref with no layout");
3746 if (op.getResultMemRefType().getMemorySpace() !=
3747 op.getMemRefType().getMemorySpace())
3748 return op.emitOpError("expects result in same memory space");
3749
3750 auto sourceType = op.getMemRefType();
3751 auto resultType = op.getResultMemRefType();
3752 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
3753 getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
3754 return op.emitOpError(
3755 "expects result and operand with same underlying scalar type: ")
3756 << resultType;
3757 if (extractShape(sourceType) != extractShape(resultType))
3758 return op.emitOpError(
3759 "expects concatenated result and operand shapes to be equal: ")
3760 << resultType;
3761 return success();
3762 }
3763
3764 //===----------------------------------------------------------------------===//
3765 // TransposeOp
3766 //===----------------------------------------------------------------------===//
3767
build(OpBuilder & builder,OperationState & result,Value vector,ArrayRef<int64_t> transp)3768 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
3769 Value vector, ArrayRef<int64_t> transp) {
3770 VectorType vt = vector.getType().cast<VectorType>();
3771 SmallVector<int64_t, 4> transposedShape(vt.getRank());
3772 for (unsigned i = 0; i < transp.size(); ++i)
3773 transposedShape[i] = vt.getShape()[transp[i]];
3774
3775 result.addOperands(vector);
3776 result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
3777 result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
3778 }
3779
3780 // Eliminates transpose operations, which produce values identical to their
3781 // input values. This happens when the dimensions of the input vector remain in
3782 // their original order after the transpose operation.
fold(ArrayRef<Attribute> operands)3783 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
3784 SmallVector<int64_t, 4> transp;
3785 getTransp(transp);
3786
3787 // Check if the permutation of the dimensions contains sequential values:
3788 // {0, 1, 2, ...}.
3789 for (int64_t i = 0, e = transp.size(); i < e; i++) {
3790 if (transp[i] != i)
3791 return {};
3792 }
3793
3794 return vector();
3795 }
3796
verify(vector::TransposeOp op)3797 static LogicalResult verify(vector::TransposeOp op) {
3798 VectorType vectorType = op.getVectorType();
3799 VectorType resultType = op.getResultType();
3800 int64_t rank = resultType.getRank();
3801 if (vectorType.getRank() != rank)
3802 return op.emitOpError("vector result rank mismatch: ") << rank;
3803 // Verify transposition array.
3804 auto transpAttr = op.transp().getValue();
3805 int64_t size = transpAttr.size();
3806 if (rank != size)
3807 return op.emitOpError("transposition length mismatch: ") << size;
3808 SmallVector<bool, 8> seen(rank, false);
3809 for (auto ta : llvm::enumerate(transpAttr)) {
3810 int64_t i = ta.value().cast<IntegerAttr>().getInt();
3811 if (i < 0 || i >= rank)
3812 return op.emitOpError("transposition index out of range: ") << i;
3813 if (seen[i])
3814 return op.emitOpError("duplicate position index: ") << i;
3815 seen[i] = true;
3816 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
3817 return op.emitOpError("dimension size mismatch at: ") << i;
3818 }
3819 return success();
3820 }
3821
3822 namespace {
3823
3824 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
3825 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
3826 public:
3827 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
3828
matchAndRewrite(vector::TransposeOp transposeOp,PatternRewriter & rewriter) const3829 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
3830 PatternRewriter &rewriter) const override {
3831 // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
3832 auto getPermutation = [](vector::TransposeOp transpose) {
3833 SmallVector<int64_t, 4> permutation;
3834 transpose.getTransp(permutation);
3835 return permutation;
3836 };
3837
3838 // Composes two permutations: result[i] = permutation1[permutation2[i]].
3839 auto composePermutations = [](ArrayRef<int64_t> permutation1,
3840 ArrayRef<int64_t> permutation2) {
3841 SmallVector<int64_t, 4> result;
3842 for (auto index : permutation2)
3843 result.push_back(permutation1[index]);
3844 return result;
3845 };
3846
3847 // Return if the input of 'transposeOp' is not defined by another transpose.
3848 vector::TransposeOp parentTransposeOp =
3849 transposeOp.vector().getDefiningOp<vector::TransposeOp>();
3850 if (!parentTransposeOp)
3851 return failure();
3852
3853 SmallVector<int64_t, 4> permutation = composePermutations(
3854 getPermutation(parentTransposeOp), getPermutation(transposeOp));
3855 // Replace 'transposeOp' with a new transpose operation.
3856 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
3857 transposeOp, transposeOp.getResult().getType(),
3858 parentTransposeOp.vector(),
3859 vector::getVectorSubscriptAttr(rewriter, permutation));
3860 return success();
3861 }
3862 };
3863
3864 } // end anonymous namespace
3865
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)3866 void vector::TransposeOp::getCanonicalizationPatterns(
3867 RewritePatternSet &results, MLIRContext *context) {
3868 results.add<TransposeFolder>(context);
3869 }
3870
getTransp(SmallVectorImpl<int64_t> & results)3871 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
3872 populateFromInt64AttrArray(transp(), results);
3873 }
3874
3875 //===----------------------------------------------------------------------===//
3876 // ConstantMaskOp
3877 //===----------------------------------------------------------------------===//
3878
verify(ConstantMaskOp & op)3879 static LogicalResult verify(ConstantMaskOp &op) {
3880 // Verify that array attr size matches the rank of the vector result.
3881 auto resultType = op.getResult().getType().cast<VectorType>();
3882 if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
3883 return op.emitOpError(
3884 "must specify array attr of size equal vector result rank");
3885 // Verify that each array attr element is in bounds of corresponding vector
3886 // result dimension size.
3887 auto resultShape = resultType.getShape();
3888 SmallVector<int64_t, 4> maskDimSizes;
3889 for (auto it : llvm::enumerate(op.mask_dim_sizes())) {
3890 int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
3891 if (attrValue < 0 || attrValue > resultShape[it.index()])
3892 return op.emitOpError(
3893 "array attr of size out of bounds of vector result dimension size");
3894 maskDimSizes.push_back(attrValue);
3895 }
3896 // Verify that if one mask dim size is zero, they all should be zero (because
3897 // the mask region is a conjunction of each mask dimension interval).
3898 bool any_zeros = llvm::is_contained(maskDimSizes, 0);
3899 bool all_zeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
3900 if (any_zeros && !all_zeros)
3901 return op.emitOpError("expected all mask dim sizes to be zeros, "
3902 "as a result of conjunction with zero mask dim");
3903 return success();
3904 }
3905
3906 //===----------------------------------------------------------------------===//
3907 // CreateMaskOp
3908 //===----------------------------------------------------------------------===//
3909
verify(CreateMaskOp op)3910 static LogicalResult verify(CreateMaskOp op) {
3911 // Verify that an operand was specified for each result vector each dimension.
3912 if (op.getNumOperands() !=
3913 op.getResult().getType().cast<VectorType>().getRank())
3914 return op.emitOpError(
3915 "must specify an operand for each result vector dimension");
3916 return success();
3917 }
3918
3919 namespace {
3920
3921 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
3922 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
3923 public:
3924 using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
3925
matchAndRewrite(CreateMaskOp createMaskOp,PatternRewriter & rewriter) const3926 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
3927 PatternRewriter &rewriter) const override {
3928 // Return if any of 'createMaskOp' operands are not defined by a constant.
3929 auto is_not_def_by_constant = [](Value operand) {
3930 return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
3931 };
3932 if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
3933 return failure();
3934 // Gather constant mask dimension sizes.
3935 SmallVector<int64_t, 4> maskDimSizes;
3936 for (auto operand : createMaskOp.operands()) {
3937 auto defOp = operand.getDefiningOp();
3938 maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
3939 }
3940 // Replace 'createMaskOp' with ConstantMaskOp.
3941 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3942 createMaskOp, createMaskOp.getResult().getType(),
3943 vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
3944 return success();
3945 }
3946 };
3947
3948 } // end anonymous namespace
3949
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)3950 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
3951 MLIRContext *context) {
3952 results.add<CreateMaskFolder>(context);
3953 }
3954
populateVectorToVectorCanonicalizationPatterns(RewritePatternSet & patterns)3955 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
3956 RewritePatternSet &patterns) {
3957 patterns
3958 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
3959 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
3960 StridedSliceConstantMaskFolder, TransposeFolder>(
3961 patterns.getContext());
3962 }
3963
3964 #define GET_OP_CLASSES
3965 #include "mlir/Dialect/Vector/VectorOps.cpp.inc"
3966