1 //===- VectorOps.cpp - MLIR Super Vectorizer Operations -------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/VectorOps/VectorOps.h"
15 #include "mlir/Dialect/StandardOps/Ops.h"
16 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17 #include "mlir/IR/AffineExpr.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Support/Functional.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Support/MathExtras.h"
26 #include "mlir/Support/STLExtras.h"
27 #include "llvm/ADT/StringSet.h"
28
29 using namespace mlir;
30 using namespace mlir::vector;
31
32 //===----------------------------------------------------------------------===//
33 // VectorOpsDialect
34 //===----------------------------------------------------------------------===//
35
/// Constructs the dialect under its namespace in 'context' and adds the
/// operations named by the generated op list in VectorOps.cpp.inc.
VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // GET_OP_LIST makes the included file expand to the comma-separated list of
  // op classes, which becomes the template argument pack of addOperations.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/VectorOps/VectorOps.cpp.inc"
      >();
}
43
44 /// Materialize a single constant operation from a given attribute value with
45 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)46 Operation *VectorOpsDialect::materializeConstant(OpBuilder &builder,
47 Attribute value, Type type,
48 Location loc) {
49 return builder.create<ConstantOp>(loc, type, value);
50 }
51
getVectorSubscriptType(Builder & builder)52 IntegerType vector::getVectorSubscriptType(Builder &builder) {
53 return builder.getIntegerType(64);
54 }
55
getVectorSubscriptAttr(Builder & builder,ArrayRef<int64_t> values)56 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
57 ArrayRef<int64_t> values) {
58 return builder.getI64ArrayAttr(values);
59 }
60
61 //===----------------------------------------------------------------------===//
62 // ContractionOp
63 //===----------------------------------------------------------------------===//
64
build(Builder * builder,OperationState & result,Value lhs,Value rhs,Value acc,ArrayAttr indexingMaps,ArrayAttr iteratorTypes)65 void vector::ContractionOp::build(Builder *builder, OperationState &result,
66 Value lhs, Value rhs, Value acc,
67 ArrayAttr indexingMaps,
68 ArrayAttr iteratorTypes) {
69 result.addOperands({lhs, rhs, acc});
70 result.addTypes(acc.getType());
71 result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
72 result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
73 }
74
parseContractionOp(OpAsmParser & parser,OperationState & result)75 static ParseResult parseContractionOp(OpAsmParser &parser,
76 OperationState &result) {
77 OpAsmParser::OperandType lhsInfo;
78 OpAsmParser::OperandType rhsInfo;
79 OpAsmParser::OperandType accInfo;
80 SmallVector<OpAsmParser::OperandType, 2> masksInfo;
81 SmallVector<Type, 2> types;
82 Type resultVectorType;
83 auto loc = parser.getCurrentLocation();
84 DictionaryAttr dictAttr;
85 // TODO(andydavis, ntv) Unify linalg op attribute parsing.
86 if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
87 parser.parseOperand(lhsInfo) || parser.parseComma() ||
88 parser.parseOperand(rhsInfo) || parser.parseComma() ||
89 parser.parseOperand(accInfo) ||
90 parser.parseTrailingOperandList(masksInfo) ||
91 parser.parseOptionalAttrDict(result.attributes) ||
92 parser.parseColonTypeList(types) ||
93 parser.parseKeywordType("into", resultVectorType) ||
94 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
95 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
96 parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
97 parser.addTypeToList(resultVectorType, result.types))
98 return failure();
99 result.attributes.assign(dictAttr.getValue().begin(),
100 dictAttr.getValue().end());
101 if (masksInfo.empty())
102 return success();
103 if (masksInfo.size() != 2)
104 return parser.emitError(parser.getNameLoc(),
105 "expected zero or exactly 2 vector mask operands");
106 auto lhsType = types[0].cast<VectorType>();
107 auto rhsType = types[1].cast<VectorType>();
108 auto maskElementType = parser.getBuilder().getI1Type();
109 SmallVector<Type, 2> maskTypes;
110 maskTypes.push_back(VectorType::get(lhsType.getShape(), maskElementType));
111 maskTypes.push_back(VectorType::get(rhsType.getShape(), maskElementType));
112 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
113 return failure();
114 return success();
115 }
116
print(OpAsmPrinter & p,ContractionOp op)117 static void print(OpAsmPrinter &p, ContractionOp op) {
118 // TODO(andydavis, ntv) Unify printing code with linalg ops.
119 auto attrNames = op.getTraitAttrNames();
120 llvm::StringSet<> traitAttrsSet;
121 traitAttrsSet.insert(attrNames.begin(), attrNames.end());
122 SmallVector<NamedAttribute, 8> attrs;
123 for (auto attr : op.getAttrs())
124 if (traitAttrsSet.count(attr.first.strref()) > 0)
125 attrs.push_back(attr);
126
127 auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
128 p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
129 p << op.rhs() << ", " << op.acc();
130 if (op.masks().size() == 2)
131 p << ", " << op.masks();
132
133 p.printOptionalAttrDict(op.getAttrs(), attrNames);
134 p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
135 << op.getResultType();
136 }
137
verifyDimMap(VectorType lhsType,VectorType rhsType,const std::vector<std::pair<int64_t,int64_t>> & map)138 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
139 const std::vector<std::pair<int64_t, int64_t>> &map) {
140 for (auto &dimPair : map) {
141 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
142 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
143 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
144 return false;
145 }
146 return true;
147 }
148
verifyOutputShape(VectorType lhsType,VectorType rhsType,VectorType accType,VectorType resType,const std::vector<std::pair<int64_t,int64_t>> & contractingDimMap,const std::vector<std::pair<int64_t,int64_t>> & batchDimMap)149 static bool verifyOutputShape(
150 VectorType lhsType, VectorType rhsType, VectorType accType,
151 VectorType resType,
152 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
153 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
154 DenseSet<int64_t> lhsContractingDimSet;
155 DenseSet<int64_t> rhsContractingDimSet;
156 for (auto &dimPair : contractingDimMap) {
157 lhsContractingDimSet.insert(dimPair.first);
158 rhsContractingDimSet.insert(dimPair.second);
159 }
160 DenseSet<int64_t> rhsBatchDimSet;
161 for (auto &dimPair : batchDimMap)
162 rhsBatchDimSet.insert(dimPair.second);
163
164 // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
165 SmallVector<int64_t, 4> expectedResultDims;
166 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
167 if (lhsContractingDimSet.count(i) > 0)
168 continue;
169 expectedResultDims.push_back(lhsType.getDimSize(i));
170 }
171
172 // Add free dimensions from 'rhsType' to 'expectedResultDims'.
173 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
174 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
175 continue;
176 expectedResultDims.push_back(rhsType.getDimSize(i));
177 }
178
179 // Verify dimension from 'resType' against 'expectedResultDims'.
180 if (resType.getShape().size() != expectedResultDims.size() ||
181 accType.getShape().size() != expectedResultDims.size())
182 return false;
183 for (int64_t i = 0, e = resType.getRank(); i < e; ++i) {
184 if (resType.getDimSize(i) != expectedResultDims[i] ||
185 accType.getDimSize(i) != expectedResultDims[i])
186 return false;
187 }
188 return true;
189 }
190
verify(ContractionOp op)191 static LogicalResult verify(ContractionOp op) {
192 auto lhsType = op.getLhsType();
193 auto rhsType = op.getRhsType();
194 auto accType = op.getAccType();
195 auto resType = op.getResultType();
196
197 // Verify that an indexing map was specified for each vector operand.
198 if (op.indexing_maps().size() != 3)
199 return op.emitOpError("expected an indexing map for each vector operand");
200
201 // Verify that each index map has 'numIterators' inputs, no symbols, and
202 // that the number of map outputs equals the rank of its associated
203 // vector operand.
204 unsigned numIterators = op.iterator_types().getValue().size();
205 for (auto it : llvm::enumerate(op.indexing_maps())) {
206 auto index = it.index();
207 auto map = it.value().cast<AffineMapAttr>().getValue();
208 if (map.getNumSymbols() != 0)
209 return op.emitOpError("expected indexing map ")
210 << index << " to have no symbols";
211 if (map.getNumDims() != numIterators)
212 return op.emitOpError("expected indexing map ")
213 << index << " to have " << numIterators << " number of inputs";
214 auto operandType = op.getOperand(index).getType().cast<VectorType>();
215 unsigned rank = operandType.getShape().size();
216 if (map.getNumResults() != rank)
217 return op.emitOpError("expected indexing map ")
218 << index << " to have " << rank << " number of outputs";
219 if (!map.isProjectedPermutation())
220 return op.emitOpError("expected indexing map ")
221 << index << " to be a projected permutation of its inputs";
222 }
223
224 auto contractingDimMap = op.getContractingDimMap();
225 auto batchDimMap = op.getBatchDimMap();
226
227 // Verify at least one contracting dimension pair was specified.
228 if (contractingDimMap.empty())
229 return op.emitOpError("expected at least one contracting dimension pair");
230
231 // Verify contracting dimension map was properly constructed.
232 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
233 return op.emitOpError("invalid contracting dimension map");
234
235 // Verify batch dimension map was properly constructed.
236 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
237 return op.emitOpError("invalid batch dimension map");
238
239 // Verify 'accType' and 'resType' shape.
240 if (!verifyOutputShape(lhsType, rhsType, accType, resType, contractingDimMap,
241 batchDimMap))
242 return op.emitOpError("invalid accumulator/result vector shape");
243
244 // Verify that either two vector masks are set or none are set.
245 auto lhsMaskType = op.getLHSVectorMaskType();
246 auto rhsMaskType = op.getRHSVectorMaskType();
247 if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
248 return op.emitOpError("invalid number of vector masks specified");
249 if (lhsMaskType && rhsMaskType) {
250 // Verify mask rank == argument rank.
251 if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
252 rhsMaskType.getShape().size() != rhsType.getShape().size())
253 return op.emitOpError("invalid vector mask rank");
254 }
255 return success();
256 }
257
getTraitAttrNames()258 ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
259 static constexpr StringLiteral names[2] = {getIndexingMapsAttrName(),
260 getIteratorTypesAttrName()};
261 ArrayRef<StringLiteral> res{names};
262 return ArrayRef<StringRef>{res.begin(), res.end()};
263 }
264
getResultIndex(AffineMap map,AffineExpr targetExpr)265 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
266 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
267 if (targetExpr == map.getResult(i))
268 return i;
269 return -1;
270 }
271
272 static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps,ArrayAttr iteratorTypes,StringRef targetIteratorTypeName,MLIRContext * context)273 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
274 StringRef targetIteratorTypeName, MLIRContext *context) {
275 std::vector<std::pair<int64_t, int64_t>> dimMap;
276 for (auto it : llvm::enumerate(iteratorTypes)) {
277 auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
278 if (iteratorTypeName != targetIteratorTypeName)
279 continue;
280 // Search lhs/rhs map results for 'targetExpr'.
281 auto targetExpr = getAffineDimExpr(it.index(), context);
282 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
283 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
284 if (lhsDim >= 0 && rhsDim >= 0)
285 dimMap.push_back({lhsDim, rhsDim});
286 }
287 return dimMap;
288 }
289
getIterationBounds(SmallVectorImpl<int64_t> & iterationBounds)290 void ContractionOp::getIterationBounds(
291 SmallVectorImpl<int64_t> &iterationBounds) {
292 auto lhsShape = getLhsType().getShape();
293 auto resShape = getResultType().getShape();
294 SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
295 SmallVector<int64_t, 2> iterationShape;
296 for (auto it : llvm::enumerate(iterator_types())) {
297 // Search lhs/rhs map results for 'targetExpr'.
298 auto targetExpr = getAffineDimExpr(it.index(), getContext());
299 auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
300 if (iteratorTypeName == getReductionIteratorTypeName()) {
301 // Get reduction dim size from lhs shape (same size in rhsShape).
302 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
303 assert(lhsDimIndex >= 0);
304 iterationBounds.push_back(lhsShape[lhsDimIndex]);
305 continue;
306 }
307 // Get parallel dimension size from result shape.
308 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
309 assert(resDimIndex >= 0);
310 iterationBounds.push_back(resShape[resDimIndex]);
311 }
312 }
313
getIterationIndexMap(std::vector<DenseMap<int64_t,int64_t>> & iterationIndexMap)314 void ContractionOp::getIterationIndexMap(
315 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
316 unsigned numMaps = indexing_maps().getValue().size();
317 iterationIndexMap.resize(numMaps);
318 for (auto it : llvm::enumerate(indexing_maps())) {
319 auto index = it.index();
320 auto map = it.value().cast<AffineMapAttr>().getValue();
321 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
322 auto dim = map.getResult(i).cast<AffineDimExpr>();
323 iterationIndexMap[index][dim.getPosition()] = i;
324 }
325 }
326 }
327
getContractingDimMap()328 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
329 SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
330 return getDimMap(indexingMaps, iterator_types(),
331 getReductionIteratorTypeName(), getContext());
332 }
333
getBatchDimMap()334 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
335 SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
336 return getDimMap(indexingMaps, iterator_types(),
337 getParallelIteratorTypeName(), getContext());
338 }
339
getIndexingMaps()340 SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
341 SmallVector<AffineMap, 4> res;
342 auto mapAttrs = indexing_maps().getValue();
343 res.reserve(mapAttrs.size());
344 for (auto mapAttr : mapAttrs)
345 res.push_back(mapAttr.cast<AffineMapAttr>().getValue());
346 return res;
347 }
348
349 //===----------------------------------------------------------------------===//
350 // ExtractElementOp
351 //===----------------------------------------------------------------------===//
352
print(OpAsmPrinter & p,vector::ExtractElementOp op)353 static void print(OpAsmPrinter &p, vector::ExtractElementOp op) {
354 p << op.getOperationName() << " " << op.vector() << "[" << op.position()
355 << " : " << op.position().getType() << "]";
356 p.printOptionalAttrDict(op.getAttrs());
357 p << " : " << op.vector().getType();
358 }
359
parseExtractElementOp(OpAsmParser & parser,OperationState & result)360 static ParseResult parseExtractElementOp(OpAsmParser &parser,
361 OperationState &result) {
362 OpAsmParser::OperandType vector, position;
363 Type positionType;
364 VectorType vectorType;
365 if (parser.parseOperand(vector) || parser.parseLSquare() ||
366 parser.parseOperand(position) || parser.parseColonType(positionType) ||
367 parser.parseRSquare() ||
368 parser.parseOptionalAttrDict(result.attributes) ||
369 parser.parseColonType(vectorType))
370 return failure();
371 Type resultType = vectorType.getElementType();
372 return failure(
373 parser.resolveOperand(vector, vectorType, result.operands) ||
374 parser.resolveOperand(position, positionType, result.operands) ||
375 parser.addTypeToList(resultType, result.types));
376 }
377
verify(vector::ExtractElementOp op)378 static LogicalResult verify(vector::ExtractElementOp op) {
379 VectorType vectorType = op.getVectorType();
380 if (vectorType.getRank() != 1)
381 return op.emitOpError("expected 1-D vector");
382 return success();
383 }
384
385 //===----------------------------------------------------------------------===//
386 // ExtractOp
387 //===----------------------------------------------------------------------===//
388
inferExtractOpResultType(VectorType vectorType,ArrayAttr position)389 static Type inferExtractOpResultType(VectorType vectorType,
390 ArrayAttr position) {
391 if (static_cast<int64_t>(position.size()) == vectorType.getRank())
392 return vectorType.getElementType();
393 return VectorType::get(vectorType.getShape().drop_front(position.size()),
394 vectorType.getElementType());
395 }
396
build(Builder * builder,OperationState & result,Value source,ArrayRef<int64_t> position)397 void vector::ExtractOp::build(Builder *builder, OperationState &result,
398 Value source, ArrayRef<int64_t> position) {
399 result.addOperands(source);
400 auto positionAttr = getVectorSubscriptAttr(*builder, position);
401 result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
402 positionAttr));
403 result.addAttribute(getPositionAttrName(), positionAttr);
404 }
405
print(OpAsmPrinter & p,vector::ExtractOp op)406 static void print(OpAsmPrinter &p, vector::ExtractOp op) {
407 p << op.getOperationName() << " " << op.vector() << op.position();
408 p.printOptionalAttrDict(op.getAttrs(), {"position"});
409 p << " : " << op.vector().getType();
410 }
411
parseExtractOp(OpAsmParser & parser,OperationState & result)412 static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
413 llvm::SMLoc attributeLoc, typeLoc;
414 SmallVector<NamedAttribute, 4> attrs;
415 OpAsmParser::OperandType vector;
416 Type type;
417 Attribute attr;
418 if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
419 parser.parseAttribute(attr, "position", attrs) ||
420 parser.parseOptionalAttrDict(attrs) ||
421 parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
422 return failure();
423
424 auto vectorType = type.dyn_cast<VectorType>();
425 if (!vectorType)
426 return parser.emitError(typeLoc, "expected vector type");
427
428 auto positionAttr = attr.dyn_cast<ArrayAttr>();
429 if (!positionAttr ||
430 static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
431 return parser.emitError(
432 attributeLoc,
433 "expected position attribute of rank smaller than vector rank");
434
435 Type resType = inferExtractOpResultType(vectorType, positionAttr);
436 result.attributes = attrs;
437 return failure(parser.resolveOperand(vector, type, result.operands) ||
438 parser.addTypeToList(resType, result.types));
439 }
440
verify(vector::ExtractOp op)441 static LogicalResult verify(vector::ExtractOp op) {
442 auto positionAttr = op.position().getValue();
443 if (positionAttr.empty())
444 return op.emitOpError("expected non-empty position attribute");
445 if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
446 return op.emitOpError(
447 "expected position attribute of rank smaller than vector rank");
448 for (auto en : llvm::enumerate(positionAttr)) {
449 auto attr = en.value().dyn_cast<IntegerAttr>();
450 if (!attr || attr.getInt() < 0 ||
451 attr.getInt() >= op.getVectorType().getDimSize(en.index()))
452 return op.emitOpError("expected position attribute #")
453 << (en.index() + 1)
454 << " to be a non-negative integer smaller than the corresponding "
455 "vector dimension";
456 }
457 return success();
458 }
459
460 //===----------------------------------------------------------------------===//
461 // ExtractSlicesOp
462 //===----------------------------------------------------------------------===//
463
build(Builder * builder,OperationState & result,TupleType tupleType,Value vector,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)464 void ExtractSlicesOp::build(Builder *builder, OperationState &result,
465 TupleType tupleType, Value vector,
466 ArrayRef<int64_t> sizes,
467 ArrayRef<int64_t> strides) {
468 result.addOperands(vector);
469 auto sizesAttr = getVectorSubscriptAttr(*builder, sizes);
470 auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
471 result.addTypes(tupleType);
472 result.addAttribute(getSizesAttrName(), sizesAttr);
473 result.addAttribute(getStridesAttrName(), stridesAttr);
474 }
475
parseExtractSlicesOp(OpAsmParser & parser,OperationState & result)476 static ParseResult parseExtractSlicesOp(OpAsmParser &parser,
477 OperationState &result) {
478 OpAsmParser::OperandType operandInfo;
479 ArrayAttr sizesAttr;
480 StringRef sizesAttrName = ExtractSlicesOp::getSizesAttrName();
481 ArrayAttr stridesAttr;
482 StringRef stridesAttrName = ExtractSlicesOp::getStridesAttrName();
483 VectorType vectorType;
484 TupleType resultTupleType;
485 return failure(
486 parser.parseOperand(operandInfo) || parser.parseComma() ||
487 parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
488 parser.parseComma() ||
489 parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
490 parser.parseOptionalAttrDict(result.attributes) ||
491 parser.parseColonType(vectorType) ||
492 parser.parseKeywordType("into", resultTupleType) ||
493 parser.resolveOperand(operandInfo, vectorType, result.operands) ||
494 parser.addTypeToList(resultTupleType, result.types));
495 }
496
print(OpAsmPrinter & p,ExtractSlicesOp op)497 static void print(OpAsmPrinter &p, ExtractSlicesOp op) {
498 p << op.getOperationName() << ' ' << op.vector() << ", ";
499 p << op.sizes() << ", " << op.strides();
500 p.printOptionalAttrDict(
501 op.getAttrs(),
502 /*elidedAttrs=*/{ExtractSlicesOp::getSizesAttrName(),
503 ExtractSlicesOp::getStridesAttrName()});
504 p << " : " << op.vector().getType();
505 p << " into " << op.getResultTupleType();
506 }
507
508 static LogicalResult
isValidExtractOrInsertSlicesType(Operation * op,VectorType vectorType,TupleType tupleType,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)509 isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
510 TupleType tupleType, ArrayRef<int64_t> sizes,
511 ArrayRef<int64_t> strides) {
512 // Check for non-unit strides.
513 // TODO(b/144845578) Support non-1 strides.
514 if (llvm::any_of(strides, [](int64_t s) { return s != 1; }))
515 return op->emitError("requires unit strides");
516 // Check that 'vectorType' rank matches rank of tuple element vectors.
517 unsigned rank = vectorType.getRank();
518 auto is_vector_type_of_rank = [&](Type t) {
519 return t.isa<VectorType>() && t.cast<VectorType>().getRank() == rank;
520 };
521 if (!llvm::all_of(tupleType.getTypes(), is_vector_type_of_rank))
522 return op->emitError("requires vector tuple elements of rank ") << rank;
523 // Check that 'sizes' and 'strides' are of size == 'rank'.
524 if (sizes.size() != rank || strides.size() != rank)
525 return op->emitError("requires sizes and strides of rank ") << rank;
526
527 // Compute the number of slices in each dimension.
528 // TODO(andydavis) Move this into a slice generation helper function.
529 auto shape = vectorType.getShape();
530 SmallVector<int64_t, 4> dimSliceCounts(rank);
531 for (unsigned i = 0; i < rank; ++i)
532 dimSliceCounts[i] = ceilDiv(shape[i], sizes[i]);
533 // Compute the strides between slices in each dimension.
534 SmallVector<int64_t, 4> sliceStrides(rank);
535 sliceStrides[rank - 1] = 1;
536 for (int i = rank - 2; i >= 0; --i)
537 sliceStrides[i] = sliceStrides[i + 1] * dimSliceCounts[i + 1];
538
539 // Generate each slice shape based on 'sizes', 'strides' and 'vectorType',
540 // and verify that the same matches the corresponding tuple element 'i'.
541 for (int64_t i = 0, e = tupleType.size(); i < e; ++i) {
542 // De-linearize w.r.t. 'sliceStrides'.
543 SmallVector<int64_t, 4> vectorOffsets(rank);
544 int64_t linearIndex = i;
545 for (unsigned j = 0; j < rank; ++j) {
546 vectorOffsets[j] = linearIndex / sliceStrides[j];
547 linearIndex %= sliceStrides[j];
548 }
549 // Convert from unrolled vector-space offsets to element-space offsets.
550 auto offsets = mlir::functional::zipMap(
551 [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
552 // Initialize 'sliceSizes' to target 'sizes'
553 SmallVector<int64_t, 4> sliceSizes(sizes.begin(), sizes.end());
554 for (unsigned j = 0; j < rank; ++j) {
555 // Based on 'offsets' and 'shape' clip some dim sizes for partial tiles.
556 sliceSizes[j] = std::min(sliceSizes[j], shape[j] - offsets[j]);
557 }
558 // Create slice VectorType type.
559 auto sliceVectorType =
560 VectorType::get(sliceSizes, vectorType.getElementType());
561 // Verify that 'sliceVectorType' matches tupleType.getTypes(i)
562 if (sliceVectorType != tupleType.getType(i))
563 return op->emitError("invalid tuple element type ") << sliceVectorType;
564 }
565 return success();
566 }
567
verify(ExtractSlicesOp op)568 static LogicalResult verify(ExtractSlicesOp op) {
569 SmallVector<int64_t, 4> sizes;
570 op.getSizes(sizes);
571 SmallVector<int64_t, 4> strides;
572 op.getStrides(strides);
573 return isValidExtractOrInsertSlicesType(
574 op.getOperation(), op.getSourceVectorType(), op.getResultTupleType(),
575 sizes, strides);
576 }
577
populateFromInt64AttrArray(ArrayAttr arrayAttr,SmallVectorImpl<int64_t> & results)578 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
579 SmallVectorImpl<int64_t> &results) {
580 for (auto attr : arrayAttr)
581 results.push_back(attr.cast<IntegerAttr>().getInt());
582 }
583
getSizes(SmallVectorImpl<int64_t> & results)584 void ExtractSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
585 populateFromInt64AttrArray(sizes(), results);
586 }
587
getStrides(SmallVectorImpl<int64_t> & results)588 void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
589 populateFromInt64AttrArray(strides(), results);
590 }
591
592 //===----------------------------------------------------------------------===//
593 // BroadcastOp
594 //===----------------------------------------------------------------------===//
595
print(OpAsmPrinter & p,BroadcastOp op)596 static void print(OpAsmPrinter &p, BroadcastOp op) {
597 p << op.getOperationName() << " " << op.source() << " : "
598 << op.getSourceType() << " to " << op.getVectorType();
599 }
600
verify(BroadcastOp op)601 static LogicalResult verify(BroadcastOp op) {
602 VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
603 VectorType dstVectorType = op.getVectorType();
604 // Scalar to vector broadcast is always valid. A vector
605 // to vector broadcast needs some additional checking.
606 if (srcVectorType) {
607 int64_t srcRank = srcVectorType.getRank();
608 int64_t dstRank = dstVectorType.getRank();
609 if (srcRank > dstRank)
610 return op.emitOpError("source rank higher than destination rank");
611 // Source has an exact match or singleton value for all trailing dimensions
612 // (all leading dimensions are simply duplicated).
613 int64_t lead = dstRank - srcRank;
614 for (int64_t r = 0; r < srcRank; ++r) {
615 int64_t srcDim = srcVectorType.getDimSize(r);
616 int64_t dstDim = dstVectorType.getDimSize(lead + r);
617 if (srcDim != 1 && srcDim != dstDim)
618 return op.emitOpError("dimension mismatch (")
619 << srcDim << " vs. " << dstDim << ")";
620 }
621 }
622 return success();
623 }
624
parseBroadcastOp(OpAsmParser & parser,OperationState & result)625 static ParseResult parseBroadcastOp(OpAsmParser &parser,
626 OperationState &result) {
627 OpAsmParser::OperandType source;
628 Type sourceType;
629 VectorType vectorType;
630 return failure(parser.parseOperand(source) ||
631 parser.parseColonType(sourceType) ||
632 parser.parseKeywordType("to", vectorType) ||
633 parser.resolveOperand(source, sourceType, result.operands) ||
634 parser.addTypeToList(vectorType, result.types));
635 }
636
637 //===----------------------------------------------------------------------===//
638 // ShuffleOp
639 //===----------------------------------------------------------------------===//
640
build(Builder * builder,OperationState & result,Value v1,Value v2,ArrayRef<int64_t> mask)641 void ShuffleOp::build(Builder *builder, OperationState &result, Value v1,
642 Value v2, ArrayRef<int64_t> mask) {
643 result.addOperands({v1, v2});
644 auto maskAttr = getVectorSubscriptAttr(*builder, mask);
645 result.addTypes(v1.getType());
646 result.addAttribute(getMaskAttrName(), maskAttr);
647 }
648
print(OpAsmPrinter & p,ShuffleOp op)649 static void print(OpAsmPrinter &p, ShuffleOp op) {
650 p << op.getOperationName() << " " << op.v1() << ", " << op.v2() << " "
651 << op.mask();
652 p.printOptionalAttrDict(op.getAttrs(), {ShuffleOp::getMaskAttrName()});
653 p << " : " << op.v1().getType() << ", " << op.v2().getType();
654 }
655
verify(ShuffleOp op)656 static LogicalResult verify(ShuffleOp op) {
657 VectorType resultType = op.getVectorType();
658 VectorType v1Type = op.getV1VectorType();
659 VectorType v2Type = op.getV2VectorType();
660 // Verify ranks.
661 int64_t resRank = resultType.getRank();
662 int64_t v1Rank = v1Type.getRank();
663 int64_t v2Rank = v2Type.getRank();
664 if (resRank != v1Rank || v1Rank != v2Rank)
665 return op.emitOpError("rank mismatch");
666 // Verify all but leading dimension sizes.
667 for (int64_t r = 1; r < v1Rank; ++r) {
668 int64_t resDim = resultType.getDimSize(r);
669 int64_t v1Dim = v1Type.getDimSize(r);
670 int64_t v2Dim = v2Type.getDimSize(r);
671 if (resDim != v1Dim || v1Dim != v2Dim)
672 return op.emitOpError("dimension mismatch");
673 }
674 // Verify mask length.
675 auto maskAttr = op.mask().getValue();
676 int64_t maskLength = maskAttr.size();
677 if (maskLength != resultType.getDimSize(0))
678 return op.emitOpError("mask length mismatch");
679 // Verify all indices.
680 int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
681 for (auto en : llvm::enumerate(maskAttr)) {
682 auto attr = en.value().dyn_cast<IntegerAttr>();
683 if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
684 return op.emitOpError("mask index #")
685 << (en.index() + 1) << " out of range";
686 }
687 return success();
688 }
689
parseShuffleOp(OpAsmParser & parser,OperationState & result)690 static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
691 OpAsmParser::OperandType v1, v2;
692 Attribute attr;
693 VectorType v1Type, v2Type;
694 if (parser.parseOperand(v1) || parser.parseComma() ||
695 parser.parseOperand(v2) ||
696 parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
697 result.attributes) ||
698 parser.parseOptionalAttrDict(result.attributes) ||
699 parser.parseColonType(v1Type) || parser.parseComma() ||
700 parser.parseType(v2Type) ||
701 parser.resolveOperand(v1, v1Type, result.operands) ||
702 parser.resolveOperand(v2, v2Type, result.operands))
703 return failure();
704 // Construct resulting type: leading dimension matches mask length,
705 // all trailing dimensions match the operands.
706 auto maskAttr = attr.dyn_cast<ArrayAttr>();
707 if (!maskAttr)
708 return parser.emitError(parser.getNameLoc(), "missing mask attribute");
709 int64_t maskLength = maskAttr.size();
710 if (maskLength <= 0)
711 return parser.emitError(parser.getNameLoc(), "invalid mask length");
712 int64_t v1Rank = v1Type.getRank();
713 SmallVector<int64_t, 4> shape;
714 shape.reserve(v1Rank);
715 shape.push_back(maskLength);
716 for (int64_t r = 1; r < v1Rank; ++r)
717 shape.push_back(v1Type.getDimSize(r));
718 VectorType resType = VectorType::get(shape, v1Type.getElementType());
719 parser.addTypeToList(resType, result.types);
720 return success();
721 }
722
723 //===----------------------------------------------------------------------===//
724 // InsertElementOp
725 //===----------------------------------------------------------------------===//
726
print(OpAsmPrinter & p,InsertElementOp op)727 static void print(OpAsmPrinter &p, InsertElementOp op) {
728 p << op.getOperationName() << " " << op.source() << ", " << op.dest() << "["
729 << op.position() << " : " << op.position().getType() << "]";
730 p.printOptionalAttrDict(op.getAttrs());
731 p << " : " << op.dest().getType();
732 }
733
parseInsertElementOp(OpAsmParser & parser,OperationState & result)734 static ParseResult parseInsertElementOp(OpAsmParser &parser,
735 OperationState &result) {
736 OpAsmParser::OperandType source, dest, position;
737 Type positionType;
738 VectorType destType;
739 if (parser.parseOperand(source) || parser.parseComma() ||
740 parser.parseOperand(dest) || parser.parseLSquare() ||
741 parser.parseOperand(position) || parser.parseColonType(positionType) ||
742 parser.parseRSquare() ||
743 parser.parseOptionalAttrDict(result.attributes) ||
744 parser.parseColonType(destType))
745 return failure();
746 Type sourceType = destType.getElementType();
747 return failure(
748 parser.resolveOperand(source, sourceType, result.operands) ||
749 parser.resolveOperand(dest, destType, result.operands) ||
750 parser.resolveOperand(position, positionType, result.operands) ||
751 parser.addTypeToList(destType, result.types));
752 }
753
verify(InsertElementOp op)754 static LogicalResult verify(InsertElementOp op) {
755 auto dstVectorType = op.getDestVectorType();
756 if (dstVectorType.getRank() != 1)
757 return op.emitOpError("expected 1-D vector");
758 return success();
759 }
760
761 //===----------------------------------------------------------------------===//
762 // InsertOp
763 //===----------------------------------------------------------------------===//
764
build(Builder * builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> position)765 void InsertOp::build(Builder *builder, OperationState &result, Value source,
766 Value dest, ArrayRef<int64_t> position) {
767 result.addOperands({source, dest});
768 auto positionAttr = getVectorSubscriptAttr(*builder, position);
769 result.addTypes(dest.getType());
770 result.addAttribute(getPositionAttrName(), positionAttr);
771 }
772
print(OpAsmPrinter & p,InsertOp op)773 static void print(OpAsmPrinter &p, InsertOp op) {
774 p << op.getOperationName() << " " << op.source() << ", " << op.dest()
775 << op.position();
776 p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()});
777 p << " : " << op.getSourceType() << " into " << op.getDestVectorType();
778 }
779
parseInsertOp(OpAsmParser & parser,OperationState & result)780 static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) {
781 SmallVector<NamedAttribute, 4> attrs;
782 OpAsmParser::OperandType source, dest;
783 Type sourceType;
784 VectorType destType;
785 Attribute attr;
786 return failure(parser.parseOperand(source) || parser.parseComma() ||
787 parser.parseOperand(dest) ||
788 parser.parseAttribute(attr, InsertOp::getPositionAttrName(),
789 result.attributes) ||
790 parser.parseOptionalAttrDict(attrs) ||
791 parser.parseColonType(sourceType) ||
792 parser.parseKeywordType("into", destType) ||
793 parser.resolveOperand(source, sourceType, result.operands) ||
794 parser.resolveOperand(dest, destType, result.operands) ||
795 parser.addTypeToList(destType, result.types));
796 }
797
verify(InsertOp op)798 static LogicalResult verify(InsertOp op) {
799 auto positionAttr = op.position().getValue();
800 if (positionAttr.empty())
801 return op.emitOpError("expected non-empty position attribute");
802 auto destVectorType = op.getDestVectorType();
803 if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
804 return op.emitOpError(
805 "expected position attribute of rank smaller than dest vector rank");
806 auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
807 if (srcVectorType &&
808 (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
809 static_cast<unsigned>(destVectorType.getRank())))
810 return op.emitOpError("expected position attribute rank + source rank to "
811 "match dest vector rank");
812 else if (!srcVectorType && (positionAttr.size() !=
813 static_cast<unsigned>(destVectorType.getRank())))
814 return op.emitOpError(
815 "expected position attribute rank to match the dest vector rank");
816 for (auto en : llvm::enumerate(positionAttr)) {
817 auto attr = en.value().dyn_cast<IntegerAttr>();
818 if (!attr || attr.getInt() < 0 ||
819 attr.getInt() >= destVectorType.getDimSize(en.index()))
820 return op.emitOpError("expected position attribute #")
821 << (en.index() + 1)
822 << " to be a non-negative integer smaller than the corresponding "
823 "dest vector dimension";
824 }
825 return success();
826 }
827
828 //===----------------------------------------------------------------------===//
829 // InsertSlicesOp
830 //===----------------------------------------------------------------------===//
831
parseInsertSlicesOp(OpAsmParser & parser,OperationState & result)832 static ParseResult parseInsertSlicesOp(OpAsmParser &parser,
833 OperationState &result) {
834 OpAsmParser::OperandType operandInfo;
835 ArrayAttr sizesAttr;
836 StringRef sizesAttrName = InsertSlicesOp::getSizesAttrName();
837 ArrayAttr stridesAttr;
838 StringRef stridesAttrName = InsertSlicesOp::getStridesAttrName();
839 TupleType tupleType;
840 VectorType resultVectorType;
841 return failure(
842 parser.parseOperand(operandInfo) || parser.parseComma() ||
843 parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
844 parser.parseComma() ||
845 parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
846 parser.parseOptionalAttrDict(result.attributes) ||
847 parser.parseColonType(tupleType) ||
848 parser.parseKeywordType("into", resultVectorType) ||
849 parser.resolveOperand(operandInfo, tupleType, result.operands) ||
850 parser.addTypeToList(resultVectorType, result.types));
851 }
852
print(OpAsmPrinter & p,InsertSlicesOp op)853 static void print(OpAsmPrinter &p, InsertSlicesOp op) {
854 p << op.getOperationName() << ' ' << op.vectors() << ", ";
855 p << op.sizes() << ", " << op.strides();
856 p.printOptionalAttrDict(
857 op.getAttrs(),
858 /*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(),
859 InsertSlicesOp::getStridesAttrName()});
860 p << " : " << op.vectors().getType();
861 p << " into " << op.getResultVectorType();
862 }
863
verify(InsertSlicesOp op)864 static LogicalResult verify(InsertSlicesOp op) {
865 SmallVector<int64_t, 4> sizes;
866 op.getSizes(sizes);
867 SmallVector<int64_t, 4> strides;
868 op.getStrides(strides);
869 return isValidExtractOrInsertSlicesType(
870 op.getOperation(), op.getResultVectorType(), op.getSourceTupleType(),
871 sizes, strides);
872 }
873
getSizes(SmallVectorImpl<int64_t> & results)874 void InsertSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
875 populateFromInt64AttrArray(sizes(), results);
876 }
877
getStrides(SmallVectorImpl<int64_t> & results)878 void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
879 populateFromInt64AttrArray(strides(), results);
880 }
881
882 //===----------------------------------------------------------------------===//
883 // InsertStridedSliceOp
884 //===----------------------------------------------------------------------===//
885
build(Builder * builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> offsets,ArrayRef<int64_t> strides)886 void InsertStridedSliceOp::build(Builder *builder, OperationState &result,
887 Value source, Value dest,
888 ArrayRef<int64_t> offsets,
889 ArrayRef<int64_t> strides) {
890 result.addOperands({source, dest});
891 auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets);
892 auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
893 result.addTypes(dest.getType());
894 result.addAttribute(getOffsetsAttrName(), offsetsAttr);
895 result.addAttribute(getStridesAttrName(), stridesAttr);
896 }
897
print(OpAsmPrinter & p,InsertStridedSliceOp op)898 static void print(OpAsmPrinter &p, InsertStridedSliceOp op) {
899 p << op.getOperationName() << " " << op.source() << ", " << op.dest() << " ";
900 p.printOptionalAttrDict(op.getAttrs());
901 p << " : " << op.getSourceVectorType() << " into " << op.getDestVectorType();
902 }
903
parseInsertStridedSliceOp(OpAsmParser & parser,OperationState & result)904 static ParseResult parseInsertStridedSliceOp(OpAsmParser &parser,
905 OperationState &result) {
906 OpAsmParser::OperandType source, dest;
907 VectorType sourceVectorType, destVectorType;
908 return failure(
909 parser.parseOperand(source) || parser.parseComma() ||
910 parser.parseOperand(dest) ||
911 parser.parseOptionalAttrDict(result.attributes) ||
912 parser.parseColonType(sourceVectorType) ||
913 parser.parseKeywordType("into", destVectorType) ||
914 parser.resolveOperand(source, sourceVectorType, result.operands) ||
915 parser.resolveOperand(dest, destVectorType, result.operands) ||
916 parser.addTypeToList(destVectorType, result.types));
917 }
918
919 // TODO(ntv) Should be moved to Tablegen Confined attributes.
920 template <typename OpType>
isIntegerArrayAttrSmallerThanShape(OpType op,ArrayAttr arrayAttr,ArrayRef<int64_t> shape,StringRef attrName)921 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
922 ArrayAttr arrayAttr,
923 ArrayRef<int64_t> shape,
924 StringRef attrName) {
925 if (arrayAttr.size() > shape.size())
926 return op.emitOpError("expected ")
927 << attrName << " attribute of rank smaller than vector rank";
928 return success();
929 }
930
931 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
932 // interval. If `halfOpen` is true then the admissible interval is [min, max).
933 // Otherwise, the admissible interval is [min, max].
934 template <typename OpType>
935 static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op,ArrayAttr arrayAttr,int64_t min,int64_t max,StringRef attrName,bool halfOpen=true)936 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
937 int64_t max, StringRef attrName,
938 bool halfOpen = true) {
939 for (auto attr : arrayAttr) {
940 auto val = attr.cast<IntegerAttr>().getInt();
941 auto upper = max;
942 if (!halfOpen)
943 upper += 1;
944 if (val < min || val >= upper)
945 return op.emitOpError("expected ") << attrName << " to be confined to ["
946 << min << ", " << upper << ")";
947 }
948 return success();
949 }
950
951 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
952 // interval. If `halfOpen` is true then the admissible interval is [min, max).
953 // Otherwise, the admissible interval is [min, max].
954 template <typename OpType>
955 static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op,ArrayAttr arrayAttr,ArrayRef<int64_t> shape,StringRef attrName,bool halfOpen=true,int64_t min=0)956 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
957 ArrayRef<int64_t> shape, StringRef attrName,
958 bool halfOpen = true, int64_t min = 0) {
959 assert(arrayAttr.size() <= shape.size());
960 unsigned index = 0;
961 for (auto it : llvm::zip(arrayAttr, shape)) {
962 auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
963 auto max = std::get<1>(it);
964 if (!halfOpen)
965 max += 1;
966 if (val < min || val >= max)
967 return op.emitOpError("expected ")
968 << attrName << " dimension " << index << " to be confined to ["
969 << min << ", " << max << ")";
970 ++index;
971 }
972 return success();
973 }
974
975 // Returns true if all integers in `arrayAttr` are in the interval [min, max}.
976 // interval. If `halfOpen` is true then the admissible interval is [min, max).
977 // Otherwise, the admissible interval is [min, max].
978 template <typename OpType>
isSumOfIntegerArrayAttrConfinedToShape(OpType op,ArrayAttr arrayAttr1,ArrayAttr arrayAttr2,ArrayRef<int64_t> shape,StringRef attrName1,StringRef attrName2,bool halfOpen=true,int64_t min=1)979 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
980 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
981 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
982 bool halfOpen = true, int64_t min = 1) {
983 assert(arrayAttr1.size() <= shape.size());
984 assert(arrayAttr2.size() <= shape.size());
985 unsigned index = 0;
986 for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
987 auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
988 auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
989 auto max = std::get<2>(it);
990 if (!halfOpen)
991 max += 1;
992 if (val1 + val2 < 0 || val1 + val2 >= max)
993 return op.emitOpError("expected sum(")
994 << attrName1 << ", " << attrName2 << ") dimension " << index
995 << " to be confined to [" << min << ", " << max << ")";
996 ++index;
997 }
998 return success();
999 }
1000
makeI64ArrayAttr(ArrayRef<int64_t> values,MLIRContext * context)1001 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
1002 MLIRContext *context) {
1003 auto attrs = functional::map(
1004 [context](int64_t v) -> Attribute {
1005 return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
1006 },
1007 values);
1008 return ArrayAttr::get(attrs, context);
1009 }
1010
verify(InsertStridedSliceOp op)1011 static LogicalResult verify(InsertStridedSliceOp op) {
1012 auto sourceVectorType = op.getSourceVectorType();
1013 auto destVectorType = op.getDestVectorType();
1014 auto offsets = op.offsets();
1015 auto strides = op.strides();
1016 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
1017 return op.emitOpError(
1018 "expected offsets of same size as destination vector rank");
1019 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
1020 return op.emitOpError(
1021 "expected strides of same size as source vector rank");
1022 if (sourceVectorType.getRank() > destVectorType.getRank())
1023 return op.emitOpError(
1024 "expected source rank to be smaller than destination rank");
1025
1026 auto sourceShape = sourceVectorType.getShape();
1027 auto destShape = destVectorType.getShape();
1028 SmallVector<int64_t, 4> sourceShapeAsDestShape(
1029 destShape.size() - sourceShape.size(), 0);
1030 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
1031 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
1032 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
1033 if (failed(
1034 isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
1035 failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
1036 /*halfOpen=*/false)) ||
1037 failed(isSumOfIntegerArrayAttrConfinedToShape(
1038 op, offsets,
1039 makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
1040 offName, "source vector shape",
1041 /*halfOpen=*/false, /*min=*/1)))
1042 return failure();
1043
1044 return success();
1045 }
1046
1047 //===----------------------------------------------------------------------===//
1048 // OuterProductOp
1049 //===----------------------------------------------------------------------===//
1050
print(OpAsmPrinter & p,OuterProductOp op)1051 static void print(OpAsmPrinter &p, OuterProductOp op) {
1052 p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
1053 if (!op.acc().empty())
1054 p << ", " << op.acc();
1055 p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
1056 }
1057
parseOuterProductOp(OpAsmParser & parser,OperationState & result)1058 static ParseResult parseOuterProductOp(OpAsmParser &parser,
1059 OperationState &result) {
1060 SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
1061 Type tLHS, tRHS;
1062 if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) ||
1063 parser.parseComma() || parser.parseType(tRHS))
1064 return failure();
1065 if (operandsInfo.size() < 2)
1066 return parser.emitError(parser.getNameLoc(),
1067 "expected at least 2 operands");
1068 VectorType vLHS = tLHS.dyn_cast<VectorType>();
1069 VectorType vRHS = tRHS.dyn_cast<VectorType>();
1070 if (!vLHS || !vRHS)
1071 return parser.emitError(parser.getNameLoc(), "expected 2 vector types");
1072 VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
1073 vLHS.getElementType());
1074 return failure(
1075 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
1076 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
1077 (operandsInfo.size() > 2 &&
1078 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
1079 parser.addTypeToList(resType, result.types));
1080 }
1081
verify(OuterProductOp op)1082 static LogicalResult verify(OuterProductOp op) {
1083 VectorType vLHS = op.getOperandVectorTypeLHS(),
1084 vRHS = op.getOperandVectorTypeRHS(),
1085 vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
1086 if (vLHS.getRank() != 1)
1087 return op.emitOpError("expected 1-d vector for operand #1");
1088 if (vRHS.getRank() != 1)
1089 return op.emitOpError("expected 1-d vector for operand #2");
1090 if (vRES.getRank() != 2)
1091 return op.emitOpError("expected 2-d vector result");
1092 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
1093 return op.emitOpError("expected #1 operand dim to match result dim #1");
1094 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
1095 return op.emitOpError("expected #2 operand dim to match result dim #2");
1096 if (vACC && vACC != vRES)
1097 return op.emitOpError("expected operand #3 of same type as result type");
1098 return success();
1099 }
1100
1101 //===----------------------------------------------------------------------===//
1102 // ReshapeOp
1103 //===----------------------------------------------------------------------===//
1104
print(OpAsmPrinter & p,ReshapeOp op)1105 static void print(OpAsmPrinter &p, ReshapeOp op) {
1106 p << op.getOperationName() << " " << op.vector() << ", [" << op.input_shape()
1107 << "], [" << op.output_shape() << "], " << op.fixed_vector_sizes();
1108 SmallVector<StringRef, 2> elidedAttrs = {
1109 ReshapeOp::getOperandSegmentSizeAttr(),
1110 ReshapeOp::getFixedVectorSizesAttrName()};
1111 p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
1112 p << " : " << op.getInputVectorType() << " to " << op.getOutputVectorType();
1113 }
1114
1115 // TODO(b/146516564) Consider passing number of inner vector dimensions that
1116 // are fixed, instead of their values in 'fixesVectorSizes' array attr.
1117 //
1118 // operation ::= ssa-id `=` `vector.reshape` ssa-use, `[` ssa-use-list `]`,
1119 // `[` ssa-use-list `]`, `[` array-attribute `]`
1120 // `:` vector-type 'to' vector-type
1121 //
parseReshapeOp(OpAsmParser & parser,OperationState & result)1122 static ParseResult parseReshapeOp(OpAsmParser &parser, OperationState &result) {
1123 OpAsmParser::OperandType inputInfo;
1124 SmallVector<OpAsmParser::OperandType, 4> inputShapeInfo;
1125 SmallVector<OpAsmParser::OperandType, 4> outputShapeInfo;
1126 ArrayAttr fixedVectorSizesAttr;
1127 StringRef attrName = ReshapeOp::getFixedVectorSizesAttrName();
1128 auto indexType = parser.getBuilder().getIndexType();
1129 if (parser.parseOperand(inputInfo) || parser.parseComma() ||
1130 parser.parseOperandList(inputShapeInfo, OpAsmParser::Delimiter::Square) ||
1131 parser.parseComma() ||
1132 parser.parseOperandList(outputShapeInfo,
1133 OpAsmParser::Delimiter::Square) ||
1134 parser.parseComma()) {
1135 return failure();
1136 }
1137
1138 auto builder = parser.getBuilder();
1139 result.addAttribute(
1140 ReshapeOp::getOperandSegmentSizeAttr(),
1141 builder.getI32VectorAttr({1, static_cast<int32_t>(inputShapeInfo.size()),
1142 static_cast<int32_t>(outputShapeInfo.size())}));
1143 Type inputType;
1144 Type outputType;
1145 return failure(
1146 parser.parseAttribute(fixedVectorSizesAttr, attrName,
1147 result.attributes) ||
1148 parser.parseOptionalAttrDict(result.attributes) ||
1149 parser.parseColonType(inputType) ||
1150 parser.resolveOperand(inputInfo, inputType, result.operands) ||
1151 parser.resolveOperands(inputShapeInfo, indexType, result.operands) ||
1152 parser.resolveOperands(outputShapeInfo, indexType, result.operands) ||
1153 parser.parseKeywordType("to", outputType) ||
1154 parser.addTypeToList(outputType, result.types));
1155 }
1156
verify(ReshapeOp op)1157 static LogicalResult verify(ReshapeOp op) {
1158 // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
1159 auto inputVectorType = op.getInputVectorType();
1160 auto outputVectorType = op.getOutputVectorType();
1161 int64_t inputShapeRank = op.getNumInputShapeSizes();
1162 int64_t outputShapeRank = op.getNumOutputShapeSizes();
1163 SmallVector<int64_t, 4> fixedVectorSizes;
1164 op.getFixedVectorSizes(fixedVectorSizes);
1165 int64_t numFixedVectorSizes = fixedVectorSizes.size();
1166
1167 if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
1168 return op.emitError("invalid input shape for vector type ")
1169 << inputVectorType;
1170
1171 if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
1172 return op.emitError("invalid output shape for vector type ")
1173 << outputVectorType;
1174
1175 // Verify that the 'fixedVectorSizes' match a input/output vector shape
1176 // suffix.
1177 unsigned inputVectorRank = inputVectorType.getRank();
1178 for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
1179 unsigned index = inputVectorRank - numFixedVectorSizes - i;
1180 if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
1181 return op.emitError("fixed vector size must match input vector for dim ")
1182 << i;
1183 }
1184
1185 unsigned outputVectorRank = outputVectorType.getRank();
1186 for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
1187 unsigned index = outputVectorRank - numFixedVectorSizes - i;
1188 if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
1189 return op.emitError("fixed vector size must match output vector for dim ")
1190 << i;
1191 }
1192
1193 // If all shape operands are produced by constant ops, verify that product
1194 // of dimensions for input/output shape match.
1195 auto isDefByConstant = [](Value operand) {
1196 return isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
1197 };
1198 if (llvm::all_of(op.input_shape(), isDefByConstant) &&
1199 llvm::all_of(op.output_shape(), isDefByConstant)) {
1200 int64_t numInputElements = 1;
1201 for (auto operand : op.input_shape())
1202 numInputElements *=
1203 cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
1204 int64_t numOutputElements = 1;
1205 for (auto operand : op.output_shape())
1206 numOutputElements *=
1207 cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
1208 if (numInputElements != numOutputElements)
1209 return op.emitError("product of input and output shape sizes must match");
1210 }
1211 return success();
1212 }
1213
getFixedVectorSizes(SmallVectorImpl<int64_t> & results)1214 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
1215 populateFromInt64AttrArray(fixed_vector_sizes(), results);
1216 }
1217
1218 //===----------------------------------------------------------------------===//
1219 // StridedSliceOp
1220 //===----------------------------------------------------------------------===//
1221
1222 // Inference works as follows:
1223 // 1. Add 'sizes' from prefix of dims in 'offsets'.
1224 // 2. Add sizes from 'vectorType' for remaining dims.
inferStridedSliceOpResultType(VectorType vectorType,ArrayAttr offsets,ArrayAttr sizes,ArrayAttr strides)1225 static Type inferStridedSliceOpResultType(VectorType vectorType,
1226 ArrayAttr offsets, ArrayAttr sizes,
1227 ArrayAttr strides) {
1228 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
1229 SmallVector<int64_t, 4> shape;
1230 shape.reserve(vectorType.getRank());
1231 unsigned idx = 0;
1232 for (unsigned e = offsets.size(); idx < e; ++idx)
1233 shape.push_back(sizes.getValue()[idx].cast<IntegerAttr>().getInt());
1234 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
1235 shape.push_back(vectorType.getShape()[idx]);
1236
1237 return VectorType::get(shape, vectorType.getElementType());
1238 }
1239
build(Builder * builder,OperationState & result,Value source,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)1240 void StridedSliceOp::build(Builder *builder, OperationState &result,
1241 Value source, ArrayRef<int64_t> offsets,
1242 ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) {
1243 result.addOperands(source);
1244 auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets);
1245 auto sizesAttr = getVectorSubscriptAttr(*builder, sizes);
1246 auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
1247 result.addTypes(
1248 inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
1249 offsetsAttr, sizesAttr, stridesAttr));
1250 result.addAttribute(getOffsetsAttrName(), offsetsAttr);
1251 result.addAttribute(getSizesAttrName(), sizesAttr);
1252 result.addAttribute(getStridesAttrName(), stridesAttr);
1253 }
1254
print(OpAsmPrinter & p,StridedSliceOp op)1255 static void print(OpAsmPrinter &p, StridedSliceOp op) {
1256 p << op.getOperationName() << " " << op.vector();
1257 p.printOptionalAttrDict(op.getAttrs());
1258 p << " : " << op.vector().getType() << " to " << op.getResult().getType();
1259 }
1260
parseStridedSliceOp(OpAsmParser & parser,OperationState & result)1261 static ParseResult parseStridedSliceOp(OpAsmParser &parser,
1262 OperationState &result) {
1263 llvm::SMLoc attributeLoc, typeLoc;
1264 OpAsmParser::OperandType vector;
1265 VectorType vectorType, resultVectorType;
1266 return failure(parser.parseOperand(vector) ||
1267 parser.getCurrentLocation(&attributeLoc) ||
1268 parser.parseOptionalAttrDict(result.attributes) ||
1269 parser.getCurrentLocation(&typeLoc) ||
1270 parser.parseColonType(vectorType) ||
1271 parser.parseKeywordType("to", resultVectorType) ||
1272 parser.resolveOperand(vector, vectorType, result.operands) ||
1273 parser.addTypeToList(resultVectorType, result.types));
1274 }
1275
verify(StridedSliceOp op)1276 static LogicalResult verify(StridedSliceOp op) {
1277 auto type = op.getVectorType();
1278 auto offsets = op.offsets();
1279 auto sizes = op.sizes();
1280 auto strides = op.strides();
1281 if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
1282 op.emitOpError(
1283 "expected offsets, sizes and strides attributes of same size");
1284 return failure();
1285 }
1286
1287 auto shape = type.getShape();
1288 auto offName = StridedSliceOp::getOffsetsAttrName();
1289 auto sizesName = StridedSliceOp::getSizesAttrName();
1290 auto stridesName = StridedSliceOp::getStridesAttrName();
1291 if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
1292 failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
1293 failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
1294 stridesName)) ||
1295 failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
1296 failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
1297 /*halfOpen=*/false,
1298 /*min=*/1)) ||
1299 failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
1300 /*halfOpen=*/false)) ||
1301 failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
1302 offName, sizesName,
1303 /*halfOpen=*/false)))
1304 return failure();
1305
1306 auto resultType = inferStridedSliceOpResultType(
1307 op.getVectorType(), op.offsets(), op.sizes(), op.strides());
1308 if (op.getResult().getType() != resultType) {
1309 op.emitOpError("expected result type to be ") << resultType;
1310 return failure();
1311 }
1312
1313 return success();
1314 }
1315
getOffsets(SmallVectorImpl<int64_t> & results)1316 void StridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
1317 populateFromInt64AttrArray(offsets(), results);
1318 }
1319
1320 namespace {
1321
1322 // Pattern to rewrite a StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
1323 class StridedSliceConstantMaskFolder final
1324 : public OpRewritePattern<StridedSliceOp> {
1325 public:
1326 using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
1327
matchAndRewrite(StridedSliceOp stridedSliceOp,PatternRewriter & rewriter) const1328 PatternMatchResult matchAndRewrite(StridedSliceOp stridedSliceOp,
1329 PatternRewriter &rewriter) const override {
1330 // Return if 'stridedSliceOp' operand is not defined by a ConstantMaskOp.
1331 auto defOp = stridedSliceOp.vector().getDefiningOp();
1332 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
1333 if (!constantMaskOp)
1334 return matchFailure();
1335 // Return if 'stridedSliceOp' has non-unit strides.
1336 if (llvm::any_of(stridedSliceOp.strides(), [](Attribute attr) {
1337 return attr.cast<IntegerAttr>().getInt() != 1;
1338 }))
1339 return matchFailure();
1340 // Gather constant mask dimension sizes.
1341 SmallVector<int64_t, 4> maskDimSizes;
1342 populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
1343 // Gather strided slice offsets and sizes.
1344 SmallVector<int64_t, 4> sliceOffsets;
1345 populateFromInt64AttrArray(stridedSliceOp.offsets(), sliceOffsets);
1346 SmallVector<int64_t, 4> sliceSizes;
1347 populateFromInt64AttrArray(stridedSliceOp.sizes(), sliceSizes);
1348
1349 // Compute slice of vector mask region.
1350 SmallVector<int64_t, 4> sliceMaskDimSizes;
1351 assert(sliceOffsets.size() == maskDimSizes.size());
1352 for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
1353 int64_t maskDimSize = std::get<0>(it);
1354 int64_t sliceOffset = std::get<1>(it);
1355 int64_t sliceSize = std::get<2>(it);
1356 int64_t sliceMaskDimSize = std::max(
1357 static_cast<int64_t>(0),
1358 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
1359 sliceMaskDimSizes.push_back(sliceMaskDimSize);
1360 }
1361 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
1362 // region is a conjunction of mask dim intervals).
1363 if (llvm::any_of(sliceMaskDimSizes, [](int64_t sz) { return sz == 0; }))
1364 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
1365
1366 // Replace 'stridedSliceOp' with ConstantMaskOp with sliced mask region.
1367 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
1368 stridedSliceOp, stridedSliceOp.getResult().getType(),
1369 vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
1370 return matchSuccess();
1371 }
1372 };
1373
1374 } // end anonymous namespace
1375
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1376 void StridedSliceOp::getCanonicalizationPatterns(
1377 OwningRewritePatternList &results, MLIRContext *context) {
1378 // Pattern to rewrite a StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
1379 results.insert<StridedSliceConstantMaskFolder>(context);
1380 }
1381
1382 //===----------------------------------------------------------------------===//
1383 // TransferReadOp
1384 //===----------------------------------------------------------------------===//
1385 template <typename EmitFun>
verifyPermutationMap(AffineMap permutationMap,EmitFun emitOpError)1386 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
1387 EmitFun emitOpError) {
1388 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
1389 for (auto expr : permutationMap.getResults()) {
1390 auto dim = expr.dyn_cast<AffineDimExpr>();
1391 auto zero = expr.dyn_cast<AffineConstantExpr>();
1392 if (zero) {
1393 if (zero.getValue() != 0) {
1394 return emitOpError(
1395 "requires a projected permutation_map (at most one dim or the zero "
1396 "constant can appear in each result)");
1397 }
1398 continue;
1399 }
1400 if (!dim) {
1401 return emitOpError("requires a projected permutation_map (at most one "
1402 "dim or the zero constant can appear in each result)");
1403 }
1404 if (seen[dim.getPosition()]) {
1405 return emitOpError(
1406 "requires a permutation_map that is a permutation (found one dim "
1407 "used more than once)");
1408 }
1409 seen[dim.getPosition()] = true;
1410 }
1411 return success();
1412 }
1413
verifyTransferOp(Operation * op,MemRefType memrefType,VectorType vectorType,AffineMap permutationMap)1414 static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
1415 VectorType vectorType,
1416 AffineMap permutationMap) {
1417 auto memrefElementType = memrefType.getElementType();
1418 if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
1419 // Memref has vector element type.
1420
1421 // Check that 'memrefVectorElementType' and vector element types match.
1422 if (memrefVectorElementType.getElementType() != vectorType.getElementType())
1423 return op->emitOpError(
1424 "requires memref and vector types of the same elemental type");
1425
1426 // Check that memref vector type is a suffix of 'vectorType.
1427 unsigned memrefVecEltRank = memrefVectorElementType.getRank();
1428 unsigned resultVecRank = vectorType.getRank();
1429 if (memrefVecEltRank > resultVecRank)
1430 return op->emitOpError(
1431 "requires memref vector element and vector result ranks to match.");
1432 // TODO(b/146516564) Move this to isSuffix in VectorOps/Utils.h.
1433 unsigned rankOffset = resultVecRank - memrefVecEltRank;
1434 auto memrefVecEltShape = memrefVectorElementType.getShape();
1435 auto resultVecShape = vectorType.getShape();
1436 for (unsigned i = 0; i < memrefVecEltRank; ++i)
1437 if (memrefVecEltShape[i] != resultVecShape[rankOffset + i])
1438 return op->emitOpError(
1439 "requires memref vector element shape to match suffix of "
1440 "vector result shape.");
1441 // Check that permutation map results match 'rankOffset' of vector type.
1442 if (permutationMap.getNumResults() != rankOffset)
1443 return op->emitOpError("requires a permutation_map with result dims of "
1444 "the same rank as the vector type");
1445 } else {
1446 // Memref has scalar element type.
1447
1448 // Check that memref and vector element types match.
1449 if (memrefType.getElementType() != vectorType.getElementType())
1450 return op->emitOpError(
1451 "requires memref and vector types of the same elemental type");
1452
1453 // Check that permutation map results match rank of vector type.
1454 if (permutationMap.getNumResults() != vectorType.getRank())
1455 return op->emitOpError("requires a permutation_map with result dims of "
1456 "the same rank as the vector type");
1457 }
1458
1459 if (permutationMap.getNumSymbols() != 0)
1460 return op->emitOpError("requires permutation_map without symbols");
1461 if (permutationMap.getNumInputs() != memrefType.getRank())
1462 return op->emitOpError("requires a permutation_map with input dims of the "
1463 "same rank as the memref type");
1464 return success();
1465 }
1466
print(OpAsmPrinter & p,TransferReadOp op)1467 static void print(OpAsmPrinter &p, TransferReadOp op) {
1468 p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
1469 << "], " << op.padding() << " ";
1470 p.printOptionalAttrDict(op.getAttrs());
1471 p << " : " << op.getMemRefType() << ", " << op.getVectorType();
1472 }
1473
parseTransferReadOp(OpAsmParser & parser,OperationState & result)1474 static ParseResult parseTransferReadOp(OpAsmParser &parser,
1475 OperationState &result) {
1476 llvm::SMLoc typesLoc;
1477 OpAsmParser::OperandType memrefInfo;
1478 SmallVector<OpAsmParser::OperandType, 8> indexInfo;
1479 OpAsmParser::OperandType paddingInfo;
1480 SmallVector<Type, 2> types;
1481 // Parsing with support for optional paddingValue.
1482 if (parser.parseOperand(memrefInfo) ||
1483 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1484 parser.parseComma() || parser.parseOperand(paddingInfo) ||
1485 parser.parseOptionalAttrDict(result.attributes) ||
1486 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1487 return failure();
1488 if (types.size() != 2)
1489 return parser.emitError(typesLoc, "two types required");
1490 auto indexType = parser.getBuilder().getIndexType();
1491 MemRefType memRefType = types[0].dyn_cast<MemRefType>();
1492 if (!memRefType)
1493 return parser.emitError(typesLoc, "memref type required"), failure();
1494 Type vectorType = types[1];
1495 return failure(
1496 parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
1497 parser.resolveOperands(indexInfo, indexType, result.operands) ||
1498 parser.resolveOperand(paddingInfo, memRefType.getElementType(),
1499 result.operands) ||
1500 parser.addTypeToList(vectorType, result.types));
1501 }
1502
verify(TransferReadOp op)1503 static LogicalResult verify(TransferReadOp op) {
1504 // Consistency of elemental types in memref and vector.
1505 MemRefType memrefType = op.getMemRefType();
1506 VectorType vectorType = op.getVectorType();
1507 auto paddingType = op.padding().getType();
1508 auto permutationMap = op.permutation_map();
1509 auto memrefElementType = memrefType.getElementType();
1510
1511 if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank())
1512 return op.emitOpError("requires ") << memrefType.getRank() << " indices";
1513
1514 if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
1515 permutationMap)))
1516 return failure();
1517
1518 if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
1519 // Memref has vector element type.
1520 // Check that 'memrefVectorElementType' and 'paddingType' types match.
1521 if (memrefVectorElementType != paddingType)
1522 return op.emitOpError(
1523 "requires memref element type and padding type to match.");
1524
1525 } else {
1526 // Check that 'paddingType' is valid to store in a vector type.
1527 if (!VectorType::isValidElementType(paddingType))
1528 return op.emitOpError("requires valid padding vector elemental type");
1529
1530 // Check that padding type and vector element types match.
1531 if (paddingType != vectorType.getElementType())
1532 return op.emitOpError(
1533 "requires formal padding and vector of the same elemental type");
1534 }
1535
1536 return verifyPermutationMap(permutationMap,
1537 [&op](Twine t) { return op.emitOpError(t); });
1538 }
1539
1540 //===----------------------------------------------------------------------===//
1541 // TransferWriteOp
1542 //===----------------------------------------------------------------------===//
print(OpAsmPrinter & p,TransferWriteOp op)1543 static void print(OpAsmPrinter &p, TransferWriteOp op) {
1544 p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
1545 << op.indices() << "]";
1546 p.printOptionalAttrDict(op.getAttrs());
1547 p << " : " << op.getVectorType() << ", " << op.getMemRefType();
1548 }
1549
parseTransferWriteOp(OpAsmParser & parser,OperationState & result)1550 static ParseResult parseTransferWriteOp(OpAsmParser &parser,
1551 OperationState &result) {
1552 llvm::SMLoc typesLoc;
1553 OpAsmParser::OperandType storeValueInfo;
1554 OpAsmParser::OperandType memRefInfo;
1555 SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1556 SmallVector<Type, 2> types;
1557 if (parser.parseOperand(storeValueInfo) || parser.parseComma() ||
1558 parser.parseOperand(memRefInfo) ||
1559 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1560 parser.parseOptionalAttrDict(result.attributes) ||
1561 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1562 return failure();
1563 if (types.size() != 2)
1564 return parser.emitError(typesLoc, "two types required");
1565 auto indexType = parser.getBuilder().getIndexType();
1566 Type vectorType = types[0], memRefType = types[1];
1567 return failure(
1568 parser.resolveOperand(storeValueInfo, vectorType, result.operands) ||
1569 parser.resolveOperand(memRefInfo, memRefType, result.operands) ||
1570 parser.resolveOperands(indexInfo, indexType, result.operands));
1571 }
1572
verify(TransferWriteOp op)1573 static LogicalResult verify(TransferWriteOp op) {
1574 // Consistency of elemental types in memref and vector.
1575 MemRefType memrefType = op.getMemRefType();
1576 VectorType vectorType = op.getVectorType();
1577 auto permutationMap = op.permutation_map();
1578
1579 if (llvm::size(op.indices()) != memrefType.getRank())
1580 return op.emitOpError("requires ") << memrefType.getRank() << " indices";
1581
1582 if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
1583 permutationMap)))
1584 return failure();
1585
1586 return verifyPermutationMap(permutationMap,
1587 [&op](Twine t) { return op.emitOpError(t); });
1588 }
1589
1590 //===----------------------------------------------------------------------===//
1591 // TypeCastOp
1592 //===----------------------------------------------------------------------===//
1593
inferVectorTypeCastResultType(MemRefType t)1594 static MemRefType inferVectorTypeCastResultType(MemRefType t) {
1595 return MemRefType::get({}, VectorType::get(t.getShape(), t.getElementType()));
1596 }
1597
build(Builder * builder,OperationState & result,Value source)1598 void TypeCastOp::build(Builder *builder, OperationState &result, Value source) {
1599 result.addOperands(source);
1600 result.addTypes(
1601 inferVectorTypeCastResultType(source.getType().cast<MemRefType>()));
1602 }
1603
print(OpAsmPrinter & p,TypeCastOp op)1604 static void print(OpAsmPrinter &p, TypeCastOp op) {
1605 auto type = op.getOperand().getType().cast<MemRefType>();
1606 p << op.getOperationName() << ' ' << op.memref() << " : " << type << " to "
1607 << inferVectorTypeCastResultType(type);
1608 }
1609
verify(TypeCastOp op)1610 static LogicalResult verify(TypeCastOp op) {
1611 auto resultType = inferVectorTypeCastResultType(op.getMemRefType());
1612 if (op.getResultMemRefType() != resultType)
1613 return op.emitOpError("expects result type to be: ") << resultType;
1614 return success();
1615 }
1616
1617 //===----------------------------------------------------------------------===//
1618 // TupleOp
1619 //===----------------------------------------------------------------------===//
1620
parseTupleOp(OpAsmParser & parser,OperationState & result)1621 static ParseResult parseTupleOp(OpAsmParser &parser, OperationState &result) {
1622 SmallVector<OpAsmParser::OperandType, 4> operandInfos;
1623 SmallVector<Type, 4> types;
1624 auto loc = parser.getCurrentLocation();
1625 auto *ctx = parser.getBuilder().getContext();
1626 return failure(
1627 parser.parseOperandList(operandInfos) ||
1628 parser.parseOptionalAttrDict(result.attributes) ||
1629 parser.parseColonTypeList(types) ||
1630 parser.resolveOperands(operandInfos, types, loc, result.operands) ||
1631 parser.addTypeToList(TupleType::get(types, ctx), result.types));
1632 }
1633
print(OpAsmPrinter & p,TupleOp op)1634 static void print(OpAsmPrinter &p, TupleOp op) {
1635 p << op.getOperationName() << ' ';
1636 p.printOperands(op.getOperands());
1637 p.printOptionalAttrDict(op.getAttrs());
1638 p << " : ";
1639 interleaveComma(op.getOperation()->getOperandTypes(), p);
1640 }
1641
verify(TupleOp op)1642 static LogicalResult verify(TupleOp op) { return success(); }
1643
1644 //===----------------------------------------------------------------------===//
1645 // TupleGetOp
1646 //===----------------------------------------------------------------------===//
1647
parseTupleGetOp(OpAsmParser & parser,OperationState & result)1648 static ParseResult parseTupleGetOp(OpAsmParser &parser,
1649 OperationState &result) {
1650 OpAsmParser::OperandType operandInfo;
1651 IntegerAttr indexAttr;
1652 StringRef indexAttrName = TupleGetOp::getIndexAttrName();
1653 Type indexType = parser.getBuilder().getIndexType();
1654 TupleType tupleType;
1655 if (parser.parseOperand(operandInfo) || parser.parseComma() ||
1656 parser.parseAttribute(indexAttr, indexType, indexAttrName,
1657 result.attributes) ||
1658 parser.parseOptionalAttrDict(result.attributes) ||
1659 parser.parseColonType(tupleType) ||
1660 parser.resolveOperand(operandInfo, tupleType, result.operands))
1661 return failure();
1662 if (indexAttr.getInt() < 0 ||
1663 indexAttr.getInt() >= static_cast<int64_t>(tupleType.size()))
1664 return failure();
1665 parser.addTypeToList(tupleType.getType(indexAttr.getInt()), result.types);
1666 return success();
1667 }
1668
print(OpAsmPrinter & p,TupleGetOp op)1669 static void print(OpAsmPrinter &p, TupleGetOp op) {
1670 p << op.getOperationName() << ' ' << op.getOperand() << ", " << op.index();
1671 p.printOptionalAttrDict(op.getAttrs(),
1672 /*elidedAttrs=*/{TupleGetOp::getIndexAttrName()});
1673 p << " : " << op.getOperand().getType();
1674 }
1675
verify(TupleGetOp op)1676 static LogicalResult verify(TupleGetOp op) {
1677 auto tupleType = op.getOperand().getType().cast<TupleType>();
1678 if (op.getIndex() < 0 ||
1679 op.getIndex() >= static_cast<int64_t>(tupleType.size()))
1680 return op.emitOpError("tuple get index out of range");
1681 return success();
1682 }
1683
1684 //===----------------------------------------------------------------------===//
1685 // ConstantMaskOp
1686 //===----------------------------------------------------------------------===//
1687
parseConstantMaskOp(OpAsmParser & parser,OperationState & result)1688 static ParseResult parseConstantMaskOp(OpAsmParser &parser,
1689 OperationState &result) {
1690 Type resultType;
1691 ArrayAttr maskDimSizesAttr;
1692 StringRef attrName = ConstantMaskOp::getMaskDimSizesAttrName();
1693 return failure(
1694 parser.parseOptionalAttrDict(result.attributes) ||
1695 parser.parseAttribute(maskDimSizesAttr, attrName, result.attributes) ||
1696 parser.parseColonType(resultType) ||
1697 parser.addTypeToList(resultType, result.types));
1698 }
1699
print(OpAsmPrinter & p,ConstantMaskOp op)1700 static void print(OpAsmPrinter &p, ConstantMaskOp op) {
1701 p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : "
1702 << op.getResult().getType();
1703 }
1704
verify(ConstantMaskOp & op)1705 static LogicalResult verify(ConstantMaskOp &op) {
1706 // Verify that array attr size matches the rank of the vector result.
1707 auto resultType = op.getResult().getType().cast<VectorType>();
1708 if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
1709 return op.emitOpError(
1710 "must specify array attr of size equal vector result rank");
1711 // Verify that each array attr element is in bounds of corresponding vector
1712 // result dimension size.
1713 auto resultShape = resultType.getShape();
1714 SmallVector<int64_t, 4> maskDimSizes;
1715 for (auto it : llvm::enumerate(op.mask_dim_sizes())) {
1716 int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
1717 if (attrValue < 0 || attrValue > resultShape[it.index()])
1718 return op.emitOpError(
1719 "array attr of size out of bounds of vector result dimension size");
1720 maskDimSizes.push_back(attrValue);
1721 }
1722 // Verify that if one mask dim size is zero, they all should be zero (because
1723 // the mask region is a conjunction of each mask dimension interval).
1724 bool any_zeros = llvm::is_contained(maskDimSizes, 0);
1725 bool all_zeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
1726 if (any_zeros && !all_zeros)
1727 return op.emitOpError("expected all mask dim sizes to be zeros, "
1728 "as a result of conjunction with zero mask dim");
1729 return success();
1730 }
1731
1732 //===----------------------------------------------------------------------===//
1733 // CreateMaskOp
1734 //===----------------------------------------------------------------------===//
1735
parseCreateMaskOp(OpAsmParser & parser,OperationState & result)1736 static ParseResult parseCreateMaskOp(OpAsmParser &parser,
1737 OperationState &result) {
1738 auto indexType = parser.getBuilder().getIndexType();
1739 Type resultType;
1740 SmallVector<OpAsmParser::OperandType, 4> operandInfo;
1741 return failure(
1742 parser.parseOperandList(operandInfo) ||
1743 parser.parseOptionalAttrDict(result.attributes) ||
1744 parser.parseColonType(resultType) ||
1745 parser.resolveOperands(operandInfo, indexType, result.operands) ||
1746 parser.addTypeToList(resultType, result.types));
1747 }
1748
print(OpAsmPrinter & p,CreateMaskOp op)1749 static void print(OpAsmPrinter &p, CreateMaskOp op) {
1750 p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType();
1751 }
1752
verify(CreateMaskOp op)1753 static LogicalResult verify(CreateMaskOp op) {
1754 // Verify that an operand was specified for each result vector each dimension.
1755 if (op.getNumOperands() !=
1756 op.getResult().getType().cast<VectorType>().getRank())
1757 return op.emitOpError(
1758 "must specify an operand for each result vector dimension");
1759 return success();
1760 }
1761
1762 //===----------------------------------------------------------------------===//
1763 // PrintOp
1764 //===----------------------------------------------------------------------===//
1765
parsePrintOp(OpAsmParser & parser,OperationState & result)1766 static ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) {
1767 OpAsmParser::OperandType source;
1768 Type sourceType;
1769 return failure(parser.parseOperand(source) ||
1770 parser.parseColonType(sourceType) ||
1771 parser.resolveOperand(source, sourceType, result.operands));
1772 }
1773
print(OpAsmPrinter & p,PrintOp op)1774 static void print(OpAsmPrinter &p, PrintOp op) {
1775 p << op.getOperationName() << ' ' << op.source() << " : "
1776 << op.getPrintType();
1777 }
1778
1779 namespace {
1780
1781 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
1782 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
1783 public:
1784 using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
1785
matchAndRewrite(CreateMaskOp createMaskOp,PatternRewriter & rewriter) const1786 PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp,
1787 PatternRewriter &rewriter) const override {
1788 // Return if any of 'createMaskOp' operands are not defined by a constant.
1789 auto is_not_def_by_constant = [](Value operand) {
1790 return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
1791 };
1792 if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
1793 return matchFailure();
1794 // Gather constant mask dimension sizes.
1795 SmallVector<int64_t, 4> maskDimSizes;
1796 for (auto operand : createMaskOp.operands()) {
1797 auto defOp = operand.getDefiningOp();
1798 maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
1799 }
1800 // Replace 'createMaskOp' with ConstantMaskOp.
1801 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
1802 createMaskOp, createMaskOp.getResult().getType(),
1803 vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
1804 return matchSuccess();
1805 }
1806 };
1807
1808 } // end anonymous namespace
1809
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1810 void CreateMaskOp::getCanonicalizationPatterns(
1811 OwningRewritePatternList &results, MLIRContext *context) {
1812 results.insert<CreateMaskFolder>(context);
1813 }
1814
populateVectorToVectorCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)1815 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
1816 OwningRewritePatternList &patterns, MLIRContext *context) {
1817 patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder>(context);
1818 }
1819
1820 namespace mlir {
1821 namespace vector {
1822
1823 #define GET_OP_CLASSES
1824 #include "mlir/Dialect/VectorOps/VectorOps.cpp.inc"
1825
1826 } // namespace vector
1827 } // namespace mlir
1828