//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/EDSC/Builders.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using llvm::dbgs;

// Helper to find an index in an affine map.
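// For example, given the map (d0, d1, d2) -> (d2, d0), getResultIndex(map, 0)
// returns 1 (d0 is the second result), and getResultIndex(map, 1) returns
// None.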
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return None;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
  for (auto it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
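// For example, removing index 0 from (d0, d1, d2) -> (d1, d2) yields
// (d0, d1) -> (d0, d1): the remaining dims are renumbered down by one.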
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper to drop dimension from vector type.
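// For example, adjustType(vector<2x3x4xf32>, 1) yields vector<2x4xf32>, and
// adjustType(vector<8xf32>, 0) yields the scalar element type f32.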
static Type adjustType(VectorType tp, int64_t index) {
  int64_t rank = tp.getRank();
  Type eltType = tp.getElementType();
  if (rank == 1) {
    assert(index == 0 && "index for scalar result out of bounds");
    return eltType;
  }
  SmallVector<int64_t, 4> adjustedShape;
  for (int64_t i = 0; i < rank; ++i) {
    // Omit dimension at the given index.
    if (i == index)
      continue;
    // Otherwise, add dimension back.
    adjustedShape.push_back(tp.getDimSize(i));
  }
  return VectorType::get(adjustedShape, eltType);
}

// Helper method to possibly drop a dimension in a load.
// TODO
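// For example, with 'type' vector<2x3x4xf32>, 'index' 1 and 'pos' 2, this
// helper extracts position 2 of the middle dimension for each leading
// element, yielding a value of type vector<2x4xf32>.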
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = adjustType(type, 0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
  VectorType vType = lowType.cast<VectorType>();
  VectorType resType = adjustType(type, index).cast<VectorType>();
  Value result =
      rewriter.create<ConstantOp>(loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result =
        rewriter.create<vector::InsertOp>(loc, resType, load, result, posAttr);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = adjustType(type, 0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = adjustType(vType, 0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return builder.createOperation(res);
}

// Populates 'resultElements[indexMap[i]]' with elements from 'inputElements[i]'
// for each index 'i' in inputElements with a valid mapping in 'indexMap'.
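// For example, with indexMap = {0 -> 0, 2 -> 1} and inputElements = [5, 6, 7],
// resultElements becomes [5, 7]; element 1 has no mapping and is skipped.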
static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
                              ArrayRef<int64_t> inputElements,
                              SmallVectorImpl<int64_t> &resultElements) {
  assert(indexMap.size() == resultElements.size());
  assert(inputElements.size() >= resultElements.size());
  for (unsigned i = 0, e = inputElements.size(); i < e; ++i) {
    auto it = indexMap.find(i);
    if (it != indexMap.end())
      resultElements[it->second] = inputElements[i];
  }
}

// Returns a tuple type with vector element types for each resulting slice
// of 'vectorType' unrolled by 'sizes' and 'strides'.
// TODO: Move this to a utility function and share it with
// Extract/InsertSlicesOp verification.
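// For example, unrolling vector<4x6xf32> by sizes [2, 3] (unit strides) gives
// slice counts [2, 2], i.e. four slices, so the result type is
// tuple<vector<2x3xf32>, vector<2x3xf32>, vector<2x3xf32>, vector<2x3xf32>>.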
static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
                                                   ArrayRef<int64_t> sizes,
                                                   ArrayRef<int64_t> strides,
                                                   OpBuilder &builder) {
  assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
  assert(static_cast<int64_t>(sizes.size()) == vectorType.getRank());
  assert(static_cast<int64_t>(strides.size()) == vectorType.getRank());

  // Compute shape ratio of 'shape' and 'sizes'.
  auto shape = vectorType.getShape();
  auto maybeDimSliceCounts = shapeRatio(shape, sizes);
  assert(maybeDimSliceCounts.hasValue());
  auto sliceDimCounts = *maybeDimSliceCounts;

  // Compute strides w.r.t number of slices in each dimension.
  auto sliceStrides = computeStrides(sliceDimCounts);
  int64_t sliceCount = computeMaxLinearIndex(sliceDimCounts);
  SmallVector<Type, 4> vectorTypes(sliceCount);
  for (unsigned i = 0; i < sliceCount; ++i) {
    auto vectorOffsets = delinearize(sliceStrides, i);
    auto elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
    auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
    // Create Vector type and add to 'vectorTypes[i]'.
    vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
  }
  return builder.getTupleType(vectorTypes);
}

// UnrolledVectorState aggregates per-operand/result vector state required for
// unrolling.
struct UnrolledVectorState {
  SmallVector<int64_t, 4> unrolledShape;
  SmallVector<int64_t, 4> unrollFactors;
  SmallVector<int64_t, 8> basis;
  int64_t numInstances;
  Value slicesTuple;
};

// Populates 'state' with unrolled shape, unroll factors, basis and
// num unrolled instances for 'vectorType'.
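// For example, a vector<4x6xf32> with an identity index map and target shape
// [2, 3] yields unrolledShape = [2, 3], unrollFactors = [2, 2],
// basis = [2, 1] and numInstances = 4.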
static void initUnrolledVectorState(VectorType vectorType, Value initValue,
                                    const DenseMap<int64_t, int64_t> &indexMap,
                                    ArrayRef<int64_t> targetShape,
                                    UnrolledVectorState &state,
                                    OpBuilder &builder) {
  // Compute unrolled shape of 'vectorType'.
  state.unrolledShape.resize(vectorType.getRank());
  getMappedElements(indexMap, targetShape, state.unrolledShape);
  // Compute unroll factors for unrolled shape.
  auto maybeUnrollFactors =
      shapeRatio(vectorType.getShape(), state.unrolledShape);
  assert(maybeUnrollFactors.hasValue());
  state.unrollFactors = *maybeUnrollFactors;
  // Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
  state.basis = computeStrides(state.unrollFactors);
  state.numInstances = computeMaxLinearIndex(state.unrollFactors);
  state.slicesTuple = nullptr;
  if (initValue != nullptr) {
    // Create ExtractSlicesOp.
    SmallVector<int64_t, 4> sizes(state.unrolledShape);
    SmallVector<int64_t, 4> strides(state.unrollFactors.size(), 1);
    auto tupleType =
        generateExtractSlicesOpResultType(vectorType, sizes, strides, builder);
    state.slicesTuple = builder.create<vector::ExtractSlicesOp>(
        initValue.getLoc(), tupleType, initValue, sizes, strides);
  }
}

// Computes and returns the linear index of the unrolled vector at
// 'vectorOffsets' within the vector represented by 'state'.
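// For example, with basis = [2, 1] and mapped slice offsets [1, 1], the
// returned linear index is 1 * 2 + 1 * 1 = 3.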
static int64_t
getUnrolledVectorLinearIndex(UnrolledVectorState &state,
                             ArrayRef<int64_t> vectorOffsets,
                             DenseMap<int64_t, int64_t> &indexMap) {
  // Compute vector offsets.
  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
  getMappedElements(indexMap, vectorOffsets, sliceOffsets);
  // Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
  return linearize(sliceOffsets, state.basis);
}

// Returns an unrolled vector at 'vectorOffsets' within the vector
// represented by 'state'. The vector is created from a slice of 'initValue'
// if not present in 'cache'.
static Value getOrCreateUnrolledVectorSlice(
    Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
    ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
    Value initValue, SmallVectorImpl<Value> &cache, OpBuilder &builder) {
  // Compute slice offsets.
  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
  getMappedElements(indexMap, offsets, sliceOffsets);
  // TODO: Support non-1 strides.
  SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
  // Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
  int64_t sliceLinearIndex =
      getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap);
  assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
  auto valueSlice = cache[sliceLinearIndex];
  if (valueSlice == nullptr) {
    // Return tuple element at 'sliceLinearIndex'.
    auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
    auto initValueType = initValue.getType().cast<VectorType>();
    auto vectorType =
        VectorType::get(state.unrolledShape, initValueType.getElementType());
    // Initialize 'cache' with slice from 'initValue'.
    valueSlice = builder.create<vector::TupleGetOp>(
        loc, vectorType, state.slicesTuple, tupleIndex);
    // Store value back to 'cache'.
    cache[sliceLinearIndex] = valueSlice;
  }
  return valueSlice;
}

// VectorState aggregates per-operand/result vector state required for
// creating slices of vector operands, and clones of the operation being
// unrolled.
struct VectorState {
  // The type of this vector.
  VectorType type;
  // Map from iteration space index to vector dimension index.
  DenseMap<int64_t, int64_t> indexMap;
  // Index of this value in operation's operand list (-1 if not an operand).
  int64_t operandIndex = -1;
  // Accumulator iterator flag.
  bool isAcc = false;
};
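// For example, an operand whose indexing map is (d0, d1, d2) -> (d0, d2)
// carries indexMap = {0 -> 0, 2 -> 1}: iteration dims d0 and d2 map to
// vector dims 0 and 1, and d1 does not appear in this operand.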

//
// unrollSingleResultStructuredOp
//
// Returns a value representing the result of structured operation 'op'
// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
// A list of VectorState objects must be specified in 'vectors', where
// each VectorState in the list represents a vector operand or vector result
// (if the operation does not have an accumulator operand).
// The VectorState at index 'resultIndex' in the list must be the state
// associated with the operation's single result (i.e. either its accumulator
// operand or vector result value).
//
// Example:
//
//  // Before unrolling
//
//   operand0       operand1       operand2
//       \              |              /
//        -------------------- opA --------------------
//
//  // After unrolling by 2
//
//   operand0       operand1       operand2
//   /      \       /      \       /      \
// slice00  slice01 slice10 slice11 slice20 slice21
//   \        |       |       |      /        |
//    -------------------- opA0 --------------------   |
//             |              |             |
//              \             |            /
//               -------------------- opA1 -------------------
//                        |            |
//                         \          /
//                          insertslice
//                               |

// TODO: Add the following canonicalization/simplification patterns:
// *) Add pattern which matches InsertStridedSlice -> StridedSlice and forwards
//    InsertStridedSlice operand to StridedSlice.
// *) Add pattern which matches SourceOp -> StridedSlice -> UserOp which checks
//    if there are duplicate identical StridedSlice ops from SourceOp, and
//    rewrites itself to use the first duplicate. This transformation should
//    cause users of identical StridedSlice ops to reuse the same StridedSlice
//    operation, and leave the duplicate StridedSlice ops with no users
//    (removable with DCE).

// TODO: Generalize this to support structured ops beyond
// vector ContractionOp, and merge it with 'unrollSingleResultVectorOp'
static Value unrollSingleResultStructuredOp(Operation *op,
                                            ArrayRef<int64_t> iterationBounds,
                                            std::vector<VectorState> &vectors,
                                            unsigned resultIndex,
                                            ArrayRef<int64_t> targetShape,
                                            OpBuilder &builder) {
  auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>();
  if (!shapedType || !shapedType.hasStaticShape())
    assert(false && "Expected a statically shaped result type");

  // Compute unroll factors for 'iterationBounds' based on 'targetShape'.
  auto maybeUnrollFactors = shapeRatio(iterationBounds, targetShape);
  if (!maybeUnrollFactors.hasValue())
    assert(false && "Failed to compute unroll factors for target shape");
  auto unrollFactors = *maybeUnrollFactors;

  // Compute unrolled vector state for each vector in 'vectors'.
  unsigned numVectors = vectors.size();
  SmallVector<UnrolledVectorState, 3> unrolledVectorState(numVectors);
  for (unsigned i = 0; i < numVectors; ++i) {
    int64_t operandIndex = vectors[i].operandIndex;
    auto operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
    initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap,
                            targetShape, unrolledVectorState[i], builder);
  }
  // Compute number of total unrolled instances.
  auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
  auto sliceStrides = computeStrides(unrollFactors);

  auto &resultValueState = unrolledVectorState[resultIndex];
  auto unrolledResultType = VectorType::get(resultValueState.unrolledShape,
                                            shapedType.getElementType());

  // Initialize caches for intermediate vector results.
  std::vector<SmallVector<Value, 4>> caches(numVectors);
  for (unsigned i = 0; i < numVectors; ++i)
    caches[i].resize(unrolledVectorState[i].numInstances);

  // Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
  for (unsigned i = 0; i < numUnrolledInstances; ++i) {
    auto vectorOffsets = delinearize(sliceStrides, i);
    auto elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
    // Get cached slice (or create slice) for each operand at 'offsets'.
    SmallVector<Value, 3> operands;
    operands.resize(op->getNumOperands());
    for (unsigned i = 0; i < numVectors; ++i) {
      int64_t operandIndex = vectors[i].operandIndex;
      if (operandIndex < 0)
        continue; // Output
      auto operand = op->getOperand(operandIndex);
      operands[operandIndex] = getOrCreateUnrolledVectorSlice(
          op->getLoc(), unrolledVectorState[i], vectorOffsets, elementOffsets,
          vectors[i].indexMap, operand, caches[i], builder);
    }
    // Create op on sliced vector arguments.
    auto resultVector =
        cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
                                    unrolledResultType)
            ->getResult(0);

    // Compute linear result index.
    int64_t linearIndex = getUnrolledVectorLinearIndex(
        resultValueState, vectorOffsets, vectors[resultIndex].indexMap);
    // Update result cache at 'linearIndex'.
    caches[resultIndex][linearIndex] = resultVector;
  }

  // Create TupleOp of unrolled result vectors.
  SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
  SmallVector<Value, 4> vectorTupleValues(resultValueState.numInstances);
  for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
    vectorTupleTypes[i] = caches[resultIndex][i].getType().cast<VectorType>();
    vectorTupleValues[i] = caches[resultIndex][i];
  }
  TupleType tupleType = builder.getTupleType(vectorTupleTypes);
  Value tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
                                                  vectorTupleValues);

  // Create InsertSlicesOp(Tuple(result_vectors)).
  auto resultVectorType = op->getResult(0).getType().cast<VectorType>();
  SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
  SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);

  Value insertSlicesOp = builder.create<vector::InsertSlicesOp>(
      op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes),
      builder.getI64ArrayAttr(strides));
  return insertSlicesOp;
}

static void getVectorContractionOpUnrollState(
    vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
    std::vector<VectorState> &vectors, unsigned &resultIndex) {
  // Get map from iteration space index to lhs/rhs/result shape index.
  std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
  contractionOp.getIterationIndexMap(iterationIndexMapList);
  unsigned numIterators = iterationIndexMapList.size();
  vectors.resize(numIterators);
  unsigned accOperandIndex = vector::ContractionOp::getAccOperandIndex();
  for (unsigned i = 0; i < numIterators; ++i) {
    vectors[i].type = contractionOp.getOperand(i).getType().cast<VectorType>();
    vectors[i].indexMap = iterationIndexMapList[i];
    vectors[i].operandIndex = i;
    vectors[i].isAcc = i == accOperandIndex;
  }

  if (llvm::size(contractionOp.masks()) == 2) {
    // Add vectors for lhs/rhs vector mask arguments. Masks have the
    // same vector shape as the lhs/rhs args, so copy their index maps.
    vectors.push_back({contractionOp.getLHSVectorMaskType(),
                       vectors[0].indexMap, accOperandIndex + 1, false});
    vectors.push_back({contractionOp.getRHSVectorMaskType(),
                       vectors[1].indexMap, accOperandIndex + 2, false});
  }
  // TODO: Use linalg style 'args_in'/'args_out' to partition
  // 'vectors' instead of 'resultIndex'.
  resultIndex = accOperandIndex;
}

static void getVectorElementwiseOpUnrollState(Operation *op,
                                              ArrayRef<int64_t> targetShape,
                                              std::vector<VectorState> &vectors,
                                              unsigned &resultIndex) {
  // Verify that operation and operands all have the same vector shape.
  auto resultType = op->getResult(0).getType().dyn_cast_or_null<VectorType>();
  assert(resultType && "Expected op with vector result type");
  auto resultShape = resultType.getShape();
  // Verify that all operands have the same vector type as result.
  assert(llvm::all_of(op->getOperandTypes(), [=](Type type) {
    return type.cast<VectorType>().getShape() == resultShape;
  }));

  // Create trivial elementwise identity index map based on 'resultShape'.
  DenseMap<int64_t, int64_t> indexMap;
  indexMap.reserve(resultShape.size());
  for (unsigned i = 0; i < resultShape.size(); ++i)
    indexMap[i] = i;

  // Create a VectorState for each operand and for the single result.
  unsigned numVectors = op->getNumOperands() + op->getNumResults();
  vectors.resize(numVectors);
  for (auto it : llvm::enumerate(op->getOperandTypes()))
    vectors[it.index()] = {it.value().cast<VectorType>(), indexMap,
                           static_cast<int64_t>(it.index()), false};
  vectors[numVectors - 1] = {resultType, indexMap, -1, false};
  resultIndex = numVectors - 1;
}

/// Generates slices of 'vectorType' according to 'sizes' and 'strides', and
/// calls 'fn' with the linear index and indices for each slice.
static void generateTransferOpSlices(
    Type shapedElementType, VectorType vectorType, TupleType tupleType,
    ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
    OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
  // Compute strides w.r.t. slice counts in each dimension.
  auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
  assert(maybeDimSliceCounts.hasValue());
  auto sliceDimCounts = *maybeDimSliceCounts;
  auto sliceStrides = computeStrides(sliceDimCounts);

  int64_t numSlices = tupleType.size();
  unsigned numSliceIndices = indices.size();
  // Compute 'indexOffset' at which to update 'indices', which is equal
  // to the memref rank (indices.size) minus the effective 'vectorRank'.
  // The effective 'vectorRank' is equal to the rank of the vector type
  // minus the rank of the memref vector element type (if it has one).
  //
  // For example:
  //
  // Given memref type 'memref<6x2x1xvector<2x4xf32>>' and vector
  // transfer_read/write ops which read/write vectors of type
  // 'vector<2x1x2x4xf32>'. The memref rank is 3, and the effective
  // vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
  //
  unsigned vectorRank = vectorType.getRank();
  if (auto sourceVectorElementType = shapedElementType.dyn_cast<VectorType>()) {
    assert(vectorRank >= sourceVectorElementType.getRank());
    vectorRank -= sourceVectorElementType.getRank();
  }
  unsigned indexOffset = numSliceIndices - vectorRank;

  auto *ctx = builder.getContext();
  for (unsigned i = 0; i < numSlices; ++i) {
    auto vectorOffsets = delinearize(sliceStrides, i);
    auto elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
    // Compute 'sliceIndices' by adding 'elementOffsets[j]' to 'indices[j]'.
    SmallVector<Value, 4> sliceIndices(numSliceIndices);
    for (unsigned j = 0; j < numSliceIndices; ++j) {
      if (j < indexOffset) {
        sliceIndices[j] = indices[j];
      } else {
        auto expr = getAffineDimExpr(0, ctx) +
                    getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
        auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
        sliceIndices[j] = builder.create<AffineApplyOp>(
            indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
      }
    }
    // Call 'fn' to generate slice 'i' at 'sliceIndices'.
    fn(i, sliceIndices);
  }
}

/// Returns true if 'map' is a suffix of an identity affine map, false
/// otherwise. Example: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
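/// Another accepted example: (d0, d1, d2) -> (d1, d2). In contrast,
/// (d0, d1) -> (d1, d0) is rejected because its result positions are not
/// consecutive.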
static bool isIdentitySuffix(AffineMap map) {
  if (map.getNumDims() < map.getNumResults())
    return false;
  ArrayRef<AffineExpr> results = map.getResults();
  Optional<int> lastPos;
  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
    auto expr = results[i].dyn_cast<AffineDimExpr>();
    if (!expr)
      return false;
    int currPos = static_cast<int>(expr.getPosition());
    if (lastPos.hasValue() && currPos != lastPos.getValue() + 1)
      return false;
    lastPos = currPos;
  }
  return true;
}

/// Unroll transfer_read ops to the given shape and create an aggregate with all
/// the chunks.
static Value unrollTransferReadOp(vector::TransferReadOp readOp,
                                  ArrayRef<int64_t> targetShape,
                                  OpBuilder &builder) {
  if (!isIdentitySuffix(readOp.permutation_map()))
    return nullptr;
  auto sourceVectorType = readOp.getVectorType();
  SmallVector<int64_t, 4> strides(targetShape.size(), 1);

  Location loc = readOp.getLoc();
  auto shapedElementType =
      readOp.source().getType().cast<ShapedType>().getElementType();
  auto tupleType = generateExtractSlicesOpResultType(
      sourceVectorType, targetShape, strides, builder);
  int64_t numSlices = tupleType.size();

  SmallVector<Value, 4> vectorTupleValues(numSlices);
  SmallVector<Value, 4> indices(readOp.indices().begin(),
                                readOp.indices().end());
  auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
    // Get VectorType for slice 'index'.
    auto sliceVectorType = tupleType.getType(index);
    // Create split TransferReadOp for this slice.
    // `masked` attribute propagates conservatively: if the coarse op didn't
    // need masking, the fine op doesn't either.
    vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
        loc, sliceVectorType, readOp.source(), sliceIndices,
        readOp.permutation_map(), readOp.padding(),
        readOp.masked() ? *readOp.masked() : ArrayAttr());
  };
  generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
                           targetShape, strides, indices, builder, createSlice);

  // Create tuple of split transfer read operations.
  Value tupleOp =
      builder.create<vector::TupleOp>(loc, tupleType, vectorTupleValues);
  // Recombine the slices into a vector of the original type.
  Value newVec = builder.create<vector::InsertSlicesOp>(
      loc, sourceVectorType, tupleOp, builder.getI64ArrayAttr(targetShape),
      builder.getI64ArrayAttr(strides));
  return newVec;
}

// Entry point for unrolling declarative pattern rewrite for transfer_write op.
LogicalResult
mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
                                    ArrayRef<int64_t> targetShape,
                                    SmallVector<Value, 1> &result) {
  auto writeOp = cast<vector::TransferWriteOp>(op);
  if (!isIdentitySuffix(writeOp.permutation_map()))
    return failure();
  VectorType sourceVectorType = writeOp.getVectorType();
  SmallVector<int64_t, 4> strides(targetShape.size(), 1);
  TupleType tupleType = generateExtractSlicesOpResultType(
      sourceVectorType, targetShape, strides, builder);
  Location loc = writeOp.getLoc();
  Value tuple = builder.create<vector::ExtractSlicesOp>(
      loc, tupleType, writeOp.vector(), targetShape, strides);
  auto shapedElementType =
      writeOp.source().getType().cast<ShapedType>().getElementType();
  SmallVector<Value, 4> indices(writeOp.indices().begin(),
                                writeOp.indices().end());
  // If the TransferWrite returns a tensor, keep track of the last tensor
  // created.
  Value resultTensor;
  auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
    auto element = builder.create<vector::TupleGetOp>(
        loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
    Operation *write = builder.create<vector::TransferWriteOp>(
        loc, element.getResult(),
        resultTensor ? resultTensor : writeOp.source(), sliceIndices,
        writeOp.permutation_map(),
        writeOp.masked() ? *writeOp.masked() : ArrayAttr());
    if (!write->getResults().empty())
      resultTensor = write->getResult(0);
  };
  generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
                           targetShape, strides, indices, builder, createSlice);
  if (resultTensor)
    result.push_back(resultTensor);
  return success();
}

// Entry point for unrolling declarative pattern rewrites.
SmallVector<Value, 1>
mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
                                         ArrayRef<int64_t> targetShape) {
  assert(op->getNumResults() == 1 && "Expected single result operation");

  // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
  SmallVector<int64_t, 6> iterationBounds;
  auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");

  std::vector<VectorState> vectors;
  unsigned resultIndex;

  if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
    return SmallVector<Value, 1>{
        unrollTransferReadOp(readOp, targetShape, builder)};

  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
    // Populate state for vector ContractionOp.
    getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
                                      resultIndex);
  } else {
    // Populate state for vector elementwise op.
    getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
  }

  // Unroll 'op' with 'iterationBounds' to 'targetShape'.
  return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
      op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
}

namespace {

// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
// scheme of its unique ExtractSlicesOp user.
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp xferReadOp,
                                PatternRewriter &rewriter) const override {
    // TODO: Support splitting TransferReadOp with non-identity
    // permutation maps. Repurpose code from MaterializeVectors transformation.
    if (!isIdentitySuffix(xferReadOp.permutation_map()))
      return failure();
    // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
    Value xferReadResult = xferReadOp.getResult();
    auto extractSlicesOp =
        dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin());
    if (!xferReadResult.hasOneUse() || !extractSlicesOp)
      return failure();

    // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
    SmallVector<int64_t, 4> sizes;
    extractSlicesOp.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    extractSlicesOp.getStrides(strides);
    assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));

    Value newVec = unrollTransferReadOp(xferReadOp, sizes, rewriter);
    if (!newVec)
      return failure();
    rewriter.replaceOp(xferReadOp, newVec);
    return success();
  }
};

// Splits vector TransferWriteOp into smaller TransferWriteOps for each source.
struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
                                PatternRewriter &rewriter) const override {
    // TODO: Support splitting TransferWriteOp with non-identity
    // permutation maps. Repurpose code from MaterializeVectors transformation.
    if (!isIdentitySuffix(xferWriteOp.permutation_map()))
      return failure();
    // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
    auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
    auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
    if (!insertSlicesOp)
      return failure();

    // Get TupleOp operand of 'insertSlicesOp'.
    auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
        insertSlicesOp.vectors().getDefiningOp());
    if (!tupleOp)
      return failure();

    // Get 'sizes' and 'strides' parameters from InsertSlicesOp user.
    auto sourceTupleType = insertSlicesOp.getSourceTupleType();
    auto resultVectorType = insertSlicesOp.getResultVectorType();
    SmallVector<int64_t, 4> sizes;
    insertSlicesOp.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    insertSlicesOp.getStrides(strides);

    Location loc = xferWriteOp.getLoc();
    auto shapedElementType =
        xferWriteOp.source().getType().cast<ShapedType>().getElementType();
    SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
                                  xferWriteOp.indices().end());
    Value resultTensor;
    auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
      // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
      // `masked` attribute propagates conservatively: if the coarse op didn't
      // need masking, the fine op doesn't either.
      Operation *write = rewriter.create<vector::TransferWriteOp>(
          loc, tupleOp.getOperand(index),
          resultTensor ? resultTensor : xferWriteOp.source(), sliceIndices,
          xferWriteOp.permutation_map(),
          xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
      if (!write->getResults().empty())
        resultTensor = write->getResult(0);
    };
    generateTransferOpSlices(shapedElementType, resultVectorType,
                             sourceTupleType, sizes, strides, indices, rewriter,
                             createSlice);

    // Erase old 'xferWriteOp'.
    if (resultTensor)
      rewriter.replaceOp(xferWriteOp, ArrayRef<Value>(resultTensor));
    else
      rewriter.eraseOp(xferWriteOp);
    return success();
  }
};

/// Decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps, each
/// on vector types.
struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has tuple source/result type.
    auto sourceTupleType =
        shapeCastOp.source().getType().dyn_cast_or_null<TupleType>();
    auto resultTupleType =
        shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
    if (!sourceTupleType || !resultTupleType)
      return failure();
    assert(sourceTupleType.size() == resultTupleType.size());

    // Create single-vector ShapeCastOp for each source tuple element.
    Location loc = shapeCastOp.getLoc();
    SmallVector<Value, 8> resultElements;
    resultElements.reserve(resultTupleType.size());
    for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i) {
      auto sourceElement = rewriter.create<vector::TupleGetOp>(
          loc, sourceTupleType.getType(i), shapeCastOp.source(),
          rewriter.getI64IntegerAttr(i));
      resultElements.push_back(rewriter.create<vector::ShapeCastOp>(
          loc, resultTupleType.getType(i), sourceElement));
    }

    // Replace 'shapeCastOp' with tuple of 'resultElements'.
    rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, resultTupleType,
                                                 resultElements);
    return success();
  }
};

/// Returns the producer Value of the same type as 'consumerValue', by tracking
/// the tuple index and offsets of the consumer vector value through the
/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp,
/// and ShapeCastOp) from consumer to producer. Each operation in the chain is
/// structured, and so the tuple index and offsets can be mapped from result to
/// input, while visiting each operation in the chain.
/// Returns nullptr on failure.
static Value getProducerValue(Value consumerValue) {
  auto consumerVectorType = consumerValue.getType().cast<VectorType>();
  // A tupleIndex == -1 indicates that 'offsets' are w.r.t a vector type.
  int64_t tupleIndex = -1;
  SmallVector<int64_t, 4> offsets(consumerVectorType.getRank(), 0);
  auto *op = consumerValue.getDefiningOp();
  while (op != nullptr) {
    if (auto tupleGetOp = dyn_cast<vector::TupleGetOp>(op)) {
      assert(tupleIndex == -1 && "TupleGetOp must have vector result type");

      // Update 'tupleIndex' and next defining 'op' to visit.
      tupleIndex = tupleGetOp.getIndex();
      op = tupleGetOp.vectors().getDefiningOp();
    } else if (auto extractSlicesOp = dyn_cast<vector::ExtractSlicesOp>(op)) {
      assert(tupleIndex >= 0);

      // Compute slice strides for 'extractSlicesOp'.
      SmallVector<int64_t, 4> sizes;
      extractSlicesOp.getSizes(sizes);
      auto sliceStrides = computeStrides(
          extractSlicesOp.getSourceVectorType().getShape(), sizes);

      // Compute 'elementOffsets' into 'extractSlicesOp' input vector type,
      // of 'extractSlicesOp' result vector tuple element at 'tupleIndex'.
      auto vectorOffsets = delinearize(sliceStrides, tupleIndex);
      auto elementOffsets =
          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);

      // Add 'elementOffsets' to 'offsets' so that 'offsets' are now relative
      // to the 'extractSlicesOp' input vector type.
      assert(offsets.size() == elementOffsets.size());
      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
        offsets[i] += elementOffsets[i];

      // Clear 'tupleIndex' and update next defining 'op' to visit.
      tupleIndex = -1;
      op = extractSlicesOp.vector().getDefiningOp();
    } else if (auto insertSlicesOp = dyn_cast<vector::InsertSlicesOp>(op)) {
      assert(tupleIndex == -1);

      // Compute slice strides for 'insertSlicesOp'.
      SmallVector<int64_t, 4> sizes;
      insertSlicesOp.getSizes(sizes);
      auto sliceStrides = computeStrides(
          insertSlicesOp.getResultVectorType().getShape(), sizes);

      // Compute 'vectorOffsets' of 'insertSlicesOp' input vector slice,
      // of 'insertSlicesOp' result vector type at 'offsets'.
      SmallVector<int64_t, 4> vectorOffsets(offsets.size());
      assert(offsets.size() == sizes.size());
      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
        vectorOffsets[i] = offsets[i] / sizes[i];

      // Compute the source tuple element index.
      tupleIndex = linearize(vectorOffsets, sliceStrides);

      // Subtract 'elementOffsets' from 'offsets' so that 'offsets' are now
      // relative to input tuple element vector type at 'tupleIndex'.
      auto elementOffsets =
          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
      assert(offsets.size() == elementOffsets.size());
      for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
        offsets[i] -= elementOffsets[i];
        assert(offsets[i] >= 0);
      }

      // Update next defining 'op' to visit.
      op = insertSlicesOp.vectors().getDefiningOp();
    } else if (auto tupleOp = dyn_cast<vector::TupleOp>(op)) {
      assert(tupleIndex >= 0);

      // Return tuple element 'value' at 'tupleIndex' if it matches type.
      auto value = tupleOp.getOperand(tupleIndex);
      if (value.getType() == consumerVectorType)
        return value;

      // Update 'tupleIndex' and next defining 'op' to visit.
      tupleIndex = -1;
      op = value.getDefiningOp();
    } else if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
      if (shapeCastOp.source().getType().isa<TupleType>())
        return nullptr;
      assert(tupleIndex == -1);
      auto sourceVectorType = shapeCastOp.getSourceVectorType();
      auto sourceVectorShape = sourceVectorType.getShape();
      unsigned sourceVectorRank = sourceVectorType.getRank();
      auto resultVectorType = shapeCastOp.getResultVectorType();
      auto resultVectorShape = resultVectorType.getShape();
      unsigned resultVectorRank = resultVectorType.getRank();

      int i = sourceVectorRank - 1;
      int j = resultVectorRank - 1;

      // Check that source/result vector shape prefixes match while updating
      // 'newOffsets'.
      SmallVector<int64_t, 4> newOffsets(sourceVectorRank, 0);
      for (auto it : llvm::zip(llvm::reverse(sourceVectorShape),
                               llvm::reverse(resultVectorShape))) {
        if (std::get<0>(it) != std::get<1>(it))
          return nullptr;
        newOffsets[i--] = offsets[j--];
      }

      // Check that remaining prefix of source/result vector shapes are all 1s.
      // Currently we only support producer/consumer tracking through trivial
      // shape cast ops. Examples:
      //   %1 = vector.shape_cast %0 : vector<1x1x2x4xf32> to vector<2x4xf32>
      //   %3 = vector.shape_cast %2 : vector<16x8xf32> to vector<1x16x8xf32>
      assert(i == -1 || j == -1);
      if (i >= 0 &&
          !std::all_of(sourceVectorShape.begin(), sourceVectorShape.begin() + i,
                       [](int64_t v) { return v == 1; }))
        return nullptr;
      if (j >= 0 &&
          !std::all_of(resultVectorShape.begin(), resultVectorShape.begin() + j,
                       [](int64_t v) { return v == 1; }))
        return nullptr;

      offsets.swap(newOffsets);
      op = shapeCastOp.source().getDefiningOp();
    } else {
      // Check if 'op' produces a Value with the same type as 'consumerValue'.
      if (op->getNumResults() == 1 &&
          op->getResult(0).getType() == consumerVectorType)
        return op->getResult(0);
      return nullptr;
    }
  }
  return nullptr;
}

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if we can replace 'shapeCastOp' result with its producer.
    if (auto producer = getProducerValue(shapeCastOp.getResult())) {
      rewriter.replaceOp(shapeCastOp, producer);
      return success();
    }

    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.source().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        sourceShapeCastOp.source().getType().cast<VectorType>();
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
    return success();
  }
};

// Pattern rewrite which forwards tuple elements to their users.
// User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
//   -> User(Producer)
struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
  using OpRewritePattern<vector::TupleGetOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
                                PatternRewriter &rewriter) const override {
    if (auto producer = getProducerValue(tupleGetOp.getResult())) {
      rewriter.replaceOp(tupleGetOp, producer);
      return success();
    }
    return failure();
  }
};

/// Progressive lowering of ExtractSlicesOp to tuple of ExtractStridedSliceOp.
/// One:
///   %x = vector.extract_slices %0
/// is replaced by:
///   %a = vector.strided_slice %0
///   %b = vector.strided_slice %0
///   ..
///   %x = vector.tuple %a, %b, ..
class ExtractSlicesOpLowering
    : public OpRewritePattern<vector::ExtractSlicesOp> {
public:
  using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractSlicesOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType vectorType = op.getSourceVectorType();
    auto shape = vectorType.getShape();

    SmallVector<int64_t, 4> sizes;
    op.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    op.getStrides(strides); // all-ones at the moment

    // For each element in the tuple, generate the proper strided slice.
    TupleType tupleType = op.getResultTupleType();
    int64_t tupleSize = tupleType.size();
    SmallVector<Value, 4> tupleValues(tupleSize);
    auto sliceStrides = computeStrides(shape, sizes);
    for (int64_t i = 0; i < tupleSize; ++i) {
      auto vectorOffsets = delinearize(sliceStrides, i);
      auto elementOffsets =
          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
      auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
      // Insert in tuple.
      tupleValues[i] = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.vector(), elementOffsets, sliceSizes, strides);
    }

    rewriter.replaceOpWithNewOp<vector::TupleOp>(op, tupleType, tupleValues);
    return success();
  }
};

/// Progressive lowering of InsertSlicesOp to series of InsertStridedSliceOp.
/// One:
///   %x = vector.insert_slices %0
/// is replaced by:
///   %r0 = zero-result
///   %t1 = vector.tuple_get %0, 0
///   %r1 = vector.insert_strided_slice %r0, %t1
///   %t2 = vector.tuple_get %0, 1
///   %r2 = vector.insert_strided_slice %r1, %t2
///   ..
///   %x  = ..
class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
public:
  using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertSlicesOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType vectorType = op.getResultVectorType();
    auto shape = vectorType.getShape();

    SmallVector<int64_t, 4> sizes;
    op.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    op.getStrides(strides); // all-ones at the moment

    // Prepare result.
    Value result = rewriter.create<ConstantOp>(
        loc, vectorType, rewriter.getZeroAttr(vectorType));

    // For each element in the tuple, extract the proper strided slice.
    TupleType tupleType = op.getSourceTupleType();
    int64_t tupleSize = tupleType.size();
    auto sliceStrides = computeStrides(shape, sizes);
    for (int64_t i = 0; i < tupleSize; ++i) {
      auto vectorOffsets = delinearize(sliceStrides, i);
      auto elementOffsets =
          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
      // Extract from tuple into the result.
      auto index = rewriter.getI64IntegerAttr(i);
      auto tupleGet = rewriter.create<vector::TupleGetOp>(
          loc, tupleType.getType(i), op.getOperand(), index);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, tupleGet, result, elementOffsets, strides);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Determine rank of source and destination.
    int64_t srcRank = srcType ? srcType.getRank() : 0;
    int64_t dstRank = dstType.getRank();

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y  : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y  : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y]       : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Scalar to any vector can use splat.
      if (srcRank == 0) {
        rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
        return success();
      }
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
      Value result = rewriter.create<ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
    if (srcRank == 1) {
      assert(m == 0);
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<ConstantOp>(loc, dstType,
                                               rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
1247
1248 /// Progressive lowering of TransposeOp.
1249 /// One:
1250 /// %x = vector.transpose %y, [1, 0]
1251 /// is replaced by:
1252 /// %z = constant dense<0.000000e+00>
1253 /// %0 = vector.extract %y[0, 0]
1254 /// %1 = vector.insert %0, %z [0, 0]
1255 /// ..
1256 /// %x = vector.insert .., .. [.., ..]
1257 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
1258 public:
1259 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
1260
TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,MLIRContext * context)1261 TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
1262 MLIRContext *context)
1263 : OpRewritePattern<vector::TransposeOp>(context),
1264 vectorTransformsOptions(vectorTransformsOptions) {}
1265
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const1266 LogicalResult matchAndRewrite(vector::TransposeOp op,
1267 PatternRewriter &rewriter) const override {
1268 auto loc = op.getLoc();
1269
1270 VectorType resType = op.getResultType();
1271
1272 // Set up convenience transposition table.
1273 SmallVector<int64_t, 4> transp;
1274 for (auto attr : op.transp())
1275 transp.push_back(attr.cast<IntegerAttr>().getInt());
1276
1277 // Handle a true 2-D matrix transpose differently when requested.
1278 if (vectorTransformsOptions.vectorTransposeLowering ==
1279 vector::VectorTransposeLowering::Flat &&
1280 resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
1281 Type flattenedType =
1282 VectorType::get(resType.getNumElements(), resType.getElementType());
1283 auto matrix =
1284 rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
1285 auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
1286 auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
1287 Value trans = rewriter.create<vector::FlatTransposeOp>(
1288 loc, flattenedType, matrix, rows, columns);
1289 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
1290 return success();
1291 }
1292
1293 // Generate fully unrolled extract/insert ops.
1294 Value result = rewriter.create<ConstantOp>(loc, resType,
1295 rewriter.getZeroAttr(resType));
1296 SmallVector<int64_t, 4> lhs(transp.size(), 0);
1297 SmallVector<int64_t, 4> rhs(transp.size(), 0);
1298 rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
1299 op.vector(), result, rewriter));
1300 return success();
1301 }
1302
1303 private:
1304 // Builds the index arrays for the lhs and rhs. Generates the extract/insert
1305 // operation when all ranks are exhausted.
1306 Value expandIndices(Location loc, VectorType resType, int64_t pos,
1307 SmallVector<int64_t, 4> &transp,
1308 SmallVector<int64_t, 4> &lhs,
1309 SmallVector<int64_t, 4> &rhs, Value input, Value result,
1310 PatternRewriter &rewriter) const {
1311 if (pos >= resType.getRank()) {
1312 auto ridx = rewriter.getI64ArrayAttr(rhs);
1313 auto lidx = rewriter.getI64ArrayAttr(lhs);
1314 Type eltType = resType.getElementType();
1315 Value e = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
1316 return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
1317 }
1318 for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
1319 lhs[pos] = d;
1320 rhs[transp[pos]] = d;
1321 result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
1322 result, rewriter);
1323 }
1324 return result;
1325 }
1326
1327 /// Options to control the vector patterns.
1328 vector::VectorTransformsOptions vectorTransformsOptions;
1329 };
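
// For illustration, a hedged sketch (shapes assumed) of the Flat path above:
//   %x = vector.transpose %y, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// becomes:
//   %0 = vector.shape_cast %y : vector<2x4xf32> to vector<8xf32>
//   %1 = vector.flat_transpose %0 {rows = 4 : i32, columns = 2 : i32}
//          : vector<8xf32> -> vector<8xf32>
//   %x = vector.shape_cast %1 : vector<8xf32> to vector<4x2xf32>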
1330
1331 /// Progressive lowering of OuterProductOp.
1332 /// One:
1333 /// %x = vector.outerproduct %lhs, %rhs, %acc
1334 /// is replaced by:
1335 /// %z = zero-result
1336 /// %0 = vector.extract %lhs[0]
1337 /// %1 = vector.broadcast %0
1338 /// %2 = vector.extract %acc[0]
1339 /// %3 = vector.fma %1, %rhs, %2
1340 /// %4 = vector.insert %3, %z[0]
1341 /// ..
1342 /// %x = vector.insert %.., %..[N-1]
1343 ///
1344 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1345 public:
1346 using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
1347
1348 LogicalResult matchAndRewrite(vector::OuterProductOp op,
1349 PatternRewriter &rewriter) const override {
1350 auto loc = op.getLoc();
1351
1352 VectorType lhsType = op.getOperandVectorTypeLHS();
1353 VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
1354 VectorType resType = op.getVectorType();
1355 Type eltType = resType.getElementType();
1356 bool isInt = eltType.isa<IntegerType>();
1357 Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
1358
1359 if (!rhsType) {
1360 // Special case: AXPY operation.
1361 Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
1362 rewriter.replaceOp(op, genMult(loc, op.lhs(), b, acc, isInt, rewriter));
1363 return success();
1364 }
1365
1366 Value result = rewriter.create<ConstantOp>(loc, resType,
1367 rewriter.getZeroAttr(resType));
1368 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1369 auto pos = rewriter.getI64ArrayAttr(d);
1370 Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
1371 Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
1372 Value r = nullptr;
1373 if (acc)
1374 r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
1375 Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter);
1376 result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
1377 }
1378 rewriter.replaceOp(op, result);
1379 return success();
1380 }
1381
1382 private:
1383 static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt,
1384 PatternRewriter &rewriter) {
1385 if (acc) {
1386 if (isInt)
1387 return rewriter.create<AddIOp>(loc, rewriter.create<MulIOp>(loc, x, y),
1388 acc);
1389 return rewriter.create<vector::FMAOp>(loc, x, y, acc);
1390 }
1391 if (isInt)
1392 return rewriter.create<MulIOp>(loc, x, y);
1393 return rewriter.create<MulFOp>(loc, x, y);
1394 }
1395 };
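
// For illustration, a hedged sketch (shapes assumed) of the float path above
// with an accumulator:
//   %x = vector.outerproduct %lhs, %rhs, %acc : vector<2xf32>, vector<3xf32>
// becomes:
//   %z = constant dense<0.000000e+00> : vector<2x3xf32>
//   %0 = vector.extract %lhs[0] : vector<2xf32>
//   %1 = vector.broadcast %0 : f32 to vector<3xf32>
//   %2 = vector.extract %acc[0] : vector<2x3xf32>
//   %3 = vector.fma %1, %rhs, %2 : vector<3xf32>
//   %4 = vector.insert %3, %z [0] : vector<3xf32> into vector<2x3xf32>
//   ... and the same for row 1, yielding %x.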
1396
1397 /// Progressive lowering of ConstantMaskOp.
1398 /// One:
1399 /// %x = vector.constant_mask [a,b]
1400 /// is replaced by:
1401 /// %z = zero-result
1402 /// %l = vector.constant_mask [b]
1403 /// %4 = vector.insert %l, %z[0]
1404 /// ..
1405 /// %x = vector.insert %l, %..[a-1]
1406 /// until a one-dimensional vector is reached. All these operations
1407 /// will be folded at the LLVM IR level.
1408 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
1409 public:
1410 using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
1411
1412 LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
1413 PatternRewriter &rewriter) const override {
1414 auto loc = op.getLoc();
1415 auto dstType = op.getType();
1416 auto eltType = dstType.getElementType();
1417 auto dimSizes = op.mask_dim_sizes();
1418 int64_t rank = dimSizes.size();
1419 int64_t trueDim = std::min(dstType.getDimSize(0),
1420 dimSizes[0].cast<IntegerAttr>().getInt());
1421
1422 if (rank == 1) {
1423 // Express constant 1-D case in explicit vector form:
1424 // [T,..,T,F,..,F].
1425 SmallVector<bool, 4> values(dstType.getDimSize(0));
1426 for (int64_t d = 0; d < trueDim; d++)
1427 values[d] = true;
1428 rewriter.replaceOpWithNewOp<ConstantOp>(
1429 op, dstType, rewriter.getBoolVectorAttr(values));
1430 return success();
1431 }
1432
1433 VectorType lowType =
1434 VectorType::get(dstType.getShape().drop_front(), eltType);
1435 SmallVector<int64_t, 4> newDimSizes;
1436 for (int64_t r = 1; r < rank; r++)
1437 newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
1438 Value trueVal = rewriter.create<vector::ConstantMaskOp>(
1439 loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
1440 Value result = rewriter.create<ConstantOp>(loc, dstType,
1441 rewriter.getZeroAttr(dstType));
1442 for (int64_t d = 0; d < trueDim; d++) {
1443 auto pos = rewriter.getI64ArrayAttr(d);
1444 result =
1445 rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
1446 }
1447 rewriter.replaceOp(op, result);
1448 return success();
1449 }
1450 };
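
// For illustration, a hedged sketch (shape assumed) of the lowering above:
//   %x = vector.constant_mask [2, 2] : vector<3x4xi1>
// becomes:
//   %z = constant dense<false> : vector<3x4xi1>
//   %l = vector.constant_mask [2] : vector<4xi1>
//   %0 = vector.insert %l, %z [0] : vector<4xi1> into vector<3x4xi1>
//   %x = vector.insert %l, %0 [1] : vector<4xi1> into vector<3x4xi1>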
1451
1452 /// Progressive lowering of CreateMaskOp.
1453 /// One:
1454 /// %x = vector.create_mask %a, ... : vector<dx...>
1455 /// is replaced by:
1456 /// %l = vector.create_mask ... : vector<...> ; one lower rank
1457 /// %0 = cmpi "slt", %ci, %a |
1458 /// %1 = select %0, %l, %zeroes |
1459 /// %r = vector.insert %1, %pr [i] | d-times
1460 /// %x = ....
1461 /// until a one-dimensional vector is reached.
1462 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
1463 public:
1464 using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
1465
1466 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1467 PatternRewriter &rewriter) const override {
1468 auto loc = op.getLoc();
1469 auto dstType = op.getResult().getType().cast<VectorType>();
1470 auto eltType = dstType.getElementType();
1471 int64_t dim = dstType.getDimSize(0);
1472 int64_t rank = dstType.getRank();
1473 Value idx = op.getOperand(0);
1474
1475 if (rank == 1)
1476 return failure(); // leave for lowering
1477
1478 VectorType lowType =
1479 VectorType::get(dstType.getShape().drop_front(), eltType);
1480 Value trueVal = rewriter.create<vector::CreateMaskOp>(
1481 loc, lowType, op.getOperands().drop_front());
1482 Value falseVal = rewriter.create<ConstantOp>(loc, lowType,
1483 rewriter.getZeroAttr(lowType));
1484 Value result = rewriter.create<ConstantOp>(loc, dstType,
1485 rewriter.getZeroAttr(dstType));
1486 for (int64_t d = 0; d < dim; d++) {
1487 Value bnd = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(d));
1488 Value val = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, bnd, idx);
1489 Value sel = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
1490 auto pos = rewriter.getI64ArrayAttr(d);
1491 result =
1492 rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
1493 }
1494 rewriter.replaceOp(op, result);
1495 return success();
1496 }
1497 };
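
// For illustration, a hedged sketch (shape assumed) of the lowering above:
//   %x = vector.create_mask %a, %b : vector<2x4xi1>
// becomes:
//   %l = vector.create_mask %b : vector<4xi1>
//   %f = constant dense<false> : vector<4xi1>
//   %z = constant dense<false> : vector<2x4xi1>
//   %c0 = constant 0 : index
//   %0 = cmpi "slt", %c0, %a
//   %1 = select %0, %l, %f : vector<4xi1>
//   %2 = vector.insert %1, %z [0] : vector<4xi1> into vector<2x4xi1>
//   ... and the same with %c1 for row 1, yielding %x.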
1498
1499 /// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
1500 /// vectors progressively on the way to targeting llvm.matrix intrinsics.
1501 /// This iterates over the most major dimension of the 2-D vector and performs
1502 /// rewrites into:
1503 /// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
1504 class ShapeCastOp2DDownCastRewritePattern
1505 : public OpRewritePattern<vector::ShapeCastOp> {
1506 public:
1507 using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
1508
1509 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1510 PatternRewriter &rewriter) const override {
1511 auto sourceVectorType = op.getSourceVectorType();
1512 auto resultVectorType = op.getResultVectorType();
1513 if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
1514 return failure();
1515
1516 auto loc = op.getLoc();
1517 Value desc = rewriter.create<ConstantOp>(
1518 loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1519 unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
1520 for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
1521 Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
1522 desc = rewriter.create<vector::InsertStridedSliceOp>(
1523 loc, vec, desc,
1524 /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
1525 }
1526 rewriter.replaceOp(op, desc);
1527 return success();
1528 }
1529 };
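
// For illustration, a hedged sketch (shapes assumed) of the 2-D -> 1-D
// flattening above:
//   %x = vector.shape_cast %y : vector<2x3xf32> to vector<6xf32>
// becomes:
//   %z = constant dense<0.000000e+00> : vector<6xf32>
//   %0 = vector.extract %y[0] : vector<2x3xf32>
//   %1 = vector.insert_strided_slice %0, %z
//          {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
//   %2 = vector.extract %y[1] : vector<2x3xf32>
//   %x = vector.insert_strided_slice %2, %1
//          {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>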
1530
1531 /// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 1-D back
1532 /// into 2-D vectors progressively on the way back from targeting llvm.matrix
1533 /// intrinsics. This iterates over the most major dimension of the 2-D vector
1534 /// and performs rewrites into:
1535 /// vector.extract_strided_slice from 1-D + vector.insert into 2-D
1536 class ShapeCastOp2DUpCastRewritePattern
1537 : public OpRewritePattern<vector::ShapeCastOp> {
1538 public:
1539 using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
1540
1541 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1542 PatternRewriter &rewriter) const override {
1543 auto sourceVectorType = op.getSourceVectorType();
1544 auto resultVectorType = op.getResultVectorType();
1545 if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
1546 return failure();
1547
1548 auto loc = op.getLoc();
1549 Value desc = rewriter.create<ConstantOp>(
1550 loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1551 unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
1552 for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
1553 Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
1554 loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
1555 /*sizes=*/mostMinorVectorSize,
1556 /*strides=*/1);
1557 desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
1558 }
1559 rewriter.replaceOp(op, desc);
1560 return success();
1561 }
1562 };
1563
1564 // We typically should not lower general shape cast operations into data
1565 // movement instructions, since the assumption is that these casts are
1566 // optimized away during progressive lowering. For completeness, however,
1567 // we fall back to a reference implementation that moves all elements
1568 // into the right place if we get here.
1569 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
1570 public:
1571 using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
1572
1573 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1574 PatternRewriter &rewriter) const override {
1575 Location loc = op.getLoc();
1576 auto sourceVectorType = op.getSourceVectorType();
1577 auto resultVectorType = op.getResultVectorType();
1578 // Intended 2D/1D lowerings with better implementations.
1579 int64_t srcRank = sourceVectorType.getRank();
1580 int64_t resRank = resultVectorType.getRank();
1581 if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
1582 return failure();
1583 // Compute number of elements involved in the reshape.
1584 int64_t numElts = 1;
1585 for (int64_t r = 0; r < srcRank; r++)
1586 numElts *= sourceVectorType.getDimSize(r);
1587 // Replace with data movement operations:
1588 // x[0,0,0] = y[0,0]
1589 // x[0,0,1] = y[0,1]
1590 // x[0,1,0] = y[0,2]
1591 // etc., incrementing the two index vectors "row-major"
1592 // within the source and result shape.
1593 SmallVector<int64_t, 4> srcIdx(srcRank);
1594 SmallVector<int64_t, 4> resIdx(resRank);
1595 Value result = rewriter.create<ConstantOp>(
1596 loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1597 for (int64_t i = 0; i < numElts; i++) {
1598 if (i != 0) {
1599 incIdx(srcIdx, sourceVectorType, srcRank - 1);
1600 incIdx(resIdx, resultVectorType, resRank - 1);
1601 }
1602 Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
1603 result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
1604 }
1605 rewriter.replaceOp(op, result);
1606 return success();
1607 }
1608
1609 private:
1610 static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
1611 assert(0 <= r && r < tp.getRank());
1612 if (++idx[r] == tp.getDimSize(r)) {
1613 idx[r] = 0;
1614 incIdx(idx, tp, r - 1);
1615 }
1616 }
1617 };
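
// For illustration, a hedged sketch (shapes assumed) of the reference
// lowering above; the 2-D/1-D cases are handled by the patterns before:
//   %x = vector.shape_cast %y : vector<2x1x2xf32> to vector<4xf32>
// becomes:
//   %z = constant dense<0.000000e+00> : vector<4xf32>
//   %0 = vector.extract %y[0, 0, 0] : vector<2x1x2xf32>
//   %1 = vector.insert %0, %z [0] : f32 into vector<4xf32>
//   %2 = vector.extract %y[0, 0, 1] : vector<2x1x2xf32>
//   %3 = vector.insert %2, %1 [1] : f32 into vector<4xf32>
//   ... and so on for the remaining two elements, yielding %x.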
1618
1619 } // namespace
1620
1621 namespace mlir {
1622
1623 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1624 /// semantics to:
1625 /// ```
1626 /// %flattened_a = vector.shape_cast %a
1627 /// %flattened_b = vector.shape_cast %b
1628 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
1629 /// %d = vector.shape_cast %flattened_d
1630 /// %e = add %c, %d
1631 /// ```
1632 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1633 ///
1634 /// This only kicks in when VectorTransformsOptions is set to Matmul and
1635 /// the vector.contract op is a row-major matrix multiply.
1636 LogicalResult ContractionOpToMatmulOpLowering::matchAndRewrite(
1637 vector::ContractionOp op, PatternRewriter &rewriter) const {
1638 // TODO: implement masks
1639 if (llvm::size(op.masks()) != 0)
1640 return failure();
1641 if (vectorTransformsOptions.vectorContractLowering !=
1642 vector::VectorContractLowering::Matmul)
1643 return failure();
1644 if (failed(filter(op)))
1645 return failure();
1646
1647 auto iteratorTypes = op.iterator_types().getValue();
1648 if (!isParallelIterator(iteratorTypes[0]) ||
1649 !isParallelIterator(iteratorTypes[1]) ||
1650 !isReductionIterator(iteratorTypes[2]))
1651 return failure();
1652
1653 if (!isRowMajorMatmul(op.indexing_maps()))
1654 return failure();
1655
1656 Type elementType = op.getLhsType().getElementType();
1657 if (!elementType.isIntOrFloat())
1658 return failure();
1659
1660 VectorType lhsType = op.getLhsType();
1661 VectorType rhsType = op.getRhsType();
1662 int64_t lhsRows = lhsType.getDimSize(0);
1663 int64_t lhsColumns = lhsType.getDimSize(1);
1664 int64_t rhsColumns = rhsType.getDimSize(1);
1665
1666 Type flattenedLHSType =
1667 VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1668 Type flattenedRHSType =
1669 VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1670 auto lhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedLHSType,
1671 op.lhs());
1672 auto rhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedRHSType,
1673 op.rhs());
1674
1675 Value mul = rewriter.create<vector::MatmulOp>(op.getLoc(), lhs, rhs, lhsRows,
1676 lhsColumns, rhsColumns);
1677 mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(), op.acc().getType(),
1678 mul);
1679 if (elementType.isa<IntegerType>())
1680 rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
1681 else
1682 rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
1683
1684 return success();
1685 }
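
// For illustration, a hedged sketch (shapes assumed) of the lowering above:
//   %d = vector.contract {...row-major matmul...} %a, %b, %c
//        : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
// becomes:
//   %0 = vector.shape_cast %a : vector<2x4xf32> to vector<8xf32>
//   %1 = vector.shape_cast %b : vector<4x3xf32> to vector<12xf32>
//   %2 = vector.matmul %0, %1
//          {lhs_rows = 2 : i32, lhs_columns = 4 : i32, rhs_columns = 3 : i32}
//          : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
//   %3 = vector.shape_cast %2 : vector<6xf32> to vector<2x3xf32>
//   %d = addf %c, %3 : vector<2x3xf32>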
1686
1687 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1688 /// semantics to a reduction_size-unrolled sequence:
1689 /// ```
1690 /// %at = vector.transpose %a, [1, 0]
1691 /// %bRow0 = vector.extract %b[0]
1692 /// %atRow0 = vector.extract %at[0]
1693 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
1694 /// ...
1695 /// %bRowK = vector.extract %b[K]
1696 /// %atRowK = vector.extract %at[K]
1697 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1698 /// ```
1699 ///
1700 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1701 /// otherwise supports any layout permutation of the matrix-multiply.
1702 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1703 vector::ContractionOp op, PatternRewriter &rewriter) const {
1704 // TODO: implement masks
1705 if (llvm::size(op.masks()) != 0)
1706 return failure();
1707
1708 if (vectorTransformsOptions.vectorContractLowering !=
1709 vector::VectorContractLowering::OuterProduct)
1710 return failure();
1711
1712 if (failed(filter(op)))
1713 return failure();
1714
1715 Location loc = op.getLoc();
1716 int64_t reductionSize = 0;
1717 VectorType lhsType = op.getLhsType();
1718 Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
1719
1720 // Set up the parallel/reduction structure in right form.
1721 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1722 auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1723 AffineExpr m, n, k;
1724 bindDims(rewriter.getContext(), m, n, k);
1725 static constexpr std::array<int64_t, 2> perm = {1, 0};
1726 auto iteratorTypes = op.iterator_types().getValue();
1727 SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1728 if (isParallelIterator(iteratorTypes[0]) &&
1729 isParallelIterator(iteratorTypes[1]) &&
1730 isReductionIterator(iteratorTypes[2])) {
1731 //
1732 // Two outer parallel, one inner reduction (matmat flavor).
1733 //
1734 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1735 // This is the classical row-major matmul. Just permute the lhs.
1736 reductionSize = lhsType.getDimSize(1);
1737 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1738 } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1739 // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1740 reductionSize = lhsType.getDimSize(1);
1741 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1742 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1743 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1744 // No need to permute anything.
1745 reductionSize = lhsType.getDimSize(0);
1746 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1747 // Just permute the rhs.
1748 reductionSize = lhsType.getDimSize(0);
1749 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1750 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1751 // Row-major matmul with a transposed result. Permute the lhs and swap.
1752 reductionSize = lhsType.getDimSize(1);
1753 Value tmp = rhs;
1754 rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1755 lhs = tmp;
1756 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1757 // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1758 reductionSize = lhsType.getDimSize(1);
1759 Value tmp = rhs;
1760 rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1761 lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1762 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1763 // No need to permute anything, but still swap lhs and rhs.
1764 reductionSize = lhsType.getDimSize(0);
1765 std::swap(lhs, rhs);
1766 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1767 // Permute the rhs and swap the operands.
1768 reductionSize = lhsType.getDimSize(0);
1769 Value tmp = lhs;
1770 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1771 rhs = tmp;
1772 } else {
1773 return failure();
1774 }
1775 } else if (isParallelIterator(iteratorTypes[0]) &&
1776 isReductionIterator(iteratorTypes[1])) {
1777 //
1778 // One outer parallel, one inner reduction (matvec flavor)
1779 //
1780 if (maps == infer({{m, n}, {n}, {m}})) {
1781 // Case mat-vec: transpose.
1782 reductionSize = lhsType.getDimSize(1);
1783 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1784 } else if (maps == infer({{n, m}, {n}, {m}})) {
1785 // Case mat-trans-vec: ready to go.
1786 reductionSize = lhsType.getDimSize(0);
1787 } else if (maps == infer({{n}, {m, n}, {m}})) {
1788 // Case vec-mat: swap and transpose.
1789 reductionSize = lhsType.getDimSize(0);
1790 std::swap(lhs, rhs);
1791 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1792 } else if (maps == infer({{n}, {n, m}, {m}})) {
1793 // Case vec-mat-trans: swap and ready to go.
1794 reductionSize = lhsType.getDimSize(0);
1795 std::swap(lhs, rhs);
1796 } else {
1797 return failure();
1798 }
1799 } else {
1800 return failure();
1801 }
1802 assert(reductionSize > 0);
1803
1804 // Unroll outer-products along reduction.
1805 for (int64_t k = 0; k < reductionSize; ++k) {
1806 Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
1807 Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
1808 res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
1809 }
1810 rewriter.replaceOp(op, res);
1811 return success();
1812 }
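
// For illustration, a hedged sketch (shapes assumed) of the row-major case
// above with a reduction size of 2:
//   %d = vector.contract {...row-major matmul...} %a, %b, %c
//        : vector<2x2xf32>, vector<2x3xf32> into vector<2x3xf32>
// becomes:
//   %at = vector.transpose %a, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
//   %a0 = vector.extract %at[0] : vector<2x2xf32>
//   %b0 = vector.extract %b[0] : vector<2x3xf32>
//   %0 = vector.outerproduct %a0, %b0, %c : vector<2xf32>, vector<3xf32>
//   %a1 = vector.extract %at[1] : vector<2x2xf32>
//   %b1 = vector.extract %b[1] : vector<2x3xf32>
//   %d = vector.outerproduct %a1, %b1, %0 : vector<2xf32>, vector<3xf32>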
1813
1814 LogicalResult
1815 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1816 PatternRewriter &rewriter) const {
1817 // TODO: implement masks
1818 if (llvm::size(op.masks()) != 0)
1819 return failure();
1820
1821 if (failed(filter(op)))
1822 return failure();
1823
1824 if (vectorTransformsOptions.vectorContractLowering !=
1825 vector::VectorContractLowering::Dot)
1826 return failure();
1827
1828 auto iteratorTypes = op.iterator_types().getValue();
1829 static constexpr std::array<int64_t, 2> perm = {1, 0};
1830 Location loc = op.getLoc();
1831 Value lhs = op.lhs(), rhs = op.rhs();
1832
1833 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1834 auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1835 AffineExpr m, n, k;
1836 bindDims(rewriter.getContext(), m, n, k);
1837 SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1838 //
1839 // In the following we wish to make the reduction dimension innermost so we
1840 // can load vectors and just fmul + reduce into a scalar.
1841 //
1842 if (isParallelIterator(iteratorTypes[0]) &&
1843 isParallelIterator(iteratorTypes[1]) &&
1844 isReductionIterator(iteratorTypes[2])) {
1845 //
1846 // Two outer parallel, one inner reduction (matmat flavor).
1847 //
1848 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1849 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1850 } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1851 // No need to permute anything.
1852 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1853 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1854 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1855 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1856 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1857 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1858 // Row-major matmul with a transposed result. Permute the rhs and swap.
1859 Value tmp = lhs;
1860 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1861 rhs = tmp;
1862 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1863 std::swap(lhs, rhs);
1864 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1865 Value tmp = lhs;
1866 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1867 rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1868 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1869 Value tmp = rhs;
1870 rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1871 lhs = tmp;
1872 } else {
1873 return failure();
1874 }
1875 } else if (isParallelIterator(iteratorTypes[0]) &&
1876 isReductionIterator(iteratorTypes[1])) {
1877 //
1878 // One outer parallel, one inner reduction (matvec flavor)
1879 //
1880 if (maps == infer({{m, n}, {n}, {m}})) {
1881 // No need to permute anything.
1882 } else if (maps == infer({{n, m}, {n}, {m}})) {
1883 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1884 } else if (maps == infer({{n}, {m, n}, {m}})) {
1885 std::swap(lhs, rhs);
1886 } else if (maps == infer({{n}, {n, m}, {m}})) {
1887 std::swap(lhs, rhs);
1888 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1889 } else {
1890 return failure();
1891 }
1892 } else {
1893 return failure();
1894 }
1895
1896 VectorType dstType = op.getResultType().cast<VectorType>();
1897 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1898 "Expected dst type of rank 1 or 2");
1899
1900 unsigned rank = dstType.getRank();
1901 unsigned dstRows = dstType.getShape()[0];
1902 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1903
1904 // ExtractOp does not allow dynamic indexing; we must unroll explicitly.
1905 Value res =
1906 rewriter.create<ConstantOp>(loc, dstType, rewriter.getZeroAttr(dstType));
1907 for (unsigned r = 0; r < dstRows; ++r) {
1908 Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1909 for (unsigned c = 0; c < dstColumns; ++c) {
1910 Value b = rank == 1
1911 ? rhs
1912 : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1913 Value m = rewriter.create<MulFOp>(op.getLoc(), a, b);
1914 Value reduced = rewriter.create<vector::ReductionOp>(
1915 op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"),
1916 m, ValueRange{});
1917
1918 SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1919 : SmallVector<int64_t, 2>{r, c};
1920 res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1921 }
1922 }
1923 if (auto acc = op.acc())
1924 res = rewriter.create<AddFOp>(op.getLoc(), res, acc);
1925 rewriter.replaceOp(op, res);
1926 return success();
1927 }
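
// For illustration, a hedged sketch (shapes assumed) of the dot lowering
// above for a row-major matvec:
//   %d = vector.contract {...matvec...} %A, %x, %c
//        : vector<2x4xf32>, vector<4xf32> into vector<2xf32>
// becomes:
//   %z = constant dense<0.000000e+00> : vector<2xf32>
//   %0 = vector.extract %A[0] : vector<2x4xf32>
//   %1 = mulf %0, %x : vector<4xf32>
//   %2 = vector.reduction "add", %1 : vector<4xf32> into f32
//   %3 = vector.insert %2, %z [0] : f32 into vector<2xf32>
//   ... the same for row 1 ...
//   %d = addf %5, %c : vector<2xf32>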
1928
1929 /// Progressive lowering of ContractionOp.
1930 /// One:
1931 /// %x = vector.contract with at least one free/batch dimension
1932 /// is replaced by:
1933 /// %a = vector.contract with one less free/batch dimension
1934 /// %b = vector.contract with one less free/batch dimension
1935 /// ..
1936 /// %x = combine %a %b ..
1937 /// until a pure contraction is reached (no free/batch dimensions),
1938 /// which is replaced by a dot-product.
1939 ///
1940 /// This only kicks in when either VectorTransformsOptions is set
1941 /// to Dot or when other contraction patterns fail.
1942 //
1943 // TODO: break down into transpose/reshape/cast ops
1944 // when they become available to avoid code dup
1945 // TODO: investigate lowering order impact on performance
1946 LogicalResult
1947 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1948 PatternRewriter &rewriter) const {
1949 // TODO: implement masks.
1950 if (llvm::size(op.masks()) != 0)
1951 return failure();
1952
1953 if (failed(filter(op)))
1954 return failure();
1955
1956 // TODO: support mixed mode contract lowering.
1957 if (op.getLhsType().getElementType() !=
1958 getElementTypeOrSelf(op.getAccType()) ||
1959 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1960 return failure();
1961
1962 // TODO: implement benefits, cost models.
1963 MLIRContext *ctx = op.getContext();
1964 ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx);
1965 if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1966 return success();
1967 ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
1968 if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1969 return success();
1970 ContractionOpToDotLowering pat3(vectorTransformsOptions, ctx);
1971 if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1972 return success();
1973
1974 // Find first batch dimension in LHS/RHS, and lower when found.
1975 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1976 if (!batchDimMap.empty()) {
1977 int64_t lhsIndex = batchDimMap[0].first;
1978 int64_t rhsIndex = batchDimMap[0].second;
1979 rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1980 return success();
1981 }
1982
1983 // Collect contracting dimensions.
1984 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1985 op.getContractingDimMap();
1986 DenseSet<int64_t> lhsContractingDimSet;
1987 DenseSet<int64_t> rhsContractingDimSet;
1988 for (auto &dimPair : contractingDimMap) {
1989 lhsContractingDimSet.insert(dimPair.first);
1990 rhsContractingDimSet.insert(dimPair.second);
1991 }
1992
1993 // Find first free dimension in LHS, and lower when found.
1994 VectorType lhsType = op.getLhsType();
1995 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1996 if (lhsContractingDimSet.count(lhsIndex) == 0) {
1997 rewriter.replaceOp(
1998 op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1999 return success();
2000 }
2001 }
2002
2003 // Find first free dimension in RHS, and lower when found.
2004 VectorType rhsType = op.getRhsType();
2005 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
2006 if (rhsContractingDimSet.count(rhsIndex) == 0) {
2007 rewriter.replaceOp(
2008 op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
2009 return success();
2010 }
2011 }
2012
2013 // Lower the first remaining reduction dimension.
2014 if (!contractingDimMap.empty()) {
2015 rewriter.replaceOp(op, lowerReduction(op, rewriter));
2016 return success();
2017 }
2018
2019 return failure();
2020 }
2021
2022 // Lower one parallel dimension.
2023 // TODO: consider reusing existing contract unrolling
2024 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
2025 int64_t lhsIndex, int64_t rhsIndex,
2026 PatternRewriter &rewriter) const {
2027 VectorType lhsType = op.getLhsType();
2028 VectorType rhsType = op.getRhsType();
2029 VectorType resType = op.getResultType().cast<VectorType>();
2030 // Find the iterator type index and result index.
2031 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
2032 int64_t iterIndex = -1;
2033 int64_t dimSize = -1;
2034 if (lhsIndex >= 0) {
2035 iterIndex = iMap[0].getDimPosition(lhsIndex);
2036 assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
2037 "parallel index should be free in LHS or batch in LHS/RHS");
2038 dimSize = lhsType.getDimSize(lhsIndex);
2039 } else {
2040 assert(rhsIndex >= 0 && "missing parallel index");
2041 iterIndex = iMap[1].getDimPosition(rhsIndex);
2042 dimSize = rhsType.getDimSize(rhsIndex);
2043 }
2044 assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
2045 Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
2046 assert(lookup.hasValue() && "parallel index not listed in reduction");
2047 int64_t resIndex = lookup.getValue();
2048 // Construct new iterator types and affine map array attribute.
2049 std::array<AffineMap, 3> lowIndexingMaps = {
2050 adjustMap(iMap[0], iterIndex, rewriter),
2051 adjustMap(iMap[1], iterIndex, rewriter),
2052 adjustMap(iMap[2], iterIndex, rewriter)};
2053 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
2054 auto lowIter =
2055 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
2056 // Unroll into a series of lower dimensional vector.contract ops.
2057 Location loc = op.getLoc();
2058 Value result =
2059 rewriter.create<ConstantOp>(loc, resType, rewriter.getZeroAttr(resType));
2060 for (int64_t d = 0; d < dimSize; ++d) {
2061 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
2062 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
2063 auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
2064 Value lowContract = rewriter.create<vector::ContractionOp>(
2065 loc, lhs, rhs, acc, lowAffine, lowIter);
2066 result =
2067 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
2068 }
2069 return result;
2070 }
2071
2072 // Lower one reduction dimension.
2073 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
2074 PatternRewriter &rewriter) const {
2075 auto loc = op.getLoc();
2076 VectorType lhsType = op.getLhsType();
2077 VectorType rhsType = op.getRhsType();
2078 Type resType = op.getResultType();
2079 assert(!resType.isa<VectorType>());
2080 // Use iterator index 0.
2081 int64_t iterIndex = 0;
2082 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
2083 Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
2084 Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
2085 assert(lookupLhs.hasValue() && "missing LHS parallel index");
2086 assert(lookupRhs.hasValue() && "missing RHS parallel index");
2087 int64_t lhsIndex = lookupLhs.getValue();
2088 int64_t rhsIndex = lookupRhs.getValue();
2089 int64_t dimSize = lhsType.getDimSize(lhsIndex);
2090 assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
2091 // Base case.
2092 if (lhsType.getRank() == 1) {
2093 assert(rhsType.getRank() == 1 && "corrupt contraction");
2094 Value m = rewriter.create<MulFOp>(loc, op.lhs(), op.rhs());
2095 StringAttr kind = rewriter.getStringAttr("add");
2096 return rewriter.create<vector::ReductionOp>(loc, resType, kind, m,
2097 op.acc());
2098 }
2099 // Construct new iterator types and affine map array attribute.
2100 std::array<AffineMap, 3> lowIndexingMaps = {
2101 adjustMap(iMap[0], iterIndex, rewriter),
2102 adjustMap(iMap[1], iterIndex, rewriter),
2103 adjustMap(iMap[2], iterIndex, rewriter)};
2104 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
2105 auto lowIter =
2106 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
2107 // Unroll into a series of lower dimensional vector.contract ops.
2108 // By feeding the initial accumulator into the first contraction,
2109 // and the result of each contraction into the next, eventually
2110 // the sum of all reductions is computed.
2111 Value result = op.acc();
2112 for (int64_t d = 0; d < dimSize; ++d) {
2113 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
2114 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
2115 result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
2116 lowAffine, lowIter);
2117 }
2118 return result;
2119 }
2120
2121 } // namespace mlir
2122
2123 static Optional<int64_t> extractConstantIndex(Value v) {
2124 if (auto cstOp = v.getDefiningOp<ConstantIndexOp>())
2125 return cstOp.getValue();
2126 if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
2127 if (affineApplyOp.getAffineMap().isSingleConstant())
2128 return affineApplyOp.getAffineMap().getSingleConstantResult();
2129 return None;
2130 }
2131
2132 // Missing foldings of scf.if make it necessary to perform poor man's folding
2133 // eagerly, especially in the case of unrolling. In the future, this should go
2134 // away once scf.if folds properly.
2135 static Value createScopedFoldedSLE(Value v, Value ub) {
2136 using namespace edsc::op;
2137 auto maybeCstV = extractConstantIndex(v);
2138 auto maybeCstUb = extractConstantIndex(ub);
2139 if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
2140 return Value();
2141 return sle(v, ub);
2142 }
2143
2144 // Operates under a scoped context to build the condition to ensure that a
2145 // particular VectorTransferOpInterface is unmasked.
2146 static Value createScopedInBoundsCond(VectorTransferOpInterface xferOp) {
2147 assert(xferOp.permutation_map().isMinorIdentity() &&
2148 "Expected minor identity map");
2149 Value inBoundsCond;
2150 xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
2151 // Zip over the resulting vector shape and memref indices.
2152 // If the dimension is known to be unmasked, it does not participate in the
2153 // construction of `inBoundsCond`.
2154 if (!xferOp.isMaskedDim(resultIdx))
2155 return;
2156 int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
2157 using namespace edsc::op;
2158 using namespace edsc::intrinsics;
2159 // Fold or create the check that `index + vector_size` <= `memref_size`.
2160 Value sum = xferOp.indices()[indicesIdx] + std_constant_index(vectorSize);
2161 Value cond =
2162 createScopedFoldedSLE(sum, std_dim(xferOp.source(), indicesIdx));
2163 if (!cond)
2164 return;
2165 // Conjunction over all dims for which we are in-bounds.
2166 inBoundsCond = inBoundsCond ? inBoundsCond && cond : cond;
2167 });
2168 return inBoundsCond;
2169 }
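
// For illustration, a hedged sketch of the condition built above for a 2-D
// read of vector<4x8xf32> from %A : memref<?x?xf32> at indices (%i, %j), with
// both dimensions masked (names assumed; constant dims would fold away):
//   %0 = addi %i, %c4 : index
//   %1 = dim %A, %c0 : memref<?x?xf32>
//   %2 = cmpi "sle", %0, %1
//   %3 = addi %j, %c8 : index
//   %4 = dim %A, %c1 : memref<?x?xf32>
//   %5 = cmpi "sle", %3, %4
//   %inBoundsCond = and %2, %5 : i1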
2170
2171 LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
2172 VectorTransferOpInterface xferOp) {
2173 // TODO: expand support to these 2 cases.
2174 if (!xferOp.permutation_map().isMinorIdentity())
2175 return failure();
2176 // Must have some masked dimension to be a candidate for splitting.
2177 if (!xferOp.hasMaskedDim())
2178 return failure();
2179 // Don't split transfer operations directly under IfOp; this avoids applying
2180 // the pattern recursively.
2181 // TODO: improve the filtering condition to make it more applicable.
2182 if (isa<scf::IfOp>(xferOp->getParentOp()))
2183 return failure();
2184 return success();
2185 }
2186
2187 /// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
2188 /// be cast. If the MemRefTypes don't have the same rank or are not strided,
2189 /// return null; otherwise:
2190 /// 1. if `aT` and `bT` are cast-compatible, return `aT`.
2191 /// 2. else return a new MemRefType obtained by iterating over the shape and
2192 /// strides and:
2193 /// a. keeping the ones that are static and equal across `aT` and `bT`.
2194 /// b. using a dynamic shape and/or stride for the dimensions that don't
2195 /// agree.
2196 static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
2197 if (MemRefCastOp::areCastCompatible(aT, bT))
2198 return aT;
2199 if (aT.getRank() != bT.getRank())
2200 return MemRefType();
2201 int64_t aOffset, bOffset;
2202 SmallVector<int64_t, 4> aStrides, bStrides;
2203 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
2204 failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
2205 aStrides.size() != bStrides.size())
2206 return MemRefType();
2207
2208 ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
2209 int64_t resOffset;
2210 SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
2211 resStrides(bT.getRank(), 0);
2212 for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
2213 resShape[idx] =
2214 (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
2215 resStrides[idx] = (aStrides[idx] == bStrides[idx])
2216 ? aStrides[idx]
2217 : MemRefType::kDynamicStrideOrOffset;
2218 }
2219 resOffset =
2220 (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
2221 return MemRefType::get(
2222 resShape, aT.getElementType(),
2223 makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
2224 }
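
// For illustration, a hedged sketch of the merge rule above (layouts written
// informally as offset/strides):
//   aT = memref<4x8xf32, offset: 0, strides: [8, 1]>
//   bT = memref<4x?xf32, offset: 0, strides: [?, 1]>
// yields:
//   memref<4x?xf32, offset: 0, strides: [?, 1]>
// i.e. static sizes/strides are kept only where both types agree.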
2225
2226 /// Operates under a scoped context to build the intersection between the
2227 /// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
2228 // TODO: view intersection/union/differences should be a proper std op.
2229 static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
2230 Value alloc) {
2231 using namespace edsc::intrinsics;
2232 int64_t memrefRank = xferOp.getShapedType().getRank();
2233 // TODO: relax this precondition, will require rank-reducing subviews.
2234 assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
2235 "Expected memref rank to match the alloc rank");
2236 ValueRange leadingIndices =
2237 xferOp.indices().take_front(xferOp.getLeadingShapedRank());
2238 SmallVector<OpFoldResult, 4> sizes;
2239 sizes.append(leadingIndices.begin(), leadingIndices.end());
2240 xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
2241 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
2242 Value dimMemRef = std_dim(xferOp.source(), indicesIdx);
2243 Value dimAlloc = std_dim(alloc, resultIdx);
2244 Value index = xferOp.indices()[indicesIdx];
2245 AffineExpr i, j, k;
2246 bindDims(xferOp.getContext(), i, j, k);
2247 SmallVector<AffineMap, 4> maps =
2248 AffineMap::inferFromExprList(MapList{{i - j, k}});
2249 // affine_min(%dimMemRef - %index, %dimAlloc)
2250 Value affineMin = affine_min(index.getType(), maps[0],
2251 ValueRange{dimMemRef, index, dimAlloc});
2252 sizes.push_back(affineMin);
2253 });
2254
2255 SmallVector<OpFoldResult, 4> indices = llvm::to_vector<4>(llvm::map_range(
2256 xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
2257 return std_sub_view(
2258 xferOp.source(), indices, sizes,
2259 SmallVector<OpFoldResult>(memrefRank, OpBuilder(xferOp).getIndexAttr(1)));
2260 }
2261
2262 /// Given an `xferOp` for which:
2263 /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
2264 /// 2. a memref of a single vector `alloc` has been allocated.
2265 /// Produce IR resembling:
2266 /// ```
2267 /// %1:3 = scf.if (%inBounds) {
2268 /// memref_cast %A: memref<A...> to compatibleMemRefType
2269 /// scf.yield %view, ... : compatibleMemRefType, index, index
2270 /// } else {
2271 /// %2 = linalg.fill(%alloc, %pad)
2272 /// %3 = subview %view [...][...][...]
2273 /// linalg.copy(%3, %alloc)
2274 /// memref_cast %alloc: memref<B...> to compatibleMemRefType
2275 /// scf.yield %4, ... : compatibleMemRefType, index, index
2276 /// }
2277 /// ```
2278 /// Return the produced scf::IfOp.
2279 static scf::IfOp createScopedFullPartialLinalgCopy(
2280 vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
2281 MemRefType compatibleMemRefType, Value alloc) {
2282 using namespace edsc;
2283 using namespace edsc::intrinsics;
2284 scf::IfOp fullPartialIfOp;
2285 Value zero = std_constant_index(0);
2286 Value memref = xferOp.source();
2287 conditionBuilder(
2288 returnTypes, inBoundsCond,
2289 [&]() -> scf::ValueVector {
2290 Value res = memref;
2291 if (compatibleMemRefType != xferOp.getShapedType())
2292 res = std_memref_cast(memref, compatibleMemRefType);
2293 scf::ValueVector viewAndIndices{res};
2294 viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
2295 xferOp.indices().end());
2296 return viewAndIndices;
2297 },
2298 [&]() -> scf::ValueVector {
2299 linalg_fill(alloc, xferOp.padding());
2300 // Take a partial subview of the memref, which guarantees no dimension
2301 // overflows.
2302 Value memRefSubView = createScopedSubViewIntersection(
2303 cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
2304 linalg_copy(memRefSubView, alloc);
2305 Value casted = std_memref_cast(alloc, compatibleMemRefType);
2306 scf::ValueVector viewAndIndices{casted};
2307 viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
2308 zero);
2309 return viewAndIndices;
2310 },
2311 &fullPartialIfOp);
2312 return fullPartialIfOp;
2313 }
2314
2315 /// Given an `xferOp` for which:
2316 /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
2317 /// 2. a memref of a single vector `alloc` has been allocated.
2318 /// Produce IR resembling:
2319 /// ```
2320 /// %1:3 = scf.if (%inBounds) {
2321 /// memref_cast %A: memref<A...> to compatibleMemRefType
2322 /// scf.yield %view, ... : compatibleMemRefType, index, index
2323 /// } else {
2324 /// %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
2325 /// %3 = vector.type_cast %extra_alloc :
2326 /// memref<...> to memref<vector<...>>
2327 /// store %2, %3[] : memref<vector<...>>
2328 /// %4 = memref_cast %alloc: memref<B...> to compatibleMemRefType
2329 /// scf.yield %4, ... : compatibleMemRefType, index, index
2330 /// }
2331 /// ```
2332 /// Return the produced scf::IfOp.
2333 static scf::IfOp createScopedFullPartialVectorTransferRead(
2334 vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
2335 MemRefType compatibleMemRefType, Value alloc) {
2336 using namespace edsc;
2337 using namespace edsc::intrinsics;
2338 scf::IfOp fullPartialIfOp;
2339 Value zero = std_constant_index(0);
2340 Value memref = xferOp.source();
2341 conditionBuilder(
2342 returnTypes, inBoundsCond,
2343 [&]() -> scf::ValueVector {
2344 Value res = memref;
2345 if (compatibleMemRefType != xferOp.getShapedType())
2346 res = std_memref_cast(memref, compatibleMemRefType);
2347 scf::ValueVector viewAndIndices{res};
2348 viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
2349 xferOp.indices().end());
2350 return viewAndIndices;
2351 },
2352 [&]() -> scf::ValueVector {
2353 Operation *newXfer =
2354 ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
2355 Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
2356 std_store(vector, vector_type_cast(
2357 MemRefType::get({}, vector.getType()), alloc));
2358
2359 Value casted = std_memref_cast(alloc, compatibleMemRefType);
2360 scf::ValueVector viewAndIndices{casted};
2361 viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
2362 zero);
2363
2364 return viewAndIndices;
2365 },
2366 &fullPartialIfOp);
2367 return fullPartialIfOp;
2368 }
2369
2370 /// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
2371 /// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
2372 /// newly created conditional upon function return.
2373 /// To accommodate the fact that the original vector.transfer indexing may be
2374 /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
2375 /// scf.if op returns a view and values of type index.
2376 /// At this time, only the vector.transfer_read case is implemented.
2377 ///
2378 /// Example (a 2-D vector.transfer_read):
2379 /// ```
2380 /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
2381 /// ```
2382 /// is transformed into:
2383 /// ```
2384 /// %1:3 = scf.if (%inBounds) {
2385 /// // fastpath, direct cast
2386 /// memref_cast %A: memref<A...> to compatibleMemRefType
2387 /// scf.yield %view : compatibleMemRefType, index, index
2388 /// } else {
2389 /// // slowpath, masked vector.transfer or linalg.copy.
2390 /// memref_cast %alloc: memref<B...> to compatibleMemRefType
2391 /// scf.yield %4 : compatibleMemRefType, index, index
2392 /// }
2393 /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
2394 /// ```
2395 /// where `alloc` is a buffer of one vector, alloca'ed at the top of the function.
2396 ///
2397 /// Preconditions:
2398 /// 1. `xferOp.permutation_map()` must be a minor identity map
2399 /// 2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
2400 /// must be equal. This will be relaxed in the future but requires
2401 /// rank-reducing subviews.
2402 LogicalResult mlir::vector::splitFullAndPartialTransfer(
2403 OpBuilder &b, VectorTransferOpInterface xferOp,
2404 VectorTransformsOptions options, scf::IfOp *ifOp) {
2405 using namespace edsc;
2406 using namespace edsc::intrinsics;
2407
2408 if (options.vectorTransferSplit == VectorTransferSplit::None)
2409 return failure();
2410
2411 SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
2412 auto unmaskedAttr = b.getBoolArrayAttr(bools);
2413 if (options.vectorTransferSplit == VectorTransferSplit::ForceUnmasked) {
2414 xferOp->setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
2415 return success();
2416 }
2417
2418 assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
2419 "Expected splitFullAndPartialTransferPrecondition to hold");
2420 auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
2421
2422 // TODO: add support for write case.
2423 if (!xferReadOp)
2424 return failure();
2425
2426 OpBuilder::InsertionGuard guard(b);
2427 if (Operation *sourceOp = xferOp.source().getDefiningOp())
2428 b.setInsertionPointAfter(sourceOp);
2429 else
2430 b.setInsertionPoint(xferOp);
2431 ScopedContext scope(b, xferOp.getLoc());
2432 Value inBoundsCond = createScopedInBoundsCond(
2433 cast<VectorTransferOpInterface>(xferOp.getOperation()));
2434 if (!inBoundsCond)
2435 return failure();
2436
2437 // Top-of-function `alloc` for transient storage.
2438 Value alloc;
2439 {
2440 FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
2441 OpBuilder::InsertionGuard guard(b);
2442 b.setInsertionPointToStart(&funcOp.getRegion().front());
2443 auto shape = xferOp.getVectorType().getShape();
2444 Type elementType = xferOp.getVectorType().getElementType();
2445 alloc = std_alloca(MemRefType::get(shape, elementType), ValueRange{},
2446 b.getI64IntegerAttr(32));
2447 }
2448
2449 MemRefType compatibleMemRefType =
2450 getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
2451 alloc.getType().cast<MemRefType>());
2452
2453 // Read case: full fill + partial copy -> unmasked vector.xfer_read.
2454 SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
2455 b.getIndexType());
2456 returnTypes[0] = compatibleMemRefType;
2457 scf::IfOp fullPartialIfOp =
2458 options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
2459 ? createScopedFullPartialVectorTransferRead(
2460 xferReadOp, returnTypes, inBoundsCond, compatibleMemRefType,
2461 alloc)
2462 : createScopedFullPartialLinalgCopy(xferReadOp, returnTypes,
2463 inBoundsCond,
2464 compatibleMemRefType, alloc);
2465 if (ifOp)
2466 *ifOp = fullPartialIfOp;
2467
2468 // Unmask the existing read op; it always reads from a full buffer.
2469 for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
2470 xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
2471 xferOp->setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
2472
2473 return success();
2474 }
2475
2476 LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
2477 Operation *op, PatternRewriter &rewriter) const {
2478 auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
2479 if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
2480 failed(filter(xferOp)))
2481 return failure();
2482 rewriter.startRootUpdate(xferOp);
2483 if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
2484 rewriter.finalizeRootUpdate(xferOp);
2485 return success();
2486 }
2487 rewriter.cancelRootUpdate(xferOp);
2488 return failure();
2489 }
2490
2491 LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
2492 ExtractMapOp extract, PatternRewriter &rewriter) const {
2493 Operation *definedOp = extract.vector().getDefiningOp();
2494 if (!definedOp || definedOp->getNumResults() != 1)
2495 return failure();
2496 // TODO: Create an interfaceOp for elementwise operations.
2497 if (!isa<AddFOp>(definedOp))
2498 return failure();
2499 Location loc = extract.getLoc();
2500 SmallVector<Value, 4> extractOperands;
2501 for (OpOperand &operand : definedOp->getOpOperands())
2502 extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
2503 loc, extract.getResultType(), operand.get(), extract.ids()));
2504 Operation *newOp = cloneOpWithOperandsAndTypes(
2505 rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
2506 rewriter.replaceOp(extract, newOp->getResult(0));
2507 return success();
2508 }
2509
2510 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
2511 OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
2512 ArrayRef<int64_t> multiplicity, const AffineMap &map) {
2513 OpBuilder::InsertionGuard guard(builder);
2514 builder.setInsertionPointAfter(op);
2515 Location loc = op->getLoc();
2516 if (op->getNumResults() != 1)
2517 return {};
2518 Value result = op->getResult(0);
2519 VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
2520 if (!type || map.getNumResults() != multiplicity.size())
2521 return {};
2522 // For each dimension being distributed, check that the size is a multiple of
2523 // the multiplicity. To handle more sizes we would need to support masking.
2524 unsigned multiplicityCount = 0;
2525 for (auto exp : map.getResults()) {
2526 auto affineExp = exp.dyn_cast<AffineDimExpr>();
2527 if (!affineExp || affineExp.getPosition() >= type.getRank() ||
2528 type.getDimSize(affineExp.getPosition()) %
2529 multiplicity[multiplicityCount++] !=
2530 0)
2531 return {};
2532 }
2533 DistributeOps ops;
2534 ops.extract =
2535 builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
2536 ops.insert =
2537 builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
2538 return ops;
2539 }
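
// A hedged usage sketch of distributPointwiseVectorOp; the op, id value and
// multiplicity below are assumptions for illustration. Distribute the single
// result of an elementwise `addOp` across an id `%tid` with multiplicity 32
// along dimension 0:
//
//   AffineMap map = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
//   if (Optional<vector::DistributeOps> ops =
//           vector::distributPointwiseVectorOp(builder, addOp, {tid},
//                                              /*multiplicity=*/{32}, map)) {
//     // Uses inside the distributed region consume ops->extract; uses
//     // outside consume ops->insert.
//   }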
2540
2541 struct TransferReadExtractPattern
2542 : public OpRewritePattern<vector::TransferReadOp> {
2543 TransferReadExtractPattern(MLIRContext *context)
2544 : OpRewritePattern<vector::TransferReadOp>(context) {}
2545 LogicalResult matchAndRewrite(vector::TransferReadOp read,
2546 PatternRewriter &rewriter) const override {
2547 if (!read.getResult().hasOneUse())
2548 return failure();
2549 auto extract =
2550 dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
2551 if (!extract)
2552 return failure();
2553 edsc::ScopedContext scope(rewriter, read.getLoc());
2554 using mlir::edsc::op::operator+;
2555 using mlir::edsc::op::operator*;
2556 using namespace mlir::edsc::intrinsics;
2557 SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
2558 AffineMap map = extract.map();
2559 unsigned idCount = 0;
2560 for (auto expr : map.getResults()) {
2561 unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2562 indices[pos] =
2563 indices[pos] +
2564 extract.ids()[idCount++] *
2565 std_constant_index(extract.getResultType().getDimSize(pos));
2566 }
2567 Value newRead = vector_transfer_read(extract.getType(), read.source(),
2568 indices, read.permutation_map(),
2569 read.padding(), read.maskedAttr());
2570 Value dest = rewriter.create<ConstantOp>(
2571 read.getLoc(), read.getType(), rewriter.getZeroAttr(read.getType()));
2572 newRead = rewriter.create<vector::InsertMapOp>(read.getLoc(), newRead, dest,
2573 extract.ids());
2574 rewriter.replaceOp(read, newRead);
2575 return success();
2576 }
2577 };
2578
2579 struct TransferWriteInsertPattern
2580 : public OpRewritePattern<vector::TransferWriteOp> {
2581 TransferWriteInsertPattern(MLIRContext *context)
2582 : OpRewritePattern<vector::TransferWriteOp>(context) {}
2583 LogicalResult matchAndRewrite(vector::TransferWriteOp write,
2584 PatternRewriter &rewriter) const override {
2585 auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
2586 if (!insert)
2587 return failure();
2588 edsc::ScopedContext scope(rewriter, write.getLoc());
2589 using mlir::edsc::op::operator+;
2590 using mlir::edsc::op::operator*;
2591 using namespace mlir::edsc::intrinsics;
2592 SmallVector<Value, 4> indices(write.indices().begin(),
2593 write.indices().end());
2594 AffineMap map = insert.map();
2595 unsigned idCount = 0;
2596 for (auto expr : map.getResults()) {
2597 unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2598 indices[pos] =
2599 indices[pos] +
2600 insert.ids()[idCount++] *
2601 std_constant_index(insert.getSourceVectorType().getDimSize(pos));
2602 }
2603 vector_transfer_write(insert.vector(), write.source(), indices,
2604 write.permutation_map(), write.maskedAttr());
2605 rewriter.eraseOp(write);
2606 return success();
2607 }
2608 };
2609
2610 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
2611 // TODO: Add this as DRR pattern.
2612 void mlir::vector::populateVectorToVectorTransformationPatterns(
2613 OwningRewritePatternList &patterns, MLIRContext *context) {
2614 // clang-format off
2615 patterns.insert<ShapeCastOpDecomposer,
2616 ShapeCastOpFolder,
2617 SplitTransferReadOp,
2618 SplitTransferWriteOp,
2619 TupleGetFolderOp,
2620 TransferReadExtractPattern,
2621 TransferWriteInsertPattern>(context);
2622 // clang-format on
2623 }
2624
2625 void mlir::vector::populateVectorSlicesLoweringPatterns(
2626 OwningRewritePatternList &patterns, MLIRContext *context) {
2627 patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
2628 }
2629
2630 void mlir::vector::populateVectorContractLoweringPatterns(
2631 OwningRewritePatternList &patterns, MLIRContext *context,
2632 VectorTransformsOptions parameters) {
2633 // clang-format off
2634 patterns.insert<BroadcastOpLowering,
2635 CreateMaskOpLowering,
2636 ConstantMaskOpLowering,
2637 OuterProductOpLowering,
2638 ShapeCastOp2DDownCastRewritePattern,
2639 ShapeCastOp2DUpCastRewritePattern,
2640 ShapeCastOpRewritePattern>(context);
2641 patterns.insert<TransposeOpLowering,
2642 ContractionOpLowering,
2643 ContractionOpToMatmulOpLowering,
2644 ContractionOpToOuterProductOpLowering>(parameters, context);
2645 // clang-format on
2646 }
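
// A hedged usage sketch (pass boilerplate omitted; the exact greedy-driver
// signature depends on the MLIR revision): lower all vector.contract ops in
// a function to outer products:
//
//   OwningRewritePatternList patterns;
//   vector::VectorTransformsOptions options;
//   options.vectorContractLowering =
//       vector::VectorContractLowering::OuterProduct;
//   vector::populateVectorContractLoweringPatterns(patterns, ctx, options);
//   applyPatternsAndFoldGreedily(funcOp, std::move(patterns));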
2647