//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/EDSC/Builders.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using llvm::dbgs;

// Helper to find an index in an affine map.
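// For example (illustrative values): with
//   map = affine_map<(d0, d1, d2) -> (d2, d0)>
// getResultIndex(map, 0) returns 1, and getResultIndex(map, 1) returns None.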
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return None;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
  for (auto it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
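// For example (illustrative values): adjusting
//   affine_map<(d0, d1, d2) -> (d1, d2)>
// by removing index 1 yields affine_map<(d0, d1) -> (d1)>.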
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper to drop a dimension from a vector type.
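// For example (illustrative values): adjusting vector<2x3x4xf32> at index 1
// yields vector<2x4xf32>, and adjusting vector<8xf32> at index 0 yields the
// scalar element type f32.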
static Type adjustType(VectorType tp, int64_t index) {
  int64_t rank = tp.getRank();
  Type eltType = tp.getElementType();
  if (rank == 1) {
    assert(index == 0 && "index for scalar result out of bounds");
    return eltType;
  }
  SmallVector<int64_t, 4> adjustedShape;
  for (int64_t i = 0; i < rank; ++i) {
    // Omit dimension at the given index.
    if (i == index)
      continue;
    // Otherwise, add dimension back.
    adjustedShape.push_back(tp.getDimSize(i));
  }
  return VectorType::get(adjustedShape, eltType);
}

// Helper method to possibly drop a dimension in a load.
// TODO
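// For example (illustrative values): reshapeLoad of a vector<2x3x4xf32> value
// with index = 1 and pos = 0 extracts position 0 of the middle dimension from
// each leading slice, producing a vector<2x4xf32>.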
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = adjustType(type, 0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
  VectorType vType = lowType.cast<VectorType>();
  VectorType resType = adjustType(type, index).cast<VectorType>();
  Value result =
      rewriter.create<ConstantOp>(loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result =
        rewriter.create<vector::InsertOp>(loc, resType, load, result, posAttr);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = adjustType(type, 0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = adjustType(vType, 0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return builder.createOperation(res);
}

// Populates 'resultElements[indexMap[i]]' with elements from 'inputElements[i]'
// for each index 'i' in inputElements with a valid mapping in 'indexMap'.
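// For example (illustrative values): with indexMap = {0 -> 0, 2 -> 1} and
// inputElements = [4, 5, 6], resultElements becomes [4, 6].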
static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
                              ArrayRef<int64_t> inputElements,
                              SmallVectorImpl<int64_t> &resultElements) {
  assert(indexMap.size() == resultElements.size());
  assert(inputElements.size() >= resultElements.size());
  for (unsigned i = 0, e = inputElements.size(); i < e; ++i) {
    auto it = indexMap.find(i);
    if (it != indexMap.end())
      resultElements[it->second] = inputElements[i];
  }
}

// Returns a tuple type with vector element types for each resulting slice
// of 'vectorType' unrolled by 'sizes' and 'strides'.
// TODO: Move this to a utility function and share it with
// Extract/InsertSlicesOp verification.
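// For example (illustrative values): unrolling vector<4x6xf32> with
// sizes = [2, 3] and unit strides yields
// tuple<vector<2x3xf32>, vector<2x3xf32>, vector<2x3xf32>, vector<2x3xf32>>.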
static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
                                                   ArrayRef<int64_t> sizes,
                                                   ArrayRef<int64_t> strides,
                                                   OpBuilder &builder) {
  assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
  assert(static_cast<int64_t>(sizes.size()) == vectorType.getRank());
  assert(static_cast<int64_t>(strides.size()) == vectorType.getRank());

  // Compute shape ratio of 'shape' and 'sizes'.
  auto shape = vectorType.getShape();
  auto maybeDimSliceCounts = shapeRatio(shape, sizes);
  assert(maybeDimSliceCounts.hasValue());
  auto sliceDimCounts = *maybeDimSliceCounts;

  // Compute strides w.r.t. the number of slices in each dimension.
  auto sliceStrides = computeStrides(sliceDimCounts);
  int64_t sliceCount = computeMaxLinearIndex(sliceDimCounts);
  SmallVector<Type, 4> vectorTypes(sliceCount);
  for (unsigned i = 0; i < sliceCount; ++i) {
    auto vectorOffsets = delinearize(sliceStrides, i);
    auto elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
    auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
    // Create Vector type and add to 'vectorTypes[i]'.
    vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
  }
  return builder.getTupleType(vectorTypes);
}

// UnrolledVectorState aggregates per-operand/result vector state required for
// unrolling.
struct UnrolledVectorState {
  SmallVector<int64_t, 4> unrolledShape;
  SmallVector<int64_t, 4> unrollFactors;
  SmallVector<int64_t, 8> basis;
  int64_t numInstances;
  Value slicesTuple;
};

// Populates 'state' with unrolled shape, unroll factors, basis and
// num unrolled instances for 'vectorType'.
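// For example (illustrative values): for vectorType = vector<4x6xf32>, an
// identity indexMap, and targetShape = [2, 3], this sets
// unrolledShape = [2, 3], unrollFactors = [2, 2], numInstances = 4, and basis
// to the corresponding strides ([2, 1] in row-major order).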
static void initUnrolledVectorState(VectorType vectorType, Value initValue,
                                    const DenseMap<int64_t, int64_t> &indexMap,
                                    ArrayRef<int64_t> targetShape,
                                    UnrolledVectorState &state,
                                    OpBuilder &builder) {
  // Compute unrolled shape of 'vectorType'.
  state.unrolledShape.resize(vectorType.getRank());
  getMappedElements(indexMap, targetShape, state.unrolledShape);
  // Compute unroll factors for unrolled shape.
  auto maybeUnrollFactors =
      shapeRatio(vectorType.getShape(), state.unrolledShape);
  assert(maybeUnrollFactors.hasValue());
  state.unrollFactors = *maybeUnrollFactors;
  // Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
  state.basis = computeStrides(state.unrollFactors);
  state.numInstances = computeMaxLinearIndex(state.unrollFactors);
  state.slicesTuple = nullptr;
  if (initValue != nullptr) {
    // Create ExtractSlicesOp.
    SmallVector<int64_t, 4> sizes(state.unrolledShape);
    SmallVector<int64_t, 4> strides(state.unrollFactors.size(), 1);
    auto tupleType =
        generateExtractSlicesOpResultType(vectorType, sizes, strides, builder);
    state.slicesTuple = builder.create<vector::ExtractSlicesOp>(
        initValue.getLoc(), tupleType, initValue, sizes, strides);
  }
}

// Computes and returns the linear index of the unrolled vector at
// 'vectorOffsets' within the vector represented by 'state'.
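// For example (illustrative values): with state.basis = [2, 1] and mapped
// slice offsets [1, 0], the returned linear index is 1 * 2 + 0 * 1 = 2.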
static int64_t
getUnrolledVectorLinearIndex(UnrolledVectorState &state,
                             ArrayRef<int64_t> vectorOffsets,
                             DenseMap<int64_t, int64_t> &indexMap) {
  // Compute vector offsets.
  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
  getMappedElements(indexMap, vectorOffsets, sliceOffsets);
  // Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
  return linearize(sliceOffsets, state.basis);
}

// Returns an unrolled vector at 'vectorOffsets' within the vector
// represented by 'state'. The vector is created from a slice of 'initValue'
// if not present in 'cache'.
static Value getOrCreateUnrolledVectorSlice(
    Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
    ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
    Value initValue, SmallVectorImpl<Value> &cache, OpBuilder &builder) {
  // Compute slice offsets.
  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
  getMappedElements(indexMap, offsets, sliceOffsets);
  // TODO: Support non-1 strides.
  SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
  // Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
  int64_t sliceLinearIndex =
      getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap);
  assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
  auto valueSlice = cache[sliceLinearIndex];
  if (valueSlice == nullptr) {
    // Return tuple element at 'sliceLinearIndex'.
    auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
    auto initValueType = initValue.getType().cast<VectorType>();
    auto vectorType =
        VectorType::get(state.unrolledShape, initValueType.getElementType());
    // Initialize 'cache' with slice from 'initValue'.
    valueSlice = builder.create<vector::TupleGetOp>(
        loc, vectorType, state.slicesTuple, tupleIndex);
    // Store value back to 'cache'.
    cache[sliceLinearIndex] = valueSlice;
  }
  return valueSlice;
}

// VectorState aggregates per-operand/result vector state required for
// creating slices of vector operands, and clones of the operation being
// unrolled.
struct VectorState {
  // The type of this vector.
  VectorType type;
  // Map from iteration space index to vector dimension index.
  DenseMap<int64_t, int64_t> indexMap;
  // Index of this value in operation's operand list (-1 if not an operand).
  int64_t operandIndex = -1;
  // Accumulator iterator flag.
  bool isAcc = false;
};

//
// unrollSingleResultStructuredOp
//
// Returns a value representing the result of structured operation 'op'
// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
// A list of VectorState objects must be specified in 'vectors', where
// each VectorState in the list represents a vector operand or vector result
// (if the operation does not have an accumulator operand).
// The VectorState at index 'resultIndex' in the list must be the state
// associated with the operation's single result (i.e. either its accumulator
// operand or vector result value).
//
// Example:
//
//  // Before unrolling
//
//   operand0                operand1                operand2
//       \                      |                      /
//        -------------------- opA --------------------
//
//  // After unrolling by 2
//
//   operand0                operand1                operand2
//   /      \                /      \                /      \
// slice00  slice01       slice10  slice11        slice20  slice21
//   \         |            |          |            /          |
//    -------------------- opA0 --------------------           |
//             |            |          |                       |
//              \           |          |                      /
//               -------------------- opA1 -------------------
//                          |          |
//                           \        /
//                           insertslice
//                                |

// TODO: Add the following canonicalization/simplification patterns:
// *) Add pattern which matches InsertStridedSlice -> StridedSlice and forwards
//    InsertStridedSlice operand to StridedSlice.
// *) Add pattern which matches SourceOp -> StridedSlice -> UserOp which checks
//    if there are duplicate identical StridedSlice ops from SourceOp, and
//    rewrites itself to use the first duplicate. This transformation should
//    cause users of identical StridedSlice ops to reuse the same StridedSlice
//    operation, and leave the duplicate StridedSlice ops with no users
//    (removable with DCE).

// TODO: Generalize this to support structured ops beyond
// vector ContractionOp, and merge it with 'unrollSingleResultVectorOp'
static Value unrollSingleResultStructuredOp(Operation *op,
                                            ArrayRef<int64_t> iterationBounds,
                                            std::vector<VectorState> &vectors,
                                            unsigned resultIndex,
                                            ArrayRef<int64_t> targetShape,
                                            OpBuilder &builder) {
  auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>();
  if (!shapedType || !shapedType.hasStaticShape())
    assert(false && "Expected a statically shaped result type");

  // Compute unroll factors for 'iterationBounds' based on 'targetShape'
  auto maybeUnrollFactors = shapeRatio(iterationBounds, targetShape);
  if (!maybeUnrollFactors.hasValue())
    assert(false && "Failed to compute unroll factors for target shape");
  auto unrollFactors = *maybeUnrollFactors;

  // Compute unrolled vector state for each vector in 'vectors'.
  unsigned numVectors = vectors.size();
  SmallVector<UnrolledVectorState, 3> unrolledVectorState(numVectors);
  for (unsigned i = 0; i < numVectors; ++i) {
    int64_t operandIndex = vectors[i].operandIndex;
    auto operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
    initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap,
                            targetShape, unrolledVectorState[i], builder);
  }
  // Compute number of total unrolled instances.
  auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
  auto sliceStrides = computeStrides(unrollFactors);

  auto &resultValueState = unrolledVectorState[resultIndex];
  auto unrolledResultType = VectorType::get(resultValueState.unrolledShape,
                                            shapedType.getElementType());

  // Initialize caches for intermediate vector results.
  std::vector<SmallVector<Value, 4>> caches(numVectors);
  for (unsigned i = 0; i < numVectors; ++i)
    caches[i].resize(unrolledVectorState[i].numInstances);

  // Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
  for (unsigned i = 0; i < numUnrolledInstances; ++i) {
    auto vectorOffsets = delinearize(sliceStrides, i);
    auto elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
    // Get cached slice (or create slice) for each operand at 'offsets'.
    SmallVector<Value, 3> operands;
    operands.resize(op->getNumOperands());
    for (unsigned i = 0; i < numVectors; ++i) {
      int64_t operandIndex = vectors[i].operandIndex;
      if (operandIndex < 0)
        continue; // Output
      auto operand = op->getOperand(operandIndex);
      operands[operandIndex] = getOrCreateUnrolledVectorSlice(
          op->getLoc(), unrolledVectorState[i], vectorOffsets, elementOffsets,
          vectors[i].indexMap, operand, caches[i], builder);
    }
    // Create op on sliced vector arguments.
    auto resultVector =
        cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
                                    unrolledResultType)
            ->getResult(0);

    // Compute linear result index.
    int64_t linearIndex = getUnrolledVectorLinearIndex(
        resultValueState, vectorOffsets, vectors[resultIndex].indexMap);
    // Update result cache at 'linearIndex'.
    caches[resultIndex][linearIndex] = resultVector;
  }

  // Create TupleOp of unrolled result vectors.
  SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
  SmallVector<Value, 4> vectorTupleValues(resultValueState.numInstances);
  for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
    vectorTupleTypes[i] = caches[resultIndex][i].getType().cast<VectorType>();
    vectorTupleValues[i] = caches[resultIndex][i];
  }
  TupleType tupleType = builder.getTupleType(vectorTupleTypes);
  Value tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
                                                  vectorTupleValues);

  // Create InsertSlicesOp(Tuple(result_vectors)).
  auto resultVectorType = op->getResult(0).getType().cast<VectorType>();
  SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
  SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);

  Value insertSlicesOp = builder.create<vector::InsertSlicesOp>(
      op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes),
      builder.getI64ArrayAttr(strides));
  return insertSlicesOp;
}

static void getVectorContractionOpUnrollState(
    vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
    std::vector<VectorState> &vectors, unsigned &resultIndex) {
  // Get map from iteration space index to lhs/rhs/result shape index.
  std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
  contractionOp.getIterationIndexMap(iterationIndexMapList);
  unsigned numIterators = iterationIndexMapList.size();
  vectors.resize(numIterators);
  unsigned accOperandIndex = vector::ContractionOp::getAccOperandIndex();
  for (unsigned i = 0; i < numIterators; ++i) {
    vectors[i].type = contractionOp.getOperand(i).getType().cast<VectorType>();
    vectors[i].indexMap = iterationIndexMapList[i];
    vectors[i].operandIndex = i;
    vectors[i].isAcc = i == accOperandIndex ? true : false;
  }

  if (llvm::size(contractionOp.masks()) == 2) {
    // Add vectors for lhs/rhs vector mask arguments. Masks have the
    // same vector shape as the lhs/rhs args, so copy their index maps.
    vectors.push_back({contractionOp.getLHSVectorMaskType(),
                       vectors[0].indexMap, accOperandIndex + 1, false});
    vectors.push_back({contractionOp.getRHSVectorMaskType(),
                       vectors[1].indexMap, accOperandIndex + 2, false});
  }
  // TODO: Use linalg style 'args_in'/'args_out' to partition
  // 'vectors' instead of 'resultIndex'.
  resultIndex = accOperandIndex;
}

static void getVectorElementwiseOpUnrollState(Operation *op,
                                              ArrayRef<int64_t> targetShape,
                                              std::vector<VectorState> &vectors,
                                              unsigned &resultIndex) {
  // Verify that operation and operands all have the same vector shape.
  auto resultType = op->getResult(0).getType().dyn_cast_or_null<VectorType>();
  assert(resultType && "Expected op with vector result type");
  auto resultShape = resultType.getShape();
  // Verify that all operands have the same vector type as result.
  assert(llvm::all_of(op->getOperandTypes(), [=](Type type) {
    return type.cast<VectorType>().getShape() == resultShape;
  }));

  // Create trivial elementwise identity index map based on 'resultShape'.
  DenseMap<int64_t, int64_t> indexMap;
  indexMap.reserve(resultShape.size());
  for (unsigned i = 0; i < resultShape.size(); ++i)
    indexMap[i] = i;

  // Create a VectorState for each operand and for the single result.
  unsigned numVectors = op->getNumOperands() + op->getNumResults();
  vectors.resize(numVectors);
  for (auto it : llvm::enumerate(op->getOperandTypes()))
    vectors[it.index()] = {it.value().cast<VectorType>(), indexMap,
                           static_cast<int64_t>(it.index()), false};
  vectors[numVectors - 1] = {resultType, indexMap, -1, false};
  resultIndex = numVectors - 1;
}

/// Generates slices of 'vectorType' according to 'sizes' and 'strides', and
/// calls 'fn' with the linear index and indices for each slice.
static void generateTransferOpSlices(
    Type shapedElementType, VectorType vectorType, TupleType tupleType,
    ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
    OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
  // Compute strides w.r.t. slice counts in each dimension.
  auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
  assert(maybeDimSliceCounts.hasValue());
  auto sliceDimCounts = *maybeDimSliceCounts;
  auto sliceStrides = computeStrides(sliceDimCounts);

  int64_t numSlices = tupleType.size();
  unsigned numSliceIndices = indices.size();
  // Compute 'indexOffset' at which to update 'indices', which is equal
  // to the memref rank (indices.size) minus the effective 'vectorRank'.
  // The effective 'vectorRank' is equal to the rank of the vector type
  // minus the rank of the memref vector element type (if it has one).
  //
  // For example:
  //
  //   Given memref type 'memref<6x2x1xvector<2x4xf32>>' and vector
  //   transfer_read/write ops which read/write vectors of type
  //   'vector<2x1x2x4xf32>'. The memref rank is 3, and the effective
  //   vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
  //
  unsigned vectorRank = vectorType.getRank();
  if (auto sourceVectorElementType = shapedElementType.dyn_cast<VectorType>()) {
    assert(vectorRank >= sourceVectorElementType.getRank());
    vectorRank -= sourceVectorElementType.getRank();
  }
  unsigned indexOffset = numSliceIndices - vectorRank;

  auto *ctx = builder.getContext();
  for (unsigned i = 0; i < numSlices; ++i) {
    auto vectorOffsets = delinearize(sliceStrides, i);
    auto elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
    // Compute 'sliceIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
    SmallVector<Value, 4> sliceIndices(numSliceIndices);
    for (unsigned j = 0; j < numSliceIndices; ++j) {
      if (j < indexOffset) {
        sliceIndices[j] = indices[j];
      } else {
        auto expr = getAffineDimExpr(0, ctx) +
                    getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
        auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
        sliceIndices[j] = builder.create<AffineApplyOp>(
            indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
      }
    }
    // Call 'fn' to generate slice 'i' at 'sliceIndices'.
    fn(i, sliceIndices);
  }
}

/// Returns true if 'map' is a suffix of an identity affine map, false
/// otherwise. Example: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
static bool isIdentitySuffix(AffineMap map) {
  if (map.getNumDims() < map.getNumResults())
    return false;
  ArrayRef<AffineExpr> results = map.getResults();
  Optional<int> lastPos;
  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
    auto expr = results[i].dyn_cast<AffineDimExpr>();
    if (!expr)
      return false;
    int currPos = static_cast<int>(expr.getPosition());
    if (lastPos.hasValue() && currPos != lastPos.getValue() + 1)
      return false;
    lastPos = currPos;
  }
  return true;
}

/// Unroll transfer_read ops to the given shape and create an aggregate with all
/// the chunks.
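/// A minimal sketch of the rewrite (illustrative IR, names assumed), unrolling
/// to target shape [2, 2]:
///
///   %v = vector.transfer_read %A[%c0, %c0], %pad
///       : memref<4x4xf32>, vector<4x4xf32>
///
/// becomes four transfer_read ops producing vector<2x2xf32> chunks at offsets
/// {0, 2} x {0, 2}, which are re-assembled into the original vector<4x4xf32>
/// through vector.tuple and vector.insert_slices.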
static Value unrollTransferReadOp(vector::TransferReadOp readOp,
                                  ArrayRef<int64_t> targetShape,
                                  OpBuilder &builder) {
  if (!isIdentitySuffix(readOp.permutation_map()))
    return nullptr;
  auto sourceVectorType = readOp.getVectorType();
  SmallVector<int64_t, 4> strides(targetShape.size(), 1);

  Location loc = readOp.getLoc();
  auto shapedElementType =
      readOp.source().getType().cast<ShapedType>().getElementType();
  auto tupleType = generateExtractSlicesOpResultType(
      sourceVectorType, targetShape, strides, builder);
  int64_t numSlices = tupleType.size();

  SmallVector<Value, 4> vectorTupleValues(numSlices);
  SmallVector<Value, 4> indices(readOp.indices().begin(),
                                readOp.indices().end());
  auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
    // Get VectorType for slice 'index'.
    auto sliceVectorType = tupleType.getType(index);
    // Create split TransferReadOp for this slice.
    // `masked` attribute propagates conservatively: if the coarse op didn't
    // need masking, the fine op doesn't either.
    vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
        loc, sliceVectorType, readOp.source(), sliceIndices,
        readOp.permutation_map(), readOp.padding(),
        readOp.masked() ? *readOp.masked() : ArrayAttr());
  };
  generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
                           targetShape, strides, indices, builder, createSlice);

  // Create tuple of sliced transfer read operations.
  Value tupleOp =
      builder.create<vector::TupleOp>(loc, tupleType, vectorTupleValues);
  // Reassemble the full vector from the tuple with an InsertSlicesOp.
  Value newVec = builder.create<vector::InsertSlicesOp>(
      loc, sourceVectorType, tupleOp, builder.getI64ArrayAttr(targetShape),
      builder.getI64ArrayAttr(strides));
  return newVec;
}

// Entry point for unrolling declarative pattern rewrite for transfer_write op.
LogicalResult
mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
                                    ArrayRef<int64_t> targetShape,
                                    SmallVector<Value, 1> &result) {
  auto writeOp = cast<vector::TransferWriteOp>(op);
  if (!isIdentitySuffix(writeOp.permutation_map()))
    return failure();
  VectorType sourceVectorType = writeOp.getVectorType();
  SmallVector<int64_t, 4> strides(targetShape.size(), 1);
  TupleType tupleType = generateExtractSlicesOpResultType(
      sourceVectorType, targetShape, strides, builder);
  Location loc = writeOp.getLoc();
  Value tuple = builder.create<vector::ExtractSlicesOp>(
      loc, tupleType, writeOp.vector(), targetShape, strides);
  auto shapedElementType =
      writeOp.source().getType().cast<ShapedType>().getElementType();
  SmallVector<Value, 4> indices(writeOp.indices().begin(),
                                writeOp.indices().end());
  // If the TransferWrite returns a tensor, keep track of the last tensor
  // created.
  Value resultTensor;
  auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
    auto element = builder.create<vector::TupleGetOp>(
        loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
    Operation *write = builder.create<vector::TransferWriteOp>(
        loc, element.getResult(),
        resultTensor ? resultTensor : writeOp.source(), sliceIndices,
        writeOp.permutation_map(),
        writeOp.masked() ? *writeOp.masked() : ArrayAttr());
    if (!write->getResults().empty())
      resultTensor = write->getResult(0);
  };
  generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
                           targetShape, strides, indices, builder, createSlice);
  if (resultTensor)
    result.push_back(resultTensor);
  return success();
}

// Entry point for unrolling declarative pattern rewrites.
SmallVector<Value, 1>
mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
                                         ArrayRef<int64_t> targetShape) {
  assert(op->getNumResults() == 1 && "Expected single result operation");

  // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
  SmallVector<int64_t, 6> iterationBounds;
  auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");

  std::vector<VectorState> vectors;
  unsigned resultIndex;

  if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
    return SmallVector<Value, 1>{
        unrollTransferReadOp(readOp, targetShape, builder)};

  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
    // Populate state for vector ContractionOp.
    getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
                                      resultIndex);
  } else {
    // Populate state for vector elementwise op.
    getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
  }

  // Unroll 'op' with 'iterationBounds' to 'targetShape'.
  return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
      op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
}

namespace {

// Splits vector TransferReadOp into smaller TransferReadOps based on the
// slicing scheme of its unique ExtractSlicesOp user.
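// A minimal sketch of the match (illustrative IR, names assumed):
//
//   %0 = vector.transfer_read %A[%c0, %c0], %pad
//       : memref<4x4xf32>, vector<4x4xf32>
//   %1 = vector.extract_slices %0, [2, 2], [1, 1]   // unique user of %0
//
// The coarse transfer_read is replaced by four transfer_reads of
// vector<2x2xf32> whose results are recombined, so the extract_slices can
// subsequently fold away.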
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp xferReadOp,
                                PatternRewriter &rewriter) const override {
    // TODO: Support splitting TransferReadOp with non-identity
    // permutation maps. Repurpose code from MaterializeVectors transformation.
    if (!isIdentitySuffix(xferReadOp.permutation_map()))
      return failure();
    // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
    Value xferReadResult = xferReadOp.getResult();
    auto extractSlicesOp =
        dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin());
    if (!xferReadResult.hasOneUse() || !extractSlicesOp)
      return failure();

    // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
    SmallVector<int64_t, 4> sizes;
    extractSlicesOp.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    extractSlicesOp.getStrides(strides);
    assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));

    Value newVec = unrollTransferReadOp(xferReadOp, sizes, rewriter);
    if (!newVec)
      return failure();
    rewriter.replaceOp(xferReadOp, newVec);
    return success();
  }
};

// Splits vector TransferWriteOp into smaller TransferWriteOps for each source.
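// A minimal sketch of the match (illustrative IR, names assumed):
//
//   %t = vector.tuple %a, %b, %c, %d : vector<2x2xf32>, ...
//   %v = vector.insert_slices %t, [2, 2], [1, 1] : ... into vector<4x4xf32>
//   vector.transfer_write %v, %A[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
//
// is rewritten into one transfer_write of vector<2x2xf32> per tuple operand,
// each written at the corresponding offsets into %A.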
struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
                                PatternRewriter &rewriter) const override {
    // TODO: Support splitting TransferWriteOp with non-identity
    // permutation maps. Repurpose code from MaterializeVectors transformation.
    if (!isIdentitySuffix(xferWriteOp.permutation_map()))
      return failure();
    // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
    auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
    auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
    if (!insertSlicesOp)
      return failure();

    // Get TupleOp operand of 'insertSlicesOp'.
    auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
        insertSlicesOp.vectors().getDefiningOp());
    if (!tupleOp)
      return failure();

    // Get 'sizes' and 'strides' parameters from InsertSlicesOp user.
    auto sourceTupleType = insertSlicesOp.getSourceTupleType();
    auto resultVectorType = insertSlicesOp.getResultVectorType();
    SmallVector<int64_t, 4> sizes;
    insertSlicesOp.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    insertSlicesOp.getStrides(strides);

    Location loc = xferWriteOp.getLoc();
    auto shapedElementType =
        xferWriteOp.source().getType().cast<ShapedType>().getElementType();
    SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
                                  xferWriteOp.indices().end());
    Value resultTensor;
    auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
      // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
      // `masked` attribute propagates conservatively: if the coarse op didn't
      // need masking, the fine op doesn't either.
      Operation *write = rewriter.create<vector::TransferWriteOp>(
          loc, tupleOp.getOperand(index),
          resultTensor ? resultTensor : xferWriteOp.source(), sliceIndices,
          xferWriteOp.permutation_map(),
          xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
      if (!write->getResults().empty())
        resultTensor = write->getResult(0);
    };
    generateTransferOpSlices(shapedElementType, resultVectorType,
                             sourceTupleType, sizes, strides, indices, rewriter,
                             createSlice);

    // Erase old 'xferWriteOp'.
    if (resultTensor)
      rewriter.replaceOp(xferWriteOp, ArrayRef<Value>(resultTensor));
    else
      rewriter.eraseOp(xferWriteOp);
    return success();
  }
};

/// Decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps, each
/// on vector types.
struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has tuple source/result type.
    auto sourceTupleType =
        shapeCastOp.source().getType().dyn_cast_or_null<TupleType>();
    auto resultTupleType =
        shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
    if (!sourceTupleType || !resultTupleType)
      return failure();
    assert(sourceTupleType.size() == resultTupleType.size());

    // Create single-vector ShapeCastOp for each source tuple element.
    Location loc = shapeCastOp.getLoc();
    SmallVector<Value, 8> resultElements;
    resultElements.reserve(resultTupleType.size());
    for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i) {
      auto sourceElement = rewriter.create<vector::TupleGetOp>(
          loc, sourceTupleType.getType(i), shapeCastOp.source(),
          rewriter.getI64IntegerAttr(i));
      resultElements.push_back(rewriter.create<vector::ShapeCastOp>(
          loc, resultTupleType.getType(i), sourceElement));
    }

    // Replace 'shapeCastOp' with tuple of 'resultElements'.
    rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, resultTupleType,
                                                 resultElements);
    return success();
  }
};

/// Returns the producer Value of the same type as 'consumerValue', by tracking
/// the tuple index and offsets of the consumer vector value through the
/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp,
/// and ShapeCastOp) from consumer to producer. Each operation in the chain is
/// structured, and so the tuple index and offsets can be mapped from result to
/// input, while visiting each operation in the chain.
/// Returns nullptr on failure.
static Value getProducerValue(Value consumerValue) {
  auto consumerVectorType = consumerValue.getType().cast<VectorType>();
  // A tupleIndex == -1 indicates that 'offsets' are w.r.t a vector type.
  int64_t tupleIndex = -1;
  SmallVector<int64_t, 4> offsets(consumerVectorType.getRank(), 0);
  auto *op = consumerValue.getDefiningOp();
  while (op != nullptr) {
    if (auto tupleGetOp = dyn_cast<vector::TupleGetOp>(op)) {
      assert(tupleIndex == -1 && "TupleGetOp must have vector result type");

      // Update 'tupleIndex' and next defining 'op' to visit.
      tupleIndex = tupleGetOp.getIndex();
      op = tupleGetOp.vectors().getDefiningOp();
    } else if (auto extractSlicesOp = dyn_cast<vector::ExtractSlicesOp>(op)) {
      assert(tupleIndex >= 0);

      // Compute slice strides for 'extractSlicesOp'.
      SmallVector<int64_t, 4> sizes;
      extractSlicesOp.getSizes(sizes);
      auto sliceStrides = computeStrides(
          extractSlicesOp.getSourceVectorType().getShape(), sizes);

      // Compute 'elementOffsets' into 'extractSlicesOp' input vector type,
      // of 'extractSlicesOp' result vector tuple element at 'tupleIndex'.
      auto vectorOffsets = delinearize(sliceStrides, tupleIndex);
      auto elementOffsets =
          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);

      // Add 'elementOffsets' to 'offsets' so that 'offsets' are now relative
      // to the 'extractSlicesOp' input vector type.
      assert(offsets.size() == elementOffsets.size());
      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
        offsets[i] += elementOffsets[i];

      // Clear 'tupleIndex' and update next defining 'op' to visit.
      tupleIndex = -1;
      op = extractSlicesOp.vector().getDefiningOp();
    } else if (auto insertSlicesOp = dyn_cast<vector::InsertSlicesOp>(op)) {
      assert(tupleIndex == -1);

      // Compute slice strides for 'insertSlicesOp'.
      SmallVector<int64_t, 4> sizes;
      insertSlicesOp.getSizes(sizes);
      auto sliceStrides = computeStrides(
          insertSlicesOp.getResultVectorType().getShape(), sizes);

      // Compute 'vectorOffsets' of 'insertSlicesOp' input vector slice,
      // of 'insertSlicesOp' result vector type at 'offsets'.
      SmallVector<int64_t, 4> vectorOffsets(offsets.size());
      assert(offsets.size() == sizes.size());
      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
        vectorOffsets[i] = offsets[i] / sizes[i];

      // Compute the source tuple element index.
      tupleIndex = linearize(vectorOffsets, sliceStrides);

      // Subtract 'elementOffsets' from 'offsets' so that 'offsets' are now
      // relative to input tuple element vector type at 'tupleIndex'.
      auto elementOffsets =
          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
      assert(offsets.size() == elementOffsets.size());
      for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
        offsets[i] -= elementOffsets[i];
        assert(offsets[i] >= 0);
      }

      // Update next defining 'op' to visit.
      op = insertSlicesOp.vectors().getDefiningOp();
    } else if (auto tupleOp = dyn_cast<vector::TupleOp>(op)) {
      assert(tupleIndex >= 0);

      // Return tuple element 'value' at 'tupleIndex' if it matches type.
      auto value = tupleOp.getOperand(tupleIndex);
      if (value.getType() == consumerVectorType)
        return value;

      // Update 'tupleIndex' and next defining 'op' to visit.
      tupleIndex = -1;
      op = value.getDefiningOp();
    } else if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
      if (shapeCastOp.source().getType().isa<TupleType>())
        return nullptr;
      assert(tupleIndex == -1);
      auto sourceVectorType = shapeCastOp.getSourceVectorType();
      auto sourceVectorShape = sourceVectorType.getShape();
      unsigned sourceVectorRank = sourceVectorType.getRank();
      auto resultVectorType = shapeCastOp.getResultVectorType();
      auto resultVectorShape = resultVectorType.getShape();
      unsigned resultVectorRank = resultVectorType.getRank();

      int i = sourceVectorRank - 1;
      int j = resultVectorRank - 1;

      // Check that source/result vector shape prefixes match while updating
      // 'newOffsets'.
      SmallVector<int64_t, 4> newOffsets(sourceVectorRank, 0);
      for (auto it : llvm::zip(llvm::reverse(sourceVectorShape),
                               llvm::reverse(resultVectorShape))) {
        if (std::get<0>(it) != std::get<1>(it))
          return nullptr;
        newOffsets[i--] = offsets[j--];
      }

      // Check that remaining prefix of source/result vector shapes are all 1s.
      // Currently we only support producer/consumer tracking through trivial
      // shape cast ops. Examples:
      //   %1 = vector.shape_cast %0 : vector<1x1x2x4xf32> to vector<2x4xf32>
      //   %3 = vector.shape_cast %2 : vector<16x8xf32> to vector<1x16x8xf32>
      assert(i == -1 || j == -1);
      if (i >= 0 &&
          !std::all_of(sourceVectorShape.begin(), sourceVectorShape.begin() + i,
                       [](int64_t v) { return v == 1; }))
        return nullptr;
      if (j >= 0 &&
          !std::all_of(resultVectorShape.begin(), resultVectorShape.begin() + j,
                       [](int64_t v) { return v == 1; }))
        return nullptr;

      offsets.swap(newOffsets);
      op = shapeCastOp.source().getDefiningOp();
    } else {
      // Check if 'op' produces a Value with the same type as 'consumerValue'.
      if (op->getNumResults() == 1 &&
          op->getResult(0).getType() == consumerVectorType)
        return op->getResult(0);
      return nullptr;
    }
  }
  return nullptr;
}

975 
976 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
977 //
978 // Example:
979 //
980 //  The following MLIR with cancelling ShapeCastOps:
981 //
982 //   %0 = source : vector<5x4x2xf32>
983 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
984 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
985 //   %3 = user %2 : vector<5x4x2xf32>
986 //
987 //  Should canonicalize to the following:
988 //
989 //   %0 = source : vector<5x4x2xf32>
990 //   %1 = user %0 : vector<5x4x2xf32>
991 //
992 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
993   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
994 
matchAndRewrite__anon85c463fe0511::ShapeCastOpFolder995   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
996                                 PatternRewriter &rewriter) const override {
997     // Check if we can replace 'shapeCastOp' result with its producer.
998     if (auto producer = getProducerValue(shapeCastOp.getResult())) {
999       rewriter.replaceOp(shapeCastOp, producer);
1000       return success();
1001     }
1002 
1003     // Check if 'shapeCastOp' has vector source/result type.
1004     auto sourceVectorType =
1005         shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
1006     auto resultVectorType =
1007         shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
1008     if (!sourceVectorType || !resultVectorType)
1009       return failure();
1010 
1011     // Check if shape cast op source operand is also a shape cast op.
1012     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
1013         shapeCastOp.source().getDefiningOp());
1014     if (!sourceShapeCastOp)
1015       return failure();
1016     auto operandSourceVectorType =
1017         sourceShapeCastOp.source().getType().cast<VectorType>();
1018     auto operandResultVectorType = sourceShapeCastOp.getType();
1019 
1020     // Check if shape cast operations invert each other.
1021     if (operandSourceVectorType != resultVectorType ||
1022         operandResultVectorType != sourceVectorType)
1023       return failure();
1024 
1025     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
1026     return success();
1027   }
1028 };
1029 
1030 // Patter rewrite which forward tuple elements to their users.
1031 // User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
1032 //   -> User(Producer)
1033 struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
1034   using OpRewritePattern<vector::TupleGetOp>::OpRewritePattern;
1035 
matchAndRewrite__anon85c463fe0511::TupleGetFolderOp1036   LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
1037                                 PatternRewriter &rewriter) const override {
1038     if (auto producer = getProducerValue(tupleGetOp.getResult())) {
1039       rewriter.replaceOp(tupleGetOp, producer);
1040       return success();
1041     }
1042     return failure();
1043   }
1044 };
1045 
/// Progressive lowering of ExtractSlicesOp to tuple of ExtractStridedSliceOp.
/// One:
///   %x = vector.extract_slices %0
/// is replaced by:
///   %a = vector.strided_slice %0
///   %b = vector.strided_slice %0
///   ..
///   %x = vector.tuple %a, %b, ..
class ExtractSlicesOpLowering
    : public OpRewritePattern<vector::ExtractSlicesOp> {
public:
  using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractSlicesOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType vectorType = op.getSourceVectorType();
    auto shape = vectorType.getShape();

    SmallVector<int64_t, 4> sizes;
    op.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    op.getStrides(strides); // all-ones at the moment

    // For each element in the tuple, generate the proper strided slice.
    TupleType tupleType = op.getResultTupleType();
    int64_t tupleSize = tupleType.size();
    SmallVector<Value, 4> tupleValues(tupleSize);
    auto sliceStrides = computeStrides(shape, sizes);
    for (int64_t i = 0; i < tupleSize; ++i) {
      auto vectorOffsets = delinearize(sliceStrides, i);
      auto elementOffsets =
          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
      auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
      // Insert in tuple.
      tupleValues[i] = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.vector(), elementOffsets, sliceSizes, strides);
    }

    rewriter.replaceOpWithNewOp<vector::TupleOp>(op, tupleType, tupleValues);
    return success();
  }
};

/// Progressive lowering of InsertSlicesOp to series of InsertStridedSliceOp.
/// One:
///   %x = vector.insert_slices %0
/// is replaced by:
///   %r0 = zero-result
///   %t1 = vector.tuple_get %0, 0
///   %r1 = vector.insert_strided_slice %r0, %t1
///   %t2 = vector.tuple_get %0, 1
///   %r2 = vector.insert_strided_slice %r1, %t2
///   ..
///   %x  = ..
class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
public:
  using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertSlicesOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType vectorType = op.getResultVectorType();
    auto shape = vectorType.getShape();

    SmallVector<int64_t, 4> sizes;
    op.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    op.getStrides(strides); // all-ones at the moment

    // Prepare result.
    Value result = rewriter.create<ConstantOp>(
        loc, vectorType, rewriter.getZeroAttr(vectorType));

    // For each element in the tuple, extract the proper strided slice.
    TupleType tupleType = op.getSourceTupleType();
    int64_t tupleSize = tupleType.size();
    auto sliceStrides = computeStrides(shape, sizes);
    for (int64_t i = 0; i < tupleSize; ++i) {
      auto vectorOffsets = delinearize(sliceStrides, i);
      auto elementOffsets =
          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
      // Extract from tuple into the result.
      auto index = rewriter.getI64IntegerAttr(i);
      auto tupleGet = rewriter.create<vector::TupleGetOp>(
          loc, tupleType.getType(i), op.getOperand(), index);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, tupleGet, result, elementOffsets, strides);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

1143 /// Progressive lowering of BroadcastOp.
1144 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
1145 public:
1146   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
1147 
1148   LogicalResult matchAndRewrite(vector::BroadcastOp op,
1149                                 PatternRewriter &rewriter) const override {
1150     auto loc = op.getLoc();
1151     VectorType dstType = op.getVectorType();
1152     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
1153     Type eltType = dstType.getElementType();
1154 
1155     // Determine rank of source and destination.
1156     int64_t srcRank = srcType ? srcType.getRank() : 0;
1157     int64_t dstRank = dstType.getRank();
1158 
1159     // Duplicate this rank.
1160     // For example:
1161     //   %x = broadcast %y  : k-D to n-D, k < n
1162     // becomes:
1163     //   %b = broadcast %y  : k-D to (n-1)-D
1164     //   %x = [%b,%b,%b,%b] : n-D
1165     // becomes:
1166     //   %b = [%y,%y]       : (n-1)-D
1167     //   %x = [%b,%b,%b,%b] : n-D
1168     if (srcRank < dstRank) {
1169       // Scalar to any vector can use splat.
1170       if (srcRank == 0) {
1171         rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
1172         return success();
1173       }
1174       // Duplication.
1175       VectorType resType =
1176           VectorType::get(dstType.getShape().drop_front(), eltType);
1177       Value bcst =
1178           rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
1179       Value result = rewriter.create<ConstantOp>(loc, dstType,
1180                                                  rewriter.getZeroAttr(dstType));
1181       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
1182         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
1183       rewriter.replaceOp(op, result);
1184       return success();
1185     }
1186 
1187     // Find non-matching dimension, if any.
1188     assert(srcRank == dstRank);
1189     int64_t m = -1;
1190     for (int64_t r = 0; r < dstRank; r++)
1191       if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
1192         m = r;
1193         break;
1194       }
1195 
1196     // All dimensions match: the broadcast is a no-op, simply pass through.
1197     if (m == -1) {
1198       rewriter.replaceOp(op, op.source());
1199       return success();
1200     }
1201 
1202     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
1203     if (srcRank == 1) {
1204       assert(m == 0);
1205       Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
1206       rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
1207       return success();
1208     }
1209 
1210     // Any non-matching dimension forces a stretch along this rank.
1211     // For example:
1212     //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
1213     // becomes:
1214     //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
1215     //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
1216     //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
1217     //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
1218     //   %x = [%a,%b,%c,%d]
1219     // becomes:
1220     //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
1221     //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
1222     //   %a = [%u, %v]
1223     //   ..
1224     //   %x = [%a,%b,%c,%d]
1225     VectorType resType =
1226         VectorType::get(dstType.getShape().drop_front(), eltType);
1227     Value result = rewriter.create<ConstantOp>(loc, dstType,
1228                                                rewriter.getZeroAttr(dstType));
1229     if (m == 0) {
1230       // Stretch at start.
1231       Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
1232       Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
1233       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
1234         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
1235     } else {
1236       // Stretch not at start.
1237       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
1238         Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
1239         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
1240         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
1241       }
1242     }
1243     rewriter.replaceOp(op, result);
1244     return success();
1245   }
1246 };
1247 
1248 /// Progressive lowering of TransposeOp.
1249 /// One:
1250 ///   %x = vector.transpose %y, [1, 0]
1251 /// is replaced by:
1252 ///   %z = constant dense<0.000000e+00>
1253 ///   %0 = vector.extract %y[0, 0]
1254 ///   %1 = vector.insert %0, %z [0, 0]
1255 ///   ..
1256 ///   %x = vector.insert .., .. [.., ..]
1257 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
1258 public:
1259   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
1260 
1261   TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
1262                       MLIRContext *context)
1263       : OpRewritePattern<vector::TransposeOp>(context),
1264         vectorTransformsOptions(vectorTransformsOptions) {}
1265 
1266   LogicalResult matchAndRewrite(vector::TransposeOp op,
1267                                 PatternRewriter &rewriter) const override {
1268     auto loc = op.getLoc();
1269 
1270     VectorType resType = op.getResultType();
1271 
1272     // Set up convenience transposition table.
1273     SmallVector<int64_t, 4> transp;
1274     for (auto attr : op.transp())
1275       transp.push_back(attr.cast<IntegerAttr>().getInt());
1276 
1277     // Handle a true 2-D matrix transpose differently when requested.
1278     if (vectorTransformsOptions.vectorTransposeLowering ==
1279             vector::VectorTransposeLowering::Flat &&
1280         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
1281       Type flattenedType =
1282           VectorType::get(resType.getNumElements(), resType.getElementType());
1283       auto matrix =
1284           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
1285       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
1286       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
1287       Value trans = rewriter.create<vector::FlatTransposeOp>(
1288           loc, flattenedType, matrix, rows, columns);
1289       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
1290       return success();
1291     }
1292 
1293     // Generate fully unrolled extract/insert ops.
1294     Value result = rewriter.create<ConstantOp>(loc, resType,
1295                                                rewriter.getZeroAttr(resType));
1296     SmallVector<int64_t, 4> lhs(transp.size(), 0);
1297     SmallVector<int64_t, 4> rhs(transp.size(), 0);
1298     rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
1299                                          op.vector(), result, rewriter));
1300     return success();
1301   }
1302 
1303 private:
1304   // Builds the indices arrays for the lhs and rhs. Generates the extract/insert
1305   // operation when all ranks are exhausted.
1306   Value expandIndices(Location loc, VectorType resType, int64_t pos,
1307                       SmallVector<int64_t, 4> &transp,
1308                       SmallVector<int64_t, 4> &lhs,
1309                       SmallVector<int64_t, 4> &rhs, Value input, Value result,
1310                       PatternRewriter &rewriter) const {
1311     if (pos >= resType.getRank()) {
1312       auto ridx = rewriter.getI64ArrayAttr(rhs);
1313       auto lidx = rewriter.getI64ArrayAttr(lhs);
1314       Type eltType = resType.getElementType();
1315       Value e = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
1316       return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
1317     }
1318     for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
1319       lhs[pos] = d;
1320       rhs[transp[pos]] = d;
1321       result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
1322                              result, rewriter);
1323     }
1324     return result;
1325   }
1326 
1327   /// Options to control the vector patterns.
1328   vector::VectorTransformsOptions vectorTransformsOptions;
1329 };
1330 
1331 /// Progressive lowering of OuterProductOp.
1332 /// One:
1333 ///   %x = vector.outerproduct %lhs, %rhs, %acc
1334 /// is replaced by:
1335 ///   %z = zero-result
1336 ///   %0 = vector.extract %lhs[0]
1337 ///   %1 = vector.broadcast %0
1338 ///   %2 = vector.extract %acc[0]
1339 ///   %3 = vector.fma %1, %rhs, %2
1340 ///   %4 = vector.insert %3, %z[0]
1341 ///   ..
1342 ///   %x = vector.insert %.., %..[N-1]
1343 ///
1344 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1345 public:
1346   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
1347 
1348   LogicalResult matchAndRewrite(vector::OuterProductOp op,
1349                                 PatternRewriter &rewriter) const override {
1350     auto loc = op.getLoc();
1351 
1352     VectorType lhsType = op.getOperandVectorTypeLHS();
1353     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
1354     VectorType resType = op.getVectorType();
1355     Type eltType = resType.getElementType();
1356     bool isInt = eltType.isa<IntegerType>();
1357     Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
1358 
1359     if (!rhsType) {
1360       // Special case: AXPY operation.
1361       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
1362       rewriter.replaceOp(op, genMult(loc, op.lhs(), b, acc, isInt, rewriter));
1363       return success();
1364     }
1365 
1366     Value result = rewriter.create<ConstantOp>(loc, resType,
1367                                                rewriter.getZeroAttr(resType));
1368     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1369       auto pos = rewriter.getI64ArrayAttr(d);
1370       Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
1371       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
1372       Value r = nullptr;
1373       if (acc)
1374         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
1375       Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter);
1376       result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
1377     }
1378     rewriter.replaceOp(op, result);
1379     return success();
1380   }
1381 
1382 private:
1383   static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt,
1384                        PatternRewriter &rewriter) {
1385     if (acc) {
1386       if (isInt)
1387         return rewriter.create<AddIOp>(loc, rewriter.create<MulIOp>(loc, x, y),
1388                                        acc);
1389       return rewriter.create<vector::FMAOp>(loc, x, y, acc);
1390     }
1391     if (isInt)
1392       return rewriter.create<MulIOp>(loc, x, y);
1393     return rewriter.create<MulFOp>(loc, x, y);
1394   }
1395 };
1396 
1397 /// Progressive lowering of ConstantMaskOp.
1398 /// One:
1399 ///   %x = vector.constant_mask [a,b]
1400 /// is replaced by:
1401 ///   %z = zero-result
1402 ///   %l = vector.constant_mask [b]
1403 ///   %4 = vector.insert %l, %z[0]
1404 ///   ..
1405 ///   %x = vector.insert %l, %..[a-1]
1406 /// until a one-dimensional vector is reached. All these operations
1407 /// will be folded at LLVM IR level.
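/// As an illustrative sketch (shapes assumed, not taken from a test):
///   %x = vector.constant_mask [2, 2] : vector<4x3xi1>
/// would lower to roughly:
///   %l = vector.constant_mask [2] : vector<3xi1>
///   %z = constant dense<false> : vector<4x3xi1>
///   %0 = vector.insert %l, %z [0]
///   %x = vector.insert %l, %0 [1]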
1408 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
1409 public:
1410   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
1411 
1412   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
1413                                 PatternRewriter &rewriter) const override {
1414     auto loc = op.getLoc();
1415     auto dstType = op.getType();
1416     auto eltType = dstType.getElementType();
1417     auto dimSizes = op.mask_dim_sizes();
1418     int64_t rank = dimSizes.size();
1419     int64_t trueDim = std::min(dstType.getDimSize(0),
1420                                dimSizes[0].cast<IntegerAttr>().getInt());
1421 
1422     if (rank == 1) {
1423       // Express constant 1-D case in explicit vector form:
1424       //   [T,..,T,F,..,F].
1425       SmallVector<bool, 4> values(dstType.getDimSize(0));
1426       for (int64_t d = 0; d < trueDim; d++)
1427         values[d] = true;
1428       rewriter.replaceOpWithNewOp<ConstantOp>(
1429           op, dstType, rewriter.getBoolVectorAttr(values));
1430       return success();
1431     }
1432 
1433     VectorType lowType =
1434         VectorType::get(dstType.getShape().drop_front(), eltType);
1435     SmallVector<int64_t, 4> newDimSizes;
1436     for (int64_t r = 1; r < rank; r++)
1437       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
1438     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
1439         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
1440     Value result = rewriter.create<ConstantOp>(loc, dstType,
1441                                                rewriter.getZeroAttr(dstType));
1442     for (int64_t d = 0; d < trueDim; d++) {
1443       auto pos = rewriter.getI64ArrayAttr(d);
1444       result =
1445           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
1446     }
1447     rewriter.replaceOp(op, result);
1448     return success();
1449   }
1450 };
1451 
1452 /// Progressive lowering of CreateMaskOp.
1453 /// One:
1454 ///   %x = vector.create_mask %a, ... : vector<dx...>
1455 /// is replaced by:
1456 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
1457 ///   %0 = cmpi "slt", %ci, %a       |
1458 ///   %1 = select %0, %l, %zeroes    |
1459 ///   %r = vector.insert %1, %pr [i] | d-times
1460 ///   %x = ....
1461 /// until a one-dimensional vector is reached.
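/// As an illustrative sketch (shapes assumed; %a, %b are of type index):
///   %x = vector.create_mask %a, %b : vector<4x3xi1>
/// would expand to roughly:
///   %l  = vector.create_mask %b : vector<3xi1>
///   %c0 = constant 0 : index
///   %s0 = cmpi "slt", %c0, %a
///   %r0 = select %s0, %l, %zeroes
///   %p0 = vector.insert %r0, %z [0]
///   ... repeated for rows 1 through 3 ...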
1462 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
1463 public:
1464   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
1465 
1466   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1467                                 PatternRewriter &rewriter) const override {
1468     auto loc = op.getLoc();
1469     auto dstType = op.getResult().getType().cast<VectorType>();
1470     auto eltType = dstType.getElementType();
1471     int64_t dim = dstType.getDimSize(0);
1472     int64_t rank = dstType.getRank();
1473     Value idx = op.getOperand(0);
1474 
1475     if (rank == 1)
1476       return failure(); // leave for lowering
1477 
1478     VectorType lowType =
1479         VectorType::get(dstType.getShape().drop_front(), eltType);
1480     Value trueVal = rewriter.create<vector::CreateMaskOp>(
1481         loc, lowType, op.getOperands().drop_front());
1482     Value falseVal = rewriter.create<ConstantOp>(loc, lowType,
1483                                                  rewriter.getZeroAttr(lowType));
1484     Value result = rewriter.create<ConstantOp>(loc, dstType,
1485                                                rewriter.getZeroAttr(dstType));
1486     for (int64_t d = 0; d < dim; d++) {
1487       Value bnd = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(d));
1488       Value val = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, bnd, idx);
1489       Value sel = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
1490       auto pos = rewriter.getI64ArrayAttr(d);
1491       result =
1492           rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
1493     }
1494     rewriter.replaceOp(op, result);
1495     return success();
1496   }
1497 };
1498 
1499 /// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
1500 /// vectors progressively on the way to target llvm.matrix intrinsics.
1501 /// This iterates over the most major dimension of the 2-D vector and performs
1502 /// rewrites into:
1503 ///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
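/// As an illustrative sketch (source type vector<2x3xf32> assumed):
///   %f = vector.shape_cast %s : vector<2x3xf32> to vector<6xf32>
/// would become roughly:
///   %0 = vector.extract %s[0] : vector<2x3xf32>
///   %1 = vector.insert_strided_slice %0, %z {offsets = [0], strides = [1]}
///   %2 = vector.extract %s[1] : vector<2x3xf32>
///   %f = vector.insert_strided_slice %2, %1 {offsets = [3], strides = [1]}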
1504 class ShapeCastOp2DDownCastRewritePattern
1505     : public OpRewritePattern<vector::ShapeCastOp> {
1506 public:
1507   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
1508 
1509   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1510                                 PatternRewriter &rewriter) const override {
1511     auto sourceVectorType = op.getSourceVectorType();
1512     auto resultVectorType = op.getResultVectorType();
1513     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
1514       return failure();
1515 
1516     auto loc = op.getLoc();
1517     Value desc = rewriter.create<ConstantOp>(
1518         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1519     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
1520     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
1521       Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
1522       desc = rewriter.create<vector::InsertStridedSliceOp>(
1523           loc, vec, desc,
1524           /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
1525     }
1526     rewriter.replaceOp(op, desc);
1527     return success();
1528   }
1529 };
1530 
1531 /// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
1532 /// vectors progressively on the way back from targeting llvm.matrix intrinsics.
1533 /// This iterates over the most major dimension of the 2-D vector and performs
1534 /// rewrites into:
1535 ///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
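/// As an illustrative sketch (result type vector<2x3xf32> assumed), a source
/// %s of type vector<6xf32> would be unflattened as roughly:
///   %0 = vector.extract_strided_slice %s {offsets = [0], sizes = [3], strides = [1]}
///   %1 = vector.insert %0, %z [0]
///   %2 = vector.extract_strided_slice %s {offsets = [3], sizes = [3], strides = [1]}
///   %r = vector.insert %2, %1 [1]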
1536 class ShapeCastOp2DUpCastRewritePattern
1537     : public OpRewritePattern<vector::ShapeCastOp> {
1538 public:
1539   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
1540 
1541   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1542                                 PatternRewriter &rewriter) const override {
1543     auto sourceVectorType = op.getSourceVectorType();
1544     auto resultVectorType = op.getResultVectorType();
1545     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
1546       return failure();
1547 
1548     auto loc = op.getLoc();
1549     Value desc = rewriter.create<ConstantOp>(
1550         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1551     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
1552     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
1553       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
1554           loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
1555           /*sizes=*/mostMinorVectorSize,
1556           /*strides=*/1);
1557       desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
1558     }
1559     rewriter.replaceOp(op, desc);
1560     return success();
1561   }
1562 };
1563 
1564 // We typically should not lower general shape cast operations into data
1565 // movement instructions, since the assumption is that these casts are
1566 // optimized away during progressive lowering. For completeness, however,
1567 // we fall back to a reference implementation that moves all elements
1568 // into the right place if we get here.
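// As an illustrative sketch (shapes assumed), casting vector<2x3xf32> to
// vector<3x2xf32> extracts the six elements at [0,0], [0,1], ..., [1,2] and
// re-inserts them at [0,0], [0,1], ..., [2,1], incrementing both index
// vectors in row-major order.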
1569 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
1570 public:
1571   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
1572 
1573   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1574                                 PatternRewriter &rewriter) const override {
1575     Location loc = op.getLoc();
1576     auto sourceVectorType = op.getSourceVectorType();
1577     auto resultVectorType = op.getResultVectorType();
1578     // Intended 2D/1D lowerings with better implementations.
1579     int64_t srcRank = sourceVectorType.getRank();
1580     int64_t resRank = resultVectorType.getRank();
1581     if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
1582       return failure();
1583     // Compute number of elements involved in the reshape.
1584     int64_t numElts = 1;
1585     for (int64_t r = 0; r < srcRank; r++)
1586       numElts *= sourceVectorType.getDimSize(r);
1587     // Replace with data movement operations:
1588     //    x[0,0,0] = y[0,0]
1589     //    x[0,0,1] = y[0,1]
1590     //    x[0,1,0] = y[0,2]
1591     // etc., incrementing the two index vectors "row-major"
1592     // within the source and result shape.
1593     SmallVector<int64_t, 4> srcIdx(srcRank);
1594     SmallVector<int64_t, 4> resIdx(resRank);
1595     Value result = rewriter.create<ConstantOp>(
1596         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1597     for (int64_t i = 0; i < numElts; i++) {
1598       if (i != 0) {
1599         incIdx(srcIdx, sourceVectorType, srcRank - 1);
1600         incIdx(resIdx, resultVectorType, resRank - 1);
1601       }
1602       Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
1603       result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
1604     }
1605     rewriter.replaceOp(op, result);
1606     return success();
1607   }
1608 
1609 private:
1610   static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
1611     assert(0 <= r && r < tp.getRank());
1612     if (++idx[r] == tp.getDimSize(r)) {
1613       idx[r] = 0;
1614       incIdx(idx, tp, r - 1);
1615     }
1616   }
1617 };
1618 
1619 } // namespace
1620 
1621 namespace mlir {
1622 
1623 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1624 /// semantics to:
1625 /// ```
1626 ///    %flattened_a = vector.shape_cast %a
1627 ///    %flattened_b = vector.shape_cast %b
1628 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1629 ///    %d = vector.shape_cast %flattened_d
1630 ///    %e = add %c, %d
1631 /// ```
1632 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1633 //
1634 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
1635 /// the vector.contract op is a row-major matrix multiply.
1636 LogicalResult ContractionOpToMatmulOpLowering::matchAndRewrite(
1637     vector::ContractionOp op, PatternRewriter &rewriter) const {
1638   // TODO: implement masks
1639   if (llvm::size(op.masks()) != 0)
1640     return failure();
1641   if (vectorTransformsOptions.vectorContractLowering !=
1642       vector::VectorContractLowering::Matmul)
1643     return failure();
1644   if (failed(filter(op)))
1645     return failure();
1646 
1647   auto iteratorTypes = op.iterator_types().getValue();
1648   if (!isParallelIterator(iteratorTypes[0]) ||
1649       !isParallelIterator(iteratorTypes[1]) ||
1650       !isReductionIterator(iteratorTypes[2]))
1651     return failure();
1652 
1653   if (!isRowMajorMatmul(op.indexing_maps()))
1654     return failure();
1655 
1656   Type elementType = op.getLhsType().getElementType();
1657   if (!elementType.isIntOrFloat())
1658     return failure();
1659 
1660   VectorType lhsType = op.getLhsType();
1661   VectorType rhsType = op.getRhsType();
1662   int64_t lhsRows = lhsType.getDimSize(0);
1663   int64_t lhsColumns = lhsType.getDimSize(1);
1664   int64_t rhsColumns = rhsType.getDimSize(1);
1665 
1666   Type flattenedLHSType =
1667       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1668   Type flattenedRHSType =
1669       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1670   auto lhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedLHSType,
1671                                                   op.lhs());
1672   auto rhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedRHSType,
1673                                                   op.rhs());
1674 
1675   Value mul = rewriter.create<vector::MatmulOp>(op.getLoc(), lhs, rhs, lhsRows,
1676                                                 lhsColumns, rhsColumns);
1677   mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(), op.acc().getType(),
1678                                              mul);
1679   if (elementType.isa<IntegerType>())
1680     rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
1681   else
1682     rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
1683 
1684   return success();
1685 }
1686 
1687 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1688 /// semantics to a reduction_size-unrolled sequence:
1689 /// ```
1690 ///    %at = vector.transpose %a, [1, 0]
1691 ///    %bRow0 = vector.extract %b[0]
1692 ///    %atRow0 = vector.extract %at[0]
1693 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
1694 ///    ...
1695 ///    %bRowK = vector.extract %b[K]
1696 ///    %atRowK = vector.extract %at[K]
1697 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1698 /// ```
1699 ///
1700 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1701 /// otherwise supports any layout permutation of the matrix-multiply.
1702 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1703     vector::ContractionOp op, PatternRewriter &rewriter) const {
1704   // TODO: implement masks
1705   if (llvm::size(op.masks()) != 0)
1706     return failure();
1707 
1708   if (vectorTransformsOptions.vectorContractLowering !=
1709       vector::VectorContractLowering::OuterProduct)
1710     return failure();
1711 
1712   if (failed(filter(op)))
1713     return failure();
1714 
1715   Location loc = op.getLoc();
1716   int64_t reductionSize = 0;
1717   VectorType lhsType = op.getLhsType();
1718   Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
1719 
1720   // Set up the parallel/reduction structure in right form.
1721   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1722   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1723   AffineExpr m, n, k;
1724   bindDims(rewriter.getContext(), m, n, k);
1725   static constexpr std::array<int64_t, 2> perm = {1, 0};
1726   auto iteratorTypes = op.iterator_types().getValue();
1727   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1728   if (isParallelIterator(iteratorTypes[0]) &&
1729       isParallelIterator(iteratorTypes[1]) &&
1730       isReductionIterator(iteratorTypes[2])) {
1731     //
1732     // Two outer parallel, one inner reduction (matmat flavor).
1733     //
1734     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1735       // This is the classical row-major matmul. Just permute the lhs.
1736       reductionSize = lhsType.getDimSize(1);
1737       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1738     } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1739       // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1740       reductionSize = lhsType.getDimSize(1);
1741       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1742       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1743     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1744       // No need to permute anything.
1745       reductionSize = lhsType.getDimSize(0);
1746     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1747       // Just permute the rhs.
1748       reductionSize = lhsType.getDimSize(0);
1749       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1750     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1751       // Transposed-result matmul: swap lhs/rhs and permute the new rhs.
1752       reductionSize = lhsType.getDimSize(1);
1753       Value tmp = rhs;
1754       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1755       lhs = tmp;
1756     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1757       // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1758       reductionSize = lhsType.getDimSize(1);
1759       Value tmp = rhs;
1760       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1761       lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1762     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1763       // No need to permute anything, but still swap lhs and rhs.
1764       reductionSize = lhsType.getDimSize(0);
1765       std::swap(lhs, rhs);
1766     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1767       // Swap lhs/rhs and permute the new lhs (the original rhs).
1768       reductionSize = lhsType.getDimSize(0);
1769       Value tmp = lhs;
1770       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1771       rhs = tmp;
1772     } else {
1773       return failure();
1774     }
1775   } else if (isParallelIterator(iteratorTypes[0]) &&
1776              isReductionIterator(iteratorTypes[1])) {
1777     //
1778     // One outer parallel, one inner reduction (matvec flavor)
1779     //
1780     if (maps == infer({{m, n}, {n}, {m}})) {
1781       // Case mat-vec: transpose.
1782       reductionSize = lhsType.getDimSize(1);
1783       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1784     } else if (maps == infer({{n, m}, {n}, {m}})) {
1785       // Case mat-trans-vec: ready to go.
1786       reductionSize = lhsType.getDimSize(0);
1787     } else if (maps == infer({{n}, {m, n}, {m}})) {
1788       // Case vec-mat: swap and transpose.
1789       reductionSize = lhsType.getDimSize(0);
1790       std::swap(lhs, rhs);
1791       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1792     } else if (maps == infer({{n}, {n, m}, {m}})) {
1793       // Case vec-mat-trans: swap and ready to go.
1794       reductionSize = lhsType.getDimSize(0);
1795       std::swap(lhs, rhs);
1796     } else {
1797       return failure();
1798     }
1799   } else {
1800     return failure();
1801   }
1802   assert(reductionSize > 0);
1803 
1804   // Unroll outer-products along reduction.
1805   for (int64_t k = 0; k < reductionSize; ++k) {
1806     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
1807     Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
1808     res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
1809   }
1810   rewriter.replaceOp(op, res);
1811   return success();
1812 }
1813 
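/// Progressively lower a `vector.contract %a, %b, %c` to an unrolled sequence
/// of elementwise vector products reduced into scalars, one per entry of the
/// result. As an illustrative sketch for the row-major matmul case (names
/// assumed; %bt denotes the permuted %b):
/// ```
///    %a0 = vector.extract %a[0]
///    %b0 = vector.extract %bt[0]
///    %m0 = mulf %a0, %b0
///    %r0 = vector.reduction "add", %m0
///    %d  = vector.insert %r0, %zero [0, 0]
///    ...
///    %e  = addf %d, %c
/// ```
/// This only kicks in when VectorTransformsOptions is set to Dot.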
1814 LogicalResult
1815 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1816                                             PatternRewriter &rewriter) const {
1817   // TODO: implement masks
1818   if (llvm::size(op.masks()) != 0)
1819     return failure();
1820 
1821   if (failed(filter(op)))
1822     return failure();
1823 
1824   if (vectorTransformsOptions.vectorContractLowering !=
1825       vector::VectorContractLowering::Dot)
1826     return failure();
1827 
1828   auto iteratorTypes = op.iterator_types().getValue();
1829   static constexpr std::array<int64_t, 2> perm = {1, 0};
1830   Location loc = op.getLoc();
1831   Value lhs = op.lhs(), rhs = op.rhs();
1832 
1833   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1834   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1835   AffineExpr m, n, k;
1836   bindDims(rewriter.getContext(), m, n, k);
1837   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1838   //
1839   // In the following we wish to make the reduction dimension innermost so we
1840   // can load vectors and just fmul + reduce into a scalar.
1841   //
1842   if (isParallelIterator(iteratorTypes[0]) &&
1843       isParallelIterator(iteratorTypes[1]) &&
1844       isReductionIterator(iteratorTypes[2])) {
1845     //
1846     // Two outer parallel, one inner reduction (matmat flavor).
1847     //
1848     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1849       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1850     } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1851       // No need to permute anything.
1852     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1853       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1854       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1855     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1856       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1857     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1858       // Transposed-result matmul: swap lhs/rhs and permute the new lhs.
1859       Value tmp = lhs;
1860       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1861       rhs = tmp;
1862     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1863       std::swap(lhs, rhs);
1864     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1865       Value tmp = lhs;
1866       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1867       rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1868     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1869       Value tmp = rhs;
1870       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1871       lhs = tmp;
1872     } else {
1873       return failure();
1874     }
1875   } else if (isParallelIterator(iteratorTypes[0]) &&
1876              isReductionIterator(iteratorTypes[1])) {
1877     //
1878     // One outer parallel, one inner reduction (matvec flavor)
1879     //
1880     if (maps == infer({{m, n}, {n}, {m}})) {
1881       // No need to permute anything.
1882     } else if (maps == infer({{n, m}, {n}, {m}})) {
1883       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1884     } else if (maps == infer({{n}, {m, n}, {m}})) {
1885       std::swap(lhs, rhs);
1886     } else if (maps == infer({{n}, {n, m}, {m}})) {
1887       std::swap(lhs, rhs);
1888       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1889     } else {
1890       return failure();
1891     }
1892   } else {
1893     return failure();
1894   }
1895 
1896   VectorType dstType = op.getResultType().cast<VectorType>();
1897   assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1898          "Expected dst type of rank 1 or 2");
1899 
1900   unsigned rank = dstType.getRank();
1901   unsigned dstRows = dstType.getShape()[0];
1902   unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1903 
1904   // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1905   Value res =
1906       rewriter.create<ConstantOp>(loc, dstType, rewriter.getZeroAttr(dstType));
1907   for (unsigned r = 0; r < dstRows; ++r) {
1908     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1909     for (unsigned c = 0; c < dstColumns; ++c) {
1910       Value b = rank == 1
1911                     ? rhs
1912                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1913       Value m = rewriter.create<MulFOp>(op.getLoc(), a, b);
1914       Value reduced = rewriter.create<vector::ReductionOp>(
1915           op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"),
1916           m, ValueRange{});
1917 
1918       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1919                                               : SmallVector<int64_t, 2>{r, c};
1920       res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1921     }
1922   }
1923   if (auto acc = op.acc())
1924     res = rewriter.create<AddFOp>(op.getLoc(), res, acc);
1925   rewriter.replaceOp(op, res);
1926   return success();
1927 }
1928 
1929 /// Progressive lowering of ContractionOp.
1930 /// One:
1931 ///   %x = vector.contract with at least one free/batch dimension
1932 /// is replaced by:
1933 ///   %a = vector.contract with one less free/batch dimension
1934 ///   %b = vector.contract with one less free/batch dimension
1935 ///   ..
1936 ///   %x = combine %a %b ..
1937 /// until a pure contraction is reached (no free/batch dimensions),
1938 /// which is replaced by a dot-product.
1939 ///
1940 /// This only kicks in when either VectorTransformsOptions is set
1941 /// to DOT or when other contraction patterns fail.
1942 //
1943 // TODO: break down into transpose/reshape/cast ops
1944 //               when they become available to avoid code dup
1945 // TODO: investigate lowering order impact on performance
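// As an illustrative example (shapes assumed), a contraction with one batch
// dimension of size 2 is rewritten into two lower-dimensional vector.contract
// ops whose results are re-assembled into the full result vector; recursion
// bottoms out in lowerReduction() once only contracting dimensions remain.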
1946 LogicalResult
1947 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1948                                        PatternRewriter &rewriter) const {
1949   // TODO: implement masks.
1950   if (llvm::size(op.masks()) != 0)
1951     return failure();
1952 
1953   if (failed(filter(op)))
1954     return failure();
1955 
1956   // TODO: support mixed mode contract lowering.
1957   if (op.getLhsType().getElementType() !=
1958           getElementTypeOrSelf(op.getAccType()) ||
1959       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1960     return failure();
1961 
1962   // TODO: implement benefits, cost models.
1963   MLIRContext *ctx = op.getContext();
1964   ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx);
1965   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1966     return success();
1967   ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
1968   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1969     return success();
1970   ContractionOpToDotLowering pat3(vectorTransformsOptions, ctx);
1971   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1972     return success();
1973 
1974   // Find first batch dimension in LHS/RHS, and lower when found.
1975   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1976   if (!batchDimMap.empty()) {
1977     int64_t lhsIndex = batchDimMap[0].first;
1978     int64_t rhsIndex = batchDimMap[0].second;
1979     rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1980     return success();
1981   }
1982 
1983   // Collect contracting dimensions.
1984   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1985       op.getContractingDimMap();
1986   DenseSet<int64_t> lhsContractingDimSet;
1987   DenseSet<int64_t> rhsContractingDimSet;
1988   for (auto &dimPair : contractingDimMap) {
1989     lhsContractingDimSet.insert(dimPair.first);
1990     rhsContractingDimSet.insert(dimPair.second);
1991   }
1992 
1993   // Find first free dimension in LHS, and lower when found.
1994   VectorType lhsType = op.getLhsType();
1995   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1996     if (lhsContractingDimSet.count(lhsIndex) == 0) {
1997       rewriter.replaceOp(
1998           op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1999       return success();
2000     }
2001   }
2002 
2003   // Find first free dimension in RHS, and lower when found.
2004   VectorType rhsType = op.getRhsType();
2005   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
2006     if (rhsContractingDimSet.count(rhsIndex) == 0) {
2007       rewriter.replaceOp(
2008           op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
2009       return success();
2010     }
2011   }
2012 
2013   // Lower the first remaining reduction dimension.
2014   if (!contractingDimMap.empty()) {
2015     rewriter.replaceOp(op, lowerReduction(op, rewriter));
2016     return success();
2017   }
2018 
2019   return failure();
2020 }
2021 
2022 // Lower one parallel dimension.
2023 // TODO: consider reusing existing contract unrolling
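// Illustrative behavior (sketch): for a parallel dimension of size N, this
// emits N lower-rank vector.contract ops, each consuming reshapeLoad'ed slices
// of lhs/rhs/acc at position d and reshapeStore'ing its result back at d.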
2024 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
2025                                            int64_t lhsIndex, int64_t rhsIndex,
2026                                            PatternRewriter &rewriter) const {
2027   VectorType lhsType = op.getLhsType();
2028   VectorType rhsType = op.getRhsType();
2029   VectorType resType = op.getResultType().cast<VectorType>();
2030   // Find the iterator type index and result index.
2031   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
2032   int64_t iterIndex = -1;
2033   int64_t dimSize = -1;
2034   if (lhsIndex >= 0) {
2035     iterIndex = iMap[0].getDimPosition(lhsIndex);
2036     assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
2037            "parallel index should be free in LHS or batch in LHS/RHS");
2038     dimSize = lhsType.getDimSize(lhsIndex);
2039   } else {
2040     assert(rhsIndex >= 0 && "missing parallel index");
2041     iterIndex = iMap[1].getDimPosition(rhsIndex);
2042     dimSize = rhsType.getDimSize(rhsIndex);
2043   }
2044   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
2045   Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
2046   assert(lookup.hasValue() && "parallel index not listed in reduction");
2047   int64_t resIndex = lookup.getValue();
2048   // Construct new iterator types and affine map array attribute.
2049   std::array<AffineMap, 3> lowIndexingMaps = {
2050       adjustMap(iMap[0], iterIndex, rewriter),
2051       adjustMap(iMap[1], iterIndex, rewriter),
2052       adjustMap(iMap[2], iterIndex, rewriter)};
2053   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
2054   auto lowIter =
2055       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
2056   // Unroll into a series of lower dimensional vector.contract ops.
2057   Location loc = op.getLoc();
2058   Value result =
2059       rewriter.create<ConstantOp>(loc, resType, rewriter.getZeroAttr(resType));
2060   for (int64_t d = 0; d < dimSize; ++d) {
2061     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
2062     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
2063     auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
2064     Value lowContract = rewriter.create<vector::ContractionOp>(
2065         loc, lhs, rhs, acc, lowAffine, lowIter);
2066     result =
2067         reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
2068   }
2069   return result;
2070 }
2071 
2072 // Lower one reduction dimension.
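// Illustrative behavior (sketch): for a reduction dimension of size K, the
// accumulator is threaded through a chain of K lower-rank contractions,
//   %acc1 = vector.contract(%lhs0, %rhs0, %acc0)
//   %acc2 = vector.contract(%lhs1, %rhs1, %acc1)
//   ...
// so the final value sums all partial reductions.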
2073 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
2074                                             PatternRewriter &rewriter) const {
2075   auto loc = op.getLoc();
2076   VectorType lhsType = op.getLhsType();
2077   VectorType rhsType = op.getRhsType();
2078   Type resType = op.getResultType();
2079   assert(!resType.isa<VectorType>());
2080   // Use iterator index 0.
2081   int64_t iterIndex = 0;
2082   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
2083   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
2084   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
2085   assert(lookupLhs.hasValue() && "missing LHS parallel index");
2086   assert(lookupRhs.hasValue() && "missing RHS parallel index");
2087   int64_t lhsIndex = lookupLhs.getValue();
2088   int64_t rhsIndex = lookupRhs.getValue();
2089   int64_t dimSize = lhsType.getDimSize(lhsIndex);
2090   assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
2091   // Base case.
2092   if (lhsType.getRank() == 1) {
2093     assert(rhsType.getRank() == 1 && "corrupt contraction");
2094     Value m = rewriter.create<MulFOp>(loc, op.lhs(), op.rhs());
2095     StringAttr kind = rewriter.getStringAttr("add");
2096     return rewriter.create<vector::ReductionOp>(loc, resType, kind, m,
2097                                                 op.acc());
2098   }
2099   // Construct new iterator types and affine map array attribute.
2100   std::array<AffineMap, 3> lowIndexingMaps = {
2101       adjustMap(iMap[0], iterIndex, rewriter),
2102       adjustMap(iMap[1], iterIndex, rewriter),
2103       adjustMap(iMap[2], iterIndex, rewriter)};
2104   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
2105   auto lowIter =
2106       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
2107   // Unroll into a series of lower dimensional vector.contract ops.
2108   // By feeding the initial accumulator into the first contraction,
2109   // and the result of each contraction into the next, eventually
2110   // the sum of all reductions is computed.
2111   Value result = op.acc();
2112   for (int64_t d = 0; d < dimSize; ++d) {
2113     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
2114     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
2115     result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
2116                                                     lowAffine, lowIter);
2117   }
2118   return result;
2119 }
2120 
2121 } // namespace mlir
2122 
2123 static Optional<int64_t> extractConstantIndex(Value v) {
2124   if (auto cstOp = v.getDefiningOp<ConstantIndexOp>())
2125     return cstOp.getValue();
2126   if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
2127     if (affineApplyOp.getAffineMap().isSingleConstant())
2128       return affineApplyOp.getAffineMap().getSingleConstantResult();
2129   return None;
2130 }
2131 
2132 // Missing foldings of scf.if make it necessary to perform poor man's folding
2133 // eagerly, especially in the case of unrolling. In the future, this should go
2134 // away once scf.if folds properly.
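// Illustrative sketch (constants assumed): if both `v` and `ub` fold to
// constant indices, say 4 and 8, the comparison is statically known to hold
// and no `cmpi` is materialized; the null Value signals "trivially true".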
2135 static Value createScopedFoldedSLE(Value v, Value ub) {
2136   using namespace edsc::op;
2137   auto maybeCstV = extractConstantIndex(v);
2138   auto maybeCstUb = extractConstantIndex(ub);
2139   if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
2140     return Value();
2141   return sle(v, ub);
2142 }
2143 
2144 // Operates under a scoped context to build the condition to ensure that a
2145 // particular VectorTransferOpInterface is unmasked.
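// Illustrative sketch (types assumed): for a masked read of vector<8xf32>
// from %A : memref<?xf32> at index %i, this builds roughly
//   (%i + 8) "sle" dim(%A, 0)
// i.e. an add of index and vector size, a dim, and a cmpi "sle", conjoined
// over all masked dimensions.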
2146 static Value createScopedInBoundsCond(VectorTransferOpInterface xferOp) {
2147   assert(xferOp.permutation_map().isMinorIdentity() &&
2148          "Expected minor identity map");
2149   Value inBoundsCond;
2150   xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
2151     // Zip over the resulting vector shape and memref indices.
2152     // If the dimension is known to be unmasked, it does not participate in the
2153     // construction of `inBoundsCond`.
2154     if (!xferOp.isMaskedDim(resultIdx))
2155       return;
2156     int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
2157     using namespace edsc::op;
2158     using namespace edsc::intrinsics;
2159     // Fold or create the check that `index + vector_size` <= `memref_size`.
2160     Value sum = xferOp.indices()[indicesIdx] + std_constant_index(vectorSize);
2161     Value cond =
2162         createScopedFoldedSLE(sum, std_dim(xferOp.source(), indicesIdx));
2163     if (!cond)
2164       return;
2165     // Conjunction over all dims for which we are in-bounds.
2166     inBoundsCond = inBoundsCond ? inBoundsCond && cond : cond;
2167   });
2168   return inBoundsCond;
2169 }
2170 
2171 LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
2172     VectorTransferOpInterface xferOp) {
2173   // TODO: expand support to these 2 cases.
2174   if (!xferOp.permutation_map().isMinorIdentity())
2175     return failure();
2176   // Must have some masked dimension to be a candidate for splitting.
2177   if (!xferOp.hasMaskedDim())
2178     return failure();
2179   // Don't split transfer operations directly under IfOp, this avoids applying
2180   // the pattern recursively.
2181   // TODO: improve the filtering condition to make it more applicable.
2182   if (isa<scf::IfOp>(xferOp->getParentOp()))
2183     return failure();
2184   return success();
2185 }
2186 
2187 /// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
2188 /// be cast. If the MemRefTypes don't have the same rank or are not strided,
2189 /// return null; otherwise:
2190 ///   1. if `aT` and `bT` are cast-compatible, return `aT`.
2191 ///   2. else return a new MemRefType obtained by iterating over the shape and
2192 ///   strides and:
2193 ///     a. keeping the ones that are static and equal across `aT` and `bT`.
2194 ///     b. using a dynamic shape and/or stride for the dimensions that don't
2195 ///        agree.
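/// As an illustrative sketch (types assumed): for `aT = memref<4x?xf32>` and
/// `bT = memref<?x8xf32>` with default strided layouts, the two types are not
/// cast-compatible, so the result is `memref<?x?xf32>` with a dynamic outer
/// stride, a unit inner stride, and offset 0.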
2196 static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
2197   if (MemRefCastOp::areCastCompatible(aT, bT))
2198     return aT;
2199   if (aT.getRank() != bT.getRank())
2200     return MemRefType();
2201   int64_t aOffset, bOffset;
2202   SmallVector<int64_t, 4> aStrides, bStrides;
2203   if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
2204       failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
2205       aStrides.size() != bStrides.size())
2206     return MemRefType();
2207 
2208   ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
2209   int64_t resOffset;
2210   SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
2211       resStrides(bT.getRank(), 0);
2212   for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
2213     resShape[idx] =
2214         (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
2215     resStrides[idx] = (aStrides[idx] == bStrides[idx])
2216                           ? aStrides[idx]
2217                           : MemRefType::kDynamicStrideOrOffset;
2218   }
2219   resOffset =
2220       (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
2221   return MemRefType::get(
2222       resShape, aT.getElementType(),
2223       makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
2224 }
2225 
2226 /// Operates under a scoped context to build the intersection between the
2227 /// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
2228 // TODO: view intersection/union/differences should be a proper std op.
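// Illustrative sketch (shapes assumed): for a transfer of vector<4x8xf32>
// from %A at indices (%i, %j) into a 4x8 `alloc`, the subview is taken at
// offsets (%i, %j), with unit strides, and sizes
//   affine_min(dim(%A, 0) - %i, 4) x affine_min(dim(%A, 1) - %j, 8).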
2229 static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
2230                                              Value alloc) {
2231   using namespace edsc::intrinsics;
2232   int64_t memrefRank = xferOp.getShapedType().getRank();
2233   // TODO: relax this precondition, will require rank-reducing subviews.
2234   assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
2235          "Expected memref rank to match the alloc rank");
2236   ValueRange leadingIndices =
2237       xferOp.indices().take_front(xferOp.getLeadingShapedRank());
2238   SmallVector<OpFoldResult, 4> sizes;
2239   sizes.append(leadingIndices.begin(), leadingIndices.end());
2240   xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
2241     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
2242     Value dimMemRef = std_dim(xferOp.source(), indicesIdx);
2243     Value dimAlloc = std_dim(alloc, resultIdx);
2244     Value index = xferOp.indices()[indicesIdx];
2245     AffineExpr i, j, k;
2246     bindDims(xferOp.getContext(), i, j, k);
2247     SmallVector<AffineMap, 4> maps =
2248         AffineMap::inferFromExprList(MapList{{i - j, k}});
2249     // affine_min(%dimMemRef - %index, %dimAlloc)
2250     Value affineMin = affine_min(index.getType(), maps[0],
2251                                  ValueRange{dimMemRef, index, dimAlloc});
2252     sizes.push_back(affineMin);
2253   });
2254 
2255   SmallVector<OpFoldResult, 4> indices = llvm::to_vector<4>(llvm::map_range(
2256       xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
2257   return std_sub_view(
2258       xferOp.source(), indices, sizes,
2259       SmallVector<OpFoldResult>(memrefRank, OpBuilder(xferOp).getIndexAttr(1)));
2260 }
2261 
2262 /// Given an `xferOp` for which:
2263 ///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
2264 ///   2. a memref of a single vector `alloc` has been allocated.
2265 /// Produce IR resembling:
2266 /// ```
2267 ///    %1:3 = scf.if (%inBounds) {
2268 ///      memref_cast %A: memref<A...> to compatibleMemRefType
2269 ///      scf.yield %view, ... : compatibleMemRefType, index, index
2270 ///    } else {
2271 ///      %2 = linalg.fill(%alloc, %pad)
2272 ///      %3 = subview %view [...][...][...]
2273 ///      linalg.copy(%3, %alloc)
2274 ///      memref_cast %alloc: memref<B...> to compatibleMemRefType
2275 ///      scf.yield %4, ... : compatibleMemRefType, index, index
2276 ///   }
2277 /// ```
2278 /// Return the produced scf::IfOp.
2279 static scf::IfOp createScopedFullPartialLinalgCopy(
2280     vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
2281     MemRefType compatibleMemRefType, Value alloc) {
2282   using namespace edsc;
2283   using namespace edsc::intrinsics;
2284   scf::IfOp fullPartialIfOp;
2285   Value zero = std_constant_index(0);
2286   Value memref = xferOp.source();
2287   conditionBuilder(
2288       returnTypes, inBoundsCond,
2289       [&]() -> scf::ValueVector {
2290         Value res = memref;
2291         if (compatibleMemRefType != xferOp.getShapedType())
2292           res = std_memref_cast(memref, compatibleMemRefType);
2293         scf::ValueVector viewAndIndices{res};
2294         viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
2295                               xferOp.indices().end());
2296         return viewAndIndices;
2297       },
2298       [&]() -> scf::ValueVector {
2299         linalg_fill(alloc, xferOp.padding());
2300         // Take partial subview of memref which guarantees no dimension
2301         // overflows.
2302         Value memRefSubView = createScopedSubViewIntersection(
2303             cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
2304         linalg_copy(memRefSubView, alloc);
2305         Value casted = std_memref_cast(alloc, compatibleMemRefType);
2306         scf::ValueVector viewAndIndices{casted};
2307         viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
2308                               zero);
2309         return viewAndIndices;
2310       },
2311       &fullPartialIfOp);
2312   return fullPartialIfOp;
2313 }
2314 
2315 /// Given an `xferOp` for which:
2316 ///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
2317 ///   2. a memref of a single vector `alloc` has been allocated.
2318 /// Produce IR resembling:
2319 /// ```
2320 ///    %1:3 = scf.if (%inBounds) {
2321 ///      memref_cast %A: memref<A...> to compatibleMemRefType
2322 ///      scf.yield %view, ... : compatibleMemRefType, index, index
2323 ///    } else {
2324 ///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
2325 ///      %3 = vector.type_cast %extra_alloc :
2326 ///        memref<...> to memref<vector<...>>
2327 ///      store %2, %3[] : memref<vector<...>>
2328 ///      %4 = memref_cast %alloc: memref<B...> to compatibleMemRefType
2329 ///      scf.yield %4, ... : compatibleMemRefType, index, index
2330 ///   }
2331 /// ```
2332 /// Return the produced scf::IfOp.
2333 static scf::IfOp createScopedFullPartialVectorTransferRead(
2334     vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
2335     MemRefType compatibleMemRefType, Value alloc) {
2336   using namespace edsc;
2337   using namespace edsc::intrinsics;
2338   scf::IfOp fullPartialIfOp;
2339   Value zero = std_constant_index(0);
2340   Value memref = xferOp.source();
2341   conditionBuilder(
2342       returnTypes, inBoundsCond,
2343       [&]() -> scf::ValueVector {
2344         Value res = memref;
2345         if (compatibleMemRefType != xferOp.getShapedType())
2346           res = std_memref_cast(memref, compatibleMemRefType);
2347         scf::ValueVector viewAndIndices{res};
2348         viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
2349                               xferOp.indices().end());
2350         return viewAndIndices;
2351       },
2352       [&]() -> scf::ValueVector {
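        // Slow path: clone the original (masked) transfer_read as-is and spill
        // its result into the staging `alloc` through a vector-typed view.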
        Operation *newXfer =
            ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        std_store(vector, vector_type_cast(
                              MemRefType::get({}, vector.getType()), alloc));

        Value casted = std_memref_cast(alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);

        return viewAndIndices;
      },
      &fullPartialIfOp);
  return fullPartialIfOp;
}

/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref_cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, masked vector.transfer or linalg.copy.
///      memref_cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
/// where `alloc` is a top-of-the-function `alloca`'ed buffer large enough to
/// hold one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map.
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///  must be equal. This will be relaxed in the future but requires
///  rank-reducing subviews.
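///
/// A minimal usage sketch, assuming an OpBuilder `b` positioned at the
/// transfer op and a VectorTransferOpInterface `xfer` (the variable names
/// below are illustrative, not part of this API):
/// ```
///   VectorTransformsOptions options;
///   options.vectorTransferSplit = VectorTransferSplit::VectorTransfer;
///   scf::IfOp ifOp;
///   if (succeeded(splitFullAndPartialTransfer(b, xfer, options, &ifOp))) {
///     // `ifOp` now points to the newly created scf.if.
///   }
/// ```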
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    OpBuilder &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  using namespace edsc;
  using namespace edsc::intrinsics;

  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
  auto unmaskedAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceUnmasked) {
    xferOp->setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
    return success();
  }

  assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
         "Expected splitFullAndPartialTransferPrecondition to hold");
  auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());

  // TODO: add support for write case.
  if (!xferReadOp)
    return failure();

  OpBuilder::InsertionGuard guard(b);
  if (Operation *sourceOp = xferOp.source().getDefiningOp())
    b.setInsertionPointAfter(sourceOp);
  else
    b.setInsertionPoint(xferOp);
  ScopedContext scope(b, xferOp.getLoc());
  Value inBoundsCond = createScopedInBoundsCond(
      cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top-of-the-function `alloc` for transient storage.
  Value alloc;
  {
    FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(&funcOp.getRegion().front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = std_alloca(MemRefType::get(shape, elementType), ValueRange{},
                       b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());

  // Read case: full fill + partial copy -> unmasked vector.xfer_read.
  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;
  scf::IfOp fullPartialIfOp =
      options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
          ? createScopedFullPartialVectorTransferRead(
                xferReadOp, returnTypes, inBoundsCond, compatibleMemRefType,
                alloc)
          : createScopedFullPartialLinalgCopy(xferReadOp, returnTypes,
                                              inBoundsCond,
                                              compatibleMemRefType, alloc);
  if (ifOp)
    *ifOp = fullPartialIfOp;

  // Unmask the existing read op; it always reads from a full buffer (either
  // the original memref or the staging alloca).
  for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
    xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
  xferOp->setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);

  return success();
}

LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
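  // The split mutates `xferOp` in place (new operands and the unmasked
  // attribute) rather than replacing it, so notify the rewriter through the
  // root-update protocol and roll back if the split does not apply.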
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
  rewriter.cancelRootUpdate(xferOp);
  return failure();
}

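/// Rewrite `extract_map(addf(x, y))` into `addf(extract_map(x),
/// extract_map(y))` so that the elementwise computation is performed on the
/// distributed (smaller) vector type. Only a single-result AddFOp producer is
/// handled for now.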
LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
    ExtractMapOp extract, PatternRewriter &rewriter) const {
  Operation *definedOp = extract.vector().getDefiningOp();
  if (!definedOp || definedOp->getNumResults() != 1)
    return failure();
  // TODO: Create an interfaceOp for elementwise operations.
  if (!isa<AddFOp>(definedOp))
    return failure();
  Location loc = extract.getLoc();
  SmallVector<Value, 4> extractOperands;
  for (OpOperand &operand : definedOp->getOpOperands())
    extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
        loc, extract.getResultType(), operand.get(), extract.ids()));
  Operation *newOp = cloneOpWithOperandsAndTypes(
      rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
  rewriter.replaceOp(extract, newOp->getResult(0));
  return success();
}

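/// Distribute the single vector result of `op` over `ids` according to `map`:
/// create a vector.extract_map of the result followed by a vector.insert_map
/// that reconstitutes the original value, and return both ops. Returns None
/// when the op is not distributable (multiple results, non-vector result, or
/// a distributed dimension whose size is not a multiple of its multiplicity).
///
/// A minimal usage sketch, assuming `builder`, an elementwise `op` with a
/// single vector<32xf32> result, and a lane id Value `id` (names are
/// illustrative only):
/// ```
///   AffineMap map = AffineMap::getMultiDimIdentityMap(1, op->getContext());
///   if (auto ops = distributPointwiseVectorOp(builder, op, {id}, {32}, map)) {
///     // `ops->extract` / `ops->insert` bracket the original result; rewiring
///     // the result's uses to `ops->insert` is left to the caller.
///   }
/// ```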
Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
    OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
    ArrayRef<int64_t> multiplicity, const AffineMap &map) {
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointAfter(op);
  Location loc = op->getLoc();
  if (op->getNumResults() != 1)
    return {};
  Value result = op->getResult(0);
  VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
  if (!type || map.getNumResults() != multiplicity.size())
    return {};
  // For each dimension being distributed, check that its size is a multiple of
  // the multiplicity. To handle more sizes we would need to support masking.
  unsigned multiplicityCount = 0;
  for (auto exp : map.getResults()) {
    auto affineExp = exp.dyn_cast<AffineDimExpr>();
    if (!affineExp || affineExp.getPosition() >= type.getRank() ||
        type.getDimSize(affineExp.getPosition()) %
                multiplicity[multiplicityCount++] !=
            0)
      return {};
  }
  DistributeOps ops;
  ops.extract =
      builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
  ops.insert =
      builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
  return ops;
}

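/// Rewrite a vector.transfer_read whose only user is a vector.extract_map:
/// shrink the read to the extract_map result type, offset the indices of the
/// distributed dimensions by id * dimSize, and reinsert the slice with a
/// vector.insert_map so the original result type is preserved.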
struct TransferReadExtractPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadExtractPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferReadOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    if (!read.getResult().hasOneUse())
      return failure();
    auto extract =
        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
    if (!extract)
      return failure();
    edsc::ScopedContext scope(rewriter, read.getLoc());
    using mlir::edsc::op::operator+;
    using mlir::edsc::op::operator*;
    using namespace mlir::edsc::intrinsics;
    SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
    AffineMap map = extract.map();
    unsigned idCount = 0;
    for (auto expr : map.getResults()) {
      unsigned pos = expr.cast<AffineDimExpr>().getPosition();
      indices[pos] =
          indices[pos] +
          extract.ids()[idCount++] *
              std_constant_index(extract.getResultType().getDimSize(pos));
    }
    Value newRead = vector_transfer_read(extract.getType(), read.source(),
                                         indices, read.permutation_map(),
                                         read.padding(), read.maskedAttr());
    Value dest = rewriter.create<ConstantOp>(
        read.getLoc(), read.getType(), rewriter.getZeroAttr(read.getType()));
    newRead = rewriter.create<vector::InsertMapOp>(read.getLoc(), newRead, dest,
                                                   extract.ids());
    rewriter.replaceOp(read, newRead);
    return success();
  }
};

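/// Rewrite a vector.transfer_write whose vector operand is produced by a
/// vector.insert_map: write only the distributed slice, offsetting the indices
/// of the distributed dimensions by id * dimSize, and erase the original
/// full-vector write.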
struct TransferWriteInsertPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteInsertPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
    if (!insert)
      return failure();
    edsc::ScopedContext scope(rewriter, write.getLoc());
    using mlir::edsc::op::operator+;
    using mlir::edsc::op::operator*;
    using namespace mlir::edsc::intrinsics;
    SmallVector<Value, 4> indices(write.indices().begin(),
                                  write.indices().end());
    AffineMap map = insert.map();
    unsigned idCount = 0;
    for (auto expr : map.getResults()) {
      unsigned pos = expr.cast<AffineDimExpr>().getPosition();
      indices[pos] =
          indices[pos] +
          insert.ids()[idCount++] *
              std_constant_index(insert.getSourceVectorType().getDimSize(pos));
    }
    vector_transfer_write(insert.vector(), write.source(), indices,
                          write.permutation_map(), write.maskedAttr());
    rewriter.eraseOp(write);
    return success();
  }
};

// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
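//
// A minimal usage sketch, assuming an MLIRContext `ctx` and a FuncOp `funcOp`
// to transform (names are illustrative only):
//
//   OwningRewritePatternList patterns;
//   vector::populateVectorToVectorTransformationPatterns(patterns, ctx);
//   // Hand `patterns` to the greedy pattern driver of your MLIR revision,
//   // e.g. applyPatternsAndFoldGreedily(funcOp, std::move(patterns)).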
void mlir::vector::populateVectorToVectorTransformationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // clang-format off
  patterns.insert<ShapeCastOpDecomposer,
                  ShapeCastOpFolder,
                  SplitTransferReadOp,
                  SplitTransferWriteOp,
                  TupleGetFolderOp,
                  TransferReadExtractPattern,
                  TransferWriteInsertPattern>(context);
  // clang-format on
}

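// Populate `patterns` with the progressive lowering of vector.extract_slices
// and vector.insert_slices ops.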
void mlir::vector::populateVectorSlicesLoweringPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
}

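// Populate `patterns` with the lowerings of broadcast, mask, outerproduct,
// shape_cast, transpose and contraction ops; the contraction lowerings are
// configured through `parameters`.
//
// A minimal usage sketch, assuming an MLIRContext `ctx` (names illustrative;
// default-constructed options pick the default lowering strategies):
//
//   VectorTransformsOptions options;
//   OwningRewritePatternList patterns;
//   vector::populateVectorContractLoweringPatterns(patterns, ctx, options);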
void mlir::vector::populateVectorContractLoweringPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context,
    VectorTransformsOptions parameters) {
  // clang-format off
  patterns.insert<BroadcastOpLowering,
                  CreateMaskOpLowering,
                  ConstantMaskOpLowering,
                  OuterProductOpLowering,
                  ShapeCastOp2DDownCastRewritePattern,
                  ShapeCastOp2DUpCastRewritePattern,
                  ShapeCastOpRewritePattern>(context);
  patterns.insert<TransposeOpLowering,
                  ContractionOpLowering,
                  ContractionOpToMatmulOpLowering,
                  ContractionOpToOuterProductOpLowering>(parameters, context);
  // clang-format on
}