//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
#include <numeric>

#include "mlir/Dialect/Vector/VectorOpsDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/VectorOpsEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a 1-D mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
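///
/// For example (illustrative; uses the std dialect constant syntax of this
/// era), the following masks classify as AllTrue, AllFalse, and Unknown:
///   %0 = constant dense<true> : vector<4xi1>                        // AllTrue
///   %1 = vector.constant_mask [0] : vector<4xi1>                    // AllFalse
///   %2 = constant dense<[true, false, true, true]> : vector<4xi1>   // Unknown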
static MaskFormat get1DMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = c.value().dyn_cast<DenseIntElementsAttr>()) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayAttr masks = m.mask_dim_sizes();
    assert(masks.size() == 1);
    int64_t i = masks[0].cast<IntegerAttr>().getInt();
    int64_t u = m.getType().getDimSize(0);
    if (i >= u)
      return MaskFormat::AllTrue;
    if (i <= 0)
      return MaskFormat::AllFalse;
  }
  return MaskFormat::Unknown;
}

// Helper for verifying combining kinds in contractions and reductions.
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINF:
  case CombiningKind::MAXF:
    return elementType.isa<FloatType>();
  }
  return false;
}

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
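///
/// For example (illustrative), memref<4x8xf32> (identity layout, last stride
/// 1) and memref<f32> (rank 0, no strides) qualify, while a strided layout
/// such as (d0, d1) -> (d0 * 8 + d1 * 2), whose last stride is 2, does not.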
bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir

CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
                                         MLIRContext *context) {
  return Base::get(context, static_cast<uint64_t>(kind));
}

CombiningKind CombiningKindAttr::getKind() const {
  return static_cast<CombiningKind>(getImpl()->value);
}

static constexpr const CombiningKind combiningKindsList[] = {
    // clang-format off
    CombiningKind::ADD,
    CombiningKind::MUL,
    CombiningKind::MINUI,
    CombiningKind::MINSI,
    CombiningKind::MINF,
    CombiningKind::MAXUI,
    CombiningKind::MAXSI,
    CombiningKind::MAXF,
    CombiningKind::AND,
    CombiningKind::OR,
    CombiningKind::XOR,
    // clang-format on
};

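/// Prints the attribute's textual form, e.g. kind<add> (within the dialect
/// namespace, #vector.kind<add>); since the enum is a bitmask, a value with
/// several kinds set would print as e.g. kind<add,mul> (illustrative).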
void CombiningKindAttr::print(DialectAsmPrinter &printer) const {
  printer << "kind<";
  auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
    return bitEnumContains(this->getKind(), kind);
  });
  llvm::interleaveComma(kinds, printer,
                        [&](auto kind) { printer << stringifyEnum(kind); });
  printer << ">";
}

Attribute CombiningKindAttr::parse(DialectAsmParser &parser) {
  if (failed(parser.parseLess()))
    return {};

  StringRef elemName;
  if (failed(parser.parseKeyword(&elemName)))
    return {};

  auto kind = symbolizeCombiningKind(elemName);
  if (!kind) {
    parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
        << elemName;
    return {};
  }

  if (failed(parser.parseGreater()))
    return {};

  return CombiningKindAttr::get(kind.getValue(), parser.getContext());
}

Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
                                        Type type) const {
  StringRef attrKind;
  if (parser.parseKeyword(&attrKind))
    return {};

  if (attrKind == "kind")
    return CombiningKindAttr::parse(parser);

  parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
  return {};
}

void VectorDialect::printAttribute(Attribute attr,
                                   DialectAsmPrinter &os) const {
  if (auto ck = attr.dyn_cast<CombiningKindAttr>())
    ck.print(os);
  else
    llvm_unreachable("Unknown attribute type");
}

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

void VectorDialect::initialize() {
  addAttributes<CombiningKindAttr>();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
      >();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<ConstantOp>(loc, type, value);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

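/// Builds a MultiDimReductionOp whose destination type drops the reduced
/// dimensions. For example (illustrative, assuming inferDestType drops each
/// dimension whose mask entry is true), a vector<4x8x16xf32> source with
/// reductionMask [false, true, false] yields a vector<4x16xf32> result.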
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  result.addOperands(source);
  auto sourceVectorType = source.getType().cast<VectorType>();
  auto targetType = MultiDimReductionOp::inferDestType(
      sourceVectorType.getShape(), reductionMask,
      sourceVectorType.getElementType());
  result.addTypes(targetType);

  SmallVector<int64_t> reductionDims;
  for (auto en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  result.addAttribute(getReductionDimsAttrName(),
                      builder.getI64ArrayAttr(reductionDims));
  result.addAttribute(getKindAttrName(),
                      CombiningKindAttr::get(kind, builder.getContext()));
}

static LogicalResult verify(MultiDimReductionOp op) {
  auto reductionMask = op.getReductionMask();
  auto targetType = MultiDimReductionOp::inferDestType(
      op.getSourceVectorType().getShape(), reductionMask,
      op.getSourceVectorType().getElementType());
  // TODO: update to support 0-d vectors when available.
  if (targetType != op.getDestType())
    return op.emitError("invalid output vector type: ")
           << op.getDestType() << " (expected: " << targetType << ")";
  return success();
}

OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return source();
  return {};
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReductionOp op) {
  // Verify for 1-D vector.
  int64_t rank = op.getVectorType().getRank();
  if (rank != 1)
    return op.emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  StringRef strKind = op.kind();
  auto maybeKind = symbolizeCombiningKind(strKind);
  if (!maybeKind)
    return op.emitOpError("unknown reduction kind: ") << strKind;

  Type eltType = op.dest().getType();
  if (!isSupportedCombiningKind(*maybeKind, eltType))
    return op.emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << op.kind() << "'";

  // Verify optional accumulator.
  if (!op.acc().empty()) {
    if (strKind != "add" && strKind != "mul")
      return op.emitOpError("no accumulator for reduction kind: ") << strKind;
    if (!eltType.isa<FloatType>())
      return op.emitOpError("no accumulator for type: ") << eltType;
  }

  return success();
}

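/// Parses the custom form of vector.reduction, which (matching the printer
/// below) looks like, e.g. (illustrative):
///   vector.reduction "add", %v : vector<16xf32> into f32
///   vector.reduction "mul", %v, %acc : vector<16xf32> into f32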
static ParseResult parseReductionOp(OpAsmParser &parser,
                                    OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
  Type redType;
  Type resType;
  Attribute attr;
  if (parser.parseAttribute(attr, "kind", result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (operandsInfo.size() > 0 &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  if (operandsInfo.size() < 1 || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}

static void print(OpAsmPrinter &p, ReductionOp op) {
  p << " \"" << op.kind() << "\", " << op.vector();
  if (!op.acc().empty())
    p << ", " << op.acc();
  p << " : " << op.vector().getType() << " into " << op.dest().getType();
}

Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder,
                                         Location loc, Value vector) {
  Type scalarType = vector.getType().cast<ShapedType>().getElementType();
  switch (op) {
  case AtomicRMWKind::addf:
  case AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("add"),
                                               vector, ValueRange{});
  case AtomicRMWKind::mulf:
  case AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("mul"),
                                               vector, ValueRange{});
  case AtomicRMWKind::minf:
  case AtomicRMWKind::mins:
  case AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("min"),
                                               vector, ValueRange{});
  case AtomicRMWKind::maxf:
  case AtomicRMWKind::maxs:
  case AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("max"),
                                               vector, ValueRange{});
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<StringRef> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(),
                      builder.getAffineMapArrayAttr(
                          AffineMap::inferFromExprList(indexingExprs)));
  result.addAttribute(getIteratorTypesAttrName(),
                      builder.getStrArrayAttr(iteratorTypes));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
  result.addAttribute(ContractionOp::getKindAttrName(),
                      CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                             builder.getContext()));
}

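/// Parses the custom form of vector.contract. A matmul-like example of the
/// expected form (illustrative; names and shapes are hypothetical):
///
/// ```mlir
///   %res = vector.contract {
///            indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
///                             affine_map<(i, j, k) -> (k, j)>,
///                             affine_map<(i, j, k) -> (i, j)>],
///            iterator_types = ["parallel", "parallel", "reduction"]}
///          %lhs, %rhs, %acc
///        : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
/// ```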
static ParseResult parseContractionOp(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::OperandType lhsInfo;
  OpAsmParser::OperandType rhsInfo;
  OpAsmParser::OperandType accInfo;
  SmallVector<OpAsmParser::OperandType, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
      parser.parseOperand(lhsInfo) || parser.parseComma() ||
      parser.parseOperand(rhsInfo) || parser.parseComma() ||
      parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());
  if (!result.attributes.get(ContractionOp::getKindAttrName())) {
    result.addAttribute(ContractionOp::getKindAttrName(),
                        CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                               result.getContext()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = types[0].cast<VectorType>();
  auto rhsType = types[1].cast<VectorType>();
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<Type, 2> maskTypes = {
      VectorType::get(lhsType.getShape(), maskElementType),
      VectorType::get(rhsType.getShape(), maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

static void print(OpAsmPrinter &p, ContractionOp op) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = op.getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : op->getAttrs())
    if (traitAttrsSet.count(attr.first.strref()) > 0)
      attrs.push_back(attr);

  auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);
  p << " " << dictAttr << " " << op.lhs() << ", ";
  p << op.rhs() << ", " << op.acc();
  if (op.masks().size() == 2)
    p << ", " << op.masks();

  p.printOptionalAttrDict(op->getAttrs(), attrNames);
  p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
    << op.getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.size() == 0) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
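    // For example (illustrative), for a matmul-like contraction with
    //   lhsType = vector<4x8xf32>,  lhsMap = (i, j, k) -> (i, k)
    //   rhsType = vector<8x16xf32>, rhsMap = (i, j, k) -> (k, j)
    // the inferred extents are [i: 4, j: 16, k: 8]; composing the result map
    // (i, j, k) -> (i, j) with those extents yields vector<4x16xf32>.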
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMaps()[0];
    AffineMap rhsMap = op.getIndexingMaps()[1];
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
           "expected extent along all dimensions.");

    AffineMap resMap = op.getIndexingMaps()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(
               expectedMap.getResults(),
               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}

static LogicalResult verify(ContractionOp op) {
  auto lhsType = op.getLhsType();
  auto rhsType = op.getRhsType();
  auto accType = op.getAccType();
  auto resType = op.getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (op.indexing_maps().size() != 3)
    return op.emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = op.iterator_types().getValue().size();
  for (auto it : llvm::enumerate(op.indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    if (map.getNumSymbols() != 0)
      return op.emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and indices.
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return op.emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = op.getContractingDimMap();
  auto batchDimMap = op.getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return op.emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return op.emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return op.emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = op.getLHSVectorMaskType();
  auto rhsMaskType = op.getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return op.emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return op.emitOpError("invalid vector mask rank");
  }

  // Verify supported combining kind.
  auto vectorType = resType.dyn_cast<VectorType>();
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(op.kind(), elementType))
    return op.emitOpError("unsupported contraction type");

  return success();
}

ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[3] = {getIndexingMapsAttrName(),
                                         getIteratorTypesAttrName(),
                                         ContractionOp::getKindAttrName()};
  return llvm::makeArrayRef(names);
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

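/// Returns the (lhs result index, rhs result index) pairs for every iterator
/// of type `targetIteratorTypeName` that appears in both indexing maps. For
/// the matmul-like maps (i, j, k) -> (i, k) and (i, j, k) -> (k, j), the
/// "reduction" iterator k appears at lhs result 1 and rhs result 0, so the
/// returned map is {(1, 0)} (illustrative).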
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          StringRef targetIteratorTypeName, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (auto it : llvm::enumerate(iteratorTypes)) {
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName != targetIteratorTypeName)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.push_back({lhsDim, rhsDim});
  }
  return dimMap;
}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = getResultType().dyn_cast<VectorType>();
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  SmallVector<int64_t, 2> iterationShape;
  for (auto it : llvm::enumerate(iterator_types())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName == getReductionIteratorTypeName()) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = indexing_maps().getValue().size();
  iterationIndexMap.resize(numMaps);
  for (auto it : llvm::enumerate(indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = map.getResult(i).cast<AffineDimExpr>();
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getReductionIteratorTypeName(), getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getParallelIteratorTypeName(), getContext());
}

SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
  return llvm::to_vector<4>(
      llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
        return mapAttr.cast<AffineMapAttr>().getValue();
      }));
}

Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which represents a pattern such as:
///
/// ```mlir
///    %c0 = vector.constant 0: ...
///    %c = vector.contract %a, %b, %c0: ...
///    %e = add %c, %d: ...
/// ```
///
/// by:
///
/// ```mlir
///    %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<ConstantOp>(
              contractionOp.acc().getDefiningOp())) {
        if (maybeZero.value() ==
            rewriter.getZeroAttr(contractionOp.acc().getType())) {
          BlockAndValueMapping bvm;
          bvm.map(contractionOp.acc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<AddIOp>, CanonicalizeContractAdd<AddFOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source, Value position) {
  result.addOperands({source, position});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source, int64_t position) {
  Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
  build(builder, result, source, pos);
}

static LogicalResult verify(vector::ExtractElementOp op) {
  VectorType vectorType = op.getVectorType();
  if (vectorType.getRank() != 1)
    return op.emitOpError("expected 1-D vector");
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

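/// Infers the result type of an ExtractOp at `position` applied to
/// `vectorType`; e.g. (illustrative), position [1] on vector<4x8x16xf32>
/// yields vector<8x16xf32>, while a full-rank position [1, 2, 3] yields the
/// element type f32.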
static Type inferExtractOpResultType(VectorType vectorType,
                                     ArrayAttr position) {
  if (static_cast<int64_t>(position.size()) == vectorType.getRank())
    return vectorType.getElementType();
  return VectorType::get(vectorType.getShape().drop_front(position.size()),
                         vectorType.getElementType());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  result.addOperands(source);
  auto positionAttr = getVectorSubscriptAttr(builder, position);
  result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
                                           positionAttr));
  result.addAttribute(getPositionAttrName(), positionAttr);
}

// Convenience builder which assumes the values are constant indices.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<ConstantIndexOp>().getValue();
      }));
  build(builder, result, source, positionConstants);
}

static void print(OpAsmPrinter &p, vector::ExtractOp op) {
  p << " " << op.vector() << op.position();
  p.printOptionalAttrDict(op->getAttrs(), {"position"});
  p << " : " << op.vector().getType();
}

static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc attributeLoc, typeLoc;
  NamedAttrList attrs;
  OpAsmParser::OperandType vector;
  Type type;
  Attribute attr;
  if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(attr, "position", attrs) ||
      parser.parseOptionalAttrDict(attrs) ||
      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
    return failure();

  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typeLoc, "expected vector type");

  auto positionAttr = attr.dyn_cast<ArrayAttr>();
  if (!positionAttr ||
      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
    return parser.emitError(
        attributeLoc,
        "expected position attribute of rank smaller than vector rank");

  Type resType = inferExtractOpResultType(vectorType, positionAttr);
  result.attributes = attrs;
  return failure(parser.resolveOperand(vector, type, result.operands) ||
                 parser.addTypeToList(resType, result.types));
}

static LogicalResult verify(vector::ExtractOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than vector rank");
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= op.getVectorType().getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
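///
/// For example (illustrative):
///   %a = vector.extract %v[0] : vector<4x8x16xf32>
///   %b = vector.extract %a[1, 2] : vector<8x16xf32>
/// folds %b in place to vector.extract %v[0, 1, 2].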
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.vector().getDefiningOp<ExtractOp>())
    return failure();

  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  auto extractedPos = extractVector<int64_t>(currentOp.position());
  globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
  while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    auto extractedPos = extractVector<int64_t>(currentOp.position());
    globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
  }
  extractOp.setOperand(currentOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(globalPosition));
  return success();
}

/// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
  if (!transposeOp)
    return failure();

  auto permutation = extractVector<unsigned>(transposeOp.transp());
  auto extractedPos = extractVector<int64_t>(extractOp.position());

  // If transposition permutation is larger than the ExtractOp, all minor
  // dimensions must be an identity for folding to occur. If not, individual
  // elements within the extracted value are transposed and this is not just a
  // simple folding.
  unsigned minorRank = permutation.size() - extractedPos.size();
  MLIRContext *ctx = extractOp.getContext();
  AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
  AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
  if (minorMap && !minorMap.isMinorIdentity())
    return failure();

  //   %1 = transpose %0[x, y, z] : vector<axbxcxf32>
  //   %2 = extract %1[u, v] : vector<..xf32>
  // may turn into:
  //   %2 = extract %0[w, x] : vector<..xf32>
  // iff z == 2 and [w, x] = [x, y]^-1 o [u, v] here o denotes composition and
  // -1 denotes the inverse.
  permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
  // The major submap has fewer results but the same number of dims. To compose
  // cleanly, we need to drop dims to form a "square matrix". This is possible
  // because:
  //   (a) this is a permutation map and
  //   (b) the minor map has already been checked to be identity.
  // Therefore, the major map cannot contain dims of position greater or equal
  // than the number of results.
  assert(llvm::all_of(permutationMap.getResults(),
                      [&](AffineExpr e) {
                        auto dim = e.dyn_cast<AffineDimExpr>();
                        return dim && dim.getPosition() <
                                          permutationMap.getNumResults();
                      }) &&
         "Unexpected map results depend on higher rank positions");
  // Project on the first domain dimensions to allow composition.
  permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
                                  permutationMap.getResults(), ctx);

  extractOp.setOperand(transposeOp.vector());
  // Compose the inverse permutation map with the extractedPos.
  auto newExtractedPos =
      inversePermutation(permutationMap).compose(extractedPos);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(newExtractedPos));

  return success();
}

/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
/// result is always the input to some InsertOp.
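///
/// For example (illustrative), given:
///   %0 = vector.insert %s, %v[0, 1] : f32 into vector<4x8xf32>
///   %1 = vector.extract %0[0, 1] : vector<4x8xf32>
/// the extract folds to %s, the value inserted at the matching position.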
static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
  MLIRContext *context = extractOp.getContext();
  AffineMap permutationMap;
  auto extractedPos = extractVector<unsigned>(extractOp.position());
  // Walk back a chain of InsertOp/TransposeOp until we hit a match.
  // Compose TransposeOp permutations as we walk back.
  auto insertOp = extractOp.vector().getDefiningOp<vector::InsertOp>();
  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
  while (insertOp || transposeOp) {
    if (transposeOp) {
      // If it is transposed, compose the map and iterate.
      auto permutation = extractVector<unsigned>(transposeOp.transp());
      AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
      if (!permutationMap)
        permutationMap = newMap;
      else if (newMap.getNumInputs() != permutationMap.getNumResults())
        return Value();
      else
        permutationMap = newMap.compose(permutationMap);
      // Compute insert/transpose for the next iteration.
      Value transposed = transposeOp.vector();
      insertOp = transposed.getDefiningOp<vector::InsertOp>();
      transposeOp = transposed.getDefiningOp<vector::TransposeOp>();
      continue;
    }

    assert(insertOp);
    Value insertionDest = insertOp.dest();
    // If it is inserted into, either the position matches and we have a
    // successful folding; or we iterate until we run out of
    // InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector`
    // produces a new vector with 1 modified value/slice in exactly the static
    // position we need to match.
    auto insertedPos = extractVector<unsigned>(insertOp.position());
    // Trivial permutations are solved with position equality checks.
    if (!permutationMap || permutationMap.isIdentity()) {
      if (extractedPos == insertedPos)
        return insertOp.source();
      // Fallthrough: if the position does not match, just skip to the next
      // producing `vector.insert` / `vector.transpose`.
      // Compute insert/transpose for the next iteration.
      insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
      transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
      continue;
    }

    // More advanced permutations require application of the permutation.
    // However, the rank of `insertedPos` may be different from that of the
    // `permutationMap`. To support such case, we need to:
    //   1. apply on the `insertedPos.size()` major dimensions
    //   2. check the other dimensions of the permutation form a minor identity.
    assert(permutationMap.isPermutation() && "expected a permutation");
    if (insertedPos.size() == extractedPos.size()) {
      bool fold = true;
      for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
        auto pos = permutationMap.getDimPosition(idx);
        if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
          fold = false;
          break;
        }
      }
      if (fold) {
        assert(permutationMap.getNumResults() >= insertedPos.size() &&
               "expected map of rank larger than insert indexing");
        unsigned minorRank =
            permutationMap.getNumResults() - insertedPos.size();
        AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
        if (!minorMap || minorMap.isMinorIdentity())
          return insertOp.source();
      }
    }

    // If we haven't found a match, just continue to the next producing
    // `vector.insert` / `vector.transpose`.
    // Compute insert/transpose for the next iteration.
    insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
    transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
  }
  return Value();
}

/// Fold an ExtractOp whose source comes from a BroadcastOp.
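///
/// For example (illustrative), with
///   %b = vector.broadcast %s : vector<8xf32> to vector<4x8xf32>
///   %e = vector.extract %b[0] : vector<4x8xf32>
/// the extract folds to %s, since its type matches the broadcast source.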
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
  if (!broadcastOp)
    return Value();
  if (extractOp.getType() == broadcastOp.getSourceType())
    return broadcastOp.source();
  auto getRank = [](Type type) {
    return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
  };
  unsigned broadcastSrcRank = getRank(broadcastOp.getSourceType());
  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank < broadcastSrcRank) {
    auto extractPos = extractVector<int64_t>(extractOp.position());
    unsigned rankDiff = broadcastSrcRank - extractResultRank;
    extractPos.erase(
        extractPos.begin(),
        std::next(extractPos.begin(), extractPos.size() - rankDiff));
    extractOp.setOperand(broadcastOp.source());
    // OpBuilder is only used as a helper to build an I64ArrayAttr.
    OpBuilder b(extractOp.getContext());
    extractOp->setAttr(ExtractOp::getPositionAttrName(),
                       b.getI64ArrayAttr(extractPos));
    return extractOp.getResult();
  }
  // TODO: In case the rank of the broadcast source is greater than the rank of
  // the extract result, this can be combined into a new broadcast op. This
  // would need to be added as a canonicalization pattern if needed.
  return Value();
}

// Fold extractOp with source coming from ShapeCast op.
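// For example (illustrative), with
//   %c = vector.shape_cast %s : vector<2x16xf32> to vector<4x8xf32>
//   %e = vector.extract %c[1, 2] : vector<4x8xf32>
// the linearized position is 1 * 8 + 2 = 10, which delinearizes in the source
// shape to [0, 10], so the fold rewrites to vector.extract %s[0, 10].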
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.source());
  return extractOp.getResult();
}

OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
  if (position().empty())
    return vector();
  if (succeeded(foldExtractOpFromExtractChain(*this)))
    return getResult();
  if (succeeded(foldExtractOpFromTranspose(*this)))
    return getResult();
  if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
    return val;
  if (auto val = foldExtractFromBroadcast(*this))
    return val;
  if (auto val = foldExtractFromShapeCast(*this))
    return val;
  return OpFoldResult();
}

namespace {

// If extractOp is only removing unit dimensions it can be transformed to a
// shapecast.
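// For example (an illustrative sketch):
//   %1 = vector.extract %0[0] : vector<1x4xf32>  // yields vector<4xf32>
// can be rewritten as:
//   %1 = vector.shape_cast %0 : vector<1x4xf32> to vector<4xf32>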
class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern<ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
    if (!dstVecType || extractOp.getVectorType().getNumElements() !=
                           dstVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
                                             extractOp.vector());
    return success();
  }
};

} // namespace

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractToShapeCast>(context);
}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(attr.cast<IntegerAttr>().getInt());
}

//===----------------------------------------------------------------------===//
// ExtractMapOp
//===----------------------------------------------------------------------===//

void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
                         Value vector, ValueRange ids,
                         ArrayRef<int64_t> multiplicity,
                         AffineMap permutationMap) {
  assert(ids.size() == multiplicity.size() &&
         ids.size() == permutationMap.getNumResults());
  assert(permutationMap.isProjectedPermutation());
  VectorType type = vector.getType().cast<VectorType>();
  SmallVector<int64_t, 4> newShape(type.getShape().begin(),
                                   type.getShape().end());
  for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
    AffineExpr expr = permutationMap.getResult(i);
    auto dim = expr.cast<AffineDimExpr>();
    newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
  }
  VectorType resultType = VectorType::get(newShape, type.getElementType());
  ExtractMapOp::build(builder, result, resultType, vector, ids);
}
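
// As an illustrative sketch of the builder above: distributing dim 0 of a
// vector<64x4xf32> with multiplicity 16 (permutation map (d0, d1) -> (d0))
// divides that dimension, giving the inferred result type vector<4x4xf32>.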

static LogicalResult verify(ExtractMapOp op) {
  if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
    return op.emitOpError(
        "expected source and destination vectors of same rank");
  unsigned numId = 0;
  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
    if (op.getSourceVectorType().getDimSize(i) %
            op.getResultType().getDimSize(i) !=
        0)
      return op.emitOpError("source vector dimensions must be a multiple of "
                            "destination vector dimensions");
    if (op.getSourceVectorType().getDimSize(i) !=
        op.getResultType().getDimSize(i))
      numId++;
  }
  if (numId != op.ids().size())
    return op.emitOpError("expected number of ids must match the number of "
                          "dimensions distributed");
  return success();
}

OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
  auto insert = vector().getDefiningOp<vector::InsertMapOp>();
  if (insert == nullptr || getType() != insert.vector().getType() ||
      ids() != insert.ids())
    return {};
  return insert.vector();
}

void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
  assert(multiplicity.empty());
  for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
    if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
      multiplicity.push_back(getSourceVectorType().getDimSize(i) /
                             getResultType().getDimSize(i));
  }
}

template <typename MapOp>
AffineMap calculateImplicitMap(MapOp op) {
  SmallVector<AffineExpr, 4> perm;
  // Check which dimensions have a multiplicity greater than 1 and associate
  // them with the IDs in order.
  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
    if (op.getSourceVectorType().getDimSize(i) !=
        op.getResultType().getDimSize(i))
      perm.push_back(getAffineDimExpr(i, op.getContext()));
  }
  auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
                            op.getContext());
  return map;
}
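
// For example (illustrative): with source vector<32x16x4xf32> and result
// vector<32x4x4xf32>, only dim 1 is distributed, so the implicit map is
// (d0, d1, d2) -> (d1).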

AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }

//===----------------------------------------------------------------------===//
// FMAOp
//===----------------------------------------------------------------------===//

Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

BroadcastableToResult
mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
                                std::pair<int, int> *mismatchingDims) {
  // Broadcast scalar to vector of the same element type.
  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
    return BroadcastableToResult::Success;
  // From now on, only vectors broadcast.
  VectorType srcVectorType = srcType.dyn_cast<VectorType>();
  if (!srcVectorType)
    return BroadcastableToResult::SourceTypeNotAVector;

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;
  // Source has an exact match or singleton value for all trailing dimensions
  // (all leading dimensions are simply duplicated).
  int64_t lead = dstRank - srcRank;
  for (int64_t r = 0; r < srcRank; ++r) {
    int64_t srcDim = srcVectorType.getDimSize(r);
    int64_t dstDim = dstVectorType.getDimSize(lead + r);
    if (srcDim != 1 && srcDim != dstDim) {
      if (mismatchingDims) {
        mismatchingDims->first = srcDim;
        mismatchingDims->second = dstDim;
      }
      return BroadcastableToResult::DimensionMismatch;
    }
  }

  return BroadcastableToResult::Success;
}
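
// A few illustrative cases for the rule above:
//   f32             -> vector<4xf32>    : Success (scalar broadcast)
//   vector<8xf32>   -> vector<2x8xf32>  : Success (leading dim duplicated)
//   vector<1x8xf32> -> vector<4x8xf32>  : Success (singleton dim stretched)
//   vector<3xf32>   -> vector<4xf32>    : DimensionMismatch (3 vs. 4)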

static LogicalResult verify(BroadcastOp op) {
  std::pair<int, int> mismatchingDims;
  BroadcastableToResult res = isBroadcastableTo(
      op.getSourceType(), op.getVectorType(), &mismatchingDims);
  if (res == BroadcastableToResult::Success)
    return success();
  if (res == BroadcastableToResult::SourceRankHigher)
    return op.emitOpError("source rank higher than destination rank");
  if (res == BroadcastableToResult::DimensionMismatch)
    return op.emitOpError("dimension mismatch (")
           << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
  if (res == BroadcastableToResult::SourceTypeNotAVector)
    return op.emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
}

OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (getSourceType() == getVectorType())
    return source();
  if (!operands[0])
    return {};
  auto vectorType = getVectorType();
  if (operands[0].getType().isIntOrIndexOrFloat())
    return DenseElementsAttr::get(vectorType, operands[0]);
  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
    return DenseElementsAttr::get(vectorType, attr.getSplatValue());
  return {};
}

namespace {

// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
// the degenerate case where the broadcast only adds dimensions of size 1 it
// can be replaced by a ShapeCastOp. This canonicalization checks if the total
// number of elements is the same before and after the broadcast to detect if
// the only changes in the vector type are new dimensions of size 1.
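// For example (an illustrative sketch):
//   %1 = vector.broadcast %0 : vector<4xf32> to vector<1x1x4xf32>
// can be rewritten as:
//   %1 = vector.shape_cast %0 : vector<4xf32> to vector<1x1x4xf32>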
class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
public:
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
    if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<ShapeCastOp>(
        broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
    return success();
  }
};

// Fold broadcast1(broadcast2(x)) into broadcast1(x).
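// For example (illustrative): broadcasting f32 to vector<4xf32> and then
// vector<4xf32> to vector<2x4xf32> collapses into a single broadcast of the
// f32 value to vector<2x4xf32>.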
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>();
    if (!srcBroadcast)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source());
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastToShapeCast, BroadcastFolder>(context);
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
                      Value v2, ArrayRef<int64_t> mask) {
  result.addOperands({v1, v2});
  auto maskAttr = getVectorSubscriptAttr(builder, mask);
  result.addTypes(v1.getType());
  result.addAttribute(getMaskAttrName(), maskAttr);
}

static void print(OpAsmPrinter &p, ShuffleOp op) {
  p << " " << op.v1() << ", " << op.v2() << " " << op.mask();
  p.printOptionalAttrDict(op->getAttrs(), {ShuffleOp::getMaskAttrName()});
  p << " : " << op.v1().getType() << ", " << op.v2().getType();
}

static LogicalResult verify(ShuffleOp op) {
  VectorType resultType = op.getVectorType();
  VectorType v1Type = op.getV1VectorType();
  VectorType v2Type = op.getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  if (resRank != v1Rank || v1Rank != v2Rank)
    return op.emitOpError("rank mismatch");
  // Verify all but leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return op.emitOpError("dimension mismatch");
  }
  // Verify mask length.
  auto maskAttr = op.mask().getValue();
  int64_t maskLength = maskAttr.size();
  if (maskLength != resultType.getDimSize(0))
    return op.emitOpError("mask length mismatch");
  // Verify all indices.
  int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
  for (auto en : llvm::enumerate(maskAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
      return op.emitOpError("mask index #")
             << (en.index() + 1) << " out of range";
  }
  return success();
}

static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType v1, v2;
  Attribute attr;
  VectorType v1Type, v2Type;
  if (parser.parseOperand(v1) || parser.parseComma() ||
      parser.parseOperand(v2) ||
      parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
                            result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(v1Type) || parser.parseComma() ||
      parser.parseType(v2Type) ||
      parser.resolveOperand(v1, v1Type, result.operands) ||
      parser.resolveOperand(v2, v2Type, result.operands))
    return failure();
  // Construct resulting type: leading dimension matches mask length,
  // all trailing dimensions match the operands.
  auto maskAttr = attr.dyn_cast<ArrayAttr>();
  if (!maskAttr)
    return parser.emitError(parser.getNameLoc(), "missing mask attribute");
  int64_t maskLength = maskAttr.size();
  if (maskLength <= 0)
    return parser.emitError(parser.getNameLoc(), "invalid mask length");
  int64_t v1Rank = v1Type.getRank();
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(maskLength);
  for (int64_t r = 1; r < v1Rank; ++r)
    shape.push_back(v1Type.getDimSize(r));
  VectorType resType = VectorType::get(shape, v1Type.getElementType());
  parser.addTypeToList(resType, result.types);
  return success();
}
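
// For example (an illustrative sketch), the result's leading dimension is the
// mask length and indices select from the concatenation of both operands:
//   %2 = vector.shuffle %a, %b [0, 3, 1] : vector<2xf32>, vector<2xf32>
// yields a vector<3xf32> holding %a[0], %b[1], %a[1].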

//===----------------------------------------------------------------------===//
// InsertElementOp
//===----------------------------------------------------------------------===//

void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest, Value position) {
  result.addOperands({source, dest, position});
  result.addTypes(dest.getType());
}

void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest, int64_t position) {
  Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
  build(builder, result, source, dest, pos);
}

static LogicalResult verify(InsertElementOp op) {
  auto dstVectorType = op.getDestVectorType();
  if (dstVectorType.getRank() != 1)
    return op.emitOpError("expected 1-D vector");
  return success();
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
                     Value dest, ArrayRef<int64_t> position) {
  result.addOperands({source, dest});
  auto positionAttr = getVectorSubscriptAttr(builder, position);
  result.addTypes(dest.getType());
  result.addAttribute(getPositionAttrName(), positionAttr);
}

// Convenience builder which assumes the values are constant indices.
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
                     Value dest, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<ConstantIndexOp>().getValue();
      }));
  build(builder, result, source, dest, positionConstants);
}

static LogicalResult verify(InsertOp op) {
  auto positionAttr = op.position().getValue();
  auto destVectorType = op.getDestVectorType();
  if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than dest vector rank");
  auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return op.emitOpError("expected position attribute rank + source rank to "
                          "match dest vector rank");
  else if (!srcVectorType && (positionAttr.size() !=
                              static_cast<unsigned>(destVectorType.getRank())))
    return op.emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= destVectorType.getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "dest vector dimension";
  }
  return success();
}
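
// For example (illustrative), both forms below satisfy the checks above:
//   %1 = vector.insert %f, %v[1, 2] : f32 into vector<2x4xf32>
//   %2 = vector.insert %s, %v[1] : vector<4xf32> into vector<2x4xf32>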

namespace {

// If insertOp is only inserting unit dimensions it can be transformed to a
// shapecast.
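// For example (an illustrative sketch):
//   %1 = vector.insert %0, %v[0] : vector<4xf32> into vector<1x4xf32>
// can be rewritten as:
//   %1 = vector.shape_cast %0 : vector<4xf32> to vector<1x4xf32>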
class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern<InsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<ShapeCastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.source());
    return success();
  }
};

} // namespace

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToShapeCast>(context);
}

// Eliminates insert operations that produce values identical to their source
// value. This happens when the source and destination vectors have identical
// sizes.
OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
  if (position().empty())
    return source();
  return {};
}

//===----------------------------------------------------------------------===//
// InsertMapOp
//===----------------------------------------------------------------------===//

void InsertMapOp::build(OpBuilder &builder, OperationState &result,
                        Value vector, Value dest, ValueRange ids) {
  InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
}

static LogicalResult verify(InsertMapOp op) {
  if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
    return op.emitOpError(
        "expected source and destination vectors of same rank");
  unsigned numId = 0;
  for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) {
    if (op.getResultType().getDimSize(i) %
            op.getSourceVectorType().getDimSize(i) !=
        0)
      return op.emitOpError(
          "destination vector size must be a multiple of source vector size");
    if (op.getResultType().getDimSize(i) !=
        op.getSourceVectorType().getDimSize(i))
      numId++;
  }
  if (numId != op.ids().size())
    return op.emitOpError("expected number of ids must match the number of "
                          "dimensions distributed");
  return success();
}

AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }

//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//

void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                 Value source, Value dest,
                                 ArrayRef<int64_t> offsets,
                                 ArrayRef<int64_t> strides) {
  result.addOperands({source, dest});
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(dest.getType());
  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
  result.addAttribute(getStridesAttrName(), stridesAttr);
}

// TODO: Should be moved to Tablegen Confined attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank smaller than vector rank";
  return success();
}

// Returns success if all integers in `arrayAttr` are confined to the given
// range. If `halfOpen` is true then the admissible interval is [min, max).
// Otherwise, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = attr.cast<IntegerAttr>().getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}

// Returns success if all integers in `arrayAttr` are confined to the
// corresponding dimension in `shape`. If `halfOpen` is true then the
// admissible interval is [min, max). Otherwise, the admissible interval is
// [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  assert(arrayAttr.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr, shape)) {
    auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto max = std::get<1>(it);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
    ++index;
  }
  return success();
}

// Returns success if, for each dimension, the sum of the corresponding
// integers in `arrayAttr1` and `arrayAttr2` is confined to the matching
// dimension in `shape`. If `halfOpen` is true then the admissible interval is
// [min, max). Otherwise, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
    auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
    auto max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
    ++index;
  }
  return success();
}

static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}

static LogicalResult verify(InsertStridedSliceOp op) {
  auto sourceVectorType = op.getSourceVectorType();
  auto destVectorType = op.getDestVectorType();
  auto offsets = op.offsets();
  auto strides = op.strides();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return op.emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return op.emitOpError(
        "expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return op.emitOpError(
        "expected source rank to be smaller than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(
          isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          op, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  return success();
}
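
// For example (illustrative), this satisfies the checks above: offsets rank
// equals the destination rank, strides rank equals the source rank, and the
// inserted slice fits within the destination:
//   %2 = vector.insert_strided_slice %src, %dst
//       {offsets = [0, 1], strides = [1]} : vector<2xf32> into vector<2x4xf32>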

OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
  if (getSourceVectorType() == getDestVectorType())
    return source();
  return {};
}

//===----------------------------------------------------------------------===//
// OuterProductOp
//===----------------------------------------------------------------------===//

/// Build an op without mask, use the type of `acc` as the return type.
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
                           Value lhs, Value rhs, Value acc) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
}

static void print(OpAsmPrinter &p, OuterProductOp op) {
  p << " " << op.lhs() << ", " << op.rhs();
  if (!op.acc().empty()) {
    p << ", " << op.acc();
    p.printOptionalAttrDict(op->getAttrs());
  }
  p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
}

static ParseResult parseOuterProductOp(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = tLHS.dyn_cast<VectorType>();
  VectorType vRHS = tRHS.dyn_cast<VectorType>();
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");
  VectorType resType =
      vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                             vLHS.getElementType())
           : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());

  if (!result.attributes.get(OuterProductOp::getKindAttrName())) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(),
        CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
                               result.getContext()));
  }

  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}

static LogicalResult verify(OuterProductOp op) {
  Type tRHS = op.getOperandTypeRHS();
  VectorType vLHS = op.getOperandVectorTypeLHS(),
             vRHS = tRHS.dyn_cast<VectorType>(),
             vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();

  if (vLHS.getRank() != 1)
    return op.emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return op.emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return op.emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return op.emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return op.emitOpError("expected #2 operand dim to match result dim #2");
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return op.emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return op.emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return op.emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(op.kind(), vRES.getElementType()))
    return op.emitOpError("unsupported outerproduct type");

  return success();
}
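
// For example (an illustrative sketch), the "proper outer" form takes two 1-D
// operands and produces a 2-D result, while the AXPY form takes a 1-D vector
// and a scalar:
//   %o = vector.outerproduct %a, %b, %acc : vector<4xf32>, vector<8xf32>
//   %x = vector.outerproduct %a, %s, %acc : vector<4xf32>, f32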

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReshapeOp op) {
  // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
  auto inputVectorType = op.getInputVectorType();
  auto outputVectorType = op.getOutputVectorType();
  int64_t inputShapeRank = op.getNumInputShapeSizes();
  int64_t outputShapeRank = op.getNumOutputShapeSizes();
  SmallVector<int64_t, 4> fixedVectorSizes;
  op.getFixedVectorSizes(fixedVectorSizes);
  int64_t numFixedVectorSizes = fixedVectorSizes.size();

  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
    return op.emitError("invalid input shape for vector type ")
           << inputVectorType;

  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
    return op.emitError("invalid output shape for vector type ")
           << outputVectorType;

  // Verify that the 'fixedVectorSizes' match an input/output vector shape
  // suffix.
  unsigned inputVectorRank = inputVectorType.getRank();
  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
    unsigned index = inputVectorRank - numFixedVectorSizes - i;
    if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
      return op.emitError("fixed vector size must match input vector for dim ")
             << i;
  }

  unsigned outputVectorRank = outputVectorType.getRank();
  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
    unsigned index = outputVectorRank - numFixedVectorSizes - i;
    if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
      return op.emitError("fixed vector size must match output vector for dim ")
             << i;
  }

  // If all shape operands are produced by constant ops, verify that product
  // of dimensions for input/output shape match.
  auto isDefByConstant = [](Value operand) {
    return isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
  };
  if (llvm::all_of(op.input_shape(), isDefByConstant) &&
      llvm::all_of(op.output_shape(), isDefByConstant)) {
    int64_t numInputElements = 1;
    for (auto operand : op.input_shape())
      numInputElements *=
          cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
    int64_t numOutputElements = 1;
    for (auto operand : op.output_shape())
      numOutputElements *=
          cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
    if (numInputElements != numOutputElements)
      return op.emitError("product of input and output shape sizes must match");
  }
  return success();
}

void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(fixed_vector_sizes(), results);
}

//===----------------------------------------------------------------------===//
// ExtractStridedSliceOp
//===----------------------------------------------------------------------===//

// Inference works as follows:
//   1. Add 'sizes' from prefix of dims in 'offsets'.
//   2. Add sizes from 'vectorType' for remaining dims.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType());
}
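
// For example (illustrative): slicing vector<4x8x16xf32> with sizes = [2, 4]
// infers the result type vector<2x4x16xf32>; the unsliced trailing dimension
// keeps its original size.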

void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> sizes,
                                  ArrayRef<int64_t> strides) {
  result.addOperands(source);
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(
      inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
                                    offsetsAttr, sizesAttr, stridesAttr));
  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
  result.addAttribute(getSizesAttrName(), sizesAttr);
  result.addAttribute(getStridesAttrName(), stridesAttr);
}

static LogicalResult verify(ExtractStridedSliceOp op) {
  auto type = op.getVectorType();
  auto offsets = op.offsets();
  auto sizes = op.sizes();
  auto strides = op.strides();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
    op.emitOpError(
        "expected offsets, sizes and strides attributes of same size");
    return failure();
  }

  auto shape = type.getShape();
  auto offName = ExtractStridedSliceOp::getOffsetsAttrName();
  auto sizesName = ExtractStridedSliceOp::getSizesAttrName();
  auto stridesName = ExtractStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
                                                stridesName)) ||
      failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
                                                    offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  auto resultType = inferStridedSliceOpResultType(
      op.getVectorType(), op.offsets(), op.sizes(), op.strides());
  if (op.getResult().getType() != resultType) {
    op.emitOpError("expected result type to be ") << resultType;
    return failure();
  }

  return success();
}
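
// For example (illustrative), this passes the checks above:
//   %1 = vector.extract_strided_slice %0
//       {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
//       : vector<4x8x16xf32> to vector<2x4x16xf32>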

// When the source of ExtractStrided comes from a chain of InsertStrided ops,
// try to use the source of the InsertStrided ops if we can detect that the
// extracted vector is a subset of one of the inserted vectors.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  // Helper to extract integer out of ArrayAttr.
  auto getElement = [](ArrayAttr array, int idx) {
    return array[idx].cast<IntegerAttr>().getInt();
  };
  ArrayAttr extractOffsets = op.offsets();
  ArrayAttr extractStrides = op.strides();
  ArrayAttr extractSizes = op.sizes();
  auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.offsets();
    ArrayAttr insertStrides = insertOp.strides();
    // If the rank of extract is greater than the rank of insert, we are likely
    // extracting a partial chunk of the vector inserted.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool partialOverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the start of the extract offset is in the interval inserted.
      if (start <= offset && offset < end) {
        // If the extract interval overlaps but is not fully included we may
        // have a partial overlap that will prevent any folding.
        if (offset + size > end)
          partialOverlap = true;
        offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted element chunk is a subset of the inserted element chunk.
    if (!disjoint && !partialOverlap) {
      op.setOperand(insertOp.source());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op->setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
                  b.getI64ArrayAttr(offsetDiffs));
      return success();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep looking
    // in the insert chain.
    if (disjoint)
      insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
    else {
      // The extracted vector partially overlaps the inserted vector; we cannot
      // fold.
      return failure();
    }
  }
  return failure();
}
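
// For example (an illustrative sketch):
//   %0 = vector.insert_strided_slice %src, %dst {offsets = [2], strides = [1]}
//       : vector<4xf32> into vector<8xf32>
//   %1 = vector.extract_strided_slice %0
//       {offsets = [3], sizes = [2], strides = [1]}
//       : vector<8xf32> to vector<2xf32>
// The extracted chunk [3, 5) lies inside the inserted chunk [2, 6), so the
// fold above rewrites %1 to extract directly from %src with offsets = [1].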

OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
  if (getVectorType() == getResult().getType())
    return vector();
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();
  return {};
}

void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(offsets(), results);
}

namespace {

// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
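// For example (illustrative):
//   %m = vector.constant_mask [4] : vector<8xi1>
//   %s = vector.extract_strided_slice %m
//       {offsets = [2], sizes = [4], strides = [1]}
//       : vector<8xi1> to vector<4xi1>
// becomes:
//   %s = vector.constant_mask [2] : vector<4xi1>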
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'extractStridedSliceOp' operand is not defined by a
    // ConstantMaskOp.
    auto defOp = extractStridedSliceOp.vector().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if 'extractStridedSliceOp' has non-unit strides.
    if (llvm::any_of(extractStridedSliceOp.strides(), [](Attribute attr) {
          return attr.cast<IntegerAttr>().getInt() != 1;
        }))
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t, 4> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets);
    SmallVector<int64_t, 4> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes);

    // Compute slice of vector mask region.
    SmallVector<int64_t, 4> sliceMaskDimSizes;
    assert(sliceOffsets.size() == maskDimSizes.size());
    for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t maskDimSize = std::get<0>(it);
      int64_t sliceOffset = std::get<1>(it);
      int64_t sliceSize = std::get<2>(it);
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
    // region is a conjunction of mask dim intervals).
    if (llvm::any_of(sliceMaskDimSizes, [](int64_t sz) { return sz == 0; }))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);

    // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
    // region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
    return success();
  }
};

// Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
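// For example (illustrative), any strided slice of
//   %c = constant dense<1.0> : vector<4x8xf32>
// is itself a splat, e.g. constant dense<1.0> : vector<2x4xf32>.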
class StridedSliceConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'extractStridedSliceOp' operand is not defined by a
    // ConstantOp.
    auto constantOp =
        extractStridedSliceOp.vector().getDefiningOp<ConstantOp>();
    if (!constantOp)
      return failure();
    auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
    if (!dense)
      return failure();
    auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
                                          dense.getSplatValue());
    rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
    return success();
  }
};

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStridedSliceOp).
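// For example (an illustrative sketch):
//   %b = vector.broadcast %v : vector<8xf32> to vector<4x8xf32>
//   %s = vector.extract_strided_slice %b
//       {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
//       : vector<4x8xf32> to vector<2x4xf32>
// becomes an extract of the narrower source followed by a broadcast:
//   %e = vector.extract_strided_slice %v
//       {offsets = [2], sizes = [4], strides = [1]}
//       : vector<8xf32> to vector<4xf32>
//   %s = vector.broadcast %e : vector<4xf32> to vector<2x4xf32>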
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = op.getType().cast<VectorType>();
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the innermost dimensions of the source of the broadcast are the
    // same as the destination of the extract. If this is the case we can just
    // use a broadcast as the original dimensions are untouched.
    bool lowerDimMatch = true;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        lowerDimMatch = false;
        break;
      }
    }
    Value source = broadcast.source();
    if (!lowerDimMatch) {
      // The inner dimensions don't match, which means we need to extract from
      // the source of the original broadcast and then broadcast the extracted
      // value.
      source = rewriter.create<ExtractStridedSliceOp>(
          op->getLoc(), source,
          getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
          getI64SubArray(op.sizes(), /*dropFront=*/rankDiff),
          getI64SubArray(op.strides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};

} // end anonymous namespace

void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
  results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
              StridedSliceBroadcast>(context);
}

//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//

template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = expr.dyn_cast<AffineDimExpr>();
    auto zero = expr.dyn_cast<AffineConstantExpr>();
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
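
// For example (illustrative): (d0, d1, d2) -> (d2, 0) is a valid projected
// permutation (the zero result encodes a broadcast), while (d0, d1) -> (d0, d0)
// is rejected because d0 appears twice.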

static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 AffineMap permutationMap, ArrayAttr inBounds) {
  if (shapedType.getRank() == 0 && !op.isZeroD())
    return op->emitOpError("0-d transfer requires vector<1xt> shape and () -> "
                           "(0) permutation_map");

  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!shapedType.isa<MemRefType, RankedTensorType>())
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");
  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    VectorType expectedMaskType =
        vector::detail::transferMaskType(vectorType, permutationMap);
    if (maskType && expectedMaskType != maskType)
      return op->emitOpError("expects mask type consistent with permutation "
                             "map: ")
             << maskType;
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");
  // TODO: implement 0-d vector corner cases.
  if (!op.isZeroD() && permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (inBounds) {
    if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
      return op->emitOpError("expects the optional in_bounds attr of same rank "
                             "as permutation_map results: ")
             << AffineMapAttr::get(permutationMap);
    for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
      if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
          !inBounds.getValue()[i].cast<BoolAttr>().getValue())
        return op->emitOpError("requires broadcast dimensions to be in-bounds");
  }

  return success();
}
2388 
2389 /// Builder that sets padding to zero.
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,AffineMap permutationMap,ArrayRef<bool> inBounds)2390 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2391                            VectorType vectorType, Value source,
2392                            ValueRange indices, AffineMap permutationMap,
2393                            ArrayRef<bool> inBounds) {
2394   Type elemType = source.getType().cast<ShapedType>().getElementType();
2395   Value padding = builder.create<ConstantOp>(result.location, elemType,
2396                                              builder.getZeroAttr(elemType));
2397   if (inBounds.empty())
2398     return build(builder, result, vectorType, source, indices, permutationMap,
2399                  padding, ArrayAttr());
2400   ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
2401   build(builder, result, vectorType, source, indices, permutationMap, padding,
2402         inBoundsArrayAttr);
2403 }
2404 
2405 /// Builder that sets permutation map to 'getMinorIdentityMap'.
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,Value padding,ArrayRef<bool> inBounds)2406 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2407                            VectorType vectorType, Value source,
2408                            ValueRange indices, Value padding,
2409                            ArrayRef<bool> inBounds) {
2410   auto permMap = getTransferMinorIdentityMap(
2411       source.getType().cast<ShapedType>(), vectorType);
2412   if (inBounds.empty())
2413     return build(builder, result, vectorType, source, indices, permMap, padding,
2414                  ArrayAttr());
2415   ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
2416   build(builder, result, vectorType, source, indices, permMap, padding,
2417         inBoundsArrayAttr);
2418 }
2419 
2420 /// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
2421 /// (resp. zero).
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,ArrayRef<bool> inBounds)2422 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2423                            VectorType vectorType, Value source,
2424                            ValueRange indices, ArrayRef<bool> inBounds) {
2425   auto permMap = getTransferMinorIdentityMap(
2426       source.getType().cast<ShapedType>(), vectorType);
2427   build(builder, result, vectorType, source, indices, permMap, inBounds);
2428 }
2429 
2430 /// Builder that does not provide a mask.
build(OpBuilder & builder,OperationState & result,Type vectorType,Value source,ValueRange indices,AffineMap permutationMap,Value padding,ArrayAttr inBounds)2431 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2432                            Type vectorType, Value source, ValueRange indices,
2433                            AffineMap permutationMap, Value padding,
2434                            ArrayAttr inBounds) {
2435   build(builder, result, vectorType, source, indices, permutationMap, padding,
2436         /*mask=*/Value(), inBounds);
2437 }
2438 
2439 /// Builder that does not provide a mask.
build(OpBuilder & builder,OperationState & result,Type vectorType,Value source,ValueRange indices,AffineMapAttr permutationMap,Value padding,ArrayAttr inBounds)2440 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2441                            Type vectorType, Value source, ValueRange indices,
2442                            AffineMapAttr permutationMap, Value padding,
2443                            ArrayAttr inBounds) {
2444   build(builder, result, vectorType, source, indices, permutationMap, padding,
2445         /*mask=*/Value(), inBounds);
2446 }
2447 
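/// Builds a scalar read as a vector<1xT> transfer_read followed by an extract
/// of element 0. Sketch of the generated IR (illustrative):
///
///   %r = vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<1xf32>
///   %s = vector.extract %r[0] : vector<1xf32>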
Value TransferReadOp::createScalarOp(OpBuilder &builder, Location loc,
                                     Value source, ValueRange indices,
                                     ArrayRef<bool> inBounds) {
  Type elemType = source.getType().cast<ShapedType>().getElementType();
  auto vectorType = VectorType::get(ArrayRef<int64_t>{1}, elemType);
  AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
                                 getAffineConstantExpr(0, loc.getContext()));
  Value read = builder.create<vector::TransferReadOp>(loc, vectorType, source,
                                                      indices, map, inBounds);
  return builder.create<vector::ExtractOp>(loc, read, ArrayRef<int64_t>{0});
}

static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
  if (op.permutation_map().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrName());
  // Elide the in_bounds attribute when it is absent or all entries are false
  // (i.e. the default, fully out-of-bounds case).
  bool elideInBounds = true;
  if (auto inBounds = op.in_bounds()) {
    for (auto attr : *inBounds) {
      if (attr.template cast<BoolAttr>().getValue()) {
        elideInBounds = false;
        break;
      }
    }
  }
  if (elideInBounds)
    elidedAttrs.push_back(op.getInBoundsAttrName());
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

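/// Prints the op in its custom assembly format, e.g. (illustrative):
///
///   vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, false]}
///       : memref<?x?xf32>, vector<4x8xf32>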
static void print(OpAsmPrinter &p, TransferReadOp op) {
  p << " " << op.source() << "[" << op.indices() << "], " << op.padding();
  if (op.mask())
    p << ", " << op.mask();
  printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
  p << " : " << op.getShapedType() << ", " << op.getVectorType();
}

static ParseResult parseTransferReadOp(OpAsmParser &parser,
                                       OperationState &result) {
  auto &builder = parser.getBuilder();
  llvm::SMLoc typesLoc;
  OpAsmParser::OperandType sourceInfo;
  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
  OpAsmParser::OperandType paddingInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::OperandType maskInfo;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = types[0].dyn_cast<ShapedType>();
  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = types[1].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
  Attribute mapAttr = result.attributes.get(permutationAttrName);
  if (!mapAttr) {
    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    mapAttr = AffineMapAttr::get(permMap);
    result.attributes.set(permutationAttrName, mapAttr);
  }
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (shapedType.getElementType().dyn_cast<VectorType>())
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
    // Instead of adding the mask type as an op type, compute it based on the
    // vector type and the permutation map (to keep the type signature small).
    auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(
      TransferReadOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1,
                                static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}

static LogicalResult verify(TransferReadOp op) {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = op.getShapedType();
  VectorType vectorType = op.getVectorType();
  VectorType maskType = op.getMaskType();
  auto paddingType = op.padding().getType();
  auto permutationMap = op.permutation_map();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
    return op.emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(
          verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
                           shapedType, vectorType, maskType, permutationMap,
                           op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
    return failure();

  if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
    // Source has vector element type.
    // Check that 'sourceVectorElementType' and 'paddingType' types match.
    if (sourceVectorElementType != paddingType)
      return op.emitOpError(
          "requires source element type and padding type to match.");

  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return op.emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return op.emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&op](Twine t) { return op.emitOpError(t); });
}

/// This is a common folding helper used for patterns of the form
/// ```
///    someop(memrefcast) -> someop
/// ```
/// It folds the source of the memref.cast into the root operation directly.
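/// For example (illustrative):
/// ```
///    %0 = memref.cast %A : memref<4xf32> to memref<?xf32>
///    %1 = vector.load %0[%i] : memref<?xf32>, vector<4xf32>
/// ```
/// becomes a vector.load taking %A directly.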
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

static LogicalResult foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

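/// Returns true if we can statically prove that the transfer dimension
/// `resultIdx` stays within the (static) source dimension `indicesIdx`, i.e.
/// the corresponding index is a known constant `c` with
/// `c + vectorDimSize <= sourceDimSize`.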
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.indices()[indicesIdx];
  auto cstOp = index.getDefiningOp<ConstantIndexOp>();
  if (!cstOp)
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

  return cstOp.getValue() + vectorSize <= sourceSize;
}

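/// Upgrades the in_bounds attribute of a transfer op when static information
/// proves a dimension cannot go out-of-bounds. For example (illustrative),
/// with %c0 a constant 0 index:
/// ```
///   %v = vector.transfer_read %A[%c0], %pad : memref<16xf32>, vector<4xf32>
/// ```
/// can be rewritten with in_bounds = [true], since 0 + 4 <= 16.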
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: Be less conservative once we have 0-d vectors.
  if (op.isZeroD())
    return failure();
  AffineMap permutationMap = op.permutation_map();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // Currently out-of-bounds, check whether we can statically determine it is
    // inBounds.
    auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
    assert(dimExpr && "Broadcast dims must be in-bounds");
    auto inBounds =
        isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
    newInBounds.push_back(inBounds);
    // Commit the change only if the attribute becomes "more in-bounds".
    changed |= inBounds;
  }
  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build a BoolArrayAttr.
  OpBuilder b(op.getContext());
  op->setAttr(TransferOp::getInBoundsAttrName(),
              b.getBoolArrayAttr(newInBounds));
  return success();
}

/// Fold a read-after-write (RAW) on tensors:
///  ```
///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
///    : tensor<4x4xf32>, vector<1x4xf32>
///  ```
///  -> Folds into
///  ```
///  %v0
///  ```
static Value foldRAW(TransferReadOp readOp) {
  if (!readOp.getShapedType().isa<RankedTensorType>())
    return {};
  auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.vector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
  if (Value vec = foldRAW(*this))
    return vec;
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  // transfer_read(memrefcast) -> transfer_read
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  if (succeeded(foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}

Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (getShapedType().isa<MemRefType>())
    effects.emplace_back(MemoryEffects::Read::get(), source(),
                         SideEffects::DefaultResource::get());
}

namespace {
/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
///
/// ```
/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
///     : tensor<?x?xf32> to tensor<?x?xf32>
/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
///     : tensor<?x?xf32>, vector<4x5xf32>
/// ```
/// is rewritten to:
/// ```
/// %p0 = addi %a, %e : index
/// %p1 = addi %b, %f : index
/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
///     : tensor<?x?xf32>, vector<4x5xf32>
/// ```
struct FoldExtractSliceIntoTransferRead
    : public OpRewritePattern<TransferReadOp> {
public:
  using OpRewritePattern<TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    if (!xferOp.permutation_map().isIdentity())
      return failure();
    if (xferOp.mask())
      return failure();
    auto extractOp = xferOp.source().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp)
      return failure();
    if (!extractOp.hasUnitStride())
      return failure();

    int64_t rankReduced =
        extractOp.getSourceType().getRank() - extractOp.getType().getRank();
    SmallVector<Value> newIndices;
    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
    // indices first.
    for (int64_t i = 0; i < rankReduced; ++i) {
      OpFoldResult offset = extractOp.getMixedOffsets()[i];
      newIndices.push_back(getValueOrCreateConstantIndexOp(
          rewriter, extractOp.getLoc(), offset));
    }
    for (auto it : llvm::enumerate(xferOp.indices())) {
      OpFoldResult offset =
          extractOp.getMixedOffsets()[it.index() + rankReduced];
      newIndices.push_back(
          rewriter.create<AddIOp>(xferOp->getLoc(), it.value(),
                                  getValueOrCreateConstantIndexOp(
                                      rewriter, extractOp.getLoc(), offset)));
    }
    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
    rewriter.replaceOpWithNewOp<TransferReadOp>(xferOp, xferOp.getVectorType(),
                                                extractOp.source(), newIndices,
                                                xferOp.padding(), inBounds);

    return success();
  }
};
} // namespace

void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<FoldExtractSliceIntoTransferRead>(context);
}

//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap, ArrayRef<bool> inBounds) {
  if (inBounds.empty())
    return build(builder, result, vector, dest, indices, permutationMap,
                 /*mask=*/Value(), ArrayAttr());
  build(builder, result, vector, dest, indices, permutationMap,
        /*mask=*/Value(), builder.getBoolArrayAttr(inBounds));
}

/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            ArrayRef<bool> inBounds) {
  auto vectorType = vector.getType().cast<VectorType>();
  auto permMap = getTransferMinorIdentityMap(
      source.getType().cast<ShapedType>(), vectorType);
  if (inBounds.empty())
    return build(builder, result, vector, source, indices, permMap,
                 ArrayAttr());
  ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
  build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            AffineMapAttr permutationMap,
                            /*optional*/ ArrayAttr inBounds) {
  Type resultType = source.getType().dyn_cast<RankedTensorType>();
  build(builder, result, resultType, vector, source, indices, permutationMap,
        /*mask=*/Value(), inBounds);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            AffineMap permutationMap,
                            /*optional*/ ArrayAttr inBounds) {
  Type resultType = source.getType().dyn_cast<RankedTensorType>();
  build(builder, result, resultType, vector, source, indices, permutationMap,
        /*mask=*/Value(), inBounds);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            AffineMap permutationMap, /*optional*/ Value mask,
                            /*optional*/ ArrayAttr inBounds) {
  Type resultType = source.getType().dyn_cast<RankedTensorType>();
  build(builder, result, resultType, vector, source, indices, permutationMap,
        mask, inBounds);
}

Operation *TransferWriteOp::createScalarOp(OpBuilder &builder, Location loc,
                                           Value value, Value dest,
                                           ValueRange indices,
                                           ArrayRef<bool> inBounds) {
  Value vectorOfAScalar = value;
  if (!value.getType().isa<VectorType>())
    vectorOfAScalar = builder.create<vector::BroadcastOp>(
        loc, VectorType::get({1}, value.getType()), value);
  AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
                                 getAffineConstantExpr(0, loc.getContext()));
  return builder.create<vector::TransferWriteOp>(loc, vectorOfAScalar, dest,
                                                 indices, map, inBounds);
}

static ParseResult parseTransferWriteOp(OpAsmParser &parser,
                                        OperationState &result) {
  auto &builder = parser.getBuilder();
  llvm::SMLoc typesLoc;
  OpAsmParser::OperandType vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::OperandType maskInfo;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = types[0].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = types[1].dyn_cast<ShapedType>();
  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
  auto attr = result.attributes.get(permutationAttrName);
  if (!attr) {
    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (shapedType.getElementType().dyn_cast<VectorType>())
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(
      TransferWriteOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()),
                                static_cast<int32_t>(hasMask.succeeded())}));
  return failure(shapedType.isa<RankedTensorType>() &&
                 parser.addTypeToList(shapedType, result.types));
}

static void print(OpAsmPrinter &p, TransferWriteOp op) {
  p << " " << op.vector() << ", " << op.source() << "[" << op.indices() << "]";
  if (op.mask())
    p << ", " << op.mask();
  printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
  p << " : " << op.getVectorType() << ", " << op.getShapedType();
}

static LogicalResult verify(TransferWriteOp op) {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = op.getShapedType();
  VectorType vectorType = op.getVectorType();
  VectorType maskType = op.getMaskType();
  auto permutationMap = op.permutation_map();

  if (llvm::size(op.indices()) != shapedType.getRank())
    return op.emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics are unclear. This can be revisited later if necessary.
  if (op.hasBroadcastDim())
    return op.emitOpError("should not have broadcast dimensions");

  if (failed(
          verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
                           shapedType, vectorType, maskType, permutationMap,
                           op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&op](Twine t) { return op.emitOpError(t); });
}

/// Fold:
/// ```
///    %t1 = ...
///    %v = vector.transfer_read %t0[%c0...], %cst {in_bounds = [true...]} :
///      tensor<static_sizesxf32>, vector<static_sizesxf32>
///    %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
///      vector<static_sizesxf32>, tensor<static_sizesxf32>
/// ```
///
/// into:
///
/// ```
///    %t0
/// ```
///
/// The producer of t1 may or may not be DCE'd depending on whether it is a
/// block argument or has side effects.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>();
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // For now, only accept minor identity maps; in the future, accept any
  // composition that is a minor identity.
  if (!read.permutation_map().isMinorIdentity() ||
      !write.permutation_map().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.source().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and Tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // Bail if any index is nonzero.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = v.getDefiningOp<ConstantIndexOp>();
    return !cstOp || cstOp.getValue() != 0;
  };
  if (llvm::any_of(read.indices(), isNotConstantZero) ||
      llvm::any_of(write.indices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.source());
  return success();
}

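/// Returns true if the write stores back exactly what the read loaded: same
/// tensor, same indices, same permutation map and vector type, and neither
/// access is masked.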
static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.source() == write.source() && read.indices() == write.indices() &&
         read.permutation_map() == write.permutation_map() &&
         read.getVectorType() == write.getVectorType() && !read.mask() &&
         !write.mask();
}
/// Fold transfer_write write after read:
/// ```
///    %t0 = ...
///    %v = vector.transfer_read %t0[%c0...] :
///      tensor<static_sizesxf32>, vector<static_sizesxf32>
///    %t1 = vector.transfer_write %v, %t0[%c0...] :
///      vector<static_sizesxf32>, tensor<static_sizesxf32>
/// ```
///
/// into:
///
/// ```
///    %t0
/// ```
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!write.source().getType().isa<RankedTensorType>())
    return failure();
  auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();

  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.source());
  return success();
}

LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, operands, results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return success();
  return foldMemRefCast(*this);
}

Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (getShapedType().isa<MemRefType>())
    effects.emplace_back(MemoryEffects::Write::get(), source(),
                         SideEffects::DefaultResource::get());
}

namespace {
/// Remove a dead transfer write from the SSA chain so that it can be
/// eliminated by DCE:
/// ```
///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// into:
///
/// ```
///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// The `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't
/// have any other uses.
class foldWAW final : public OpRewritePattern<TransferWriteOp> {
public:
  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!writeOp.getShapedType().isa<RankedTensorType>())
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        writeToModify.sourceMutable().assign(defWrite.source());
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely look
      // at the previous store to see if it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};

/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
/// could directly write to the insert_slice's destination. E.g.:
///
/// ```
/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
///     : vector<4x5xf32>, tensor<4x5xf32>
/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
///     : tensor<4x5xf32> into tensor<?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
///     : vector<4x5xf32>, tensor<?x?xf32>
/// ```
struct FoldInsertSliceIntoTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();
    auto xferOp = insertOp.source().getDefiningOp<TransferWriteOp>();
    if (!xferOp)
      return failure();
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
      return failure();
    if (xferOp.mask())
      return failure();
    // Fold only if the TransferWriteOp completely overwrites the `source` with
    // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
    // content is the data of the vector.
    if (!llvm::equal(xferOp.getVectorType().getShape(),
                     xferOp.getShapedType().getShape()))
      return failure();
    if (!xferOp.permutation_map().isIdentity())
      return failure();

    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
    rewriter.replaceOpWithNewOp<TransferWriteOp>(
        insertOp, xferOp.vector(), insertOp.dest(), indices, inBounds);
    return success();
  }
};
} // namespace

void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<foldWAW, FoldInsertSliceIntoTransferWrite>(context);
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

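/// Both vector.load and vector.store lower to contiguous accesses, so the most
/// minor memref dimension must have unit stride. E.g. (illustrative), an
/// identity-layout memref<4x8xf32> is accepted, while a strided layout such as
/// memref<4x8xf32, affine_map<(d0, d1) -> (d0 * 16 + d1 * 2)>> is rejected.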
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 MemRefType memRefTy) {
  if (!isLastMemrefDimUnitStride(memRefTy))
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}

static LogicalResult verify(vector::LoadOp op) {
  VectorType resVecTy = op.getVectorType();
  MemRefType memRefTy = op.getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
    return failure();

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
    if (memVecTy != resVecTy)
      return op.emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return op.emitOpError("base and result element types should match");
  if (llvm::size(op.indices()) != memRefTy.getRank())
    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(vector::StoreOp op) {
  VectorType valueVecTy = op.getVectorType();
  MemRefType memRefTy = op.getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
    return failure();

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
    if (memVecTy != valueVecTy)
      return op.emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return op.emitOpError("base and valueToStore element type should match");
  if (llvm::size(op.indices()) != memRefTy.getRank())
    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
                            SmallVectorImpl<OpFoldResult> &results) {
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(MaskedLoadOp op) {
  VectorType maskVType = op.getMaskVectorType();
  VectorType passVType = op.getPassThruVectorType();
  VectorType resVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and result element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return op.emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
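/// Canonicalizes masked loads with a statically known mask: an all-true mask
/// becomes a regular vector.load, and an all-false mask yields the pass_thru
/// value. Illustrative example for the all-true case:
/// ```
///   %l = vector.maskedload %base[%i], %alltrue, %pass
///       : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```
/// folds to:
/// ```
///   %l = vector.load %base[%i] : memref<?xf32>, vector<16xf32>
/// ```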
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(load.mask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(),
                                                  load.base(), load.indices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.pass_thru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};
} // namespace

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}

OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// MaskedStoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(MaskedStoreOp op) {
  VectorType maskVType = op.getMaskVectorType();
  VectorType valueVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and valueToStore element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(store.mask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.valueToStore(), store.base(), store.indices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};
} // namespace

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(GatherOp op) {
  VectorType indVType = op.getIndexVectorType();
  VectorType maskVType = op.getMaskVectorType();
  VectorType resVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and result element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != indVType.getDimSize(0))
    return op.emitOpError("expected result dim to match indices dim");
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected result dim to match mask dim");
  if (resVType != op.getPassThruVectorType())
    return op.emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
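/// Canonicalizes gathers with a statically known mask. Unlike masked loads,
/// an all-true gather has no unmasked equivalent op to rewrite to, so only
/// the all-false case folds (to the pass_thru operand).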
3369 class GatherFolder final : public OpRewritePattern<GatherOp> {
3370 public:
3371   using OpRewritePattern<GatherOp>::OpRewritePattern;
matchAndRewrite(GatherOp gather,PatternRewriter & rewriter) const3372   LogicalResult matchAndRewrite(GatherOp gather,
3373                                 PatternRewriter &rewriter) const override {
3374     switch (get1DMaskFormat(gather.mask())) {
3375     case MaskFormat::AllTrue:
3376       return failure(); // no unmasked equivalent
3377     case MaskFormat::AllFalse:
3378       rewriter.replaceOp(gather, gather.pass_thru());
3379       return success();
3380     case MaskFormat::Unknown:
3381       return failure();
3382     }
3383     llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
3384   }
3385 };
3386 } // namespace
3387 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)3388 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
3389                                            MLIRContext *context) {
3390   results.add<GatherFolder>(context);
3391 }
3392 
3393 //===----------------------------------------------------------------------===//
3394 // ScatterOp
3395 //===----------------------------------------------------------------------===//
3396 
verify(ScatterOp op)3397 static LogicalResult verify(ScatterOp op) {
3398   VectorType indVType = op.getIndexVectorType();
3399   VectorType maskVType = op.getMaskVectorType();
3400   VectorType valueVType = op.getVectorType();
3401   MemRefType memType = op.getMemRefType();
3402 
3403   if (valueVType.getElementType() != memType.getElementType())
3404     return op.emitOpError("base and valueToStore element type should match");
3405   if (llvm::size(op.indices()) != memType.getRank())
3406     return op.emitOpError("requires ") << memType.getRank() << " indices";
3407   if (valueVType.getDimSize(0) != indVType.getDimSize(0))
3408     return op.emitOpError("expected valueToStore dim to match indices dim");
3409   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
3410     return op.emitOpError("expected valueToStore dim to match mask dim");
3411   return success();
3412 }
3413 
3414 namespace {
3415 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
3416 public:
3417   using OpRewritePattern<ScatterOp>::OpRewritePattern;
matchAndRewrite(ScatterOp scatter,PatternRewriter & rewriter) const3418   LogicalResult matchAndRewrite(ScatterOp scatter,
3419                                 PatternRewriter &rewriter) const override {
3420     switch (get1DMaskFormat(scatter.mask())) {
3421     case MaskFormat::AllTrue:
3422       return failure(); // no unmasked equivalent
3423     case MaskFormat::AllFalse:
3424       rewriter.eraseOp(scatter);
3425       return success();
3426     case MaskFormat::Unknown:
3427       return failure();
3428     }
3429     llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
3430   }
3431 };
3432 } // namespace
3433 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)3434 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
3435                                             MLIRContext *context) {
3436   results.add<ScatterFolder>(context);
3437 }
3438 
3439 //===----------------------------------------------------------------------===//
3440 // ExpandLoadOp
3441 //===----------------------------------------------------------------------===//
3442 
verify(ExpandLoadOp op)3443 static LogicalResult verify(ExpandLoadOp op) {
3444   VectorType maskVType = op.getMaskVectorType();
3445   VectorType passVType = op.getPassThruVectorType();
3446   VectorType resVType = op.getVectorType();
3447   MemRefType memType = op.getMemRefType();
3448 
3449   if (resVType.getElementType() != memType.getElementType())
3450     return op.emitOpError("base and result element type should match");
3451   if (llvm::size(op.indices()) != memType.getRank())
3452     return op.emitOpError("requires ") << memType.getRank() << " indices";
3453   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3454     return op.emitOpError("expected result dim to match mask dim");
3455   if (resVType != passVType)
3456     return op.emitOpError("expected pass_thru of same type as result type");
3457   return success();
3458 }
3459 
3460 namespace {
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(expand.mask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.base(), expand.indices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.pass_thru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
} // namespace

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}

//===----------------------------------------------------------------------===//
// CompressStoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(CompressStoreOp op) {
  VectorType maskVType = op.getMaskVectorType();
  VectorType valueVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and valueToStore element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
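// Folds a compressing store with a provably constant 1-D mask: an all-true
// mask degenerates to a regular vector.store, and an all-false mask stores
// nothing, so the op can be erased (mirroring ExpandLoadFolder above).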
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(compress.mask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.valueToStore(), compress.base(),
          compress.indices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
} // namespace

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}

//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//

/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
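/// For example, a = {4, 6} and b = {2, 2, 6} satisfy this: 4 = 2 * 2 and
/// 6 = 6, so vector<4x6xf32> can be shape_cast to vector<2x2x6xf32>.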
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  unsigned rankA = a.size();
  unsigned rankB = b.size();
  assert(rankA < rankB);

  unsigned i = 0;
  unsigned j = 0;
  while (i < rankA && j < rankB) {
    int64_t dimA = a[i];
    int64_t dimB = 1;
    while (dimB < dimA && j < rankB)
      dimB *= b[j++];
    if (dimA != dimB)
      break;
    ++i;

    // Handle the case when trailing dimensions are of size 1.
    // Include them into the contiguous sequence.
    auto isOne = [](int64_t v) { return v == 1; };
    if (i < rankA && llvm::all_of(a.slice(i), isOne))
      i = rankA;
    if (j < rankB && llvm::all_of(b.slice(j), isOne))
      j = rankB;
  }

  return i == rankA && j == rankB;
}

static LogicalResult verifyVectorShapeCast(Operation *op,
                                           VectorType sourceVectorType,
                                           VectorType resultVectorType) {
  // Check that element type is the same.
  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
    return op->emitOpError("source/result vectors must have same element type");
  auto sourceShape = sourceVectorType.getShape();
  auto resultShape = resultVectorType.getShape();

  // Check that product of source dim sizes matches product of result dim sizes.
  int64_t sourceDimProduct = std::accumulate(
      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
  int64_t resultDimProduct = std::accumulate(
      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
  if (sourceDimProduct != resultDimProduct)
    return op->emitOpError("source/result number of elements must match");

  // Check the rank-expanding and rank-contracting cases.
  unsigned sourceRank = sourceVectorType.getRank();
  unsigned resultRank = resultVectorType.getRank();
  if (sourceRank < resultRank) {
    if (!isValidShapeCast(sourceShape, resultShape))
      return op->emitOpError("invalid shape cast");
  } else if (sourceRank > resultRank) {
    if (!isValidShapeCast(resultShape, sourceShape))
      return op->emitOpError("invalid shape cast");
  }
  return success();
}

static LogicalResult verify(ShapeCastOp op) {
  auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>();
  auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>();

  // Check if source/result are of vector type.
  if (sourceVectorType && resultVectorType)
    return verifyVectorShapeCast(op, sourceVectorType, resultVectorType);

  return success();
}

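/// Folds a nop shape cast (identical source/result types) to its source, and
/// folds a pair of back-to-back shape casts: shape_cast(shape_cast(%v)) where
/// the outer result type matches the type of %v folds directly to %v;
/// otherwise the outer cast is rewritten in place to cast from %v directly.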
OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
  // Nop shape cast.
  if (source().getType() == result().getType())
    return source();

  // Canceling shape casts.
  if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
    if (result().getType() == otherOp.source().getType())
      return otherOp.source();
    setOperand(otherOp.source());
    return getResult();
  }
  return {};
}

namespace {
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
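// For example, a shape cast of a splat constant dense<1.0> : vector<4xf32>
// to vector<2x2xf32> becomes a constant dense<1.0> : vector<2x2xf32>.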
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern<ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>();
    if (!constantOp)
      return failure();
    // Only handle splat for now.
    auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
    if (!dense)
      return failure();
    auto newAttr = DenseElementsAttr::get(
        shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
    rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr);
    return success();
  }
};

} // namespace

void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
  results.add<ShapeCastConstantFolder>(context);
}

//===----------------------------------------------------------------------===//
// VectorBitCastOp
//===----------------------------------------------------------------------===//

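/// The verifier below checks that all leading dimension sizes match and that
/// the total bitwidth of the minor dimension is preserved. For example,
/// bitcasting vector<2x4xf32> to vector<2x2xi64> is valid: 4 * 32 bits
/// equals 2 * 64 bits in the minor dimension.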
static LogicalResult verify(BitCastOp op) {
  auto sourceVectorType = op.getSourceVectorType();
  auto resultVectorType = op.getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return op.emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(op);
  if (dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()) *
          sourceVectorType.getShape().back() !=
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType()) *
          resultVectorType.getShape().back())
    return op.emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");

  return success();
}

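/// Besides folding nop and canceling bitcasts, the folder below handles a
/// splat f16 constant bitcast to f32: each f32 lane packs two f16 lanes, so
/// the resulting 32-bit pattern is the 16-bit pattern duplicated. For
/// example, a splat of 1.0 (f16 bit pattern 0x3C00) becomes the f32 with
/// bit pattern 0x3C003C00.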
OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
  // Nop cast.
  if (source().getType() == result().getType())
    return source();

  // Canceling bitcasts.
  if (auto otherOp = source().getDefiningOp<BitCastOp>())
    if (result().getType() == otherOp.source().getType())
      return otherOp.source();

  Attribute sourceConstant = operands.front();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//

static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
  SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
                              memRefType.getShape().end());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = source.getType().cast<MemRefType>();
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(
      MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace()));
}

static LogicalResult verify(TypeCastOp op) {
  MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
  if (!canonicalType.getAffineMaps().empty())
    return op.emitOpError("expects operand to be a memref with no layout");
  if (!op.getResultMemRefType().getAffineMaps().empty())
    return op.emitOpError("expects result to be a memref with no layout");
  if (op.getResultMemRefType().getMemorySpace() !=
      op.getMemRefType().getMemorySpace())
    return op.emitOpError("expects result in same memory space");

  auto sourceType = op.getMemRefType();
  auto resultType = op.getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return op.emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return op.emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

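/// Builds a transpose of 'vector' according to the permutation 'transp',
/// where result dimension i has the size of input dimension transp[i]. For
/// example, transp = {1, 0} on vector<2x3xf32> yields vector<3x2xf32>.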
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> transp) {
  VectorType vt = vector.getType().cast<VectorType>();
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  for (unsigned i = 0; i < transp.size(); ++i)
    transposedShape[i] = vt.getShape()[transp[i]];

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
  result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
}

// Eliminates transpose operations that produce values identical to their
// input values. This happens when the dimensions of the input vector remain
// in their original order after the transpose operation.
OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
  SmallVector<int64_t, 4> transp;
  getTransp(transp);

  // Check if the permutation of the dimensions contains sequential values:
  // {0, 1, 2, ...}.
  for (int64_t i = 0, e = transp.size(); i < e; i++) {
    if (transp[i] != i)
      return {};
  }

  return vector();
}

static LogicalResult verify(vector::TransposeOp op) {
  VectorType vectorType = op.getVectorType();
  VectorType resultType = op.getResultType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return op.emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  auto transpAttr = op.transp().getValue();
  int64_t size = transpAttr.size();
  if (rank != size)
    return op.emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (auto ta : llvm::enumerate(transpAttr)) {
    int64_t i = ta.value().cast<IntegerAttr>().getInt();
    if (i < 0 || i >= rank)
      return op.emitOpError("transposition index out of range: ") << i;
    if (seen[i])
      return op.emitOpError("duplicate position index: ") << i;
    seen[i] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
      return op.emitOpError("dimension size mismatch at: ") << i;
  }
  return success();
}

namespace {

// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
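// The permutations compose as result[i] = parent[child[i]]. For example, a
// transpose by [1, 2, 0] followed by a transpose by [2, 0, 1] composes to
// the identity permutation [0, 1, 2], which TransposeOp::fold then removes.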
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
    auto getPermutation = [](vector::TransposeOp transpose) {
      SmallVector<int64_t, 4> permutation;
      transpose.getTransp(permutation);
      return permutation;
    };

    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.vector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        getPermutation(parentTransposeOp), getPermutation(transposeOp));
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.vector(),
        vector::getVectorSubscriptAttr(rewriter, permutation));
    return success();
  }
};

} // end anonymous namespace

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<TransposeFolder>(context);
}

void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(transp(), results);
}

//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//

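/// A constant mask defines the conjunction of one interval per dimension:
/// for example (asm syntax assumed from the op's documentation),
///   vector.constant_mask [3, 2] : vector<4x3xi1>
/// sets exactly the leading 3x2 sub-region of the 4x3 result to true.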
static LogicalResult verify(ConstantMaskOp op) {
  // Verify that array attr size matches the rank of the vector result.
  auto resultType = op.getResult().getType().cast<VectorType>();
  if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
    return op.emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of corresponding vector
  // result dimension size.
  auto resultShape = resultType.getShape();
  SmallVector<int64_t, 4> maskDimSizes;
  for (auto it : llvm::enumerate(op.mask_dim_sizes())) {
    int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
    if (attrValue < 0 || attrValue > resultShape[it.index()])
      return op.emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    maskDimSizes.push_back(attrValue);
  }
  // Verify that if one mask dim size is zero, they all should be zero (because
  // the mask region is a conjunction of each mask dimension interval).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return op.emitOpError("expected all mask dim sizes to be zeros, "
                          "as a result of conjunction with zero mask dim");
  return success();
}

//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(CreateMaskOp op) {
  // Verify that an operand was specified for each result vector dimension.
  if (op.getNumOperands() !=
      op.getResult().getType().cast<VectorType>().getRank())
    return op.emitOpError(
        "must specify an operand for each result vector dimension");
  return success();
}

namespace {

// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
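// This only applies when every operand is a constant index. Illustrative
// sketch only (asm syntax assumed from the ops' documentation):
//
//   %c3 = constant 3 : index
//   %c2 = constant 2 : index
//   %m = vector.create_mask %c3, %c2 : vector<4x3xi1>
//   ==> %m = vector.constant_mask [3, 2] : vector<4x3xi1>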
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern<CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    // Return if any of 'createMaskOp' operands are not defined by a constant.
    auto isNotDefByConstant = [](Value operand) {
      return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
    };
    if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    for (auto operand : createMaskOp.operands()) {
      auto defOp = operand.getDefiningOp();
      maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
    }
    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        createMaskOp, createMaskOp.getResult().getType(),
        vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
    return success();
  }
};

} // end anonymous namespace

void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}

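/// Collects the mask-related folding patterns defined in this file so that
/// clients can apply them in one shot. A minimal usage sketch (assuming the
/// greedy rewrite driver from mlir/Transforms/GreedyPatternRewriteDriver.h
/// and an MLIRContext 'ctx' in scope):
///
///   RewritePatternSet patterns(ctx);
///   vector::populateVectorToVectorCanonicalizationPatterns(patterns);
///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));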
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"