1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/VectorOps.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
19 #include "mlir/Dialect/Vector/VectorUtils.h"
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/DialectImplementation.h"
26 #include "mlir/IR/OpImplementation.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/IR/TypeUtilities.h"
29 #include "mlir/Support/LLVM.h"
30 #include "mlir/Support/MathExtras.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/ADT/bit.h"
33 #include <numeric>
34 
35 #include "mlir/Dialect/Vector/VectorOpsDialect.cpp.inc"
36 // Pull in all enum type and utility function definitions.
37 #include "mlir/Dialect/Vector/VectorOpsEnums.cpp.inc"
38 
39 using namespace mlir;
40 using namespace mlir::vector;
41 
/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,  // every mask bit is known to be set
  AllFalse = 1, // every mask bit is known to be cleared
  Unknown = 2,  // mask contents cannot be determined statically
};
48 
49 /// Helper method to classify a 1-D mask value. Currently, the method
50 /// looks "under the hood" of a constant value with dense attributes
51 /// and a constant mask operation (since the client may be called at
52 /// various stages during progressive lowering).
get1DMaskFormat(Value mask)53 static MaskFormat get1DMaskFormat(Value mask) {
54   if (auto c = mask.getDefiningOp<ConstantOp>()) {
55     // Inspect constant dense values. We count up for bits that
56     // are set, count down for bits that are cleared, and bail
57     // when a mix is detected.
58     if (auto denseElts = c.value().dyn_cast<DenseIntElementsAttr>()) {
59       int64_t val = 0;
60       for (bool b : denseElts.getValues<bool>())
61         if (b && val >= 0)
62           val++;
63         else if (!b && val <= 0)
64           val--;
65         else
66           return MaskFormat::Unknown;
67       if (val > 0)
68         return MaskFormat::AllTrue;
69       if (val < 0)
70         return MaskFormat::AllFalse;
71     }
72   } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
73     // Inspect constant mask index. If the index exceeds the
74     // dimension size, all bits are set. If the index is zero
75     // or less, no bits are set.
76     ArrayAttr masks = m.mask_dim_sizes();
77     assert(masks.size() == 1);
78     int64_t i = masks[0].cast<IntegerAttr>().getInt();
79     int64_t u = m.getType().getDimSize(0);
80     if (i >= u)
81       return MaskFormat::AllTrue;
82     if (i <= 0)
83       return MaskFormat::AllFalse;
84   }
85   return MaskFormat::Unknown;
86 }
87 
88 // Helper for verifying combining kinds in contractions and reductions.
isSupportedCombiningKind(CombiningKind combiningKind,Type elementType)89 static bool isSupportedCombiningKind(CombiningKind combiningKind,
90                                      Type elementType) {
91   switch (combiningKind) {
92   case CombiningKind::ADD:
93   case CombiningKind::MUL:
94   case CombiningKind::MIN:
95   case CombiningKind::MAX:
96     return elementType.isIntOrIndexOrFloat();
97   case CombiningKind::AND:
98   case CombiningKind::OR:
99   case CombiningKind::XOR:
100     return elementType.isIntOrIndex();
101   }
102   return false;
103 }
104 
105 /// Return true if the last dimension of the MemRefType has unit stride. Also
106 /// return true for memrefs with no strides.
isLastMemrefDimUnitStride(MemRefType type)107 bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
108   int64_t offset;
109   SmallVector<int64_t> strides;
110   auto successStrides = getStridesAndOffset(type, strides, offset);
111   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // CombiningKindAttr
116 //===----------------------------------------------------------------------===//
117 
118 namespace mlir {
119 namespace vector {
120 namespace detail {
121 struct BitmaskEnumStorage : public AttributeStorage {
122   using KeyTy = uint64_t;
123 
BitmaskEnumStoragemlir::vector::detail::BitmaskEnumStorage124   BitmaskEnumStorage(KeyTy val) : value(val) {}
125 
operator ==mlir::vector::detail::BitmaskEnumStorage126   bool operator==(const KeyTy &key) const { return value == key; }
127 
constructmlir::vector::detail::BitmaskEnumStorage128   static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
129                                        const KeyTy &key) {
130     return new (allocator.allocate<BitmaskEnumStorage>())
131         BitmaskEnumStorage(key);
132   }
133 
134   KeyTy value = 0;
135 };
136 } // namespace detail
137 } // namespace vector
138 } // namespace mlir
139 
get(CombiningKind kind,MLIRContext * context)140 CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
141                                          MLIRContext *context) {
142   return Base::get(context, static_cast<uint64_t>(kind));
143 }
144 
getKind() const145 CombiningKind CombiningKindAttr::getKind() const {
146   return static_cast<CombiningKind>(getImpl()->value);
147 }
148 
/// All combining kinds, in the order in which they are printed inside a
/// `kind<...>` attribute (see CombiningKindAttr::print, which filters this
/// list by the bits present in the stored bitmask).
static constexpr const CombiningKind combiningKindsList[] = {
    // clang-format off
    CombiningKind::ADD,
    CombiningKind::MUL,
    CombiningKind::MIN,
    CombiningKind::MAX,
    CombiningKind::AND,
    CombiningKind::OR,
    CombiningKind::XOR,
    // clang-format on
};
160 
print(DialectAsmPrinter & printer) const161 void CombiningKindAttr::print(DialectAsmPrinter &printer) const {
162   printer << "kind<";
163   auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
164     return bitEnumContains(this->getKind(), kind);
165   });
166   llvm::interleaveComma(kinds, printer,
167                         [&](auto kind) { printer << stringifyEnum(kind); });
168   printer << ">";
169 }
170 
parse(DialectAsmParser & parser)171 Attribute CombiningKindAttr::parse(DialectAsmParser &parser) {
172   if (failed(parser.parseLess()))
173     return {};
174 
175   StringRef elemName;
176   if (failed(parser.parseKeyword(&elemName)))
177     return {};
178 
179   auto kind = symbolizeCombiningKind(elemName);
180   if (!kind) {
181     parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
182         << elemName;
183     return {};
184   }
185 
186   if (failed(parser.parseGreater()))
187     return {};
188 
189   return CombiningKindAttr::get(kind.getValue(),
190                                 parser.getBuilder().getContext());
191 }
192 
parseAttribute(DialectAsmParser & parser,Type type) const193 Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
194                                         Type type) const {
195   StringRef attrKind;
196   if (parser.parseKeyword(&attrKind))
197     return {};
198 
199   if (attrKind == "kind")
200     return CombiningKindAttr::parse(parser);
201 
202   parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
203   return {};
204 }
205 
printAttribute(Attribute attr,DialectAsmPrinter & os) const206 void VectorDialect::printAttribute(Attribute attr,
207                                    DialectAsmPrinter &os) const {
208   if (auto ck = attr.dyn_cast<CombiningKindAttr>())
209     ck.print(os);
210   else
211     llvm_unreachable("Unknown attribute type");
212 }
213 
214 //===----------------------------------------------------------------------===//
215 // VectorDialect
216 //===----------------------------------------------------------------------===//
217 
/// Registers all vector dialect attributes and operations. The operation
/// list is generated by TableGen into VectorOps.cpp.inc.
void VectorDialect::initialize() {
  addAttributes<CombiningKindAttr>();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
      >();
}
226 
227 /// Materialize a single constant operation from a given attribute value with
228 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)229 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
230                                               Attribute value, Type type,
231                                               Location loc) {
232   return builder.create<ConstantOp>(loc, type, value);
233 }
234 
getVectorSubscriptType(Builder & builder)235 IntegerType vector::getVectorSubscriptType(Builder &builder) {
236   return builder.getIntegerType(64);
237 }
238 
/// Wraps 'values' as an i64 ArrayAttr, the encoding used for vector
/// subscripts.
ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}
243 
244 //===----------------------------------------------------------------------===//
245 // MultiDimReductionOp
246 //===----------------------------------------------------------------------===//
247 
build(OpBuilder & builder,OperationState & result,Value source,ArrayRef<bool> reductionMask,CombiningKind kind)248 void vector::MultiDimReductionOp::build(OpBuilder &builder,
249                                         OperationState &result, Value source,
250                                         ArrayRef<bool> reductionMask,
251                                         CombiningKind kind) {
252   result.addOperands(source);
253   auto sourceVectorType = source.getType().cast<VectorType>();
254   auto targetShape = MultiDimReductionOp::inferDestShape(
255       sourceVectorType.getShape(), reductionMask);
256   auto targetVectorType =
257       VectorType::get(targetShape, sourceVectorType.getElementType());
258   result.addTypes(targetVectorType);
259 
260   SmallVector<int64_t> reductionDims;
261   for (auto en : llvm::enumerate(reductionMask))
262     if (en.value())
263       reductionDims.push_back(en.index());
264   result.addAttribute(getReductionDimsAttrName(),
265                       builder.getI64ArrayAttr(reductionDims));
266   result.addAttribute(getKindAttrName(),
267                       CombiningKindAttr::get(kind, builder.getContext()));
268 }
269 
verify(MultiDimReductionOp op)270 static LogicalResult verify(MultiDimReductionOp op) {
271   auto reductionMask = op.getReductionMask();
272   auto targetShape = MultiDimReductionOp::inferDestShape(
273       op.getSourceVectorType().getShape(), reductionMask);
274   auto targetVectorType =
275       VectorType::get(targetShape, op.getSourceVectorType().getElementType());
276   if (targetVectorType != op.getDestVectorType())
277     return op.emitError("invalid output vector type: ")
278            << op.getDestVectorType() << " (expected: " << targetVectorType
279            << ")";
280   return success();
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // ReductionOp
285 //===----------------------------------------------------------------------===//
286 
verify(ReductionOp op)287 static LogicalResult verify(ReductionOp op) {
288   // Verify for 1-D vector.
289   int64_t rank = op.getVectorType().getRank();
290   if (rank != 1)
291     return op.emitOpError("unsupported reduction rank: ") << rank;
292 
293   // Verify supported reduction kind.
294   auto kind = op.kind();
295   Type eltType = op.dest().getType();
296   if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
297     if (!eltType.isIntOrIndexOrFloat())
298       return op.emitOpError("unsupported reduction type");
299   } else if (kind == "and" || kind == "or" || kind == "xor") {
300     if (!eltType.isIntOrIndex())
301       return op.emitOpError("unsupported reduction type");
302   } else {
303     return op.emitOpError("unknown reduction kind: ") << kind;
304   }
305 
306   // Verify optional accumulator.
307   if (!op.acc().empty()) {
308     if (kind != "add" && kind != "mul")
309       return op.emitOpError("no accumulator for reduction kind: ") << kind;
310     if (!eltType.isa<FloatType>())
311       return op.emitOpError("no accumulator for type: ") << eltType;
312   }
313 
314   return success();
315 }
316 
/// Parses a vector.reduction op of the form:
///   vector.reduction "kind", %vector[, %acc] : vector-type into result-type
static ParseResult parseReductionOp(OpAsmParser &parser,
                                    OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
  Type redType;
  Type resType;
  Attribute attr;
  // NOTE: the short-circuit ordering below is significant. Operands are
  // resolved only after both types have been parsed: the first operand
  // (the vector) resolves against the reduction type, and the optional
  // second operand (the accumulator) against the result type.
  if (parser.parseAttribute(attr, "kind", result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (operandsInfo.size() > 0 &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  // Exactly one operand (the vector) or two (vector plus accumulator).
  if (operandsInfo.size() < 1 || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}
338 
print(OpAsmPrinter & p,ReductionOp op)339 static void print(OpAsmPrinter &p, ReductionOp op) {
340   p << op.getOperationName() << " \"" << op.kind() << "\", " << op.vector();
341   if (!op.acc().empty())
342     p << ", " << op.acc();
343   p << " : " << op.vector().getType() << " into " << op.dest().getType();
344 }
345 
getVectorReductionOp(AtomicRMWKind op,OpBuilder & builder,Location loc,Value vector)346 Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder,
347                                          Location loc, Value vector) {
348   Type scalarType = vector.getType().cast<ShapedType>().getElementType();
349   switch (op) {
350   case AtomicRMWKind::addf:
351   case AtomicRMWKind::addi:
352     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
353                                                builder.getStringAttr("add"),
354                                                vector, ValueRange{});
355   case AtomicRMWKind::mulf:
356   case AtomicRMWKind::muli:
357     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
358                                                builder.getStringAttr("mul"),
359                                                vector, ValueRange{});
360   // TODO: Add remaining reduction operations.
361   default:
362     (void)emitOptionalError(loc, "Reduction operation type not supported");
363     break;
364   }
365   return nullptr;
366 }
367 
368 //===----------------------------------------------------------------------===//
369 // ContractionOp
370 //===----------------------------------------------------------------------===//
371 
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc,ArrayRef<ArrayRef<AffineExpr>> indexingExprs,ArrayRef<StringRef> iteratorTypes)372 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
373                                   Value lhs, Value rhs, Value acc,
374                                   ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
375                                   ArrayRef<StringRef> iteratorTypes) {
376   result.addOperands({lhs, rhs, acc});
377   result.addTypes(acc.getType());
378   result.addAttribute(getIndexingMapsAttrName(),
379                       builder.getAffineMapArrayAttr(
380                           AffineMap::inferFromExprList(indexingExprs)));
381   result.addAttribute(getIteratorTypesAttrName(),
382                       builder.getStrArrayAttr(iteratorTypes));
383 }
384 
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc,ArrayAttr indexingMaps,ArrayAttr iteratorTypes)385 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
386                                   Value lhs, Value rhs, Value acc,
387                                   ArrayAttr indexingMaps,
388                                   ArrayAttr iteratorTypes) {
389   result.addOperands({lhs, rhs, acc});
390   result.addTypes(acc.getType());
391   result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
392   result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
393   result.addAttribute(ContractionOp::getKindAttrName(),
394                       CombiningKindAttr::get(ContractionOp::getDefaultKind(),
395                                              builder.getContext()));
396 }
397 
/// Parses a vector.contract op of the form:
///   vector.contract {trait-attrs} %lhs, %rhs, %acc [, %lhsMask, %rhsMask]
///     attr-dict? : lhs-type, rhs-type into result-type
static ParseResult parseContractionOp(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::OperandType lhsInfo;
  OpAsmParser::OperandType rhsInfo;
  OpAsmParser::OperandType accInfo;
  SmallVector<OpAsmParser::OperandType, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
      parser.parseOperand(lhsInfo) || parser.parseComma() ||
      parser.parseOperand(rhsInfo) || parser.parseComma() ||
      parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  // Splice the contents of the leading dictionary (parsed above under the
  // placeholder name "_") into the op's attribute list.
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());
  // Attach the default combining kind when the dictionary did not carry one.
  if (!result.attributes.get(ContractionOp::getKindAttrName())) {
    result.addAttribute(ContractionOp::getKindAttrName(),
                        CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                               result.getContext()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  // Masks are i1 vectors with the same shape as the corresponding operand.
  auto lhsType = types[0].cast<VectorType>();
  auto rhsType = types[1].cast<VectorType>();
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<Type, 2> maskTypes = {
      VectorType::get(lhsType.getShape(), maskElementType),
      VectorType::get(rhsType.getShape(), maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}
444 
print(OpAsmPrinter & p,ContractionOp op)445 static void print(OpAsmPrinter &p, ContractionOp op) {
446   // TODO: Unify printing code with linalg ops.
447   auto attrNames = op.getTraitAttrNames();
448   llvm::StringSet<> traitAttrsSet;
449   traitAttrsSet.insert(attrNames.begin(), attrNames.end());
450   SmallVector<NamedAttribute, 8> attrs;
451   for (auto attr : op->getAttrs())
452     if (traitAttrsSet.count(attr.first.strref()) > 0)
453       attrs.push_back(attr);
454 
455   auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);
456   p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
457   p << op.rhs() << ", " << op.acc();
458   if (op.masks().size() == 2)
459     p << ", " << op.masks();
460 
461   p.printOptionalAttrDict(op->getAttrs(), attrNames);
462   p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
463     << op.getResultType();
464 }
465 
verifyDimMap(VectorType lhsType,VectorType rhsType,const std::vector<std::pair<int64_t,int64_t>> & map)466 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
467                          const std::vector<std::pair<int64_t, int64_t>> &map) {
468   for (auto &dimPair : map) {
469     if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
470         dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
471         lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
472       return false;
473   }
474   return true;
475 }
476 
/// Verifies that 'accType' and 'resType' match the shape implied by the
/// operand types and the contracting/batch dimension maps: a scalar when no
/// free or batch dimension remains, otherwise a vector whose shape is
/// derived by composing the result indexing map with the constant operand
/// extents.
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  // (Batch dimensions already appear via the lhs, so they are skipped here.)
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.size() == 0) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMaps()[0];
    AffineMap rhsMap = op.getIndexingMaps()[1];
    // Collect the constant extent of each iteration dimension from whichever
    // operand uses it first.
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
           "expected extent along all dimensions.");

    AffineMap resMap = op.getIndexingMaps()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(
               expectedMap.getResults(),
               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
562 
/// Verifies a vector.contract op: well-formed indexing maps, non-empty
/// contracting dimension map, valid batch dimension map, consistent
/// accumulator/result shape, matching masks, and a supported combining kind.
static LogicalResult verify(ContractionOp op) {
  auto lhsType = op.getLhsType();
  auto rhsType = op.getRhsType();
  auto accType = op.getAccType();
  auto resType = op.getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (op.indexing_maps().size() != 3)
    return op.emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = op.iterator_types().getValue().size();
  for (auto it : llvm::enumerate(op.indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    if (map.getNumSymbols() != 0)
      return op.emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and indices.
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return op.emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = op.getContractingDimMap();
  auto batchDimMap = op.getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return op.emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return op.emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return op.emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = op.getLHSVectorMaskType();
  auto rhsMaskType = op.getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return op.emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return op.emitOpError("invalid vector mask rank");
  }

  // Verify supported combining kind. For a vector result, the kind must be
  // compatible with the element type; for a scalar result, with the result
  // type itself.
  auto vectorType = resType.dyn_cast<VectorType>();
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(op.kind(), elementType))
    return op.emitOpError("unsupported contraction type");

  return success();
}
638 
/// Returns the names of the attributes that are printed in the leading
/// dictionary of the custom assembly form (see parse/printContractionOp).
ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[3] = {getIndexingMapsAttrName(),
                                         getIteratorTypesAttrName(),
                                         ContractionOp::getKindAttrName()};
  return llvm::makeArrayRef(names);
}
645 
getResultIndex(AffineMap map,AffineExpr targetExpr)646 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
647   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
648     if (targetExpr == map.getResult(i))
649       return i;
650   return -1;
651 }
652 
653 static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps,ArrayAttr iteratorTypes,StringRef targetIteratorTypeName,MLIRContext * context)654 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
655           StringRef targetIteratorTypeName, MLIRContext *context) {
656   std::vector<std::pair<int64_t, int64_t>> dimMap;
657   for (auto it : llvm::enumerate(iteratorTypes)) {
658     auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
659     if (iteratorTypeName != targetIteratorTypeName)
660       continue;
661     // Search lhs/rhs map results for 'targetExpr'.
662     auto targetExpr = getAffineDimExpr(it.index(), context);
663     int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
664     int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
665     if (lhsDim >= 0 && rhsDim >= 0)
666       dimMap.push_back({lhsDim, rhsDim});
667   }
668   return dimMap;
669 }
670 
getIterationBounds(SmallVectorImpl<int64_t> & iterationBounds)671 void ContractionOp::getIterationBounds(
672     SmallVectorImpl<int64_t> &iterationBounds) {
673   auto lhsShape = getLhsType().getShape();
674   auto resVectorType = getResultType().dyn_cast<VectorType>();
675   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
676   SmallVector<int64_t, 2> iterationShape;
677   for (auto it : llvm::enumerate(iterator_types())) {
678     // Search lhs/rhs map results for 'targetExpr'.
679     auto targetExpr = getAffineDimExpr(it.index(), getContext());
680     auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
681     if (iteratorTypeName == getReductionIteratorTypeName()) {
682       // Get reduction dim size from lhs shape (same size in rhsShape).
683       int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
684       assert(lhsDimIndex >= 0);
685       iterationBounds.push_back(lhsShape[lhsDimIndex]);
686       continue;
687     }
688     // Get parallel dimension size from result shape.
689     int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
690     assert(resDimIndex >= 0);
691     assert(resVectorType != nullptr);
692     iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
693   }
694 }
695 
getIterationIndexMap(std::vector<DenseMap<int64_t,int64_t>> & iterationIndexMap)696 void ContractionOp::getIterationIndexMap(
697     std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
698   unsigned numMaps = indexing_maps().getValue().size();
699   iterationIndexMap.resize(numMaps);
700   for (auto it : llvm::enumerate(indexing_maps())) {
701     auto index = it.index();
702     auto map = it.value().cast<AffineMapAttr>().getValue();
703     for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
704       auto dim = map.getResult(i).cast<AffineDimExpr>();
705       iterationIndexMap[index][dim.getPosition()] = i;
706     }
707   }
708 }
709 
getContractingDimMap()710 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
711   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
712   return getDimMap(indexingMaps, iterator_types(),
713                    getReductionIteratorTypeName(), getContext());
714 }
715 
getBatchDimMap()716 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
717   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
718   return getDimMap(indexingMaps, iterator_types(),
719                    getParallelIteratorTypeName(), getContext());
720 }
721 
getIndexingMaps()722 SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
723   return llvm::to_vector<4>(
724       llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
725         return mapAttr.cast<AffineMapAttr>().getValue();
726       }));
727 }
728 
getShapeForUnroll()729 Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
730   SmallVector<int64_t, 4> shape;
731   getIterationBounds(shape);
732   return shape;
733 }
734 
735 /// Return a fused vector::ContractionOp which represents a patterns such as:
736 ///
737 /// ```mlir
738 ///    %c0 = vector.constant 0: ...
739 ///    %c = vector.contract %a, %b, %c0: ...
740 ///    %e = add %c, %d: ...
741 /// ```
742 ///
743 /// by:
744 ///
745 /// ```mlir
746 ///    %e = vector.contract %a, %b, %d: ...
747 /// ```
748 ///
749 /// Return null if the canonicalization does not apply.
750 // TODO: This should be a folding of Add into Contract in core but while they
751 // live in different dialects, it is not possible without unnatural
752 // dependencies.
753 template <typename AddOpType>
754 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
755   using OpRewritePattern<AddOpType>::OpRewritePattern;
756 
matchAndRewriteCanonicalizeContractAdd757   LogicalResult matchAndRewrite(AddOpType addOp,
758                                 PatternRewriter &rewriter) const override {
759     auto canonicalize = [&](Value maybeContraction,
760                             Value otherOperand) -> vector::ContractionOp {
761       vector::ContractionOp contractionOp =
762           dyn_cast_or_null<vector::ContractionOp>(
763               maybeContraction.getDefiningOp());
764       if (!contractionOp)
765         return vector::ContractionOp();
766       if (auto maybeZero = dyn_cast_or_null<ConstantOp>(
767               contractionOp.acc().getDefiningOp())) {
768         if (maybeZero.value() ==
769             rewriter.getZeroAttr(contractionOp.acc().getType())) {
770           BlockAndValueMapping bvm;
771           bvm.map(contractionOp.acc(), otherOperand);
772           auto newContraction =
773               cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
774           rewriter.replaceOp(addOp, newContraction.getResult());
775           return newContraction;
776         }
777       }
778       return vector::ContractionOp();
779     };
780 
781     Value a = addOp->getOperand(0), b = addOp->getOperand(1);
782     vector::ContractionOp contract = canonicalize(a, b);
783     contract = contract ? contract : canonicalize(b, a);
784     return success();
785   }
786 };
787 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)788 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
789                                                 MLIRContext *context) {
790   results.add<CanonicalizeContractAdd<AddIOp>, CanonicalizeContractAdd<AddFOp>>(
791       context);
792 }
793 
794 //===----------------------------------------------------------------------===//
795 // ExtractElementOp
796 //===----------------------------------------------------------------------===//
797 
build(OpBuilder & builder,OperationState & result,Value source,Value position)798 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
799                                      Value source, Value position) {
800   result.addOperands({source, position});
801   result.addTypes(source.getType().cast<VectorType>().getElementType());
802 }
803 
build(OpBuilder & builder,OperationState & result,Value source,int64_t position)804 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
805                                      Value source, int64_t position) {
806   Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
807   build(builder, result, source, pos);
808 }
809 
verify(vector::ExtractElementOp op)810 static LogicalResult verify(vector::ExtractElementOp op) {
811   VectorType vectorType = op.getVectorType();
812   if (vectorType.getRank() != 1)
813     return op.emitOpError("expected 1-D vector");
814   return success();
815 }
816 
817 //===----------------------------------------------------------------------===//
818 // ExtractOp
819 //===----------------------------------------------------------------------===//
820 
inferExtractOpResultType(VectorType vectorType,ArrayAttr position)821 static Type inferExtractOpResultType(VectorType vectorType,
822                                      ArrayAttr position) {
823   if (static_cast<int64_t>(position.size()) == vectorType.getRank())
824     return vectorType.getElementType();
825   return VectorType::get(vectorType.getShape().drop_front(position.size()),
826                          vectorType.getElementType());
827 }
828 
build(OpBuilder & builder,OperationState & result,Value source,ArrayRef<int64_t> position)829 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
830                               Value source, ArrayRef<int64_t> position) {
831   result.addOperands(source);
832   auto positionAttr = getVectorSubscriptAttr(builder, position);
833   result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
834                                            positionAttr));
835   result.addAttribute(getPositionAttrName(), positionAttr);
836 }
837 
838 // Convenience builder which assumes the values are constant indices.
build(OpBuilder & builder,OperationState & result,Value source,ValueRange position)839 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
840                               Value source, ValueRange position) {
841   SmallVector<int64_t, 4> positionConstants =
842       llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
843         return pos.getDefiningOp<ConstantIndexOp>().getValue();
844       }));
845   build(builder, result, source, positionConstants);
846 }
847 
/// Prints an ExtractOp in the form:
///   vector.extract %source[pos, ...] {attrs} : source-vector-type
/// The `position` attribute is elided from the attribute dictionary since it
/// is printed inline after the operand.
static void print(OpAsmPrinter &p, vector::ExtractOp op) {
  p << op.getOperationName() << " " << op.vector() << op.position();
  p.printOptionalAttrDict(op->getAttrs(), {"position"});
  p << " : " << op.vector().getType();
}
853 
/// Parses an ExtractOp of the form:
///   vector.extract %source[pos, ...] {attrs} : source-vector-type
/// The result type is not spelled in the assembly; it is inferred from the
/// source vector type and the rank of the `position` attribute.
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc attributeLoc, typeLoc;
  NamedAttrList attrs;
  OpAsmParser::OperandType vector;
  Type type;
  Attribute attr;
  // Capture source locations before the attribute and the type so that the
  // diagnostics below can point at the offending token.
  if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(attr, "position", attrs) ||
      parser.parseOptionalAttrDict(attrs) ||
      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
    return failure();

  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typeLoc, "expected vector type");

  // The position must be an array attribute with no more entries than the
  // source vector has dimensions.
  auto positionAttr = attr.dyn_cast<ArrayAttr>();
  if (!positionAttr ||
      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
    return parser.emitError(
        attributeLoc,
        "expected position attribute of rank smaller than vector rank");

  Type resType = inferExtractOpResultType(vectorType, positionAttr);
  result.attributes = attrs;
  return failure(parser.resolveOperand(vector, type, result.operands) ||
                 parser.addTypeToList(resType, result.types));
}
882 
/// Verifies an ExtractOp: the position rank may not exceed the source vector
/// rank, and each position entry must be a non-negative integer in bounds for
/// the corresponding source dimension.
static LogicalResult verify(vector::ExtractOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than vector rank");
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= op.getVectorType().getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}
899 
900 template <typename IntType>
extractVector(ArrayAttr arrayAttr)901 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
902   return llvm::to_vector<4>(llvm::map_range(
903       arrayAttr.getAsRange<IntegerAttr>(),
904       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
905 }
906 
/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions. On success, `extractOp` is retargeted at the head of the chain
/// with the concatenated position attribute; returns failure when the operand
/// is not itself produced by an ExtractOp.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.vector().getDefiningOp<ExtractOp>())
    return failure();

  // Positions are accumulated reversed (and reversed once more at the end)
  // because the chain is walked from the outermost extract to the innermost
  // producer, while the final attribute must list outer positions first.
  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  auto extractedPos = extractVector<int64_t>(currentOp.position());
  globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
  while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    auto extractedPos = extractVector<int64_t>(currentOp.position());
    globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
  }
  // Re-point this op at the source of the innermost extract in the chain.
  extractOp.setOperand(currentOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(globalPosition));
  return success();
}
930 
/// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
/// On success, `extractOp` is retargeted at the transpose's source with its
/// position rewritten through the inverse of the transposition. Returns
/// failure when the operand is not a transpose or the fold would reorder
/// elements within the extracted value.
static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
  if (!transposeOp)
    return failure();

  auto permutation = extractVector<unsigned>(transposeOp.transp());
  auto extractedPos = extractVector<int64_t>(extractOp.position());

  // If transposition permutation is larger than the ExtractOp, all minor
  // dimensions must be an identity for folding to occur. If not, individual
  // elements within the extracted value are transposed and this is not just a
  // simple folding.
  unsigned minorRank = permutation.size() - extractedPos.size();
  MLIRContext *ctx = extractOp.getContext();
  AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
  AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
  if (minorMap && !minorMap.isMinorIdentity())
    return failure();

  //   %1 = transpose %0[x, y, z] : vector<axbxcxf32>
  //   %2 = extract %1[u, v] : vector<..xf32>
  // may turn into:
  //   %2 = extract %0[w, x] : vector<..xf32>
  // iff z == 2 and [w, x] = [x, y]^-1 o [u, v] here o denotes composition and
  // -1 denotes the inverse.
  permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
  // The major submap has fewer results but the same number of dims. To compose
  // cleanly, we need to drop dims to form a "square matrix". This is possible
  // because:
  //   (a) this is a permutation map and
  //   (b) the minor map has already been checked to be identity.
  // Therefore, the major map cannot contain dims of position greater or equal
  // than the number of results.
  assert(llvm::all_of(permutationMap.getResults(),
                      [&](AffineExpr e) {
                        auto dim = e.dyn_cast<AffineDimExpr>();
                        return dim && dim.getPosition() <
                                          permutationMap.getNumResults();
                      }) &&
         "Unexpected map results depend on higher rank positions");
  // Project on the first domain dimensions to allow composition.
  permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
                                  permutationMap.getResults(), ctx);

  extractOp.setOperand(transposeOp.vector());
  // Compose the inverse permutation map with the extractedPos.
  auto newExtractedPos =
      inversePermutation(permutationMap).compose(extractedPos);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(newExtractedPos));

  return success();
}
987 
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
/// result is always the input to some InsertOp. Returns a null Value when no
/// insert along the chain matches the extracted position.
static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
  MLIRContext *context = extractOp.getContext();
  // Accumulated permutation of all TransposeOps seen so far; null while no
  // transpose has been encountered.
  AffineMap permutationMap;
  auto extractedPos = extractVector<unsigned>(extractOp.position());
  // Walk back a chain of InsertOp/TransposeOp until we hit a match.
  // Compose TransposeOp permutations as we walk back.
  auto insertOp = extractOp.vector().getDefiningOp<vector::InsertOp>();
  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
  while (insertOp || transposeOp) {
    if (transposeOp) {
      // If it is transposed, compose the map and iterate.
      auto permutation = extractVector<unsigned>(transposeOp.transp());
      AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
      if (!permutationMap)
        permutationMap = newMap;
      else if (newMap.getNumInputs() != permutationMap.getNumResults())
        // Rank mismatch between successive transposes: give up.
        return Value();
      else
        permutationMap = newMap.compose(permutationMap);
      // Compute insert/transpose for the next iteration.
      Value transposed = transposeOp.vector();
      insertOp = transposed.getDefiningOp<vector::InsertOp>();
      transposeOp = transposed.getDefiningOp<vector::TransposeOp>();
      continue;
    }

    assert(insertOp);
    Value insertionDest = insertOp.dest();
    // If it is inserted into, either the position matches and we have a
    // successful folding; or we iterate until we run out of
    // InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector`
    // produces a new vector with 1 modified value/slice in exactly the static
    // position we need to match.
    auto insertedPos = extractVector<unsigned>(insertOp.position());
    // Trivial permutations are solved with position equality checks.
    if (!permutationMap || permutationMap.isIdentity()) {
      if (extractedPos == insertedPos)
        return insertOp.source();
      // Fallthrough: if the position does not match, just skip to the next
      // producing `vector.insert` / `vector.transpose`.
      // Compute insert/transpose for the next iteration.
      insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
      transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
      continue;
    }

    // More advanced permutations require application of the permutation.
    // However, the rank of `insertedPos` may be different from that of the
    // `permutationMap`. To support such case, we need to:
    //   1. apply on the `insertedPos.size()` major dimensions
    //   2. check the other dimensions of the permutation form a minor identity.
    assert(permutationMap.isPermutation() && "expected a permutation");
    if (insertedPos.size() == extractedPos.size()) {
      bool fold = true;
      // The insert matches iff, under the accumulated permutation, every
      // inserted index lands on the corresponding extracted index.
      for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
        auto pos = permutationMap.getDimPosition(idx);
        if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
          fold = false;
          break;
        }
      }
      if (fold) {
        assert(permutationMap.getNumResults() >= insertedPos.size() &&
               "expected map of rank larger than insert indexing");
        unsigned minorRank =
            permutationMap.getNumResults() - insertedPos.size();
        AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
        if (!minorMap || minorMap.isMinorIdentity())
          return insertOp.source();
      }
    }

    // If we haven't found a match, just continue to the next producing
    // `vector.insert` / `vector.transpose`.
    // Compute insert/transpose for the next iteration.
    insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
    transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
  }
  return Value();
}
1070 
1071 /// Fold extractOp with scalar result coming from BroadcastOp.
foldExtractFromBroadcast(ExtractOp extractOp)1072 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1073   auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
1074   if (!broadcastOp)
1075     return Value();
1076   if (extractOp.getType() == broadcastOp.getSourceType())
1077     return broadcastOp.source();
1078   auto getRank = [](Type type) {
1079     return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1080   };
1081   unsigned broadcasrSrcRank = getRank(broadcastOp.getSourceType());
1082   unsigned extractResultRank = getRank(extractOp.getType());
1083   if (extractResultRank < broadcasrSrcRank) {
1084     auto extractPos = extractVector<int64_t>(extractOp.position());
1085     unsigned rankDiff = broadcasrSrcRank - extractResultRank;
1086     extractPos.erase(
1087         extractPos.begin(),
1088         std::next(extractPos.begin(), extractPos.size() - rankDiff));
1089     extractOp.setOperand(broadcastOp.source());
1090     // OpBuilder is only used as a helper to build an I64ArrayAttr.
1091     OpBuilder b(extractOp.getContext());
1092     extractOp->setAttr(ExtractOp::getPositionAttrName(),
1093                        b.getI64ArrayAttr(extractPos));
1094     return extractOp.getResult();
1095   }
1096   // TODO: In case the rank of the broadcast source is greater than the rank of
1097   // the extract result this can be combined into a new broadcast op. This needs
1098   // to be added a canonicalization pattern if needed.
1099   return Value();
1100 }
1101 
// Fold extractOp with source coming from ShapeCast op. Rewrites `extractOp`
// in place to extract directly from the shape-cast source, by linearizing the
// position in the cast's result space and delinearizing it in the cast's
// source space. Returns a null Value when the fold does not apply.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  // Scalars are treated as rank 0.
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimension of the destination must match the lowest
      // dimension of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.source());
  return extractOp.getResult();
}
1160 
fold(ArrayRef<Attribute>)1161 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
1162   if (position().empty())
1163     return vector();
1164   if (succeeded(foldExtractOpFromExtractChain(*this)))
1165     return getResult();
1166   if (succeeded(foldExtractOpFromTranspose(*this)))
1167     return getResult();
1168   if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
1169     return val;
1170   if (auto val = foldExtractFromBroadcast(*this))
1171     return val;
1172   if (auto val = foldExtractFromShapeCast(*this))
1173     return val;
1174   return OpFoldResult();
1175 }
1176 
1177 namespace {
1178 
1179 // If extractOp is only removing unit dimensions it can be transformed to a
1180 // shapecast.
1181 class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
1182 public:
1183   using OpRewritePattern<ExtractOp>::OpRewritePattern;
1184 
matchAndRewrite(ExtractOp extractOp,PatternRewriter & rewriter) const1185   LogicalResult matchAndRewrite(ExtractOp extractOp,
1186                                 PatternRewriter &rewriter) const override {
1187     auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
1188     if (!dstVecType || extractOp.getVectorType().getNumElements() !=
1189                            dstVecType.getNumElements())
1190       return failure();
1191     rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
1192                                              extractOp.vector());
1193     return success();
1194   }
1195 };
1196 
1197 } // namespace
1198 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1199 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1200                                             MLIRContext *context) {
1201   results.add<ExtractToShapeCast>(context);
1202 }
1203 
populateFromInt64AttrArray(ArrayAttr arrayAttr,SmallVectorImpl<int64_t> & results)1204 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
1205                                        SmallVectorImpl<int64_t> &results) {
1206   for (auto attr : arrayAttr)
1207     results.push_back(attr.cast<IntegerAttr>().getInt());
1208 }
1209 
1210 //===----------------------------------------------------------------------===//
1211 // ExtractMapOp
1212 //===----------------------------------------------------------------------===//
1213 
build(OpBuilder & builder,OperationState & result,Value vector,ValueRange ids,ArrayRef<int64_t> multiplicity,AffineMap permutationMap)1214 void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
1215                          Value vector, ValueRange ids,
1216                          ArrayRef<int64_t> multiplicity,
1217                          AffineMap permutationMap) {
1218   assert(ids.size() == multiplicity.size() &&
1219          ids.size() == permutationMap.getNumResults());
1220   assert(permutationMap.isProjectedPermutation());
1221   VectorType type = vector.getType().cast<VectorType>();
1222   SmallVector<int64_t, 4> newShape(type.getShape().begin(),
1223                                    type.getShape().end());
1224   for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
1225     AffineExpr expr = permutationMap.getResult(i);
1226     auto dim = expr.cast<AffineDimExpr>();
1227     newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
1228   }
1229   VectorType resultType = VectorType::get(newShape, type.getElementType());
1230   ExtractMapOp::build(builder, result, resultType, vector, ids);
1231 }
1232 
/// Verifies an ExtractMapOp: source and result ranks must match, every result
/// dimension must evenly divide the corresponding source dimension, and there
/// must be exactly one id per distributed (shrunk) dimension.
static LogicalResult verify(ExtractMapOp op) {
  if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
    return op.emitOpError(
        "expected source and destination vectors of same rank");
  // Count the distributed dimensions, i.e. those whose size changes.
  unsigned numId = 0;
  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
    if (op.getSourceVectorType().getDimSize(i) %
            op.getResultType().getDimSize(i) !=
        0)
      return op.emitOpError("source vector dimensions must be a multiple of "
                            "destination vector dimensions");
    if (op.getSourceVectorType().getDimSize(i) !=
        op.getResultType().getDimSize(i))
      numId++;
  }
  if (numId != op.ids().size())
    return op.emitOpError("expected number of ids must match the number of "
                          "dimensions distributed");
  return success();
}
1253 
fold(ArrayRef<Attribute> operands)1254 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
1255   auto insert = vector().getDefiningOp<vector::InsertMapOp>();
1256   if (insert == nullptr || getType() != insert.vector().getType() ||
1257       ids() != insert.ids())
1258     return {};
1259   return insert.vector();
1260 }
1261 
getMultiplicity(SmallVectorImpl<int64_t> & multiplicity)1262 void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
1263   assert(multiplicity.empty());
1264   for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
1265     if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
1266       multiplicity.push_back(getSourceVectorType().getDimSize(i) /
1267                              getResultType().getDimSize(i));
1268   }
1269 }
1270 
1271 template <typename MapOp>
calculateImplicitMap(MapOp op)1272 AffineMap calculateImplicitMap(MapOp op) {
1273   SmallVector<AffineExpr, 4> perm;
1274   // Check which dimension have a multiplicity greater than 1 and associated
1275   // them to the IDs in order.
1276   for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
1277     if (op.getSourceVectorType().getDimSize(i) !=
1278         op.getResultType().getDimSize(i))
1279       perm.push_back(getAffineDimExpr(i, op.getContext()));
1280   }
1281   auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
1282                             op.getContext());
1283   return map;
1284 }
1285 
/// Returns the implicit distribution map: one dim expr per dimension whose
/// size differs between the source and result vector types.
AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }
1287 
1288 //===----------------------------------------------------------------------===//
1289 // FmaOp
1290 //===----------------------------------------------------------------------===//
1291 
getShapeForUnroll()1292 Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
1293   return llvm::to_vector<4>(getVectorType().getShape());
1294 }
1295 
1296 //===----------------------------------------------------------------------===//
1297 // BroadcastOp
1298 //===----------------------------------------------------------------------===//
1299 
/// Verifies a BroadcastOp: a scalar source is always valid; a vector source
/// must have rank <= the destination rank and each trailing source dimension
/// must either match the corresponding destination dimension or be 1.
static LogicalResult verify(BroadcastOp op) {
  VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
  VectorType dstVectorType = op.getVectorType();
  // Scalar to vector broadcast is always valid. A vector
  // to vector broadcast needs some additional checking.
  if (srcVectorType) {
    int64_t srcRank = srcVectorType.getRank();
    int64_t dstRank = dstVectorType.getRank();
    if (srcRank > dstRank)
      return op.emitOpError("source rank higher than destination rank");
    // Source has an exact match or singleton value for all trailing dimensions
    // (all leading dimensions are simply duplicated).
    int64_t lead = dstRank - srcRank;
    for (int64_t r = 0; r < srcRank; ++r) {
      int64_t srcDim = srcVectorType.getDimSize(r);
      int64_t dstDim = dstVectorType.getDimSize(lead + r);
      if (srcDim != 1 && srcDim != dstDim)
        return op.emitOpError("dimension mismatch (")
               << srcDim << " vs. " << dstDim << ")";
    }
  }
  return success();
}
1323 
fold(ArrayRef<Attribute> operands)1324 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
1325   if (!operands[0])
1326     return {};
1327   auto vectorType = getVectorType();
1328   if (operands[0].getType().isIntOrIndexOrFloat())
1329     return DenseElementsAttr::get(vectorType, operands[0]);
1330   if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
1331     return DenseElementsAttr::get(vectorType, attr.getSplatValue());
1332   return {};
1333 }
1334 
1335 namespace {
1336 
1337 // BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
1338 // the degenerated case where the broadcast only adds dimensions of size 1 it
1339 // can be replaced by a ShapeCastOp. This canonicalization checks if the total
1340 // number of elements is the same before and after the broadcast to detect if
1341 // the only change in the vector type are new dimensions of size 1.
1342 class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
1343 public:
1344   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
1345 
matchAndRewrite(BroadcastOp broadcastOp,PatternRewriter & rewriter) const1346   LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
1347                                 PatternRewriter &rewriter) const override {
1348     auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
1349     if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
1350                            srcVecType.getNumElements())
1351       return failure();
1352     rewriter.replaceOpWithNewOp<ShapeCastOp>(
1353         broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
1354     return success();
1355   }
1356 };
1357 
1358 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
1359 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
1360   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
1361 
matchAndRewrite__anon80a4b5030e11::BroadcastFolder1362   LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
1363                                 PatternRewriter &rewriter) const override {
1364     auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>();
1365     if (!srcBroadcast)
1366       return failure();
1367     rewriter.replaceOpWithNewOp<BroadcastOp>(
1368         broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source());
1369     return success();
1370   }
1371 };
1372 } // namespace
1373 
/// Registers BroadcastOp canonicalizations: rewriting unit-dim-only
/// broadcasts to shape casts and collapsing chained broadcasts.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<BroadcastToShapeCast, BroadcastFolder>(context);
}
1378 
1379 //===----------------------------------------------------------------------===//
1380 // ShuffleOp
1381 //===----------------------------------------------------------------------===//
1382 
build(OpBuilder & builder,OperationState & result,Value v1,Value v2,ArrayRef<int64_t> mask)1383 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
1384                       Value v2, ArrayRef<int64_t> mask) {
1385   result.addOperands({v1, v2});
1386   auto maskAttr = getVectorSubscriptAttr(builder, mask);
1387   result.addTypes(v1.getType());
1388   result.addAttribute(getMaskAttrName(), maskAttr);
1389 }
1390 
/// Prints: `vector.shuffle %v1, %v2 [mask] {attrs} : type(v1), type(v2)`.
static void print(OpAsmPrinter &p, ShuffleOp op) {
  p << op.getOperationName() << " " << op.v1() << ", " << op.v2() << " "
    << op.mask();
  // The mask is printed inline above, so elide it from the attribute dict.
  p.printOptionalAttrDict(op->getAttrs(), {ShuffleOp::getMaskAttrName()});
  p << " : " << op.v1().getType() << ", " << op.v2().getType();
}
1397 
verify(ShuffleOp op)1398 static LogicalResult verify(ShuffleOp op) {
1399   VectorType resultType = op.getVectorType();
1400   VectorType v1Type = op.getV1VectorType();
1401   VectorType v2Type = op.getV2VectorType();
1402   // Verify ranks.
1403   int64_t resRank = resultType.getRank();
1404   int64_t v1Rank = v1Type.getRank();
1405   int64_t v2Rank = v2Type.getRank();
1406   if (resRank != v1Rank || v1Rank != v2Rank)
1407     return op.emitOpError("rank mismatch");
1408   // Verify all but leading dimension sizes.
1409   for (int64_t r = 1; r < v1Rank; ++r) {
1410     int64_t resDim = resultType.getDimSize(r);
1411     int64_t v1Dim = v1Type.getDimSize(r);
1412     int64_t v2Dim = v2Type.getDimSize(r);
1413     if (resDim != v1Dim || v1Dim != v2Dim)
1414       return op.emitOpError("dimension mismatch");
1415   }
1416   // Verify mask length.
1417   auto maskAttr = op.mask().getValue();
1418   int64_t maskLength = maskAttr.size();
1419   if (maskLength != resultType.getDimSize(0))
1420     return op.emitOpError("mask length mismatch");
1421   // Verify all indices.
1422   int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
1423   for (auto en : llvm::enumerate(maskAttr)) {
1424     auto attr = en.value().dyn_cast<IntegerAttr>();
1425     if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
1426       return op.emitOpError("mask index #")
1427              << (en.index() + 1) << " out of range";
1428   }
1429   return success();
1430 }
1431 
/// Parses `vector.shuffle`: two operands, an inline mask array attribute, an
/// optional attribute dict, and the two operand vector types. The result type
/// is not spelled in the assembly; it is reconstructed from the mask and the
/// first operand's type below.
static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType v1, v2;
  Attribute attr;
  VectorType v1Type, v2Type;
  if (parser.parseOperand(v1) || parser.parseComma() ||
      parser.parseOperand(v2) ||
      parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
                            result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(v1Type) || parser.parseComma() ||
      parser.parseType(v2Type) ||
      parser.resolveOperand(v1, v1Type, result.operands) ||
      parser.resolveOperand(v2, v2Type, result.operands))
    return failure();
  // Construct resulting type: leading dimension matches mask length,
  // all trailing dimensions match the operands.
  auto maskAttr = attr.dyn_cast<ArrayAttr>();
  if (!maskAttr)
    return parser.emitError(parser.getNameLoc(), "missing mask attribute");
  int64_t maskLength = maskAttr.size();
  if (maskLength <= 0)
    return parser.emitError(parser.getNameLoc(), "invalid mask length");
  int64_t v1Rank = v1Type.getRank();
  // Result shape: [maskLength, v1 dims 1..rank-1].
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(maskLength);
  for (int64_t r = 1; r < v1Rank; ++r)
    shape.push_back(v1Type.getDimSize(r));
  VectorType resType = VectorType::get(shape, v1Type.getElementType());
  parser.addTypeToList(resType, result.types);
  return success();
}
1464 
1465 //===----------------------------------------------------------------------===//
1466 // InsertElementOp
1467 //===----------------------------------------------------------------------===//
1468 
build(OpBuilder & builder,OperationState & result,Value source,Value dest,Value position)1469 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
1470                             Value source, Value dest, Value position) {
1471   result.addOperands({source, dest, position});
1472   result.addTypes(dest.getType());
1473 }
1474 
build(OpBuilder & builder,OperationState & result,Value source,Value dest,int64_t position)1475 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
1476                             Value source, Value dest, int64_t position) {
1477   Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
1478   build(builder, result, source, dest, pos);
1479 }
1480 
verify(InsertElementOp op)1481 static LogicalResult verify(InsertElementOp op) {
1482   auto dstVectorType = op.getDestVectorType();
1483   if (dstVectorType.getRank() != 1)
1484     return op.emitOpError("expected 1-D vector");
1485   return success();
1486 }
1487 
1488 //===----------------------------------------------------------------------===//
1489 // InsertOp
1490 //===----------------------------------------------------------------------===//
1491 
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> position)1492 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1493                      Value dest, ArrayRef<int64_t> position) {
1494   result.addOperands({source, dest});
1495   auto positionAttr = getVectorSubscriptAttr(builder, position);
1496   result.addTypes(dest.getType());
1497   result.addAttribute(getPositionAttrName(), positionAttr);
1498 }
1499 
1500 // Convenience builder which assumes the values are constant indices.
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ValueRange position)1501 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1502                      Value dest, ValueRange position) {
1503   SmallVector<int64_t, 4> positionConstants =
1504       llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
1505         return pos.getDefiningOp<ConstantIndexOp>().getValue();
1506       }));
1507   build(builder, result, source, dest, positionConstants);
1508 }
1509 
verify(InsertOp op)1510 static LogicalResult verify(InsertOp op) {
1511   auto positionAttr = op.position().getValue();
1512   auto destVectorType = op.getDestVectorType();
1513   if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
1514     return op.emitOpError(
1515         "expected position attribute of rank smaller than dest vector rank");
1516   auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
1517   if (srcVectorType &&
1518       (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
1519        static_cast<unsigned>(destVectorType.getRank())))
1520     return op.emitOpError("expected position attribute rank + source rank to "
1521                           "match dest vector rank");
1522   else if (!srcVectorType && (positionAttr.size() !=
1523                               static_cast<unsigned>(destVectorType.getRank())))
1524     return op.emitOpError(
1525         "expected position attribute rank to match the dest vector rank");
1526   for (auto en : llvm::enumerate(positionAttr)) {
1527     auto attr = en.value().dyn_cast<IntegerAttr>();
1528     if (!attr || attr.getInt() < 0 ||
1529         attr.getInt() >= destVectorType.getDimSize(en.index()))
1530       return op.emitOpError("expected position attribute #")
1531              << (en.index() + 1)
1532              << " to be a non-negative integer smaller than the corresponding "
1533                 "dest vector dimension";
1534   }
1535   return success();
1536 }
1537 
1538 namespace {
1539 
1540 // If insertOp is only inserting unit dimensions it can be transformed to a
1541 // shapecast.
1542 class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
1543 public:
1544   using OpRewritePattern<InsertOp>::OpRewritePattern;
1545 
matchAndRewrite(InsertOp insertOp,PatternRewriter & rewriter) const1546   LogicalResult matchAndRewrite(InsertOp insertOp,
1547                                 PatternRewriter &rewriter) const override {
1548     auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
1549     if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
1550                            srcVecType.getNumElements())
1551       return failure();
1552     rewriter.replaceOpWithNewOp<ShapeCastOp>(
1553         insertOp, insertOp.getDestVectorType(), insertOp.source());
1554     return success();
1555   }
1556 };
1557 
1558 } // namespace
1559 
/// Registers InsertOp canonicalizations: rewriting unit-dim-only inserts to
/// shape casts.
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToShapeCast>(context);
}
1564 
1565 // Eliminates insert operations that produce values identical to their source
1566 // value. This happens when the source and destination vectors have identical
1567 // sizes.
fold(ArrayRef<Attribute> operands)1568 OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
1569   if (position().empty())
1570     return source();
1571   return {};
1572 }
1573 
1574 //===----------------------------------------------------------------------===//
1575 // InsertMapOp
1576 //===----------------------------------------------------------------------===//
1577 
/// Builds a vector.insert_map whose result type is the destination's type,
/// delegating to the generated builder.
void InsertMapOp::build(OpBuilder &builder, OperationState &result,
                        Value vector, Value dest, ValueRange ids) {
  InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
}
1582 
verify(InsertMapOp op)1583 static LogicalResult verify(InsertMapOp op) {
1584   if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
1585     return op.emitOpError(
1586         "expected source and destination vectors of same rank");
1587   unsigned numId = 0;
1588   for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) {
1589     if (op.getResultType().getDimSize(i) %
1590             op.getSourceVectorType().getDimSize(i) !=
1591         0)
1592       return op.emitOpError(
1593           "destination vector size must be a multiple of source vector size");
1594     if (op.getResultType().getDimSize(i) !=
1595         op.getSourceVectorType().getDimSize(i))
1596       numId++;
1597   }
1598   if (numId != op.ids().size())
1599     return op.emitOpError("expected number of ids must match the number of "
1600                           "dimensions distributed");
1601   return success();
1602 }
1603 
/// Returns the op's implicit distribution map (computed by the shared
/// calculateImplicitMap helper).
AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }
1605 
1606 //===----------------------------------------------------------------------===//
1607 // InsertStridedSliceOp
1608 //===----------------------------------------------------------------------===//
1609 
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> offsets,ArrayRef<int64_t> strides)1610 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
1611                                  Value source, Value dest,
1612                                  ArrayRef<int64_t> offsets,
1613                                  ArrayRef<int64_t> strides) {
1614   result.addOperands({source, dest});
1615   auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
1616   auto stridesAttr = getVectorSubscriptAttr(builder, strides);
1617   result.addTypes(dest.getType());
1618   result.addAttribute(getOffsetsAttrName(), offsetsAttr);
1619   result.addAttribute(getStridesAttrName(), stridesAttr);
1620 }
1621 
1622 // TODO: Should be moved to Tablegen Confined attributes.
1623 template <typename OpType>
isIntegerArrayAttrSmallerThanShape(OpType op,ArrayAttr arrayAttr,ArrayRef<int64_t> shape,StringRef attrName)1624 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
1625                                                         ArrayAttr arrayAttr,
1626                                                         ArrayRef<int64_t> shape,
1627                                                         StringRef attrName) {
1628   if (arrayAttr.size() > shape.size())
1629     return op.emitOpError("expected ")
1630            << attrName << " attribute of rank smaller than vector rank";
1631   return success();
1632 }
1633 
// Returns success if all integers in `arrayAttr` are within the admissible
// interval: [min, max) when `halfOpen` is true, [min, max] otherwise.
1637 template <typename OpType>
1638 static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op,ArrayAttr arrayAttr,int64_t min,int64_t max,StringRef attrName,bool halfOpen=true)1639 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
1640                                   int64_t max, StringRef attrName,
1641                                   bool halfOpen = true) {
1642   for (auto attr : arrayAttr) {
1643     auto val = attr.cast<IntegerAttr>().getInt();
1644     auto upper = max;
1645     if (!halfOpen)
1646       upper += 1;
1647     if (val < min || val >= upper)
1648       return op.emitOpError("expected ") << attrName << " to be confined to ["
1649                                          << min << ", " << upper << ")";
1650   }
1651   return success();
1652 }
1653 
// Returns success if each entry of `arrayAttr` is within the admissible
// interval derived from the same-position dimension `dim` of `shape`:
// [min, dim) when `halfOpen` is true, [min, dim] otherwise.
1657 template <typename OpType>
1658 static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op,ArrayAttr arrayAttr,ArrayRef<int64_t> shape,StringRef attrName,bool halfOpen=true,int64_t min=0)1659 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
1660                                   ArrayRef<int64_t> shape, StringRef attrName,
1661                                   bool halfOpen = true, int64_t min = 0) {
1662   assert(arrayAttr.size() <= shape.size());
1663   unsigned index = 0;
1664   for (auto it : llvm::zip(arrayAttr, shape)) {
1665     auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
1666     auto max = std::get<1>(it);
1667     if (!halfOpen)
1668       max += 1;
1669     if (val < min || val >= max)
1670       return op.emitOpError("expected ")
1671              << attrName << " dimension " << index << " to be confined to ["
1672              << min << ", " << max << ")";
1673     ++index;
1674   }
1675   return success();
1676 }
1677 
// Returns success if, position by position, the sum of the entries of
// `arrayAttr1` and `arrayAttr2` is within the admissible interval derived
// from the same-position dimension `dim` of `shape`: the upper bound is dim
// when `halfOpen` is true, dim + 1 otherwise.
1681 template <typename OpType>
isSumOfIntegerArrayAttrConfinedToShape(OpType op,ArrayAttr arrayAttr1,ArrayAttr arrayAttr2,ArrayRef<int64_t> shape,StringRef attrName1,StringRef attrName2,bool halfOpen=true,int64_t min=1)1682 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
1683     OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
1684     ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
1685     bool halfOpen = true, int64_t min = 1) {
1686   assert(arrayAttr1.size() <= shape.size());
1687   assert(arrayAttr2.size() <= shape.size());
1688   unsigned index = 0;
1689   for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
1690     auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
1691     auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
1692     auto max = std::get<2>(it);
1693     if (!halfOpen)
1694       max += 1;
1695     if (val1 + val2 < 0 || val1 + val2 >= max)
1696       return op.emitOpError("expected sum(")
1697              << attrName1 << ", " << attrName2 << ") dimension " << index
1698              << " to be confined to [" << min << ", " << max << ")";
1699     ++index;
1700   }
1701   return success();
1702 }
1703 
makeI64ArrayAttr(ArrayRef<int64_t> values,MLIRContext * context)1704 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
1705                                   MLIRContext *context) {
1706   auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
1707     return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
1708   });
1709   return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
1710 }
1711 
verify(InsertStridedSliceOp op)1712 static LogicalResult verify(InsertStridedSliceOp op) {
1713   auto sourceVectorType = op.getSourceVectorType();
1714   auto destVectorType = op.getDestVectorType();
1715   auto offsets = op.offsets();
1716   auto strides = op.strides();
1717   if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
1718     return op.emitOpError(
1719         "expected offsets of same size as destination vector rank");
1720   if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
1721     return op.emitOpError(
1722         "expected strides of same size as source vector rank");
1723   if (sourceVectorType.getRank() > destVectorType.getRank())
1724     return op.emitOpError(
1725         "expected source rank to be smaller than destination rank");
1726 
1727   auto sourceShape = sourceVectorType.getShape();
1728   auto destShape = destVectorType.getShape();
1729   SmallVector<int64_t, 4> sourceShapeAsDestShape(
1730       destShape.size() - sourceShape.size(), 0);
1731   sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
1732   auto offName = InsertStridedSliceOp::getOffsetsAttrName();
1733   auto stridesName = InsertStridedSliceOp::getStridesAttrName();
1734   if (failed(
1735           isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
1736       failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
1737                                                /*halfOpen=*/false)) ||
1738       failed(isSumOfIntegerArrayAttrConfinedToShape(
1739           op, offsets,
1740           makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
1741           offName, "source vector shape",
1742           /*halfOpen=*/false, /*min=*/1)))
1743     return failure();
1744 
1745   return success();
1746 }
1747 
1748 //===----------------------------------------------------------------------===//
1749 // OuterProductOp
1750 //===----------------------------------------------------------------------===//
1751 
1752 /// Build an op without mask, use the type of `acc` as the return type.
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc)1753 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
1754                            Value lhs, Value rhs, Value acc) {
1755   result.addOperands({lhs, rhs, acc});
1756   result.addTypes(acc.getType());
1757 }
1758 
/// Prints: `vector.outerproduct %lhs, %rhs[, %acc] {attrs} : tLHS, tRHS`.
static void print(OpAsmPrinter &p, OuterProductOp op) {
  p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
  if (!op.acc().empty()) {
    p << ", " << op.acc();
    // NOTE(review): the attribute dict is only printed when an accumulator is
    // present — confirm attributes are not dropped in the acc-less form.
    p.printOptionalAttrDict(op->getAttrs());
  }
  p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
}
1767 
/// Parses `vector.outerproduct`: an operand list (lhs, rhs, optional acc),
/// an optional attribute dict, and the lhs/rhs types. The result type is
/// derived from the operand types; a missing `kind` attribute is defaulted
/// to OuterProductOp::getDefaultKind().
static ParseResult parseOuterProductOp(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = tLHS.dyn_cast<VectorType>();
  VectorType vRHS = tRHS.dyn_cast<VectorType>();
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");
  // Vector rhs => 2-D outer-product result; scalar rhs => 1-D (AXPY) result.
  VectorType resType =
      vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                             vLHS.getElementType())
           : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());

  // Default the combining kind when the attribute was not supplied.
  if (!result.attributes.get(OuterProductOp::getKindAttrName())) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(),
        CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
                               result.getContext()));
  }

  // The optional third operand (acc) must resolve to the result type.
  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}
1804 
verify(OuterProductOp op)1805 static LogicalResult verify(OuterProductOp op) {
1806   Type tRHS = op.getOperandTypeRHS();
1807   VectorType vLHS = op.getOperandVectorTypeLHS(),
1808              vRHS = tRHS.dyn_cast<VectorType>(),
1809              vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
1810 
1811   if (vLHS.getRank() != 1)
1812     return op.emitOpError("expected 1-d vector for operand #1");
1813 
1814   if (vRHS) {
1815     // Proper OUTER operation.
1816     if (vRHS.getRank() != 1)
1817       return op.emitOpError("expected 1-d vector for operand #2");
1818     if (vRES.getRank() != 2)
1819       return op.emitOpError("expected 2-d vector result");
1820     if (vLHS.getDimSize(0) != vRES.getDimSize(0))
1821       return op.emitOpError("expected #1 operand dim to match result dim #1");
1822     if (vRHS.getDimSize(0) != vRES.getDimSize(1))
1823       return op.emitOpError("expected #2 operand dim to match result dim #2");
1824   } else {
1825     // An AXPY operation.
1826     if (vRES.getRank() != 1)
1827       return op.emitOpError("expected 1-d vector result");
1828     if (vLHS.getDimSize(0) != vRES.getDimSize(0))
1829       return op.emitOpError("expected #1 operand dim to match result dim #1");
1830   }
1831 
1832   if (vACC && vACC != vRES)
1833     return op.emitOpError("expected operand #3 of same type as result type");
1834 
1835   // Verify supported combining kind.
1836   if (!isSupportedCombiningKind(op.kind(), vRES.getElementType()))
1837     return op.emitOpError("unsupported outerproduct type");
1838 
1839   return success();
1840 }
1841 
1842 //===----------------------------------------------------------------------===//
1843 // ReshapeOp
1844 //===----------------------------------------------------------------------===//
1845 
/// Verifies a vector.reshape: the dynamic input/output shape ranks plus the
/// number of fixed vector sizes must equal the corresponding vector ranks,
/// the fixed vector sizes must match a shape suffix of both vectors, and,
/// when all shape operands are constants, the input and output element
/// products must agree.
static LogicalResult verify(ReshapeOp op) {
  // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
  auto inputVectorType = op.getInputVectorType();
  auto outputVectorType = op.getOutputVectorType();
  int64_t inputShapeRank = op.getNumInputShapeSizes();
  int64_t outputShapeRank = op.getNumOutputShapeSizes();
  SmallVector<int64_t, 4> fixedVectorSizes;
  op.getFixedVectorSizes(fixedVectorSizes);
  int64_t numFixedVectorSizes = fixedVectorSizes.size();

  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
    return op.emitError("invalid input shape for vector type ")
           << inputVectorType;

  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
    return op.emitError("invalid output shape for vector type ")
           << outputVectorType;

  // Verify that the 'fixedVectorSizes' match an input/output vector shape
  // suffix.
  // NOTE(review): the index `rank - numFixedVectorSizes - i` walks backwards
  // from just below the suffix start — confirm the intended pairing of
  // fixedVectorSizes[i] with the shape suffix (a `+ i` would pair them in
  // order).
  unsigned inputVectorRank = inputVectorType.getRank();
  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
    unsigned index = inputVectorRank - numFixedVectorSizes - i;
    if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
      return op.emitError("fixed vector size must match input vector for dim ")
             << i;
  }

  unsigned outputVectorRank = outputVectorType.getRank();
  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
    unsigned index = outputVectorRank - numFixedVectorSizes - i;
    if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
      return op.emitError("fixed vector size must match output vector for dim ")
             << i;
  }

  // If all shape operands are produced by constant ops, verify that product
  // of dimensions for input/output shape match.
  auto isDefByConstant = [](Value operand) {
    return isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
  };
  if (llvm::all_of(op.input_shape(), isDefByConstant) &&
      llvm::all_of(op.output_shape(), isDefByConstant)) {
    int64_t numInputElements = 1;
    for (auto operand : op.input_shape())
      numInputElements *=
          cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
    int64_t numOutputElements = 1;
    for (auto operand : op.output_shape())
      numOutputElements *=
          cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
    if (numInputElements != numOutputElements)
      return op.emitError("product of input and output shape sizes must match");
  }
  return success();
}
1902 
/// Extracts the static `fixed_vector_sizes` attribute values into `results`.
void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(fixed_vector_sizes(), results);
}
1906 
1907 //===----------------------------------------------------------------------===//
1908 // ExtractStridedSliceOp
1909 //===----------------------------------------------------------------------===//
1910 
1911 // Inference works as follows:
1912 //   1. Add 'sizes' from prefix of dims in 'offsets'.
1913 //   2. Add sizes from 'vectorType' for remaining dims.
inferStridedSliceOpResultType(VectorType vectorType,ArrayAttr offsets,ArrayAttr sizes,ArrayAttr strides)1914 static Type inferStridedSliceOpResultType(VectorType vectorType,
1915                                           ArrayAttr offsets, ArrayAttr sizes,
1916                                           ArrayAttr strides) {
1917   assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
1918   SmallVector<int64_t, 4> shape;
1919   shape.reserve(vectorType.getRank());
1920   unsigned idx = 0;
1921   for (unsigned e = offsets.size(); idx < e; ++idx)
1922     shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
1923   for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
1924     shape.push_back(vectorType.getShape()[idx]);
1925 
1926   return VectorType::get(shape, vectorType.getElementType());
1927 }
1928 
build(OpBuilder & builder,OperationState & result,Value source,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)1929 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
1930                                   Value source, ArrayRef<int64_t> offsets,
1931                                   ArrayRef<int64_t> sizes,
1932                                   ArrayRef<int64_t> strides) {
1933   result.addOperands(source);
1934   auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
1935   auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
1936   auto stridesAttr = getVectorSubscriptAttr(builder, strides);
1937   result.addTypes(
1938       inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
1939                                     offsetsAttr, sizesAttr, stridesAttr));
1940   result.addAttribute(getOffsetsAttrName(), offsetsAttr);
1941   result.addAttribute(getSizesAttrName(), sizesAttr);
1942   result.addAttribute(getStridesAttrName(), stridesAttr);
1943 }
1944 
/// Verifies a vector.extract_strided_slice: offsets/sizes/strides must have
/// equal arity no larger than the source rank, each must be confined to the
/// source shape (sizes at least 1, strides exactly 1), offset + size must
/// fit within each sliced dimension, and the result type must equal the
/// inferred slice type.
static LogicalResult verify(ExtractStridedSliceOp op) {
  auto type = op.getVectorType();
  auto offsets = op.offsets();
  auto sizes = op.sizes();
  auto strides = op.strides();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
    op.emitOpError(
        "expected offsets, sizes and strides attributes of same size");
    return failure();
  }

  auto shape = type.getShape();
  auto offName = ExtractStridedSliceOp::getOffsetsAttrName();
  auto sizesName = ExtractStridedSliceOp::getSizesAttrName();
  auto stridesName = ExtractStridedSliceOp::getStridesAttrName();
  // The per-attribute checks below short-circuit; their order determines
  // which diagnostic the user sees first.
  if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
                                                stridesName)) ||
      failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
                                                    offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  // The declared result type must match what the attributes imply.
  auto resultType = inferStridedSliceOpResultType(
      op.getVectorType(), op.offsets(), op.sizes(), op.strides());
  if (op.getResult().getType() != resultType) {
    op.emitOpError("expected result type to be ") << resultType;
    return failure();
  }

  return success();
}
1984 
1985 // When the source of ExtractStrided comes from a chain of InsertStrided ops try
1986 // to use the source of the InsertStrided ops if we can detect that the
1987 // extracted vector is a subset of one of the vector inserted.
1988 static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op)1989 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
1990   // Helper to extract integer out of ArrayAttr.
1991   auto getElement = [](ArrayAttr array, int idx) {
1992     return array[idx].cast<IntegerAttr>().getInt();
1993   };
1994   ArrayAttr extractOffsets = op.offsets();
1995   ArrayAttr extractStrides = op.strides();
1996   ArrayAttr extractSizes = op.sizes();
1997   auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
1998   while (insertOp) {
1999     if (op.getVectorType().getRank() !=
2000         insertOp.getSourceVectorType().getRank())
2001       return failure();
2002     ArrayAttr insertOffsets = insertOp.offsets();
2003     ArrayAttr insertStrides = insertOp.strides();
2004     // If the rank of extract is greater than the rank of insert, we are likely
2005     // extracting a partial chunk of the vector inserted.
2006     if (extractOffsets.size() > insertOffsets.size())
2007       return failure();
2008     bool patialoverlap = false;
2009     bool disjoint = false;
2010     SmallVector<int64_t, 4> offsetDiffs;
2011     for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
2012       if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
2013         return failure();
2014       int64_t start = getElement(insertOffsets, dim);
2015       int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
2016       int64_t offset = getElement(extractOffsets, dim);
2017       int64_t size = getElement(extractSizes, dim);
2018       // Check if the start of the extract offset is in the interval inserted.
2019       if (start <= offset && offset < end) {
2020         // If the extract interval overlaps but is not fully included we may
2021         // have a partial overlap that will prevent any folding.
2022         if (offset + size > end)
2023           patialoverlap = true;
2024         offsetDiffs.push_back(offset - start);
2025         continue;
2026       }
2027       disjoint = true;
2028       break;
2029     }
2030     // The extract element chunk is a subset of the insert element.
2031     if (!disjoint && !patialoverlap) {
2032       op.setOperand(insertOp.source());
2033       // OpBuilder is only used as a helper to build an I64ArrayAttr.
2034       OpBuilder b(op.getContext());
2035       op->setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
2036                   b.getI64ArrayAttr(offsetDiffs));
2037       return success();
2038     }
2039     // If the chunk extracted is disjoint from the chunk inserted, keep looking
2040     // in the insert chain.
2041     if (disjoint)
2042       insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
2043     else {
2044       // The extracted vector partially overlap the inserted vector, we cannot
2045       // fold.
2046       return failure();
2047     }
2048   }
2049   return failure();
2050 }
2051 
fold(ArrayRef<Attribute> operands)2052 OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2053   if (getVectorType() == getResult().getType())
2054     return vector();
2055   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
2056     return getResult();
2057   return {};
2058 }
2059 
/// Populates `results` with this op's `offsets` attribute decoded as a list
/// of int64_t values.
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(offsets(), results);
}
2063 
2064 namespace {
2065 
2066 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
2067 // ConstantMaskOp.
2068 class StridedSliceConstantMaskFolder final
2069     : public OpRewritePattern<ExtractStridedSliceOp> {
2070 public:
2071   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2072 
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,PatternRewriter & rewriter) const2073   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2074                                 PatternRewriter &rewriter) const override {
2075     // Return if 'extractStridedSliceOp' operand is not defined by a
2076     // ConstantMaskOp.
2077     auto defOp = extractStridedSliceOp.vector().getDefiningOp();
2078     auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
2079     if (!constantMaskOp)
2080       return failure();
2081     // Return if 'extractStridedSliceOp' has non-unit strides.
2082     if (llvm::any_of(extractStridedSliceOp.strides(), [](Attribute attr) {
2083           return attr.cast<IntegerAttr>().getInt() != 1;
2084         }))
2085       return failure();
2086     // Gather constant mask dimension sizes.
2087     SmallVector<int64_t, 4> maskDimSizes;
2088     populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
2089     // Gather strided slice offsets and sizes.
2090     SmallVector<int64_t, 4> sliceOffsets;
2091     populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets);
2092     SmallVector<int64_t, 4> sliceSizes;
2093     populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes);
2094 
2095     // Compute slice of vector mask region.
2096     SmallVector<int64_t, 4> sliceMaskDimSizes;
2097     assert(sliceOffsets.size() == maskDimSizes.size());
2098     for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
2099       int64_t maskDimSize = std::get<0>(it);
2100       int64_t sliceOffset = std::get<1>(it);
2101       int64_t sliceSize = std::get<2>(it);
2102       int64_t sliceMaskDimSize = std::max(
2103           static_cast<int64_t>(0),
2104           std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
2105       sliceMaskDimSizes.push_back(sliceMaskDimSize);
2106     }
2107     // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
2108     // region is a conjunction of mask dim intervals).
2109     if (llvm::any_of(sliceMaskDimSizes, [](int64_t sz) { return sz == 0; }))
2110       sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
2111 
2112     // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
2113     // region.
2114     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
2115         extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
2116         vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
2117     return success();
2118   }
2119 };
2120 
2121 // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
2122 class StridedSliceConstantFolder final
2123     : public OpRewritePattern<ExtractStridedSliceOp> {
2124 public:
2125   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2126 
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,PatternRewriter & rewriter) const2127   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2128                                 PatternRewriter &rewriter) const override {
2129     // Return if 'extractStridedSliceOp' operand is not defined by a
2130     // ConstantOp.
2131     auto constantOp =
2132         extractStridedSliceOp.vector().getDefiningOp<ConstantOp>();
2133     if (!constantOp)
2134       return failure();
2135     auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
2136     if (!dense)
2137       return failure();
2138     auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
2139                                           dense.getSplatValue());
2140     rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
2141     return success();
2142   }
2143 };
2144 
2145 // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
getI64SubArray(ArrayAttr arrayAttr,unsigned dropFront=0,unsigned dropBack=0)2146 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
2147                                               unsigned dropFront = 0,
2148                                               unsigned dropBack = 0) {
2149   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
2150   auto range = arrayAttr.getAsRange<IntegerAttr>();
2151   SmallVector<int64_t, 4> res;
2152   res.reserve(arrayAttr.size() - dropFront - dropBack);
2153   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
2154        it != eit; ++it)
2155     res.push_back((*it).getValue().getSExtValue());
2156   return res;
2157 }
2158 
2159 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
2160 // BroadcastOp(ExtractStrideSliceOp).
2161 class StridedSliceBroadcast final
2162     : public OpRewritePattern<ExtractStridedSliceOp> {
2163 public:
2164   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2165 
matchAndRewrite(ExtractStridedSliceOp op,PatternRewriter & rewriter) const2166   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
2167                                 PatternRewriter &rewriter) const override {
2168     auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
2169     if (!broadcast)
2170       return failure();
2171     auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
2172     unsigned srcRrank = srcVecType ? srcVecType.getRank() : 0;
2173     auto dstVecType = op.getType().cast<VectorType>();
2174     unsigned dstRank = dstVecType.getRank();
2175     unsigned rankDiff = dstRank - srcRrank;
2176     // Check if the most inner dimensions of the source of the broadcast are the
2177     // same as the destination of the extract. If this is the case we can just
2178     // use a broadcast as the original dimensions are untouched.
2179     bool lowerDimMatch = true;
2180     for (unsigned i = 0; i < srcRrank; i++) {
2181       if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
2182         lowerDimMatch = false;
2183         break;
2184       }
2185     }
2186     Value source = broadcast.source();
2187     if (!lowerDimMatch) {
2188       // The inner dimensions don't match, it means we need to extract from the
2189       // source of the orignal broadcast and then broadcast the extracted value.
2190       source = rewriter.create<ExtractStridedSliceOp>(
2191           op->getLoc(), source,
2192           getI64SubArray(op.offsets(), /* dropFront=*/rankDiff),
2193           getI64SubArray(op.sizes(), /* dropFront=*/rankDiff),
2194           getI64SubArray(op.strides(), /* dropFront=*/rankDiff));
2195     }
2196     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
2197     return success();
2198   }
2199 };
2200 
2201 } // end anonymous namespace
2202 
/// Registers canonicalization patterns for ExtractStridedSliceOp.
void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp,
  // plus sinking an ExtractStridedSliceOp through a BroadcastOp.
  results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
              StridedSliceBroadcast>(context);
}
2210 
2211 //===----------------------------------------------------------------------===//
2212 // TransferReadOp
2213 //===----------------------------------------------------------------------===//
2214 
2215 template <typename EmitFun>
verifyPermutationMap(AffineMap permutationMap,EmitFun emitOpError)2216 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
2217                                           EmitFun emitOpError) {
2218   SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
2219   for (auto expr : permutationMap.getResults()) {
2220     auto dim = expr.dyn_cast<AffineDimExpr>();
2221     auto zero = expr.dyn_cast<AffineConstantExpr>();
2222     if (zero) {
2223       if (zero.getValue() != 0) {
2224         return emitOpError(
2225             "requires a projected permutation_map (at most one dim or the zero "
2226             "constant can appear in each result)");
2227       }
2228       continue;
2229     }
2230     if (!dim) {
2231       return emitOpError("requires a projected permutation_map (at most one "
2232                          "dim or the zero constant can appear in each result)");
2233     }
2234     if (seen[dim.getPosition()]) {
2235       return emitOpError(
2236           "requires a permutation_map that is a permutation (found one dim "
2237           "used more than once)");
2238     }
2239     seen[dim.getPosition()] = true;
2240   }
2241   return success();
2242 }
2243 
/// Verifies properties shared by vector.transfer_read and
/// vector.transfer_write: source type, element-type/bitwidth compatibility,
/// permutation map well-formedness, mask type consistency, and the optional
/// in_bounds attribute.
static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
                                      VectorType vectorType,
                                      VectorType maskType,
                                      AffineMap permutationMap,
                                      ArrayAttr inBounds) {
  // Give a targeted error for the removed "masked" attribute spelling.
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!shapedType.isa<MemRefType, RankedTensorType>())
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");
  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
    // Memref or tensor has vector element type.
    // The bitwidth of the minor 1-D result vector must be a multiple of the
    // bitwidth of the minor 1-D vector of the source element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    // Masks are not supported when the source element type is itself a vector.
    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    // The bitwidth of the minor 1-D vector must be a multiple of the scalar
    // element bitwidth.
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    // The mask shape, if a mask is present, is fully determined by the vector
    // type and the permutation map.
    VectorType expectedMaskType =
        vector::detail::transferMaskType(vectorType, permutationMap);
    if (maskType && expectedMaskType != maskType)
      return op->emitOpError("expects mask type consistent with permutation "
                             "map: ")
             << maskType;
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");
  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (inBounds) {
    if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
      return op->emitOpError("expects the optional in_bounds attr of same rank "
                             "as permutation_map results: ")
             << AffineMapAttr::get(permutationMap);
    // Broadcast dimensions (constant-zero map results) have no source
    // dimension to run out of bounds on, so they must be marked in-bounds.
    for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
      if (permutationMap.getResult(i).isa<AffineConstantExpr>()
          && !inBounds.getValue()[i].cast<BoolAttr>().getValue())
        return op->emitOpError("requires broadcast dimensions to be in-bounds");
  }

  return success();
}
2327 
2328 /// Builder that sets padding to zero.
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,AffineMap permutationMap,ArrayRef<bool> inBounds)2329 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2330                            VectorType vectorType, Value source,
2331                            ValueRange indices, AffineMap permutationMap,
2332                            ArrayRef<bool> inBounds) {
2333   Type elemType = source.getType().cast<ShapedType>().getElementType();
2334   Value padding = builder.create<ConstantOp>(result.location, elemType,
2335                                              builder.getZeroAttr(elemType));
2336   if (inBounds.empty())
2337     return build(builder, result, vectorType, source, indices, permutationMap,
2338                  padding, ArrayAttr());
2339   ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
2340   build(builder, result, vectorType, source, indices, permutationMap, padding,
2341         inBoundsArrayAttr);
2342 }
2343 
2344 /// Builder that sets permutation map to 'getMinorIdentityMap'.
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,Value padding,ArrayRef<bool> inBounds)2345 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2346                            VectorType vectorType, Value source,
2347                            ValueRange indices, Value padding,
2348                            ArrayRef<bool> inBounds) {
2349   auto permMap = getTransferMinorIdentityMap(
2350       source.getType().cast<ShapedType>(), vectorType);
2351   if (inBounds.empty())
2352     return build(builder, result, vectorType, source, indices, permMap, padding,
2353                  ArrayAttr());
2354   ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
2355   build(builder, result, vectorType, source, indices, permMap, padding,
2356         inBoundsArrayAttr);
2357 }
2358 
2359 /// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
2360 /// (resp. zero).
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,ArrayRef<bool> inBounds)2361 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2362                            VectorType vectorType, Value source,
2363                            ValueRange indices, ArrayRef<bool> inBounds) {
2364   auto permMap = getTransferMinorIdentityMap(
2365       source.getType().cast<ShapedType>(), vectorType);
2366   build(builder, result, vectorType, source, indices, permMap, inBounds);
2367 }
2368 
/// Builder that does not provide a mask.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           Type vectorType, Value source, ValueRange indices,
                           AffineMap permutationMap, Value padding,
                           ArrayAttr inBounds) {
  // Forward to the full builder with a null mask value.
  build(builder, result, vectorType, source, indices, permutationMap, padding,
        /*mask=*/Value(), inBounds);
}
2377 
/// Builder that does not provide a mask (AffineMapAttr overload).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           Type vectorType, Value source, ValueRange indices,
                           AffineMapAttr permutationMap, Value padding,
                           ArrayAttr inBounds) {
  // Forward to the full builder with a null mask value.
  build(builder, result, vectorType, source, indices, permutationMap, padding,
        /*mask=*/Value(), inBounds);
}
2386 
printTransferAttrs(OpAsmPrinter & p,VectorTransferOpInterface op)2387 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
2388   SmallVector<StringRef, 3> elidedAttrs;
2389   elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
2390   if (op.permutation_map().isMinorIdentity())
2391     elidedAttrs.push_back(op.getPermutationMapAttrName());
2392   bool elideInBounds = true;
2393   if (auto inBounds = op.in_bounds()) {
2394     for (auto attr : *inBounds) {
2395       if (attr.template cast<BoolAttr>().getValue()) {
2396         elideInBounds = false;
2397         break;
2398       }
2399     }
2400   }
2401   if (elideInBounds)
2402     elidedAttrs.push_back(op.getInBoundsAttrName());
2403   p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2404 }
2405 
print(OpAsmPrinter & p,TransferReadOp op)2406 static void print(OpAsmPrinter &p, TransferReadOp op) {
2407   p << op.getOperationName() << " " << op.source() << "[" << op.indices()
2408     << "], " << op.padding();
2409   if (op.mask())
2410     p << ", " << op.mask();
2411   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
2412   p << " : " << op.getShapedType() << ", " << op.getVectorType();
2413 }
2414 
parseTransferReadOp(OpAsmParser & parser,OperationState & result)2415 static ParseResult parseTransferReadOp(OpAsmParser &parser,
2416                                        OperationState &result) {
2417   auto &builder = parser.getBuilder();
2418   llvm::SMLoc typesLoc;
2419   OpAsmParser::OperandType sourceInfo;
2420   SmallVector<OpAsmParser::OperandType, 8> indexInfo;
2421   OpAsmParser::OperandType paddingInfo;
2422   SmallVector<Type, 2> types;
2423   OpAsmParser::OperandType maskInfo;
2424   // Parsing with support for paddingValue.
2425   if (parser.parseOperand(sourceInfo) ||
2426       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
2427       parser.parseComma() || parser.parseOperand(paddingInfo))
2428     return failure();
2429   ParseResult hasMask = parser.parseOptionalComma();
2430   if (hasMask.succeeded()) {
2431     parser.parseOperand(maskInfo);
2432   }
2433   if (parser.parseOptionalAttrDict(result.attributes) ||
2434       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
2435     return failure();
2436   if (types.size() != 2)
2437     return parser.emitError(typesLoc, "requires two types");
2438   auto indexType = builder.getIndexType();
2439   auto shapedType = types[0].dyn_cast<ShapedType>();
2440   if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
2441     return parser.emitError(typesLoc, "requires memref or ranked tensor type");
2442   VectorType vectorType = types[1].dyn_cast<VectorType>();
2443   if (!vectorType)
2444     return parser.emitError(typesLoc, "requires vector type");
2445   auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
2446   Attribute mapAttr = result.attributes.get(permutationAttrName);
2447   if (!mapAttr) {
2448     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
2449     mapAttr = AffineMapAttr::get(permMap);
2450     result.attributes.set(permutationAttrName, mapAttr);
2451   }
2452   if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
2453       parser.resolveOperands(indexInfo, indexType, result.operands) ||
2454       parser.resolveOperand(paddingInfo, shapedType.getElementType(),
2455                             result.operands))
2456     return failure();
2457   if (hasMask.succeeded()) {
2458     if (shapedType.getElementType().dyn_cast<VectorType>())
2459       return parser.emitError(
2460           maskInfo.location, "does not support masks with vector element type");
2461     auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
2462     // Instead of adding the mask type as an op type, compute it based on the
2463     // vector type and the permutation map (to keep the type signature small).
2464     auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
2465     if (parser.resolveOperand(maskInfo, maskType, result.operands))
2466       return failure();
2467   }
2468   result.addAttribute(
2469       TransferReadOp::getOperandSegmentSizeAttr(),
2470       builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1,
2471                                 static_cast<int32_t>(hasMask.succeeded())}));
2472   return parser.addTypeToList(vectorType, result.types);
2473 }
2474 
verify(TransferReadOp op)2475 static LogicalResult verify(TransferReadOp op) {
2476   // Consistency of elemental types in source and vector.
2477   ShapedType shapedType = op.getShapedType();
2478   VectorType vectorType = op.getVectorType();
2479   VectorType maskType = op.getMaskType();
2480   auto paddingType = op.padding().getType();
2481   auto permutationMap = op.permutation_map();
2482   auto sourceElementType = shapedType.getElementType();
2483 
2484   if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
2485     return op.emitOpError("requires ") << shapedType.getRank() << " indices";
2486 
2487   if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
2488                               maskType, permutationMap,
2489                               op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
2490     return failure();
2491 
2492   if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
2493     // Source has vector element type.
2494     // Check that 'sourceVectorElementType' and 'paddingType' types match.
2495     if (sourceVectorElementType != paddingType)
2496       return op.emitOpError(
2497           "requires source element type and padding type to match.");
2498 
2499   } else {
2500     // Check that 'paddingType' is valid to store in a vector type.
2501     if (!VectorType::isValidElementType(paddingType))
2502       return op.emitOpError("requires valid padding vector elemental type");
2503 
2504     // Check that padding type and vector element types match.
2505     if (paddingType != sourceElementType)
2506       return op.emitOpError(
2507           "requires formal padding and source of the same elemental type");
2508   }
2509 
2510   return verifyPermutationMap(permutationMap,
2511                               [&op](Twine t) { return op.emitOpError(t); });
2512 }
2513 
2514 /// This is a common class used for patterns of the form
2515 /// ```
2516 ///    someop(memrefcast) -> someop
2517 /// ```
2518 /// It folds the source of the memref.cast into the root operation directly.
foldMemRefCast(Operation * op)2519 static LogicalResult foldMemRefCast(Operation *op) {
2520   bool folded = false;
2521   for (OpOperand &operand : op->getOpOperands()) {
2522     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
2523     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
2524       operand.set(castOp.getOperand());
2525       folded = true;
2526     }
2527   }
2528   return success(folded);
2529 }
2530 
foldTensorCast(Operation * op)2531 static LogicalResult foldTensorCast(Operation *op) {
2532   bool folded = false;
2533   for (OpOperand &operand : op->getOpOperands()) {
2534     auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
2535     if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
2536       operand.set(castOp.getOperand());
2537       folded = true;
2538     }
2539   }
2540   return success(folded);
2541 }
2542 
2543 template <typename TransferOp>
isInBounds(TransferOp op,int64_t resultIdx,int64_t indicesIdx)2544 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
2545   // TODO: support more aggressive createOrFold on:
2546   // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
2547   if (op.getShapedType().isDynamicDim(indicesIdx))
2548     return false;
2549   Value index = op.indices()[indicesIdx];
2550   auto cstOp = index.getDefiningOp<ConstantIndexOp>();
2551   if (!cstOp)
2552     return false;
2553 
2554   int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
2555   int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
2556 
2557   return cstOp.getValue() + vectorSize <= sourceSize;
2558 }
2559 
2560 template <typename TransferOp>
foldTransferInBoundsAttribute(TransferOp op)2561 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
2562   AffineMap permutationMap = op.permutation_map();
2563   bool changed = false;
2564   SmallVector<bool, 4> newInBounds;
2565   newInBounds.reserve(op.getTransferRank());
2566   for (unsigned i = 0; i < op.getTransferRank(); ++i) {
2567     // Already marked as in-bounds, nothing to see here.
2568     if (op.isDimInBounds(i)) {
2569       newInBounds.push_back(true);
2570       continue;
2571     }
2572     // Currently out-of-bounds, check whether we can statically determine it is
2573     // inBounds.
2574     auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
2575     assert(dimExpr && "Broadcast dims must be in-bounds");
2576     auto inBounds = isInBounds(
2577         op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
2578     newInBounds.push_back(inBounds);
2579     // We commit the pattern if it is "more inbounds".
2580     changed |= inBounds;
2581   }
2582   if (!changed)
2583     return failure();
2584   // OpBuilder is only used as a helper to build an I64ArrayAttr.
2585   OpBuilder b(op.getContext());
2586   op->setAttr(TransferOp::getInBoundsAttrName(),
2587               b.getBoolArrayAttr(newInBounds));
2588   return success();
2589 }
2590 
2591 ///  ```
2592 ///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
2593 ///    : vector<1x4xf32>, tensor<4x4xf32>
2594 ///  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
2595 ///    : tensor<4x4xf32>, vector<1x4xf32>
2596 ///  ```
2597 ///  -> Folds into
2598 ///  ```
2599 ///  %v0
2600 ///  ```
foldRAW(TransferReadOp readOp)2601 static Value foldRAW(TransferReadOp readOp) {
2602   if (!readOp.getShapedType().isa<RankedTensorType>())
2603     return {};
2604   auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>();
2605   while (defWrite) {
2606     if (checkSameValueRAW(defWrite, readOp))
2607       return defWrite.vector();
2608     if (!isDisjointTransferIndices(
2609             cast<VectorTransferOpInterface>(defWrite.getOperation()),
2610             cast<VectorTransferOpInterface>(readOp.getOperation())))
2611       break;
2612     defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
2613   }
2614   return {};
2615 }
2616 
OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
  // Fold read-after-write on tensors: return the written vector directly.
  if (Value vec = foldRAW(*this))
    return vec;
  // Tighten the in_bounds attribute when statically provable.
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  /// transfer_read(tensorcast) -> transfer_read
  if (succeeded(foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}
2629 
getShapeForUnroll()2630 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
2631   return llvm::to_vector<4>(getVectorType().getShape());
2632 }
2633 
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)2634 void TransferReadOp::getEffects(
2635     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2636         &effects) {
2637   if (getShapedType().isa<MemRefType>())
2638     effects.emplace_back(MemoryEffects::Read::get(), source(),
2639                          SideEffects::DefaultResource::get());
2640 }
2641 
2642 //===----------------------------------------------------------------------===//
2643 // TransferWriteOp
2644 //===----------------------------------------------------------------------===//
2645 
2646 /// Builder that sets permutation map to 'getMinorIdentityMap'.
build(OpBuilder & builder,OperationState & result,Value vector,Value source,ValueRange indices,ArrayRef<bool> inBounds)2647 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
2648                             Value vector, Value source, ValueRange indices,
2649                             ArrayRef<bool> inBounds) {
2650   auto vectorType = vector.getType().cast<VectorType>();
2651   auto permMap = getTransferMinorIdentityMap(
2652       source.getType().cast<ShapedType>(), vectorType);
2653   if (inBounds.empty())
2654     return build(builder, result, vector, source, indices, permMap,
2655                  ArrayAttr());
2656   ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
2657   build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr);
2658 }
2659 
/// Builder with an explicit permutation map; forwards with no mask and no
/// in_bounds attribute (all dimensions treated as out-of-bounds).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            AffineMap permutationMap) {
  build(builder, result, vector, source, indices, permutationMap,
        /*inBounds=*/ArrayAttr());
}
2666 
/// Builder taking the permutation map as an attribute.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            AffineMapAttr permutationMap,
                            /*optional*/ ArrayAttr inBounds) {
  // Writing into a tensor produces a new tensor result; writing into a
  // memref produces no result (dyn_cast yields a null Type then).
  Type resultType = source.getType().dyn_cast<RankedTensorType>();
  build(builder, result, resultType, vector, source, indices, permutationMap,
        /*mask=*/Value(), inBounds);
}
2675 
/// Builder taking the permutation map as a bare AffineMap.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            AffineMap permutationMap,
                            /*optional*/ ArrayAttr inBounds) {
  // Writing into a tensor produces a new tensor result; writing into a
  // memref produces no result (dyn_cast yields a null Type then).
  Type resultType = source.getType().dyn_cast<RankedTensorType>();
  build(builder, result, resultType, vector, source, indices, permutationMap,
        /*mask=*/Value(), inBounds);
}
2684 
/// Most general builder: permutation map, optional mask and optional
/// in_bounds attribute.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value source, ValueRange indices,
                            AffineMap permutationMap, /*optional*/ Value mask,
                            /*optional*/ ArrayAttr inBounds) {
  // Writing into a tensor produces a new tensor result; writing into a
  // memref produces no result (dyn_cast yields a null Type then).
  Type resultType = source.getType().dyn_cast<RankedTensorType>();
  build(builder, result, resultType, vector, source, indices, permutationMap,
        mask, inBounds);
}
2693 
/// Parses `vector.transfer_write $vector, $source[$indices] (, $mask)?
/// attr-dict : vector-type, shaped-type`. The permutation map defaults to
/// the minor identity map when absent from the attribute dictionary.
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
                                        OperationState &result) {
  auto &builder = parser.getBuilder();
  llvm::SMLoc typesLoc;
  OpAsmParser::OperandType vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::OperandType maskInfo;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  // An optional trailing comma introduces the mask operand.
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = types[0].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = types[1].dyn_cast<ShapedType>();
  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  // If the permutation map was elided, default to the minor identity map.
  auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
  auto attr = result.attributes.get(permutationAttrName);
  if (!attr) {
    auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (shapedType.getElementType().dyn_cast<VectorType>())
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    // The mask is an i1 vector with the same shape as the transfer vector.
    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(
      TransferWriteOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()),
                                static_cast<int32_t>(hasMask.succeeded())}));
  // Writing into a tensor yields a new tensor result; memref writes do not.
  return failure(shapedType.isa<RankedTensorType>() &&
                 parser.addTypeToList(shapedType, result.types));
}
2746 
/// Prints the transfer_write in the form parsed by parseTransferWriteOp:
/// `op $vector, $source[$indices] (, $mask)? attr-dict : types`.
static void print(OpAsmPrinter &p, TransferWriteOp op) {
  p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "["
    << op.indices() << "]";
  if (op.mask())
    p << ", " << op.mask();
  printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
  p << " : " << op.getVectorType() << ", " << op.getShapedType();
}
2755 
/// Verifies a transfer_write: index count, absence of broadcast dims,
/// shared transfer-op invariants (types, mask, in_bounds) and a valid
/// permutation map.
static LogicalResult verify(TransferWriteOp op) {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = op.getShapedType();
  VectorType vectorType = op.getVectorType();
  VectorType maskType = op.getMaskType();
  auto permutationMap = op.permutation_map();

  // One index operand per dimension of the source shaped type.
  if (llvm::size(op.indices()) != shapedType.getRank())
    return op.emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (op.hasBroadcastDim())
    return op.emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
                              maskType, permutationMap,
                              op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&op](Twine t) { return op.emitOpError(t); });
}
2779 
/// Fold:
/// ```
///    %t1 = ...
///    %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
///      tensor<static_sizesxf32>, vector<static_sizesxf32>
///    %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
///      vector<static_sizesxf32>, tensor<static_sizesxf32>
/// ```
///
/// into:
///
/// ```
///    %t0
/// ```
///
/// The producer of t1 may or may not be DCE'd depending on whether it is a
/// block argument or has side effects.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>();
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // For now, only accept minor identity. Future: composition is minor identity.
  if (!read.permutation_map().isMinorIdentity() ||
      !write.permutation_map().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.source().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and Tensor shapes must match, i.e. the transfer covers the whole
  // tensor so that the write reproduces the read's source exactly.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // If any index is nonzero, the transfer does not start at the origin; bail.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = v.getDefiningOp<ConstantIndexOp>();
    return !cstOp || cstOp.getValue() != 0;
  };
  if (llvm::any_of(read.indices(), isNotConstantZero) ||
      llvm::any_of(write.indices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.source());
  return success();
}
2839 
checkSameValueWAR(vector::TransferReadOp read,vector::TransferWriteOp write)2840 static bool checkSameValueWAR(vector::TransferReadOp read,
2841                               vector::TransferWriteOp write) {
2842   return read.source() == write.source() && read.indices() == write.indices() &&
2843          read.permutation_map() == write.permutation_map() &&
2844          read.getVectorType() == write.getVectorType() && !read.mask() &&
2845          !write.mask();
2846 }
2847 /// Fold transfer_write write after read:
2848 /// ```
2849 ///    %t0 = ...
2850 ///    %v = vector.transfer_read %t0[%c0...] :
2851 ///      tensor<static_sizesxf32>, vector<static_sizesxf32>
2852 ///    %t1 = vector.transfer_write %v, %t0[%c0...] :
2853 ///      vector<static_sizesxf32>, tensor<static_sizesxf32>
2854 /// ```
2855 ///
2856 /// into:
2857 ///
2858 /// ```
2859 ///    %t0
2860 /// ```
foldWAR(TransferWriteOp write,SmallVectorImpl<OpFoldResult> & results)2861 static LogicalResult foldWAR(TransferWriteOp write,
2862                              SmallVectorImpl<OpFoldResult> &results) {
2863   if (!write.source().getType().isa<RankedTensorType>())
2864     return failure();
2865   auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
2866   if (!read)
2867     return failure();
2868 
2869   if (!checkSameValueWAR(read, write))
2870     return failure();
2871   results.push_back(read.source());
2872   return success();
2873 }
2874 
LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  // write(read(t0), t1) -> t0 when the transfer covers the whole tensor.
  if (succeeded(foldReadInitWrite(*this, operands, results)))
    return success();
  // Write-after-read of the same tensor at the same position folds to it.
  if (succeeded(foldWAR(*this, results)))
    return success();
  // Tighten the in_bounds attribute when statically provable.
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return success();
  // transfer_write(memrefcast) -> transfer_write.
  return foldMemRefCast(*this);
}
2885 
getShapeForUnroll()2886 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
2887   return llvm::to_vector<4>(getVectorType().getShape());
2888 }
2889 
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)2890 void TransferWriteOp::getEffects(
2891     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2892         &effects) {
2893   if (getShapedType().isa<MemRefType>())
2894     effects.emplace_back(MemoryEffects::Write::get(), source(),
2895                          SideEffects::DefaultResource::get());
2896 }
2897 
namespace {
/// Remove dead transfer write from the SSA chain so that it can be eliminated
/// by DCE
/// ```
///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// into:
///
/// ```
///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
/// any other uses.
class foldWAW final : public OpRewritePattern<TransferWriteOp> {
public:
  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Only tensor writes can be bypassed this way.
    if (!writeOp.getShapedType().isa<RankedTensorType>())
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    // Walk up the chain of transfer_writes producing this write's source.
    auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      // A previous write to the exact same location is fully overwritten:
      // re-route around it so DCE can remove it.
      if (checkSameValueWAW(writeOp, defWrite)) {
        writeToModify.sourceMutable().assign(defWrite.source());
        return success();
      }
      // Only step over a write that provably does not overlap with ours.
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely look
      // at the previous store to see if it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
} // namespace
2953 
/// Registers the write-after-write dead-store elimination pattern.
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<foldWAW>(context);
}
2958 
2959 //===----------------------------------------------------------------------===//
2960 // LoadOp
2961 //===----------------------------------------------------------------------===//
2962 
verifyLoadStoreMemRefLayout(Operation * op,MemRefType memRefTy)2963 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
2964                                                  MemRefType memRefTy) {
2965   if (!isLastMemrefDimUnitStride(memRefTy))
2966     return op->emitOpError("most minor memref dim must have unit stride");
2967   return success();
2968 }
2969 
/// Verifies a vector.load: unit-stride minor memref dim, element-type
/// consistency (including memref-of-vector bases), and one index per
/// memref dimension.
static LogicalResult verify(vector::LoadOp op) {
  VectorType resVecTy = op.getVectorType();
  MemRefType memRefTy = op.getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
    return failure();

  // Checks for vector memrefs: when the base holds vectors, the loaded
  // vector type must match the element vector type exactly.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
    if (memVecTy != resVecTy)
      return op.emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return op.emitOpError("base and result element types should match");
  if (llvm::size(op.indices()) != memRefTy.getRank())
    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}
2991 
fold(ArrayRef<Attribute>)2992 OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
2993   if (succeeded(foldMemRefCast(*this)))
2994     return getResult();
2995   return OpFoldResult();
2996 }
2997 
2998 //===----------------------------------------------------------------------===//
2999 // StoreOp
3000 //===----------------------------------------------------------------------===//
3001 
/// Verifies a vector.store: unit-stride minor memref dim, element-type
/// consistency (including memref-of-vector bases), and one index per
/// memref dimension.
static LogicalResult verify(vector::StoreOp op) {
  VectorType valueVecTy = op.getVectorType();
  MemRefType memRefTy = op.getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
    return failure();

  // Checks for vector memrefs: when the base holds vectors, the stored
  // vector type must match the element vector type exactly.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
    if (memVecTy != valueVecTy)
      return op.emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return op.emitOpError("base and valueToStore element type should match");
  if (llvm::size(op.indices()) != memRefTy.getRank())
    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}
3024 
/// store(memrefcast) -> store.
LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
                            SmallVectorImpl<OpFoldResult> &results) {
  return foldMemRefCast(*this);
}
3029 
3030 //===----------------------------------------------------------------------===//
3031 // MaskedLoadOp
3032 //===----------------------------------------------------------------------===//
3033 
/// Verifies a vector.maskedload: element-type match with the base memref,
/// one index per memref dim, and mask/pass_thru shapes consistent with the
/// result vector.
static LogicalResult verify(MaskedLoadOp op) {
  VectorType maskVType = op.getMaskVectorType();
  VectorType passVType = op.getPassThruVectorType();
  VectorType resVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and result element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return op.emitOpError("expected pass_thru of same type as result type");
  return success();
}
3050 
namespace {
/// Canonicalizes a masked load with a compile-time constant mask: an
/// all-true mask becomes an unmasked vector.load and an all-false mask
/// folds to the pass-through value.
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(load.mask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(),
                                                  load.base(), load.indices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.pass_thru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};
} // namespace
3072 
/// Registers the constant-mask folding pattern for masked loads.
void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}
3077 
fold(ArrayRef<Attribute>)3078 OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
3079   if (succeeded(foldMemRefCast(*this)))
3080     return getResult();
3081   return OpFoldResult();
3082 }
3083 
3084 //===----------------------------------------------------------------------===//
3085 // MaskedStoreOp
3086 //===----------------------------------------------------------------------===//
3087 
/// Verifies a vector.maskedstore: element-type match with the base memref,
/// one index per memref dim, and a mask shape consistent with the stored
/// vector.
static LogicalResult verify(MaskedStoreOp op) {
  VectorType maskVType = op.getMaskVectorType();
  VectorType valueVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and valueToStore element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
3101 
namespace {
/// Canonicalizes a masked store with a compile-time constant mask: an
/// all-true mask becomes an unmasked vector.store and an all-false mask
/// erases the store entirely.
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(store.mask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.valueToStore(), store.base(), store.indices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};
} // namespace
3123 
/// Registers the constant-mask folding pattern for masked stores.
void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}
3128 
/// maskedstore(memrefcast) -> maskedstore.
LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return foldMemRefCast(*this);
}
3133 
3134 //===----------------------------------------------------------------------===//
3135 // GatherOp
3136 //===----------------------------------------------------------------------===//
3137 
/// Verifies a vector.gather: element-type match with the base memref, one
/// index per memref dim, and indices/mask/pass_thru shapes consistent with
/// the result vector.
static LogicalResult verify(GatherOp op) {
  VectorType indVType = op.getIndexVectorType();
  VectorType maskVType = op.getMaskVectorType();
  VectorType resVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and result element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != indVType.getDimSize(0))
    return op.emitOpError("expected result dim to match indices dim");
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected result dim to match mask dim");
  if (resVType != op.getPassThruVectorType())
    return op.emitOpError("expected pass_thru of same type as result type");
  return success();
}
3156 
namespace {
/// Canonicalizes a gather with a compile-time constant mask: an all-false
/// mask folds to the pass-through value. There is no unmasked gather, so an
/// all-true mask cannot be simplified.
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern<GatherOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(gather.mask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.pass_thru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};
} // namespace
3176 
/// Registers the constant-mask folding pattern for gathers.
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder>(context);
}
3181 
3182 //===----------------------------------------------------------------------===//
3183 // ScatterOp
3184 //===----------------------------------------------------------------------===//
3185 
/// Verifies a vector.scatter: element-type match with the base memref, one
/// index per memref dim, and indices/mask shapes consistent with the stored
/// vector.
static LogicalResult verify(ScatterOp op) {
  VectorType indVType = op.getIndexVectorType();
  VectorType maskVType = op.getMaskVectorType();
  VectorType valueVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and valueToStore element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
    return op.emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected valueToStore dim to match mask dim");
  return success();
}
3202 
namespace {
/// Canonicalizes a scatter with a compile-time constant mask: an all-false
/// mask erases the scatter. There is no unmasked scatter, so an all-true
/// mask cannot be simplified.
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern<ScatterOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(scatter.mask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.eraseOp(scatter);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};
} // namespace
3222 
/// Registers the constant-mask folding pattern for scatters.
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder>(context);
}
3227 
3228 //===----------------------------------------------------------------------===//
3229 // ExpandLoadOp
3230 //===----------------------------------------------------------------------===//
3231 
/// Verifies a vector.expandload: element-type match with the base memref,
/// one index per memref dim, and mask/pass_thru shapes consistent with the
/// result vector.
static LogicalResult verify(ExpandLoadOp op) {
  VectorType maskVType = op.getMaskVectorType();
  VectorType passVType = op.getPassThruVectorType();
  VectorType resVType = op.getVectorType();
  MemRefType memType = op.getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return op.emitOpError("base and result element type should match");
  if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return op.emitOpError("expected pass_thru of same type as result type");
  return success();
}
3248 
namespace {
/// Canonicalizes an expand-load with a compile-time constant mask: an
/// all-true mask becomes an ordinary vector.load (expansion is the identity)
/// and an all-false mask folds to the pass-through value.
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (get1DMaskFormat(expand.mask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.base(), expand.indices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.pass_thru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
} // namespace
3270 
void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Fold expand-loads whose mask is statically all-true or all-false.
  results.add<ExpandLoadFolder>(context);
}
3275 
3276 //===----------------------------------------------------------------------===//
3277 // CompressStoreOp
3278 //===----------------------------------------------------------------------===//
3279 
verify(CompressStoreOp op)3280 static LogicalResult verify(CompressStoreOp op) {
3281   VectorType maskVType = op.getMaskVectorType();
3282   VectorType valueVType = op.getVectorType();
3283   MemRefType memType = op.getMemRefType();
3284 
3285   if (valueVType.getElementType() != memType.getElementType())
3286     return op.emitOpError("base and valueToStore element type should match");
3287   if (llvm::size(op.indices()) != memType.getRank())
3288     return op.emitOpError("requires ") << memType.getRank() << " indices";
3289   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
3290     return op.emitOpError("expected valueToStore dim to match mask dim");
3291   return success();
3292 }
3293 
3294 namespace {
3295 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
3296 public:
3297   using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
matchAndRewrite(CompressStoreOp compress,PatternRewriter & rewriter) const3298   LogicalResult matchAndRewrite(CompressStoreOp compress,
3299                                 PatternRewriter &rewriter) const override {
3300     switch (get1DMaskFormat(compress.mask())) {
3301     case MaskFormat::AllTrue:
3302       rewriter.replaceOpWithNewOp<vector::StoreOp>(
3303           compress, compress.valueToStore(), compress.base(),
3304           compress.indices());
3305       return success();
3306     case MaskFormat::AllFalse:
3307       rewriter.eraseOp(compress);
3308       return success();
3309     case MaskFormat::Unknown:
3310       return failure();
3311     }
3312     llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
3313   }
3314 };
3315 } // namespace
3316 
void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  // Fold compress-stores whose mask is statically all-true or all-false.
  results.add<CompressStoreFolder>(context);
}
3321 
3322 //===----------------------------------------------------------------------===//
3323 // ShapeCastOp
3324 //===----------------------------------------------------------------------===//
3325 
3326 /// Returns true if each element of 'a' is equal to the product of a contiguous
3327 /// sequence of the elements of 'b'. Returns false otherwise.
isValidShapeCast(ArrayRef<int64_t> a,ArrayRef<int64_t> b)3328 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
3329   unsigned rankA = a.size();
3330   unsigned rankB = b.size();
3331   assert(rankA < rankB);
3332 
3333   unsigned i = 0;
3334   unsigned j = 0;
3335   while (i < rankA && j < rankB) {
3336     int64_t dimA = a[i];
3337     int64_t dimB = 1;
3338     while (dimB < dimA && j < rankB)
3339       dimB *= b[j++];
3340     if (dimA != dimB)
3341       break;
3342     ++i;
3343 
3344     // Handle the case when trailing dimensions are of size 1.
3345     // Include them into the contiguous sequence.
3346     auto isOne = [](int64_t v) { return v == 1; };
3347     if (i < rankA && llvm::all_of(a.slice(i), isOne))
3348       i = rankA;
3349     if (j < rankB && llvm::all_of(b.slice(j), isOne))
3350       j = rankB;
3351   }
3352 
3353   return i == rankA && j == rankB;
3354 }
3355 
verifyVectorShapeCast(Operation * op,VectorType sourceVectorType,VectorType resultVectorType)3356 static LogicalResult verifyVectorShapeCast(Operation *op,
3357                                            VectorType sourceVectorType,
3358                                            VectorType resultVectorType) {
3359   // Check that element type is the same.
3360   if (sourceVectorType.getElementType() != resultVectorType.getElementType())
3361     return op->emitOpError("source/result vectors must have same element type");
3362   auto sourceShape = sourceVectorType.getShape();
3363   auto resultShape = resultVectorType.getShape();
3364 
3365   // Check that product of source dim sizes matches product of result dim sizes.
3366   int64_t sourceDimProduct = std::accumulate(
3367       sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
3368   int64_t resultDimProduct = std::accumulate(
3369       resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
3370   if (sourceDimProduct != resultDimProduct)
3371     return op->emitOpError("source/result number of elements must match");
3372 
3373   // Check that expanding/contracting rank cases.
3374   unsigned sourceRank = sourceVectorType.getRank();
3375   unsigned resultRank = resultVectorType.getRank();
3376   if (sourceRank < resultRank) {
3377     if (!isValidShapeCast(sourceShape, resultShape))
3378       return op->emitOpError("invalid shape cast");
3379   } else if (sourceRank > resultRank) {
3380     if (!isValidShapeCast(resultShape, sourceShape))
3381       return op->emitOpError("invalid shape cast");
3382   }
3383   return success();
3384 }
3385 
verify(ShapeCastOp op)3386 static LogicalResult verify(ShapeCastOp op) {
3387   auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>();
3388   auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>();
3389 
3390   // Check if source/result are of vector type.
3391   if (sourceVectorType && resultVectorType)
3392     return verifyVectorShapeCast(op, sourceVectorType, resultVectorType);
3393 
3394   return success();
3395 }
3396 
fold(ArrayRef<Attribute> operands)3397 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
3398   // Nop shape cast.
3399   if (source().getType() == result().getType())
3400     return source();
3401 
3402   // Canceling shape casts.
3403   if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
3404     if (result().getType() == otherOp.source().getType())
3405       return otherOp.source();
3406     setOperand(otherOp.source());
3407     return getResult();
3408   }
3409   return {};
3410 }
3411 
3412 namespace {
3413 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
3414 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
3415 public:
3416   using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
3417 
matchAndRewrite(ShapeCastOp shapeCastOp,PatternRewriter & rewriter) const3418   LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
3419                                 PatternRewriter &rewriter) const override {
3420     auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>();
3421     if (!constantOp)
3422       return failure();
3423     // Only handle splat for now.
3424     auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
3425     if (!dense)
3426       return failure();
3427     auto newAttr = DenseElementsAttr::get(
3428         shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
3429     rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr);
3430     return success();
3431   }
3432 };
3433 
3434 } // namespace
3435 
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // Pattern to rewrite a ShapeCastOp of a splat ConstantOp -> ConstantOp.
  results.add<ShapeCastConstantFolder>(context);
}
3441 
3442 //===----------------------------------------------------------------------===//
3443 // VectorBitCastOp
3444 //===----------------------------------------------------------------------===//
3445 
verify(BitCastOp op)3446 static LogicalResult verify(BitCastOp op) {
3447   auto sourceVectorType = op.getSourceVectorType();
3448   auto resultVectorType = op.getResultVectorType();
3449 
3450   for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
3451     if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
3452       return op.emitOpError("dimension size mismatch at: ") << i;
3453   }
3454 
3455   DataLayout dataLayout = DataLayout::closest(op);
3456   if (dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()) *
3457           sourceVectorType.getShape().back() !=
3458       dataLayout.getTypeSizeInBits(resultVectorType.getElementType()) *
3459           resultVectorType.getShape().back())
3460     return op.emitOpError(
3461         "source/result bitwidth of the minor 1-D vectors must be equal");
3462 
3463   return success();
3464 }
3465 
fold(ArrayRef<Attribute> operands)3466 OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
3467   // Nop cast.
3468   if (source().getType() == result().getType())
3469     return source();
3470 
3471   // Canceling bitcasts.
3472   if (auto otherOp = source().getDefiningOp<BitCastOp>())
3473     if (result().getType() == otherOp.source().getType())
3474       return otherOp.source();
3475 
3476   Attribute sourceConstant = operands.front();
3477   if (!sourceConstant)
3478     return {};
3479 
3480   Type srcElemType = getSourceVectorType().getElementType();
3481   Type dstElemType = getResultVectorType().getElementType();
3482 
3483   if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
3484     if (floatPack.isSplat()) {
3485       auto splat = floatPack.getSplatValue<FloatAttr>();
3486 
3487       // Casting fp16 into fp32.
3488       if (srcElemType.isF16() && dstElemType.isF32()) {
3489         uint32_t bits = static_cast<uint32_t>(
3490             splat.getValue().bitcastToAPInt().getZExtValue());
3491         // Duplicate the 16-bit pattern.
3492         bits = (bits << 16) | (bits & 0xffff);
3493         APInt intBits(32, bits);
3494         APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
3495         return DenseElementsAttr::get(getResultVectorType(), floatBits);
3496       }
3497     }
3498   }
3499 
3500   return {};
3501 }
3502 
3503 //===----------------------------------------------------------------------===//
3504 // TypeCastOp
3505 //===----------------------------------------------------------------------===//
3506 
extractShape(MemRefType memRefType)3507 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
3508   auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
3509   SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
3510                               memRefType.getShape().end());
3511   if (vectorType)
3512     res.append(vectorType.getShape().begin(), vectorType.getShape().end());
3513   return res;
3514 }
3515 
/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = source.getType().cast<MemRefType>();
  // Concatenate the memref shape and the element-vector shape into one
  // vector; the nested getElementTypeOrSelf peels memref-of-vector down to
  // the underlying scalar type.
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  // The result is a 0-D memref of that single vector, in the same memory
  // space as the source.
  result.addTypes(
      MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace()));
}
3528 
verify(TypeCastOp op)3529 static LogicalResult verify(TypeCastOp op) {
3530   MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
3531   if (!canonicalType.getAffineMaps().empty())
3532     return op.emitOpError("expects operand to be a memref with no layout");
3533   if (!op.getResultMemRefType().getAffineMaps().empty())
3534     return op.emitOpError("expects result to be a memref with no layout");
3535   if (op.getResultMemRefType().getMemorySpace() !=
3536       op.getMemRefType().getMemorySpace())
3537     return op.emitOpError("expects result in same memory space");
3538 
3539   auto sourceType = op.getMemRefType();
3540   auto resultType = op.getResultMemRefType();
3541   if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
3542       getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
3543     return op.emitOpError(
3544                "expects result and operand with same underlying scalar type: ")
3545            << resultType;
3546   if (extractShape(sourceType) != extractShape(resultType))
3547     return op.emitOpError(
3548                "expects concatenated result and operand shapes to be equal: ")
3549            << resultType;
3550   return success();
3551 }
3552 
3553 //===----------------------------------------------------------------------===//
3554 // TransposeOp
3555 //===----------------------------------------------------------------------===//
3556 
build(OpBuilder & builder,OperationState & result,Value vector,ArrayRef<int64_t> transp)3557 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
3558                                 Value vector, ArrayRef<int64_t> transp) {
3559   VectorType vt = vector.getType().cast<VectorType>();
3560   SmallVector<int64_t, 4> transposedShape(vt.getRank());
3561   for (unsigned i = 0; i < transp.size(); ++i)
3562     transposedShape[i] = vt.getShape()[transp[i]];
3563 
3564   result.addOperands(vector);
3565   result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
3566   result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
3567 }
3568 
3569 // Eliminates transpose operations, which produce values identical to their
3570 // input values. This happens when the dimensions of the input vector remain in
3571 // their original order after the transpose operation.
fold(ArrayRef<Attribute> operands)3572 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
3573   SmallVector<int64_t, 4> transp;
3574   getTransp(transp);
3575 
3576   // Check if the permutation of the dimensions contains sequential values:
3577   // {0, 1, 2, ...}.
3578   for (int64_t i = 0, e = transp.size(); i < e; i++) {
3579     if (transp[i] != i)
3580       return {};
3581   }
3582 
3583   return vector();
3584 }
3585 
verify(vector::TransposeOp op)3586 static LogicalResult verify(vector::TransposeOp op) {
3587   VectorType vectorType = op.getVectorType();
3588   VectorType resultType = op.getResultType();
3589   int64_t rank = resultType.getRank();
3590   if (vectorType.getRank() != rank)
3591     return op.emitOpError("vector result rank mismatch: ") << rank;
3592   // Verify transposition array.
3593   auto transpAttr = op.transp().getValue();
3594   int64_t size = transpAttr.size();
3595   if (rank != size)
3596     return op.emitOpError("transposition length mismatch: ") << size;
3597   SmallVector<bool, 8> seen(rank, false);
3598   for (auto ta : llvm::enumerate(transpAttr)) {
3599     int64_t i = ta.value().cast<IntegerAttr>().getInt();
3600     if (i < 0 || i >= rank)
3601       return op.emitOpError("transposition index out of range: ") << i;
3602     if (seen[i])
3603       return op.emitOpError("duplicate position index: ") << i;
3604     seen[i] = true;
3605     if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
3606       return op.emitOpError("dimension size mismatch at: ") << i;
3607   }
3608   return success();
3609 }
3610 
3611 namespace {
3612 
3613 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
3614 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
3615 public:
3616   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
3617 
matchAndRewrite(vector::TransposeOp transposeOp,PatternRewriter & rewriter) const3618   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
3619                                 PatternRewriter &rewriter) const override {
3620     // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
3621     auto getPermutation = [](vector::TransposeOp transpose) {
3622       SmallVector<int64_t, 4> permutation;
3623       transpose.getTransp(permutation);
3624       return permutation;
3625     };
3626 
3627     // Composes two permutations: result[i] = permutation1[permutation2[i]].
3628     auto composePermutations = [](ArrayRef<int64_t> permutation1,
3629                                   ArrayRef<int64_t> permutation2) {
3630       SmallVector<int64_t, 4> result;
3631       for (auto index : permutation2)
3632         result.push_back(permutation1[index]);
3633       return result;
3634     };
3635 
3636     // Return if the input of 'transposeOp' is not defined by another transpose.
3637     vector::TransposeOp parentTransposeOp =
3638         transposeOp.vector().getDefiningOp<vector::TransposeOp>();
3639     if (!parentTransposeOp)
3640       return failure();
3641 
3642     SmallVector<int64_t, 4> permutation = composePermutations(
3643         getPermutation(parentTransposeOp), getPermutation(transposeOp));
3644     // Replace 'transposeOp' with a new transpose operation.
3645     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
3646         transposeOp, transposeOp.getResult().getType(),
3647         parentTransposeOp.vector(),
3648         vector::getVectorSubscriptAttr(rewriter, permutation));
3649     return success();
3650   }
3651 };
3652 
3653 } // end anonymous namespace
3654 
void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Fold back-to-back transposes into a single transpose.
  results.add<TransposeFolder>(context);
}
3659 
void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
  // Materialize the 'transp' I64ArrayAttr as a vector of int64_t.
  populateFromInt64AttrArray(transp(), results);
}
3663 
3664 //===----------------------------------------------------------------------===//
3665 // ConstantMaskOp
3666 //===----------------------------------------------------------------------===//
3667 
verify(ConstantMaskOp & op)3668 static LogicalResult verify(ConstantMaskOp &op) {
3669   // Verify that array attr size matches the rank of the vector result.
3670   auto resultType = op.getResult().getType().cast<VectorType>();
3671   if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
3672     return op.emitOpError(
3673         "must specify array attr of size equal vector result rank");
3674   // Verify that each array attr element is in bounds of corresponding vector
3675   // result dimension size.
3676   auto resultShape = resultType.getShape();
3677   SmallVector<int64_t, 4> maskDimSizes;
3678   for (auto it : llvm::enumerate(op.mask_dim_sizes())) {
3679     int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
3680     if (attrValue < 0 || attrValue > resultShape[it.index()])
3681       return op.emitOpError(
3682           "array attr of size out of bounds of vector result dimension size");
3683     maskDimSizes.push_back(attrValue);
3684   }
3685   // Verify that if one mask dim size is zero, they all should be zero (because
3686   // the mask region is a conjunction of each mask dimension interval).
3687   bool any_zeros = llvm::is_contained(maskDimSizes, 0);
3688   bool all_zeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
3689   if (any_zeros && !all_zeros)
3690     return op.emitOpError("expected all mask dim sizes to be zeros, "
3691                           "as a result of conjunction with zero mask dim");
3692   return success();
3693 }
3694 
3695 //===----------------------------------------------------------------------===//
3696 // CreateMaskOp
3697 //===----------------------------------------------------------------------===//
3698 
verify(CreateMaskOp op)3699 static LogicalResult verify(CreateMaskOp op) {
3700   // Verify that an operand was specified for each result vector each dimension.
3701   if (op.getNumOperands() !=
3702       op.getResult().getType().cast<VectorType>().getRank())
3703     return op.emitOpError(
3704         "must specify an operand for each result vector dimension");
3705   return success();
3706 }
3707 
3708 namespace {
3709 
3710 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
3711 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
3712 public:
3713   using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
3714 
matchAndRewrite(CreateMaskOp createMaskOp,PatternRewriter & rewriter) const3715   LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
3716                                 PatternRewriter &rewriter) const override {
3717     // Return if any of 'createMaskOp' operands are not defined by a constant.
3718     auto is_not_def_by_constant = [](Value operand) {
3719       return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
3720     };
3721     if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
3722       return failure();
3723     // Gather constant mask dimension sizes.
3724     SmallVector<int64_t, 4> maskDimSizes;
3725     for (auto operand : createMaskOp.operands()) {
3726       auto defOp = operand.getDefiningOp();
3727       maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
3728     }
3729     // Replace 'createMaskOp' with ConstantMaskOp.
3730     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3731         createMaskOp, createMaskOp.getResult().getType(),
3732         vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
3733     return success();
3734   }
3735 };
3736 
3737 } // end anonymous namespace
3738 
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Rewrite create_mask with all-constant operands into constant_mask.
  results.add<CreateMaskFolder>(context);
}
3743 
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  // Registers, in one call, the constant-mask driven folders defined in this
  // file plus the transpose and strided-slice mask folders, so clients get
  // the full vector->vector canonicalization set without per-op wiring.
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext());
}
3752 
3753 #define GET_OP_CLASSES
3754 #include "mlir/Dialect/Vector/VectorOps.cpp.inc"
3755