1 //===- VectorOps.cpp - MLIR Super Vectorizer Operations -------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/VectorOps/VectorOps.h"
15 #include "mlir/Dialect/StandardOps/Ops.h"
16 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17 #include "mlir/IR/AffineExpr.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Support/Functional.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Support/MathExtras.h"
26 #include "mlir/Support/STLExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 
29 using namespace mlir;
30 using namespace mlir::vector;
31 
32 //===----------------------------------------------------------------------===//
33 // VectorOpsDialect
34 //===----------------------------------------------------------------------===//
35 
/// Constructs the vector dialect and registers every operation listed by the
/// ODS-generated GET_OP_LIST section of VectorOps.cpp.inc.
VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // The generated include expands in place to the list of op classes, which
  // becomes the template argument list of addOperations.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/VectorOps/VectorOps.cpp.inc"
      >();
}
43 
44 /// Materialize a single constant operation from a given attribute value with
45 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)46 Operation *VectorOpsDialect::materializeConstant(OpBuilder &builder,
47                                                  Attribute value, Type type,
48                                                  Location loc) {
49   return builder.create<ConstantOp>(loc, type, value);
50 }
51 
getVectorSubscriptType(Builder & builder)52 IntegerType vector::getVectorSubscriptType(Builder &builder) {
53   return builder.getIntegerType(64);
54 }
55 
getVectorSubscriptAttr(Builder & builder,ArrayRef<int64_t> values)56 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
57                                          ArrayRef<int64_t> values) {
58   return builder.getI64ArrayAttr(values);
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // ContractionOp
63 //===----------------------------------------------------------------------===//
64 
build(Builder * builder,OperationState & result,Value lhs,Value rhs,Value acc,ArrayAttr indexingMaps,ArrayAttr iteratorTypes)65 void vector::ContractionOp::build(Builder *builder, OperationState &result,
66                                   Value lhs, Value rhs, Value acc,
67                                   ArrayAttr indexingMaps,
68                                   ArrayAttr iteratorTypes) {
69   result.addOperands({lhs, rhs, acc});
70   result.addTypes(acc.getType());
71   result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
72   result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
73 }
74 
parseContractionOp(OpAsmParser & parser,OperationState & result)75 static ParseResult parseContractionOp(OpAsmParser &parser,
76                                       OperationState &result) {
77   OpAsmParser::OperandType lhsInfo;
78   OpAsmParser::OperandType rhsInfo;
79   OpAsmParser::OperandType accInfo;
80   SmallVector<OpAsmParser::OperandType, 2> masksInfo;
81   SmallVector<Type, 2> types;
82   Type resultVectorType;
83   auto loc = parser.getCurrentLocation();
84   DictionaryAttr dictAttr;
85   // TODO(andydavis, ntv) Unify linalg op attribute parsing.
86   if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
87       parser.parseOperand(lhsInfo) || parser.parseComma() ||
88       parser.parseOperand(rhsInfo) || parser.parseComma() ||
89       parser.parseOperand(accInfo) ||
90       parser.parseTrailingOperandList(masksInfo) ||
91       parser.parseOptionalAttrDict(result.attributes) ||
92       parser.parseColonTypeList(types) ||
93       parser.parseKeywordType("into", resultVectorType) ||
94       parser.resolveOperand(lhsInfo, types[0], result.operands) ||
95       parser.resolveOperand(rhsInfo, types[1], result.operands) ||
96       parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
97       parser.addTypeToList(resultVectorType, result.types))
98     return failure();
99   result.attributes.assign(dictAttr.getValue().begin(),
100                            dictAttr.getValue().end());
101   if (masksInfo.empty())
102     return success();
103   if (masksInfo.size() != 2)
104     return parser.emitError(parser.getNameLoc(),
105                             "expected zero or exactly 2 vector mask operands");
106   auto lhsType = types[0].cast<VectorType>();
107   auto rhsType = types[1].cast<VectorType>();
108   auto maskElementType = parser.getBuilder().getI1Type();
109   SmallVector<Type, 2> maskTypes;
110   maskTypes.push_back(VectorType::get(lhsType.getShape(), maskElementType));
111   maskTypes.push_back(VectorType::get(rhsType.getShape(), maskElementType));
112   if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
113     return failure();
114   return success();
115 }
116 
print(OpAsmPrinter & p,ContractionOp op)117 static void print(OpAsmPrinter &p, ContractionOp op) {
118   // TODO(andydavis, ntv) Unify printing code with linalg ops.
119   auto attrNames = op.getTraitAttrNames();
120   llvm::StringSet<> traitAttrsSet;
121   traitAttrsSet.insert(attrNames.begin(), attrNames.end());
122   SmallVector<NamedAttribute, 8> attrs;
123   for (auto attr : op.getAttrs())
124     if (traitAttrsSet.count(attr.first.strref()) > 0)
125       attrs.push_back(attr);
126 
127   auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
128   p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
129   p << op.rhs() << ", " << op.acc();
130   if (op.masks().size() == 2)
131     p << ", " << op.masks();
132 
133   p.printOptionalAttrDict(op.getAttrs(), attrNames);
134   p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
135     << op.getResultType();
136 }
137 
verifyDimMap(VectorType lhsType,VectorType rhsType,const std::vector<std::pair<int64_t,int64_t>> & map)138 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
139                          const std::vector<std::pair<int64_t, int64_t>> &map) {
140   for (auto &dimPair : map) {
141     if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
142         dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
143         lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
144       return false;
145   }
146   return true;
147 }
148 
verifyOutputShape(VectorType lhsType,VectorType rhsType,VectorType accType,VectorType resType,const std::vector<std::pair<int64_t,int64_t>> & contractingDimMap,const std::vector<std::pair<int64_t,int64_t>> & batchDimMap)149 static bool verifyOutputShape(
150     VectorType lhsType, VectorType rhsType, VectorType accType,
151     VectorType resType,
152     const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
153     const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
154   DenseSet<int64_t> lhsContractingDimSet;
155   DenseSet<int64_t> rhsContractingDimSet;
156   for (auto &dimPair : contractingDimMap) {
157     lhsContractingDimSet.insert(dimPair.first);
158     rhsContractingDimSet.insert(dimPair.second);
159   }
160   DenseSet<int64_t> rhsBatchDimSet;
161   for (auto &dimPair : batchDimMap)
162     rhsBatchDimSet.insert(dimPair.second);
163 
164   // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
165   SmallVector<int64_t, 4> expectedResultDims;
166   for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
167     if (lhsContractingDimSet.count(i) > 0)
168       continue;
169     expectedResultDims.push_back(lhsType.getDimSize(i));
170   }
171 
172   // Add free dimensions from 'rhsType' to 'expectedResultDims'.
173   for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
174     if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
175       continue;
176     expectedResultDims.push_back(rhsType.getDimSize(i));
177   }
178 
179   // Verify dimension from 'resType' against 'expectedResultDims'.
180   if (resType.getShape().size() != expectedResultDims.size() ||
181       accType.getShape().size() != expectedResultDims.size())
182     return false;
183   for (int64_t i = 0, e = resType.getRank(); i < e; ++i) {
184     if (resType.getDimSize(i) != expectedResultDims[i] ||
185         accType.getDimSize(i) != expectedResultDims[i])
186       return false;
187   }
188   return true;
189 }
190 
verify(ContractionOp op)191 static LogicalResult verify(ContractionOp op) {
192   auto lhsType = op.getLhsType();
193   auto rhsType = op.getRhsType();
194   auto accType = op.getAccType();
195   auto resType = op.getResultType();
196 
197   // Verify that an indexing map was specified for each vector operand.
198   if (op.indexing_maps().size() != 3)
199     return op.emitOpError("expected an indexing map for each vector operand");
200 
201   // Verify that each index map has 'numIterators' inputs, no symbols, and
202   // that the number of map outputs equals the rank of its associated
203   // vector operand.
204   unsigned numIterators = op.iterator_types().getValue().size();
205   for (auto it : llvm::enumerate(op.indexing_maps())) {
206     auto index = it.index();
207     auto map = it.value().cast<AffineMapAttr>().getValue();
208     if (map.getNumSymbols() != 0)
209       return op.emitOpError("expected indexing map ")
210              << index << " to have no symbols";
211     if (map.getNumDims() != numIterators)
212       return op.emitOpError("expected indexing map ")
213              << index << " to have " << numIterators << " number of inputs";
214     auto operandType = op.getOperand(index).getType().cast<VectorType>();
215     unsigned rank = operandType.getShape().size();
216     if (map.getNumResults() != rank)
217       return op.emitOpError("expected indexing map ")
218              << index << " to have " << rank << " number of outputs";
219     if (!map.isProjectedPermutation())
220       return op.emitOpError("expected indexing map ")
221              << index << " to be a projected permutation of its inputs";
222   }
223 
224   auto contractingDimMap = op.getContractingDimMap();
225   auto batchDimMap = op.getBatchDimMap();
226 
227   // Verify at least one contracting dimension pair was specified.
228   if (contractingDimMap.empty())
229     return op.emitOpError("expected at least one contracting dimension pair");
230 
231   // Verify contracting dimension map was properly constructed.
232   if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
233     return op.emitOpError("invalid contracting dimension map");
234 
235   // Verify batch dimension map was properly constructed.
236   if (!verifyDimMap(lhsType, rhsType, batchDimMap))
237     return op.emitOpError("invalid batch dimension map");
238 
239   // Verify 'accType' and 'resType' shape.
240   if (!verifyOutputShape(lhsType, rhsType, accType, resType, contractingDimMap,
241                          batchDimMap))
242     return op.emitOpError("invalid accumulator/result vector shape");
243 
244   // Verify that either two vector masks are set or none are set.
245   auto lhsMaskType = op.getLHSVectorMaskType();
246   auto rhsMaskType = op.getRHSVectorMaskType();
247   if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
248     return op.emitOpError("invalid number of vector masks specified");
249   if (lhsMaskType && rhsMaskType) {
250     // Verify mask rank == argument rank.
251     if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
252         rhsMaskType.getShape().size() != rhsType.getShape().size())
253       return op.emitOpError("invalid vector mask rank");
254   }
255   return success();
256 }
257 
/// Returns the names of the "trait" attributes (indexing maps and iterator
/// types). The custom assembly format groups these into the leading
/// dictionary.
ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  // Function-local static storage keeps the returned ArrayRef valid after
  // this call returns.
  static constexpr StringLiteral names[2] = {getIndexingMapsAttrName(),
                                             getIteratorTypesAttrName()};
  ArrayRef<StringLiteral> res{names};
  // Reinterpret the StringLiteral range as a StringRef range. The element
  // pointers convert because StringLiteral derives from StringRef.
  // NOTE(review): this relies on StringLiteral adding no data members (equal
  // element size) — confirm if StringLiteral ever changes.
  return ArrayRef<StringRef>{res.begin(), res.end()};
}
264 
getResultIndex(AffineMap map,AffineExpr targetExpr)265 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
266   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
267     if (targetExpr == map.getResult(i))
268       return i;
269   return -1;
270 }
271 
272 static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps,ArrayAttr iteratorTypes,StringRef targetIteratorTypeName,MLIRContext * context)273 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
274           StringRef targetIteratorTypeName, MLIRContext *context) {
275   std::vector<std::pair<int64_t, int64_t>> dimMap;
276   for (auto it : llvm::enumerate(iteratorTypes)) {
277     auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
278     if (iteratorTypeName != targetIteratorTypeName)
279       continue;
280     // Search lhs/rhs map results for 'targetExpr'.
281     auto targetExpr = getAffineDimExpr(it.index(), context);
282     int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
283     int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
284     if (lhsDim >= 0 && rhsDim >= 0)
285       dimMap.push_back({lhsDim, rhsDim});
286   }
287   return dimMap;
288 }
289 
getIterationBounds(SmallVectorImpl<int64_t> & iterationBounds)290 void ContractionOp::getIterationBounds(
291     SmallVectorImpl<int64_t> &iterationBounds) {
292   auto lhsShape = getLhsType().getShape();
293   auto resShape = getResultType().getShape();
294   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
295   SmallVector<int64_t, 2> iterationShape;
296   for (auto it : llvm::enumerate(iterator_types())) {
297     // Search lhs/rhs map results for 'targetExpr'.
298     auto targetExpr = getAffineDimExpr(it.index(), getContext());
299     auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
300     if (iteratorTypeName == getReductionIteratorTypeName()) {
301       // Get reduction dim size from lhs shape (same size in rhsShape).
302       int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
303       assert(lhsDimIndex >= 0);
304       iterationBounds.push_back(lhsShape[lhsDimIndex]);
305       continue;
306     }
307     // Get parallel dimension size from result shape.
308     int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
309     assert(resDimIndex >= 0);
310     iterationBounds.push_back(resShape[resDimIndex]);
311   }
312 }
313 
getIterationIndexMap(std::vector<DenseMap<int64_t,int64_t>> & iterationIndexMap)314 void ContractionOp::getIterationIndexMap(
315     std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
316   unsigned numMaps = indexing_maps().getValue().size();
317   iterationIndexMap.resize(numMaps);
318   for (auto it : llvm::enumerate(indexing_maps())) {
319     auto index = it.index();
320     auto map = it.value().cast<AffineMapAttr>().getValue();
321     for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
322       auto dim = map.getResult(i).cast<AffineDimExpr>();
323       iterationIndexMap[index][dim.getPosition()] = i;
324     }
325   }
326 }
327 
getContractingDimMap()328 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
329   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
330   return getDimMap(indexingMaps, iterator_types(),
331                    getReductionIteratorTypeName(), getContext());
332 }
333 
getBatchDimMap()334 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
335   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
336   return getDimMap(indexingMaps, iterator_types(),
337                    getParallelIteratorTypeName(), getContext());
338 }
339 
getIndexingMaps()340 SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
341   SmallVector<AffineMap, 4> res;
342   auto mapAttrs = indexing_maps().getValue();
343   res.reserve(mapAttrs.size());
344   for (auto mapAttr : mapAttrs)
345     res.push_back(mapAttr.cast<AffineMapAttr>().getValue());
346   return res;
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // ExtractElementOp
351 //===----------------------------------------------------------------------===//
352 
print(OpAsmPrinter & p,vector::ExtractElementOp op)353 static void print(OpAsmPrinter &p, vector::ExtractElementOp op) {
354   p << op.getOperationName() << " " << op.vector() << "[" << op.position()
355     << " : " << op.position().getType() << "]";
356   p.printOptionalAttrDict(op.getAttrs());
357   p << " : " << op.vector().getType();
358 }
359 
parseExtractElementOp(OpAsmParser & parser,OperationState & result)360 static ParseResult parseExtractElementOp(OpAsmParser &parser,
361                                          OperationState &result) {
362   OpAsmParser::OperandType vector, position;
363   Type positionType;
364   VectorType vectorType;
365   if (parser.parseOperand(vector) || parser.parseLSquare() ||
366       parser.parseOperand(position) || parser.parseColonType(positionType) ||
367       parser.parseRSquare() ||
368       parser.parseOptionalAttrDict(result.attributes) ||
369       parser.parseColonType(vectorType))
370     return failure();
371   Type resultType = vectorType.getElementType();
372   return failure(
373       parser.resolveOperand(vector, vectorType, result.operands) ||
374       parser.resolveOperand(position, positionType, result.operands) ||
375       parser.addTypeToList(resultType, result.types));
376 }
377 
verify(vector::ExtractElementOp op)378 static LogicalResult verify(vector::ExtractElementOp op) {
379   VectorType vectorType = op.getVectorType();
380   if (vectorType.getRank() != 1)
381     return op.emitOpError("expected 1-D vector");
382   return success();
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // ExtractOp
387 //===----------------------------------------------------------------------===//
388 
inferExtractOpResultType(VectorType vectorType,ArrayAttr position)389 static Type inferExtractOpResultType(VectorType vectorType,
390                                      ArrayAttr position) {
391   if (static_cast<int64_t>(position.size()) == vectorType.getRank())
392     return vectorType.getElementType();
393   return VectorType::get(vectorType.getShape().drop_front(position.size()),
394                          vectorType.getElementType());
395 }
396 
build(Builder * builder,OperationState & result,Value source,ArrayRef<int64_t> position)397 void vector::ExtractOp::build(Builder *builder, OperationState &result,
398                               Value source, ArrayRef<int64_t> position) {
399   result.addOperands(source);
400   auto positionAttr = getVectorSubscriptAttr(*builder, position);
401   result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
402                                            positionAttr));
403   result.addAttribute(getPositionAttrName(), positionAttr);
404 }
405 
print(OpAsmPrinter & p,vector::ExtractOp op)406 static void print(OpAsmPrinter &p, vector::ExtractOp op) {
407   p << op.getOperationName() << " " << op.vector() << op.position();
408   p.printOptionalAttrDict(op.getAttrs(), {"position"});
409   p << " : " << op.vector().getType();
410 }
411 
parseExtractOp(OpAsmParser & parser,OperationState & result)412 static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
413   llvm::SMLoc attributeLoc, typeLoc;
414   SmallVector<NamedAttribute, 4> attrs;
415   OpAsmParser::OperandType vector;
416   Type type;
417   Attribute attr;
418   if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
419       parser.parseAttribute(attr, "position", attrs) ||
420       parser.parseOptionalAttrDict(attrs) ||
421       parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
422     return failure();
423 
424   auto vectorType = type.dyn_cast<VectorType>();
425   if (!vectorType)
426     return parser.emitError(typeLoc, "expected vector type");
427 
428   auto positionAttr = attr.dyn_cast<ArrayAttr>();
429   if (!positionAttr ||
430       static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
431     return parser.emitError(
432         attributeLoc,
433         "expected position attribute of rank smaller than vector rank");
434 
435   Type resType = inferExtractOpResultType(vectorType, positionAttr);
436   result.attributes = attrs;
437   return failure(parser.resolveOperand(vector, type, result.operands) ||
438                  parser.addTypeToList(resType, result.types));
439 }
440 
verify(vector::ExtractOp op)441 static LogicalResult verify(vector::ExtractOp op) {
442   auto positionAttr = op.position().getValue();
443   if (positionAttr.empty())
444     return op.emitOpError("expected non-empty position attribute");
445   if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
446     return op.emitOpError(
447         "expected position attribute of rank smaller than vector rank");
448   for (auto en : llvm::enumerate(positionAttr)) {
449     auto attr = en.value().dyn_cast<IntegerAttr>();
450     if (!attr || attr.getInt() < 0 ||
451         attr.getInt() >= op.getVectorType().getDimSize(en.index()))
452       return op.emitOpError("expected position attribute #")
453              << (en.index() + 1)
454              << " to be a non-negative integer smaller than the corresponding "
455                 "vector dimension";
456   }
457   return success();
458 }
459 
460 //===----------------------------------------------------------------------===//
461 // ExtractSlicesOp
462 //===----------------------------------------------------------------------===//
463 
build(Builder * builder,OperationState & result,TupleType tupleType,Value vector,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)464 void ExtractSlicesOp::build(Builder *builder, OperationState &result,
465                             TupleType tupleType, Value vector,
466                             ArrayRef<int64_t> sizes,
467                             ArrayRef<int64_t> strides) {
468   result.addOperands(vector);
469   auto sizesAttr = getVectorSubscriptAttr(*builder, sizes);
470   auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
471   result.addTypes(tupleType);
472   result.addAttribute(getSizesAttrName(), sizesAttr);
473   result.addAttribute(getStridesAttrName(), stridesAttr);
474 }
475 
parseExtractSlicesOp(OpAsmParser & parser,OperationState & result)476 static ParseResult parseExtractSlicesOp(OpAsmParser &parser,
477                                         OperationState &result) {
478   OpAsmParser::OperandType operandInfo;
479   ArrayAttr sizesAttr;
480   StringRef sizesAttrName = ExtractSlicesOp::getSizesAttrName();
481   ArrayAttr stridesAttr;
482   StringRef stridesAttrName = ExtractSlicesOp::getStridesAttrName();
483   VectorType vectorType;
484   TupleType resultTupleType;
485   return failure(
486       parser.parseOperand(operandInfo) || parser.parseComma() ||
487       parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
488       parser.parseComma() ||
489       parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
490       parser.parseOptionalAttrDict(result.attributes) ||
491       parser.parseColonType(vectorType) ||
492       parser.parseKeywordType("into", resultTupleType) ||
493       parser.resolveOperand(operandInfo, vectorType, result.operands) ||
494       parser.addTypeToList(resultTupleType, result.types));
495 }
496 
print(OpAsmPrinter & p,ExtractSlicesOp op)497 static void print(OpAsmPrinter &p, ExtractSlicesOp op) {
498   p << op.getOperationName() << ' ' << op.vector() << ", ";
499   p << op.sizes() << ", " << op.strides();
500   p.printOptionalAttrDict(
501       op.getAttrs(),
502       /*elidedAttrs=*/{ExtractSlicesOp::getSizesAttrName(),
503                        ExtractSlicesOp::getStridesAttrName()});
504   p << " : " << op.vector().getType();
505   p << " into " << op.getResultTupleType();
506 }
507 
508 static LogicalResult
isValidExtractOrInsertSlicesType(Operation * op,VectorType vectorType,TupleType tupleType,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)509 isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
510                                  TupleType tupleType, ArrayRef<int64_t> sizes,
511                                  ArrayRef<int64_t> strides) {
512   // Check for non-unit strides.
513   // TODO(b/144845578) Support non-1 strides.
514   if (llvm::any_of(strides, [](int64_t s) { return s != 1; }))
515     return op->emitError("requires unit strides");
516   // Check that 'vectorType' rank matches rank of tuple element vectors.
517   unsigned rank = vectorType.getRank();
518   auto is_vector_type_of_rank = [&](Type t) {
519     return t.isa<VectorType>() && t.cast<VectorType>().getRank() == rank;
520   };
521   if (!llvm::all_of(tupleType.getTypes(), is_vector_type_of_rank))
522     return op->emitError("requires vector tuple elements of rank ") << rank;
523   // Check that 'sizes' and 'strides' are of size == 'rank'.
524   if (sizes.size() != rank || strides.size() != rank)
525     return op->emitError("requires sizes and strides of rank ") << rank;
526 
527   // Compute the number of slices in each dimension.
528   // TODO(andydavis) Move this into a slice generation helper function.
529   auto shape = vectorType.getShape();
530   SmallVector<int64_t, 4> dimSliceCounts(rank);
531   for (unsigned i = 0; i < rank; ++i)
532     dimSliceCounts[i] = ceilDiv(shape[i], sizes[i]);
533   // Compute the strides between slices in each dimension.
534   SmallVector<int64_t, 4> sliceStrides(rank);
535   sliceStrides[rank - 1] = 1;
536   for (int i = rank - 2; i >= 0; --i)
537     sliceStrides[i] = sliceStrides[i + 1] * dimSliceCounts[i + 1];
538 
539   // Generate each slice shape based on 'sizes', 'strides' and 'vectorType',
540   // and verify that the same matches the corresponding tuple element 'i'.
541   for (int64_t i = 0, e = tupleType.size(); i < e; ++i) {
542     // De-linearize w.r.t. 'sliceStrides'.
543     SmallVector<int64_t, 4> vectorOffsets(rank);
544     int64_t linearIndex = i;
545     for (unsigned j = 0; j < rank; ++j) {
546       vectorOffsets[j] = linearIndex / sliceStrides[j];
547       linearIndex %= sliceStrides[j];
548     }
549     // Convert from unrolled vector-space offsets to element-space offsets.
550     auto offsets = mlir::functional::zipMap(
551         [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
552     // Initialize 'sliceSizes' to target 'sizes'
553     SmallVector<int64_t, 4> sliceSizes(sizes.begin(), sizes.end());
554     for (unsigned j = 0; j < rank; ++j) {
555       // Based on 'offsets' and 'shape' clip some dim sizes for partial tiles.
556       sliceSizes[j] = std::min(sliceSizes[j], shape[j] - offsets[j]);
557     }
558     // Create slice VectorType type.
559     auto sliceVectorType =
560         VectorType::get(sliceSizes, vectorType.getElementType());
561     // Verify that 'sliceVectorType' matches tupleType.getTypes(i)
562     if (sliceVectorType != tupleType.getType(i))
563       return op->emitError("invalid tuple element type ") << sliceVectorType;
564   }
565   return success();
566 }
567 
verify(ExtractSlicesOp op)568 static LogicalResult verify(ExtractSlicesOp op) {
569   SmallVector<int64_t, 4> sizes;
570   op.getSizes(sizes);
571   SmallVector<int64_t, 4> strides;
572   op.getStrides(strides);
573   return isValidExtractOrInsertSlicesType(
574       op.getOperation(), op.getSourceVectorType(), op.getResultTupleType(),
575       sizes, strides);
576 }
577 
populateFromInt64AttrArray(ArrayAttr arrayAttr,SmallVectorImpl<int64_t> & results)578 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
579                                        SmallVectorImpl<int64_t> &results) {
580   for (auto attr : arrayAttr)
581     results.push_back(attr.cast<IntegerAttr>().getInt());
582 }
583 
getSizes(SmallVectorImpl<int64_t> & results)584 void ExtractSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
585   populateFromInt64AttrArray(sizes(), results);
586 }
587 
getStrides(SmallVectorImpl<int64_t> & results)588 void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
589   populateFromInt64AttrArray(strides(), results);
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // BroadcastOp
594 //===----------------------------------------------------------------------===//
595 
print(OpAsmPrinter & p,BroadcastOp op)596 static void print(OpAsmPrinter &p, BroadcastOp op) {
597   p << op.getOperationName() << " " << op.source() << " : "
598     << op.getSourceType() << " to " << op.getVectorType();
599 }
600 
verify(BroadcastOp op)601 static LogicalResult verify(BroadcastOp op) {
602   VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
603   VectorType dstVectorType = op.getVectorType();
604   // Scalar to vector broadcast is always valid. A vector
605   // to vector broadcast needs some additional checking.
606   if (srcVectorType) {
607     int64_t srcRank = srcVectorType.getRank();
608     int64_t dstRank = dstVectorType.getRank();
609     if (srcRank > dstRank)
610       return op.emitOpError("source rank higher than destination rank");
611     // Source has an exact match or singleton value for all trailing dimensions
612     // (all leading dimensions are simply duplicated).
613     int64_t lead = dstRank - srcRank;
614     for (int64_t r = 0; r < srcRank; ++r) {
615       int64_t srcDim = srcVectorType.getDimSize(r);
616       int64_t dstDim = dstVectorType.getDimSize(lead + r);
617       if (srcDim != 1 && srcDim != dstDim)
618         return op.emitOpError("dimension mismatch (")
619                << srcDim << " vs. " << dstDim << ")";
620     }
621   }
622   return success();
623 }
624 
parseBroadcastOp(OpAsmParser & parser,OperationState & result)625 static ParseResult parseBroadcastOp(OpAsmParser &parser,
626                                     OperationState &result) {
627   OpAsmParser::OperandType source;
628   Type sourceType;
629   VectorType vectorType;
630   return failure(parser.parseOperand(source) ||
631                  parser.parseColonType(sourceType) ||
632                  parser.parseKeywordType("to", vectorType) ||
633                  parser.resolveOperand(source, sourceType, result.operands) ||
634                  parser.addTypeToList(vectorType, result.types));
635 }
636 
637 //===----------------------------------------------------------------------===//
638 // ShuffleOp
639 //===----------------------------------------------------------------------===//
640 
build(Builder * builder,OperationState & result,Value v1,Value v2,ArrayRef<int64_t> mask)641 void ShuffleOp::build(Builder *builder, OperationState &result, Value v1,
642                       Value v2, ArrayRef<int64_t> mask) {
643   result.addOperands({v1, v2});
644   auto maskAttr = getVectorSubscriptAttr(*builder, mask);
645   result.addTypes(v1.getType());
646   result.addAttribute(getMaskAttrName(), maskAttr);
647 }
648 
print(OpAsmPrinter & p,ShuffleOp op)649 static void print(OpAsmPrinter &p, ShuffleOp op) {
650   p << op.getOperationName() << " " << op.v1() << ", " << op.v2() << " "
651     << op.mask();
652   p.printOptionalAttrDict(op.getAttrs(), {ShuffleOp::getMaskAttrName()});
653   p << " : " << op.v1().getType() << ", " << op.v2().getType();
654 }
655 
verify(ShuffleOp op)656 static LogicalResult verify(ShuffleOp op) {
657   VectorType resultType = op.getVectorType();
658   VectorType v1Type = op.getV1VectorType();
659   VectorType v2Type = op.getV2VectorType();
660   // Verify ranks.
661   int64_t resRank = resultType.getRank();
662   int64_t v1Rank = v1Type.getRank();
663   int64_t v2Rank = v2Type.getRank();
664   if (resRank != v1Rank || v1Rank != v2Rank)
665     return op.emitOpError("rank mismatch");
666   // Verify all but leading dimension sizes.
667   for (int64_t r = 1; r < v1Rank; ++r) {
668     int64_t resDim = resultType.getDimSize(r);
669     int64_t v1Dim = v1Type.getDimSize(r);
670     int64_t v2Dim = v2Type.getDimSize(r);
671     if (resDim != v1Dim || v1Dim != v2Dim)
672       return op.emitOpError("dimension mismatch");
673   }
674   // Verify mask length.
675   auto maskAttr = op.mask().getValue();
676   int64_t maskLength = maskAttr.size();
677   if (maskLength != resultType.getDimSize(0))
678     return op.emitOpError("mask length mismatch");
679   // Verify all indices.
680   int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
681   for (auto en : llvm::enumerate(maskAttr)) {
682     auto attr = en.value().dyn_cast<IntegerAttr>();
683     if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
684       return op.emitOpError("mask index #")
685              << (en.index() + 1) << " out of range";
686   }
687   return success();
688 }
689 
parseShuffleOp(OpAsmParser & parser,OperationState & result)690 static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
691   OpAsmParser::OperandType v1, v2;
692   Attribute attr;
693   VectorType v1Type, v2Type;
694   if (parser.parseOperand(v1) || parser.parseComma() ||
695       parser.parseOperand(v2) ||
696       parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
697                             result.attributes) ||
698       parser.parseOptionalAttrDict(result.attributes) ||
699       parser.parseColonType(v1Type) || parser.parseComma() ||
700       parser.parseType(v2Type) ||
701       parser.resolveOperand(v1, v1Type, result.operands) ||
702       parser.resolveOperand(v2, v2Type, result.operands))
703     return failure();
704   // Construct resulting type: leading dimension matches mask length,
705   // all trailing dimensions match the operands.
706   auto maskAttr = attr.dyn_cast<ArrayAttr>();
707   if (!maskAttr)
708     return parser.emitError(parser.getNameLoc(), "missing mask attribute");
709   int64_t maskLength = maskAttr.size();
710   if (maskLength <= 0)
711     return parser.emitError(parser.getNameLoc(), "invalid mask length");
712   int64_t v1Rank = v1Type.getRank();
713   SmallVector<int64_t, 4> shape;
714   shape.reserve(v1Rank);
715   shape.push_back(maskLength);
716   for (int64_t r = 1; r < v1Rank; ++r)
717     shape.push_back(v1Type.getDimSize(r));
718   VectorType resType = VectorType::get(shape, v1Type.getElementType());
719   parser.addTypeToList(resType, result.types);
720   return success();
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // InsertElementOp
725 //===----------------------------------------------------------------------===//
726 
print(OpAsmPrinter & p,InsertElementOp op)727 static void print(OpAsmPrinter &p, InsertElementOp op) {
728   p << op.getOperationName() << " " << op.source() << ", " << op.dest() << "["
729     << op.position() << " : " << op.position().getType() << "]";
730   p.printOptionalAttrDict(op.getAttrs());
731   p << " : " << op.dest().getType();
732 }
733 
parseInsertElementOp(OpAsmParser & parser,OperationState & result)734 static ParseResult parseInsertElementOp(OpAsmParser &parser,
735                                         OperationState &result) {
736   OpAsmParser::OperandType source, dest, position;
737   Type positionType;
738   VectorType destType;
739   if (parser.parseOperand(source) || parser.parseComma() ||
740       parser.parseOperand(dest) || parser.parseLSquare() ||
741       parser.parseOperand(position) || parser.parseColonType(positionType) ||
742       parser.parseRSquare() ||
743       parser.parseOptionalAttrDict(result.attributes) ||
744       parser.parseColonType(destType))
745     return failure();
746   Type sourceType = destType.getElementType();
747   return failure(
748       parser.resolveOperand(source, sourceType, result.operands) ||
749       parser.resolveOperand(dest, destType, result.operands) ||
750       parser.resolveOperand(position, positionType, result.operands) ||
751       parser.addTypeToList(destType, result.types));
752 }
753 
verify(InsertElementOp op)754 static LogicalResult verify(InsertElementOp op) {
755   auto dstVectorType = op.getDestVectorType();
756   if (dstVectorType.getRank() != 1)
757     return op.emitOpError("expected 1-D vector");
758   return success();
759 }
760 
761 //===----------------------------------------------------------------------===//
762 // InsertOp
763 //===----------------------------------------------------------------------===//
764 
build(Builder * builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> position)765 void InsertOp::build(Builder *builder, OperationState &result, Value source,
766                      Value dest, ArrayRef<int64_t> position) {
767   result.addOperands({source, dest});
768   auto positionAttr = getVectorSubscriptAttr(*builder, position);
769   result.addTypes(dest.getType());
770   result.addAttribute(getPositionAttrName(), positionAttr);
771 }
772 
print(OpAsmPrinter & p,InsertOp op)773 static void print(OpAsmPrinter &p, InsertOp op) {
774   p << op.getOperationName() << " " << op.source() << ", " << op.dest()
775     << op.position();
776   p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()});
777   p << " : " << op.getSourceType() << " into " << op.getDestVectorType();
778 }
779 
parseInsertOp(OpAsmParser & parser,OperationState & result)780 static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) {
781   SmallVector<NamedAttribute, 4> attrs;
782   OpAsmParser::OperandType source, dest;
783   Type sourceType;
784   VectorType destType;
785   Attribute attr;
786   return failure(parser.parseOperand(source) || parser.parseComma() ||
787                  parser.parseOperand(dest) ||
788                  parser.parseAttribute(attr, InsertOp::getPositionAttrName(),
789                                        result.attributes) ||
790                  parser.parseOptionalAttrDict(attrs) ||
791                  parser.parseColonType(sourceType) ||
792                  parser.parseKeywordType("into", destType) ||
793                  parser.resolveOperand(source, sourceType, result.operands) ||
794                  parser.resolveOperand(dest, destType, result.operands) ||
795                  parser.addTypeToList(destType, result.types));
796 }
797 
verify(InsertOp op)798 static LogicalResult verify(InsertOp op) {
799   auto positionAttr = op.position().getValue();
800   if (positionAttr.empty())
801     return op.emitOpError("expected non-empty position attribute");
802   auto destVectorType = op.getDestVectorType();
803   if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
804     return op.emitOpError(
805         "expected position attribute of rank smaller than dest vector rank");
806   auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
807   if (srcVectorType &&
808       (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
809        static_cast<unsigned>(destVectorType.getRank())))
810     return op.emitOpError("expected position attribute rank + source rank to "
811                           "match dest vector rank");
812   else if (!srcVectorType && (positionAttr.size() !=
813                               static_cast<unsigned>(destVectorType.getRank())))
814     return op.emitOpError(
815         "expected position attribute rank to match the dest vector rank");
816   for (auto en : llvm::enumerate(positionAttr)) {
817     auto attr = en.value().dyn_cast<IntegerAttr>();
818     if (!attr || attr.getInt() < 0 ||
819         attr.getInt() >= destVectorType.getDimSize(en.index()))
820       return op.emitOpError("expected position attribute #")
821              << (en.index() + 1)
822              << " to be a non-negative integer smaller than the corresponding "
823                 "dest vector dimension";
824   }
825   return success();
826 }
827 
828 //===----------------------------------------------------------------------===//
829 // InsertSlicesOp
830 //===----------------------------------------------------------------------===//
831 
parseInsertSlicesOp(OpAsmParser & parser,OperationState & result)832 static ParseResult parseInsertSlicesOp(OpAsmParser &parser,
833                                        OperationState &result) {
834   OpAsmParser::OperandType operandInfo;
835   ArrayAttr sizesAttr;
836   StringRef sizesAttrName = InsertSlicesOp::getSizesAttrName();
837   ArrayAttr stridesAttr;
838   StringRef stridesAttrName = InsertSlicesOp::getStridesAttrName();
839   TupleType tupleType;
840   VectorType resultVectorType;
841   return failure(
842       parser.parseOperand(operandInfo) || parser.parseComma() ||
843       parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
844       parser.parseComma() ||
845       parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
846       parser.parseOptionalAttrDict(result.attributes) ||
847       parser.parseColonType(tupleType) ||
848       parser.parseKeywordType("into", resultVectorType) ||
849       parser.resolveOperand(operandInfo, tupleType, result.operands) ||
850       parser.addTypeToList(resultVectorType, result.types));
851 }
852 
print(OpAsmPrinter & p,InsertSlicesOp op)853 static void print(OpAsmPrinter &p, InsertSlicesOp op) {
854   p << op.getOperationName() << ' ' << op.vectors() << ", ";
855   p << op.sizes() << ", " << op.strides();
856   p.printOptionalAttrDict(
857       op.getAttrs(),
858       /*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(),
859                        InsertSlicesOp::getStridesAttrName()});
860   p << " : " << op.vectors().getType();
861   p << " into " << op.getResultVectorType();
862 }
863 
verify(InsertSlicesOp op)864 static LogicalResult verify(InsertSlicesOp op) {
865   SmallVector<int64_t, 4> sizes;
866   op.getSizes(sizes);
867   SmallVector<int64_t, 4> strides;
868   op.getStrides(strides);
869   return isValidExtractOrInsertSlicesType(
870       op.getOperation(), op.getResultVectorType(), op.getSourceTupleType(),
871       sizes, strides);
872 }
873 
getSizes(SmallVectorImpl<int64_t> & results)874 void InsertSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
875   populateFromInt64AttrArray(sizes(), results);
876 }
877 
getStrides(SmallVectorImpl<int64_t> & results)878 void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
879   populateFromInt64AttrArray(strides(), results);
880 }
881 
882 //===----------------------------------------------------------------------===//
883 // InsertStridedSliceOp
884 //===----------------------------------------------------------------------===//
885 
build(Builder * builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> offsets,ArrayRef<int64_t> strides)886 void InsertStridedSliceOp::build(Builder *builder, OperationState &result,
887                                  Value source, Value dest,
888                                  ArrayRef<int64_t> offsets,
889                                  ArrayRef<int64_t> strides) {
890   result.addOperands({source, dest});
891   auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets);
892   auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
893   result.addTypes(dest.getType());
894   result.addAttribute(getOffsetsAttrName(), offsetsAttr);
895   result.addAttribute(getStridesAttrName(), stridesAttr);
896 }
897 
print(OpAsmPrinter & p,InsertStridedSliceOp op)898 static void print(OpAsmPrinter &p, InsertStridedSliceOp op) {
899   p << op.getOperationName() << " " << op.source() << ", " << op.dest() << " ";
900   p.printOptionalAttrDict(op.getAttrs());
901   p << " : " << op.getSourceVectorType() << " into " << op.getDestVectorType();
902 }
903 
parseInsertStridedSliceOp(OpAsmParser & parser,OperationState & result)904 static ParseResult parseInsertStridedSliceOp(OpAsmParser &parser,
905                                              OperationState &result) {
906   OpAsmParser::OperandType source, dest;
907   VectorType sourceVectorType, destVectorType;
908   return failure(
909       parser.parseOperand(source) || parser.parseComma() ||
910       parser.parseOperand(dest) ||
911       parser.parseOptionalAttrDict(result.attributes) ||
912       parser.parseColonType(sourceVectorType) ||
913       parser.parseKeywordType("into", destVectorType) ||
914       parser.resolveOperand(source, sourceVectorType, result.operands) ||
915       parser.resolveOperand(dest, destVectorType, result.operands) ||
916       parser.addTypeToList(destVectorType, result.types));
917 }
918 
919 // TODO(ntv) Should be moved to Tablegen Confined attributes.
920 template <typename OpType>
isIntegerArrayAttrSmallerThanShape(OpType op,ArrayAttr arrayAttr,ArrayRef<int64_t> shape,StringRef attrName)921 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
922                                                         ArrayAttr arrayAttr,
923                                                         ArrayRef<int64_t> shape,
924                                                         StringRef attrName) {
925   if (arrayAttr.size() > shape.size())
926     return op.emitOpError("expected ")
927            << attrName << " attribute of rank smaller than vector rank";
928   return success();
929 }
930 
931 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
932 // interval. If `halfOpen` is true then the admissible interval is [min, max).
933 // Otherwise, the admissible interval is [min, max].
934 template <typename OpType>
935 static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op,ArrayAttr arrayAttr,int64_t min,int64_t max,StringRef attrName,bool halfOpen=true)936 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
937                                   int64_t max, StringRef attrName,
938                                   bool halfOpen = true) {
939   for (auto attr : arrayAttr) {
940     auto val = attr.cast<IntegerAttr>().getInt();
941     auto upper = max;
942     if (!halfOpen)
943       upper += 1;
944     if (val < min || val >= upper)
945       return op.emitOpError("expected ") << attrName << " to be confined to ["
946                                          << min << ", " << upper << ")";
947   }
948   return success();
949 }
950 
951 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
952 // interval. If `halfOpen` is true then the admissible interval is [min, max).
953 // Otherwise, the admissible interval is [min, max].
954 template <typename OpType>
955 static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op,ArrayAttr arrayAttr,ArrayRef<int64_t> shape,StringRef attrName,bool halfOpen=true,int64_t min=0)956 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
957                                   ArrayRef<int64_t> shape, StringRef attrName,
958                                   bool halfOpen = true, int64_t min = 0) {
959   assert(arrayAttr.size() <= shape.size());
960   unsigned index = 0;
961   for (auto it : llvm::zip(arrayAttr, shape)) {
962     auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
963     auto max = std::get<1>(it);
964     if (!halfOpen)
965       max += 1;
966     if (val < min || val >= max)
967       return op.emitOpError("expected ")
968              << attrName << " dimension " << index << " to be confined to ["
969              << min << ", " << max << ")";
970     ++index;
971   }
972   return success();
973 }
974 
975 // Returns true if all integers in `arrayAttr` are in the interval [min, max}.
976 // interval. If `halfOpen` is true then the admissible interval is [min, max).
977 // Otherwise, the admissible interval is [min, max].
978 template <typename OpType>
isSumOfIntegerArrayAttrConfinedToShape(OpType op,ArrayAttr arrayAttr1,ArrayAttr arrayAttr2,ArrayRef<int64_t> shape,StringRef attrName1,StringRef attrName2,bool halfOpen=true,int64_t min=1)979 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
980     OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
981     ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
982     bool halfOpen = true, int64_t min = 1) {
983   assert(arrayAttr1.size() <= shape.size());
984   assert(arrayAttr2.size() <= shape.size());
985   unsigned index = 0;
986   for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
987     auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
988     auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
989     auto max = std::get<2>(it);
990     if (!halfOpen)
991       max += 1;
992     if (val1 + val2 < 0 || val1 + val2 >= max)
993       return op.emitOpError("expected sum(")
994              << attrName1 << ", " << attrName2 << ") dimension " << index
995              << " to be confined to [" << min << ", " << max << ")";
996     ++index;
997   }
998   return success();
999 }
1000 
makeI64ArrayAttr(ArrayRef<int64_t> values,MLIRContext * context)1001 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
1002                                   MLIRContext *context) {
1003   auto attrs = functional::map(
1004       [context](int64_t v) -> Attribute {
1005         return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
1006       },
1007       values);
1008   return ArrayAttr::get(attrs, context);
1009 }
1010 
verify(InsertStridedSliceOp op)1011 static LogicalResult verify(InsertStridedSliceOp op) {
1012   auto sourceVectorType = op.getSourceVectorType();
1013   auto destVectorType = op.getDestVectorType();
1014   auto offsets = op.offsets();
1015   auto strides = op.strides();
1016   if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
1017     return op.emitOpError(
1018         "expected offsets of same size as destination vector rank");
1019   if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
1020     return op.emitOpError(
1021         "expected strides of same size as source vector rank");
1022   if (sourceVectorType.getRank() > destVectorType.getRank())
1023     return op.emitOpError(
1024         "expected source rank to be smaller than destination rank");
1025 
1026   auto sourceShape = sourceVectorType.getShape();
1027   auto destShape = destVectorType.getShape();
1028   SmallVector<int64_t, 4> sourceShapeAsDestShape(
1029       destShape.size() - sourceShape.size(), 0);
1030   sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
1031   auto offName = InsertStridedSliceOp::getOffsetsAttrName();
1032   auto stridesName = InsertStridedSliceOp::getStridesAttrName();
1033   if (failed(
1034           isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
1035       failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
1036                                                /*halfOpen=*/false)) ||
1037       failed(isSumOfIntegerArrayAttrConfinedToShape(
1038           op, offsets,
1039           makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
1040           offName, "source vector shape",
1041           /*halfOpen=*/false, /*min=*/1)))
1042     return failure();
1043 
1044   return success();
1045 }
1046 
1047 //===----------------------------------------------------------------------===//
1048 // OuterProductOp
1049 //===----------------------------------------------------------------------===//
1050 
print(OpAsmPrinter & p,OuterProductOp op)1051 static void print(OpAsmPrinter &p, OuterProductOp op) {
1052   p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
1053   if (!op.acc().empty())
1054     p << ", " << op.acc();
1055   p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
1056 }
1057 
parseOuterProductOp(OpAsmParser & parser,OperationState & result)1058 static ParseResult parseOuterProductOp(OpAsmParser &parser,
1059                                        OperationState &result) {
1060   SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
1061   Type tLHS, tRHS;
1062   if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) ||
1063       parser.parseComma() || parser.parseType(tRHS))
1064     return failure();
1065   if (operandsInfo.size() < 2)
1066     return parser.emitError(parser.getNameLoc(),
1067                             "expected at least 2 operands");
1068   VectorType vLHS = tLHS.dyn_cast<VectorType>();
1069   VectorType vRHS = tRHS.dyn_cast<VectorType>();
1070   if (!vLHS || !vRHS)
1071     return parser.emitError(parser.getNameLoc(), "expected 2 vector types");
1072   VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
1073                                        vLHS.getElementType());
1074   return failure(
1075       parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
1076       parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
1077       (operandsInfo.size() > 2 &&
1078        parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
1079       parser.addTypeToList(resType, result.types));
1080 }
1081 
verify(OuterProductOp op)1082 static LogicalResult verify(OuterProductOp op) {
1083   VectorType vLHS = op.getOperandVectorTypeLHS(),
1084              vRHS = op.getOperandVectorTypeRHS(),
1085              vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
1086   if (vLHS.getRank() != 1)
1087     return op.emitOpError("expected 1-d vector for operand #1");
1088   if (vRHS.getRank() != 1)
1089     return op.emitOpError("expected 1-d vector for operand #2");
1090   if (vRES.getRank() != 2)
1091     return op.emitOpError("expected 2-d vector result");
1092   if (vLHS.getDimSize(0) != vRES.getDimSize(0))
1093     return op.emitOpError("expected #1 operand dim to match result dim #1");
1094   if (vRHS.getDimSize(0) != vRES.getDimSize(1))
1095     return op.emitOpError("expected #2 operand dim to match result dim #2");
1096   if (vACC && vACC != vRES)
1097     return op.emitOpError("expected operand #3 of same type as result type");
1098   return success();
1099 }
1100 
1101 //===----------------------------------------------------------------------===//
1102 // ReshapeOp
1103 //===----------------------------------------------------------------------===//
1104 
print(OpAsmPrinter & p,ReshapeOp op)1105 static void print(OpAsmPrinter &p, ReshapeOp op) {
1106   p << op.getOperationName() << " " << op.vector() << ", [" << op.input_shape()
1107     << "], [" << op.output_shape() << "], " << op.fixed_vector_sizes();
1108   SmallVector<StringRef, 2> elidedAttrs = {
1109       ReshapeOp::getOperandSegmentSizeAttr(),
1110       ReshapeOp::getFixedVectorSizesAttrName()};
1111   p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
1112   p << " : " << op.getInputVectorType() << " to " << op.getOutputVectorType();
1113 }
1114 
1115 // TODO(b/146516564) Consider passing number of inner vector dimensions that
1116 // are fixed, instead of their values in 'fixesVectorSizes' array attr.
1117 //
1118 // operation ::= ssa-id `=` `vector.reshape` ssa-use, `[` ssa-use-list `]`,
1119 //                          `[` ssa-use-list `]`, `[` array-attribute `]`
1120 //                          `:` vector-type 'to' vector-type
1121 //
parseReshapeOp(OpAsmParser & parser,OperationState & result)1122 static ParseResult parseReshapeOp(OpAsmParser &parser, OperationState &result) {
1123   OpAsmParser::OperandType inputInfo;
1124   SmallVector<OpAsmParser::OperandType, 4> inputShapeInfo;
1125   SmallVector<OpAsmParser::OperandType, 4> outputShapeInfo;
1126   ArrayAttr fixedVectorSizesAttr;
1127   StringRef attrName = ReshapeOp::getFixedVectorSizesAttrName();
1128   auto indexType = parser.getBuilder().getIndexType();
1129   if (parser.parseOperand(inputInfo) || parser.parseComma() ||
1130       parser.parseOperandList(inputShapeInfo, OpAsmParser::Delimiter::Square) ||
1131       parser.parseComma() ||
1132       parser.parseOperandList(outputShapeInfo,
1133                               OpAsmParser::Delimiter::Square) ||
1134       parser.parseComma()) {
1135     return failure();
1136   }
1137 
1138   auto builder = parser.getBuilder();
1139   result.addAttribute(
1140       ReshapeOp::getOperandSegmentSizeAttr(),
1141       builder.getI32VectorAttr({1, static_cast<int32_t>(inputShapeInfo.size()),
1142                                 static_cast<int32_t>(outputShapeInfo.size())}));
1143   Type inputType;
1144   Type outputType;
1145   return failure(
1146       parser.parseAttribute(fixedVectorSizesAttr, attrName,
1147                             result.attributes) ||
1148       parser.parseOptionalAttrDict(result.attributes) ||
1149       parser.parseColonType(inputType) ||
1150       parser.resolveOperand(inputInfo, inputType, result.operands) ||
1151       parser.resolveOperands(inputShapeInfo, indexType, result.operands) ||
1152       parser.resolveOperands(outputShapeInfo, indexType, result.operands) ||
1153       parser.parseKeywordType("to", outputType) ||
1154       parser.addTypeToList(outputType, result.types));
1155 }
1156 
verify(ReshapeOp op)1157 static LogicalResult verify(ReshapeOp op) {
1158   // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
1159   auto inputVectorType = op.getInputVectorType();
1160   auto outputVectorType = op.getOutputVectorType();
1161   int64_t inputShapeRank = op.getNumInputShapeSizes();
1162   int64_t outputShapeRank = op.getNumOutputShapeSizes();
1163   SmallVector<int64_t, 4> fixedVectorSizes;
1164   op.getFixedVectorSizes(fixedVectorSizes);
1165   int64_t numFixedVectorSizes = fixedVectorSizes.size();
1166 
1167   if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
1168     return op.emitError("invalid input shape for vector type ")
1169            << inputVectorType;
1170 
1171   if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
1172     return op.emitError("invalid output shape for vector type ")
1173            << outputVectorType;
1174 
1175   // Verify that the 'fixedVectorSizes' match a input/output vector shape
1176   // suffix.
1177   unsigned inputVectorRank = inputVectorType.getRank();
1178   for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
1179     unsigned index = inputVectorRank - numFixedVectorSizes - i;
1180     if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
1181       return op.emitError("fixed vector size must match input vector for dim ")
1182              << i;
1183   }
1184 
1185   unsigned outputVectorRank = outputVectorType.getRank();
1186   for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
1187     unsigned index = outputVectorRank - numFixedVectorSizes - i;
1188     if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
1189       return op.emitError("fixed vector size must match output vector for dim ")
1190              << i;
1191   }
1192 
1193   // If all shape operands are produced by constant ops, verify that product
1194   // of dimensions for input/output shape match.
1195   auto isDefByConstant = [](Value operand) {
1196     return isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
1197   };
1198   if (llvm::all_of(op.input_shape(), isDefByConstant) &&
1199       llvm::all_of(op.output_shape(), isDefByConstant)) {
1200     int64_t numInputElements = 1;
1201     for (auto operand : op.input_shape())
1202       numInputElements *=
1203           cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
1204     int64_t numOutputElements = 1;
1205     for (auto operand : op.output_shape())
1206       numOutputElements *=
1207           cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
1208     if (numInputElements != numOutputElements)
1209       return op.emitError("product of input and output shape sizes must match");
1210   }
1211   return success();
1212 }
1213 
getFixedVectorSizes(SmallVectorImpl<int64_t> & results)1214 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
1215   populateFromInt64AttrArray(fixed_vector_sizes(), results);
1216 }
1217 
1218 //===----------------------------------------------------------------------===//
1219 // StridedSliceOp
1220 //===----------------------------------------------------------------------===//
1221 
1222 // Inference works as follows:
1223 //   1. Add 'sizes' from prefix of dims in 'offsets'.
1224 //   2. Add sizes from 'vectorType' for remaining dims.
inferStridedSliceOpResultType(VectorType vectorType,ArrayAttr offsets,ArrayAttr sizes,ArrayAttr strides)1225 static Type inferStridedSliceOpResultType(VectorType vectorType,
1226                                           ArrayAttr offsets, ArrayAttr sizes,
1227                                           ArrayAttr strides) {
1228   assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
1229   SmallVector<int64_t, 4> shape;
1230   shape.reserve(vectorType.getRank());
1231   unsigned idx = 0;
1232   for (unsigned e = offsets.size(); idx < e; ++idx)
1233     shape.push_back(sizes.getValue()[idx].cast<IntegerAttr>().getInt());
1234   for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
1235     shape.push_back(vectorType.getShape()[idx]);
1236 
1237   return VectorType::get(shape, vectorType.getElementType());
1238 }
1239 
build(Builder * builder,OperationState & result,Value source,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)1240 void StridedSliceOp::build(Builder *builder, OperationState &result,
1241                            Value source, ArrayRef<int64_t> offsets,
1242                            ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) {
1243   result.addOperands(source);
1244   auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets);
1245   auto sizesAttr = getVectorSubscriptAttr(*builder, sizes);
1246   auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
1247   result.addTypes(
1248       inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
1249                                     offsetsAttr, sizesAttr, stridesAttr));
1250   result.addAttribute(getOffsetsAttrName(), offsetsAttr);
1251   result.addAttribute(getSizesAttrName(), sizesAttr);
1252   result.addAttribute(getStridesAttrName(), stridesAttr);
1253 }
1254 
print(OpAsmPrinter & p,StridedSliceOp op)1255 static void print(OpAsmPrinter &p, StridedSliceOp op) {
1256   p << op.getOperationName() << " " << op.vector();
1257   p.printOptionalAttrDict(op.getAttrs());
1258   p << " : " << op.vector().getType() << " to " << op.getResult().getType();
1259 }
1260 
parseStridedSliceOp(OpAsmParser & parser,OperationState & result)1261 static ParseResult parseStridedSliceOp(OpAsmParser &parser,
1262                                        OperationState &result) {
1263   llvm::SMLoc attributeLoc, typeLoc;
1264   OpAsmParser::OperandType vector;
1265   VectorType vectorType, resultVectorType;
1266   return failure(parser.parseOperand(vector) ||
1267                  parser.getCurrentLocation(&attributeLoc) ||
1268                  parser.parseOptionalAttrDict(result.attributes) ||
1269                  parser.getCurrentLocation(&typeLoc) ||
1270                  parser.parseColonType(vectorType) ||
1271                  parser.parseKeywordType("to", resultVectorType) ||
1272                  parser.resolveOperand(vector, vectorType, result.operands) ||
1273                  parser.addTypeToList(resultVectorType, result.types));
1274 }
1275 
verify(StridedSliceOp op)1276 static LogicalResult verify(StridedSliceOp op) {
1277   auto type = op.getVectorType();
1278   auto offsets = op.offsets();
1279   auto sizes = op.sizes();
1280   auto strides = op.strides();
1281   if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
1282     op.emitOpError(
1283         "expected offsets, sizes and strides attributes of same size");
1284     return failure();
1285   }
1286 
1287   auto shape = type.getShape();
1288   auto offName = StridedSliceOp::getOffsetsAttrName();
1289   auto sizesName = StridedSliceOp::getSizesAttrName();
1290   auto stridesName = StridedSliceOp::getStridesAttrName();
1291   if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
1292       failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
1293       failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
1294                                                 stridesName)) ||
1295       failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
1296       failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
1297                                                /*halfOpen=*/false,
1298                                                /*min=*/1)) ||
1299       failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
1300                                                /*halfOpen=*/false)) ||
1301       failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
1302                                                     offName, sizesName,
1303                                                     /*halfOpen=*/false)))
1304     return failure();
1305 
1306   auto resultType = inferStridedSliceOpResultType(
1307       op.getVectorType(), op.offsets(), op.sizes(), op.strides());
1308   if (op.getResult().getType() != resultType) {
1309     op.emitOpError("expected result type to be ") << resultType;
1310     return failure();
1311   }
1312 
1313   return success();
1314 }
1315 
getOffsets(SmallVectorImpl<int64_t> & results)1316 void StridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
1317   populateFromInt64AttrArray(offsets(), results);
1318 }
1319 
1320 namespace {
1321 
1322 // Pattern to rewrite a StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
1323 class StridedSliceConstantMaskFolder final
1324     : public OpRewritePattern<StridedSliceOp> {
1325 public:
1326   using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
1327 
matchAndRewrite(StridedSliceOp stridedSliceOp,PatternRewriter & rewriter) const1328   PatternMatchResult matchAndRewrite(StridedSliceOp stridedSliceOp,
1329                                      PatternRewriter &rewriter) const override {
1330     // Return if 'stridedSliceOp' operand is not defined by a ConstantMaskOp.
1331     auto defOp = stridedSliceOp.vector().getDefiningOp();
1332     auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
1333     if (!constantMaskOp)
1334       return matchFailure();
1335     // Return if 'stridedSliceOp' has non-unit strides.
1336     if (llvm::any_of(stridedSliceOp.strides(), [](Attribute attr) {
1337           return attr.cast<IntegerAttr>().getInt() != 1;
1338         }))
1339       return matchFailure();
1340     // Gather constant mask dimension sizes.
1341     SmallVector<int64_t, 4> maskDimSizes;
1342     populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
1343     // Gather strided slice offsets and sizes.
1344     SmallVector<int64_t, 4> sliceOffsets;
1345     populateFromInt64AttrArray(stridedSliceOp.offsets(), sliceOffsets);
1346     SmallVector<int64_t, 4> sliceSizes;
1347     populateFromInt64AttrArray(stridedSliceOp.sizes(), sliceSizes);
1348 
1349     // Compute slice of vector mask region.
1350     SmallVector<int64_t, 4> sliceMaskDimSizes;
1351     assert(sliceOffsets.size() == maskDimSizes.size());
1352     for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
1353       int64_t maskDimSize = std::get<0>(it);
1354       int64_t sliceOffset = std::get<1>(it);
1355       int64_t sliceSize = std::get<2>(it);
1356       int64_t sliceMaskDimSize = std::max(
1357           static_cast<int64_t>(0),
1358           std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
1359       sliceMaskDimSizes.push_back(sliceMaskDimSize);
1360     }
1361     // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
1362     // region is a conjunction of mask dim intervals).
1363     if (llvm::any_of(sliceMaskDimSizes, [](int64_t sz) { return sz == 0; }))
1364       sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
1365 
1366     // Replace 'stridedSliceOp' with ConstantMaskOp with sliced mask region.
1367     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
1368         stridedSliceOp, stridedSliceOp.getResult().getType(),
1369         vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
1370     return matchSuccess();
1371   }
1372 };
1373 
1374 } // end anonymous namespace
1375 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1376 void StridedSliceOp::getCanonicalizationPatterns(
1377     OwningRewritePatternList &results, MLIRContext *context) {
1378   // Pattern to rewrite a StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
1379   results.insert<StridedSliceConstantMaskFolder>(context);
1380 }
1381 
1382 //===----------------------------------------------------------------------===//
1383 // TransferReadOp
1384 //===----------------------------------------------------------------------===//
1385 template <typename EmitFun>
verifyPermutationMap(AffineMap permutationMap,EmitFun emitOpError)1386 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
1387                                           EmitFun emitOpError) {
1388   SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
1389   for (auto expr : permutationMap.getResults()) {
1390     auto dim = expr.dyn_cast<AffineDimExpr>();
1391     auto zero = expr.dyn_cast<AffineConstantExpr>();
1392     if (zero) {
1393       if (zero.getValue() != 0) {
1394         return emitOpError(
1395             "requires a projected permutation_map (at most one dim or the zero "
1396             "constant can appear in each result)");
1397       }
1398       continue;
1399     }
1400     if (!dim) {
1401       return emitOpError("requires a projected permutation_map (at most one "
1402                          "dim or the zero constant can appear in each result)");
1403     }
1404     if (seen[dim.getPosition()]) {
1405       return emitOpError(
1406           "requires a permutation_map that is a permutation (found one dim "
1407           "used more than once)");
1408     }
1409     seen[dim.getPosition()] = true;
1410   }
1411   return success();
1412 }
1413 
verifyTransferOp(Operation * op,MemRefType memrefType,VectorType vectorType,AffineMap permutationMap)1414 static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
1415                                       VectorType vectorType,
1416                                       AffineMap permutationMap) {
1417   auto memrefElementType = memrefType.getElementType();
1418   if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
1419     // Memref has vector element type.
1420 
1421     // Check that 'memrefVectorElementType' and vector element types match.
1422     if (memrefVectorElementType.getElementType() != vectorType.getElementType())
1423       return op->emitOpError(
1424           "requires memref and vector types of the same elemental type");
1425 
1426     // Check that memref vector type is a suffix of 'vectorType.
1427     unsigned memrefVecEltRank = memrefVectorElementType.getRank();
1428     unsigned resultVecRank = vectorType.getRank();
1429     if (memrefVecEltRank > resultVecRank)
1430       return op->emitOpError(
1431           "requires memref vector element and vector result ranks to match.");
1432     // TODO(b/146516564) Move this to isSuffix in VectorOps/Utils.h.
1433     unsigned rankOffset = resultVecRank - memrefVecEltRank;
1434     auto memrefVecEltShape = memrefVectorElementType.getShape();
1435     auto resultVecShape = vectorType.getShape();
1436     for (unsigned i = 0; i < memrefVecEltRank; ++i)
1437       if (memrefVecEltShape[i] != resultVecShape[rankOffset + i])
1438         return op->emitOpError(
1439             "requires memref vector element shape to match suffix of "
1440             "vector result shape.");
1441     // Check that permutation map results match 'rankOffset' of vector type.
1442     if (permutationMap.getNumResults() != rankOffset)
1443       return op->emitOpError("requires a permutation_map with result dims of "
1444                              "the same rank as the vector type");
1445   } else {
1446     // Memref has scalar element type.
1447 
1448     // Check that memref and vector element types match.
1449     if (memrefType.getElementType() != vectorType.getElementType())
1450       return op->emitOpError(
1451           "requires memref and vector types of the same elemental type");
1452 
1453     // Check that permutation map results match rank of vector type.
1454     if (permutationMap.getNumResults() != vectorType.getRank())
1455       return op->emitOpError("requires a permutation_map with result dims of "
1456                              "the same rank as the vector type");
1457   }
1458 
1459   if (permutationMap.getNumSymbols() != 0)
1460     return op->emitOpError("requires permutation_map without symbols");
1461   if (permutationMap.getNumInputs() != memrefType.getRank())
1462     return op->emitOpError("requires a permutation_map with input dims of the "
1463                            "same rank as the memref type");
1464   return success();
1465 }
1466 
print(OpAsmPrinter & p,TransferReadOp op)1467 static void print(OpAsmPrinter &p, TransferReadOp op) {
1468   p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
1469     << "], " << op.padding() << " ";
1470   p.printOptionalAttrDict(op.getAttrs());
1471   p << " : " << op.getMemRefType() << ", " << op.getVectorType();
1472 }
1473 
parseTransferReadOp(OpAsmParser & parser,OperationState & result)1474 static ParseResult parseTransferReadOp(OpAsmParser &parser,
1475                                        OperationState &result) {
1476   llvm::SMLoc typesLoc;
1477   OpAsmParser::OperandType memrefInfo;
1478   SmallVector<OpAsmParser::OperandType, 8> indexInfo;
1479   OpAsmParser::OperandType paddingInfo;
1480   SmallVector<Type, 2> types;
1481   // Parsing with support for optional paddingValue.
1482   if (parser.parseOperand(memrefInfo) ||
1483       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1484       parser.parseComma() || parser.parseOperand(paddingInfo) ||
1485       parser.parseOptionalAttrDict(result.attributes) ||
1486       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1487     return failure();
1488   if (types.size() != 2)
1489     return parser.emitError(typesLoc, "two types required");
1490   auto indexType = parser.getBuilder().getIndexType();
1491   MemRefType memRefType = types[0].dyn_cast<MemRefType>();
1492   if (!memRefType)
1493     return parser.emitError(typesLoc, "memref type required"), failure();
1494   Type vectorType = types[1];
1495   return failure(
1496       parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
1497       parser.resolveOperands(indexInfo, indexType, result.operands) ||
1498       parser.resolveOperand(paddingInfo, memRefType.getElementType(),
1499                             result.operands) ||
1500       parser.addTypeToList(vectorType, result.types));
1501 }
1502 
verify(TransferReadOp op)1503 static LogicalResult verify(TransferReadOp op) {
1504   // Consistency of elemental types in memref and vector.
1505   MemRefType memrefType = op.getMemRefType();
1506   VectorType vectorType = op.getVectorType();
1507   auto paddingType = op.padding().getType();
1508   auto permutationMap = op.permutation_map();
1509   auto memrefElementType = memrefType.getElementType();
1510 
1511   if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank())
1512     return op.emitOpError("requires ") << memrefType.getRank() << " indices";
1513 
1514   if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
1515                               permutationMap)))
1516     return failure();
1517 
1518   if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
1519     // Memref has vector element type.
1520     // Check that 'memrefVectorElementType' and 'paddingType' types match.
1521     if (memrefVectorElementType != paddingType)
1522       return op.emitOpError(
1523           "requires memref element type and padding type to match.");
1524 
1525   } else {
1526     // Check that 'paddingType' is valid to store in a vector type.
1527     if (!VectorType::isValidElementType(paddingType))
1528       return op.emitOpError("requires valid padding vector elemental type");
1529 
1530     // Check that padding type and vector element types match.
1531     if (paddingType != vectorType.getElementType())
1532       return op.emitOpError(
1533           "requires formal padding and vector of the same elemental type");
1534   }
1535 
1536   return verifyPermutationMap(permutationMap,
1537                               [&op](Twine t) { return op.emitOpError(t); });
1538 }
1539 
1540 //===----------------------------------------------------------------------===//
1541 // TransferWriteOp
1542 //===----------------------------------------------------------------------===//
print(OpAsmPrinter & p,TransferWriteOp op)1543 static void print(OpAsmPrinter &p, TransferWriteOp op) {
1544   p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
1545     << op.indices() << "]";
1546   p.printOptionalAttrDict(op.getAttrs());
1547   p << " : " << op.getVectorType() << ", " << op.getMemRefType();
1548 }
1549 
parseTransferWriteOp(OpAsmParser & parser,OperationState & result)1550 static ParseResult parseTransferWriteOp(OpAsmParser &parser,
1551                                         OperationState &result) {
1552   llvm::SMLoc typesLoc;
1553   OpAsmParser::OperandType storeValueInfo;
1554   OpAsmParser::OperandType memRefInfo;
1555   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1556   SmallVector<Type, 2> types;
1557   if (parser.parseOperand(storeValueInfo) || parser.parseComma() ||
1558       parser.parseOperand(memRefInfo) ||
1559       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1560       parser.parseOptionalAttrDict(result.attributes) ||
1561       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1562     return failure();
1563   if (types.size() != 2)
1564     return parser.emitError(typesLoc, "two types required");
1565   auto indexType = parser.getBuilder().getIndexType();
1566   Type vectorType = types[0], memRefType = types[1];
1567   return failure(
1568       parser.resolveOperand(storeValueInfo, vectorType, result.operands) ||
1569       parser.resolveOperand(memRefInfo, memRefType, result.operands) ||
1570       parser.resolveOperands(indexInfo, indexType, result.operands));
1571 }
1572 
verify(TransferWriteOp op)1573 static LogicalResult verify(TransferWriteOp op) {
1574   // Consistency of elemental types in memref and vector.
1575   MemRefType memrefType = op.getMemRefType();
1576   VectorType vectorType = op.getVectorType();
1577   auto permutationMap = op.permutation_map();
1578 
1579   if (llvm::size(op.indices()) != memrefType.getRank())
1580     return op.emitOpError("requires ") << memrefType.getRank() << " indices";
1581 
1582   if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
1583                               permutationMap)))
1584     return failure();
1585 
1586   return verifyPermutationMap(permutationMap,
1587                               [&op](Twine t) { return op.emitOpError(t); });
1588 }
1589 
1590 //===----------------------------------------------------------------------===//
1591 // TypeCastOp
1592 //===----------------------------------------------------------------------===//
1593 
inferVectorTypeCastResultType(MemRefType t)1594 static MemRefType inferVectorTypeCastResultType(MemRefType t) {
1595   return MemRefType::get({}, VectorType::get(t.getShape(), t.getElementType()));
1596 }
1597 
build(Builder * builder,OperationState & result,Value source)1598 void TypeCastOp::build(Builder *builder, OperationState &result, Value source) {
1599   result.addOperands(source);
1600   result.addTypes(
1601       inferVectorTypeCastResultType(source.getType().cast<MemRefType>()));
1602 }
1603 
print(OpAsmPrinter & p,TypeCastOp op)1604 static void print(OpAsmPrinter &p, TypeCastOp op) {
1605   auto type = op.getOperand().getType().cast<MemRefType>();
1606   p << op.getOperationName() << ' ' << op.memref() << " : " << type << " to "
1607     << inferVectorTypeCastResultType(type);
1608 }
1609 
verify(TypeCastOp op)1610 static LogicalResult verify(TypeCastOp op) {
1611   auto resultType = inferVectorTypeCastResultType(op.getMemRefType());
1612   if (op.getResultMemRefType() != resultType)
1613     return op.emitOpError("expects result type to be: ") << resultType;
1614   return success();
1615 }
1616 
1617 //===----------------------------------------------------------------------===//
1618 // TupleOp
1619 //===----------------------------------------------------------------------===//
1620 
parseTupleOp(OpAsmParser & parser,OperationState & result)1621 static ParseResult parseTupleOp(OpAsmParser &parser, OperationState &result) {
1622   SmallVector<OpAsmParser::OperandType, 4> operandInfos;
1623   SmallVector<Type, 4> types;
1624   auto loc = parser.getCurrentLocation();
1625   auto *ctx = parser.getBuilder().getContext();
1626   return failure(
1627       parser.parseOperandList(operandInfos) ||
1628       parser.parseOptionalAttrDict(result.attributes) ||
1629       parser.parseColonTypeList(types) ||
1630       parser.resolveOperands(operandInfos, types, loc, result.operands) ||
1631       parser.addTypeToList(TupleType::get(types, ctx), result.types));
1632 }
1633 
print(OpAsmPrinter & p,TupleOp op)1634 static void print(OpAsmPrinter &p, TupleOp op) {
1635   p << op.getOperationName() << ' ';
1636   p.printOperands(op.getOperands());
1637   p.printOptionalAttrDict(op.getAttrs());
1638   p << " : ";
1639   interleaveComma(op.getOperation()->getOperandTypes(), p);
1640 }
1641 
verify(TupleOp op)1642 static LogicalResult verify(TupleOp op) { return success(); }
1643 
1644 //===----------------------------------------------------------------------===//
1645 // TupleGetOp
1646 //===----------------------------------------------------------------------===//
1647 
parseTupleGetOp(OpAsmParser & parser,OperationState & result)1648 static ParseResult parseTupleGetOp(OpAsmParser &parser,
1649                                    OperationState &result) {
1650   OpAsmParser::OperandType operandInfo;
1651   IntegerAttr indexAttr;
1652   StringRef indexAttrName = TupleGetOp::getIndexAttrName();
1653   Type indexType = parser.getBuilder().getIndexType();
1654   TupleType tupleType;
1655   if (parser.parseOperand(operandInfo) || parser.parseComma() ||
1656       parser.parseAttribute(indexAttr, indexType, indexAttrName,
1657                             result.attributes) ||
1658       parser.parseOptionalAttrDict(result.attributes) ||
1659       parser.parseColonType(tupleType) ||
1660       parser.resolveOperand(operandInfo, tupleType, result.operands))
1661     return failure();
1662   if (indexAttr.getInt() < 0 ||
1663       indexAttr.getInt() >= static_cast<int64_t>(tupleType.size()))
1664     return failure();
1665   parser.addTypeToList(tupleType.getType(indexAttr.getInt()), result.types);
1666   return success();
1667 }
1668 
print(OpAsmPrinter & p,TupleGetOp op)1669 static void print(OpAsmPrinter &p, TupleGetOp op) {
1670   p << op.getOperationName() << ' ' << op.getOperand() << ", " << op.index();
1671   p.printOptionalAttrDict(op.getAttrs(),
1672                           /*elidedAttrs=*/{TupleGetOp::getIndexAttrName()});
1673   p << " : " << op.getOperand().getType();
1674 }
1675 
verify(TupleGetOp op)1676 static LogicalResult verify(TupleGetOp op) {
1677   auto tupleType = op.getOperand().getType().cast<TupleType>();
1678   if (op.getIndex() < 0 ||
1679       op.getIndex() >= static_cast<int64_t>(tupleType.size()))
1680     return op.emitOpError("tuple get index out of range");
1681   return success();
1682 }
1683 
1684 //===----------------------------------------------------------------------===//
1685 // ConstantMaskOp
1686 //===----------------------------------------------------------------------===//
1687 
parseConstantMaskOp(OpAsmParser & parser,OperationState & result)1688 static ParseResult parseConstantMaskOp(OpAsmParser &parser,
1689                                        OperationState &result) {
1690   Type resultType;
1691   ArrayAttr maskDimSizesAttr;
1692   StringRef attrName = ConstantMaskOp::getMaskDimSizesAttrName();
1693   return failure(
1694       parser.parseOptionalAttrDict(result.attributes) ||
1695       parser.parseAttribute(maskDimSizesAttr, attrName, result.attributes) ||
1696       parser.parseColonType(resultType) ||
1697       parser.addTypeToList(resultType, result.types));
1698 }
1699 
print(OpAsmPrinter & p,ConstantMaskOp op)1700 static void print(OpAsmPrinter &p, ConstantMaskOp op) {
1701   p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : "
1702     << op.getResult().getType();
1703 }
1704 
verify(ConstantMaskOp & op)1705 static LogicalResult verify(ConstantMaskOp &op) {
1706   // Verify that array attr size matches the rank of the vector result.
1707   auto resultType = op.getResult().getType().cast<VectorType>();
1708   if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
1709     return op.emitOpError(
1710         "must specify array attr of size equal vector result rank");
1711   // Verify that each array attr element is in bounds of corresponding vector
1712   // result dimension size.
1713   auto resultShape = resultType.getShape();
1714   SmallVector<int64_t, 4> maskDimSizes;
1715   for (auto it : llvm::enumerate(op.mask_dim_sizes())) {
1716     int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
1717     if (attrValue < 0 || attrValue > resultShape[it.index()])
1718       return op.emitOpError(
1719           "array attr of size out of bounds of vector result dimension size");
1720     maskDimSizes.push_back(attrValue);
1721   }
1722   // Verify that if one mask dim size is zero, they all should be zero (because
1723   // the mask region is a conjunction of each mask dimension interval).
1724   bool any_zeros = llvm::is_contained(maskDimSizes, 0);
1725   bool all_zeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
1726   if (any_zeros && !all_zeros)
1727     return op.emitOpError("expected all mask dim sizes to be zeros, "
1728                           "as a result of conjunction with zero mask dim");
1729   return success();
1730 }
1731 
1732 //===----------------------------------------------------------------------===//
1733 // CreateMaskOp
1734 //===----------------------------------------------------------------------===//
1735 
parseCreateMaskOp(OpAsmParser & parser,OperationState & result)1736 static ParseResult parseCreateMaskOp(OpAsmParser &parser,
1737                                      OperationState &result) {
1738   auto indexType = parser.getBuilder().getIndexType();
1739   Type resultType;
1740   SmallVector<OpAsmParser::OperandType, 4> operandInfo;
1741   return failure(
1742       parser.parseOperandList(operandInfo) ||
1743       parser.parseOptionalAttrDict(result.attributes) ||
1744       parser.parseColonType(resultType) ||
1745       parser.resolveOperands(operandInfo, indexType, result.operands) ||
1746       parser.addTypeToList(resultType, result.types));
1747 }
1748 
print(OpAsmPrinter & p,CreateMaskOp op)1749 static void print(OpAsmPrinter &p, CreateMaskOp op) {
1750   p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType();
1751 }
1752 
verify(CreateMaskOp op)1753 static LogicalResult verify(CreateMaskOp op) {
1754   // Verify that an operand was specified for each result vector each dimension.
1755   if (op.getNumOperands() !=
1756       op.getResult().getType().cast<VectorType>().getRank())
1757     return op.emitOpError(
1758         "must specify an operand for each result vector dimension");
1759   return success();
1760 }
1761 
1762 //===----------------------------------------------------------------------===//
1763 // PrintOp
1764 //===----------------------------------------------------------------------===//
1765 
parsePrintOp(OpAsmParser & parser,OperationState & result)1766 static ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) {
1767   OpAsmParser::OperandType source;
1768   Type sourceType;
1769   return failure(parser.parseOperand(source) ||
1770                  parser.parseColonType(sourceType) ||
1771                  parser.resolveOperand(source, sourceType, result.operands));
1772 }
1773 
print(OpAsmPrinter & p,PrintOp op)1774 static void print(OpAsmPrinter &p, PrintOp op) {
1775   p << op.getOperationName() << ' ' << op.source() << " : "
1776     << op.getPrintType();
1777 }
1778 
1779 namespace {
1780 
1781 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
1782 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
1783 public:
1784   using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
1785 
matchAndRewrite(CreateMaskOp createMaskOp,PatternRewriter & rewriter) const1786   PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp,
1787                                      PatternRewriter &rewriter) const override {
1788     // Return if any of 'createMaskOp' operands are not defined by a constant.
1789     auto is_not_def_by_constant = [](Value operand) {
1790       return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
1791     };
1792     if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
1793       return matchFailure();
1794     // Gather constant mask dimension sizes.
1795     SmallVector<int64_t, 4> maskDimSizes;
1796     for (auto operand : createMaskOp.operands()) {
1797       auto defOp = operand.getDefiningOp();
1798       maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
1799     }
1800     // Replace 'createMaskOp' with ConstantMaskOp.
1801     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
1802         createMaskOp, createMaskOp.getResult().getType(),
1803         vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
1804     return matchSuccess();
1805   }
1806 };
1807 
1808 } // end anonymous namespace
1809 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1810 void CreateMaskOp::getCanonicalizationPatterns(
1811     OwningRewritePatternList &results, MLIRContext *context) {
1812   results.insert<CreateMaskFolder>(context);
1813 }
1814 
populateVectorToVectorCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)1815 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
1816     OwningRewritePatternList &patterns, MLIRContext *context) {
1817   patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder>(context);
1818 }
1819 
namespace mlir {
namespace vector {

// Include the ODS-generated op class definitions from VectorOps.cpp.inc,
// selected by GET_OP_CLASSES, inside the mlir::vector namespace.
// NOTE(review): the generated code presumably calls the static
// print/parse/verify hooks defined earlier in this file; confirm against the
// generated .inc before renaming any of them.
#define GET_OP_CLASSES
#include "mlir/Dialect/VectorOps/VectorOps.cpp.inc"

} // namespace vector
} // namespace mlir
1828