1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <type_traits>
14 
15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
16 
17 #include "../PassDetail.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/Utils.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/SCF.h"
22 #include "mlir/Dialect/Vector/VectorOps.h"
23 #include "mlir/Dialect/Vector/VectorUtils.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/ImplicitLocOpBuilder.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 #include "mlir/Transforms/Passes.h"
29 
30 using namespace mlir;
31 using vector::TransferReadOp;
32 using vector::TransferWriteOp;
33 
34 namespace {
35 
36 /// Attribute name used for labeling transfer ops during progressive lowering.
37 static const char kPassLabel[] = "__vector_to_scf_lowering__";
38 
39 /// Patterns that inherit from this struct have access to
40 /// VectorTransferToSCFOptions.
41 template <typename OpTy>
42 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
VectorToSCFPattern__anon32638c370111::VectorToSCFPattern43   explicit VectorToSCFPattern(MLIRContext *context,
44                               VectorTransferToSCFOptions opt)
45       : OpRewritePattern<OpTy>(context), options(opt) {}
46 
47   VectorTransferToSCFOptions options;
48 };
49 
50 /// Given a vector transfer op, calculate which dimension of the `source`
51 /// memref should be unpacked in the next application of TransferOpConversion.
52 /// A return value of None indicates a broadcast.
53 template <typename OpTy>
unpackedDim(OpTy xferOp)54 static Optional<int64_t> unpackedDim(OpTy xferOp) {
55   auto map = xferOp.permutation_map();
56   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
57     return expr.getPosition();
58   }
59   assert(xferOp.isBroadcastDim(0) &&
60          "Expected AffineDimExpr or AffineConstantExpr");
61   return None;
62 }
63 
64 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
65 /// map is identical to the current permutation map, but the first result is
66 /// omitted.
67 template <typename OpTy>
unpackedPermutationMap(OpBuilder & b,OpTy xferOp)68 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
69   auto map = xferOp.permutation_map();
70   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
71                         b.getContext());
72 }
73 
74 /// Calculate the indices for the new vector transfer op.
75 ///
76 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
77 ///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32>
78 ///                                 ^^^^^^
79 ///              `iv` is the iteration variable of the (new) surrounding loop.
80 template <typename OpTy>
getXferIndices(OpBuilder & b,OpTy xferOp,Value iv,SmallVector<Value,8> & indices)81 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
82                            SmallVector<Value, 8> &indices) {
83   typename OpTy::Adaptor adaptor(xferOp);
84   // Corresponding memref dim of the vector dim that is unpacked.
85   auto dim = unpackedDim(xferOp);
86   auto prevIndices = adaptor.indices();
87   indices.append(prevIndices.begin(), prevIndices.end());
88 
89   Location loc = xferOp.getLoc();
90   bool isBroadcast = !dim.hasValue();
91   if (!isBroadcast) {
92     AffineExpr d0, d1;
93     bindDims(xferOp.getContext(), d0, d1);
94     Value offset = adaptor.indices()[dim.getValue()];
95     indices[dim.getValue()] =
96         makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
97   }
98 }
99 
maybeYieldValue(OpBuilder & b,Location loc,bool hasRetVal,Value value)100 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
101                             Value value) {
102   if (hasRetVal) {
103     assert(value && "Expected non-empty value");
104     b.create<scf::YieldOp>(loc, value);
105   } else {
106     b.create<scf::YieldOp>(loc);
107   }
108 }
109 
110 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
111 /// is set to true. No such check is generated under following circumstances:
112 /// * xferOp does not have a mask.
113 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
114 ///   computed and attached to the new transfer op in the pattern.)
115 /// * The to-be-unpacked dim of xferOp is a broadcast.
116 template <typename OpTy>
generateMaskCheck(OpBuilder & b,OpTy xferOp,Value iv)117 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
118   if (!xferOp.mask())
119     return Value();
120   if (xferOp.getMaskType().getRank() != 1)
121     return Value();
122   if (xferOp.isBroadcastDim(0))
123     return Value();
124 
125   Location loc = xferOp.getLoc();
126   Value ivI32 =
127       b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
128   return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
129 }
130 
131 /// Helper function TransferOpConversion and TransferOp1dConversion.
132 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
133 /// specified dimension `dim` with the loop iteration variable `iv`.
134 /// E.g., when unpacking dimension 0 from:
135 /// ```
136 /// %vec = vector.transfer_read %A[%a, %b] %cst
137 ///     : vector<5x4xf32>, memref<?x?xf32>
138 /// ```
139 /// An if check similar to this will be generated inside the loop:
140 /// ```
141 /// %d = memref.dim %A, %c0 : memref<?x?xf32>
142 /// if (%a + iv < %d) {
143 ///   (in-bounds case)
144 /// } else {
145 ///   (out-of-bounds case)
146 /// }
147 /// ```
148 ///
149 /// If the transfer is 1D and has a mask, this function generates a more complex
150 /// check also accounts for potentially masked out elements.
151 ///
152 /// This function variant returns the value returned by `inBoundsCase` or
153 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
154 /// `resultTypes`.
155 template <typename OpTy>
generateInBoundsCheck(OpBuilder & b,OpTy xferOp,Value iv,Optional<int64_t> dim,TypeRange resultTypes,function_ref<Value (OpBuilder &,Location)> inBoundsCase,function_ref<Value (OpBuilder &,Location)> outOfBoundsCase=nullptr)156 static Value generateInBoundsCheck(
157     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
158     TypeRange resultTypes,
159     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
160     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
161   bool hasRetVal = !resultTypes.empty();
162   Value cond; // Condition to be built...
163 
164   // Condition check 1: Access in-bounds?
165   bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
166   Location loc = xferOp.getLoc();
167   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
168   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
169     Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim);
170     AffineExpr d0, d1;
171     bindDims(xferOp.getContext(), d0, d1);
172     Value base = xferOp.indices()[dim.getValue()];
173     Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
174     cond = lb.create<CmpIOp>(CmpIPredicate::sgt, memrefDim, memrefIdx);
175   }
176 
177   // Condition check 2: Masked in?
178   if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
179     if (cond)
180       cond = lb.create<AndOp>(cond, maskCond);
181     else
182       cond = maskCond;
183   }
184 
185   // If the condition is non-empty, generate an SCF::IfOp.
186   if (cond) {
187     auto check = lb.create<scf::IfOp>(
188         resultTypes, cond,
189         /*thenBuilder=*/
190         [&](OpBuilder &b, Location loc) {
191           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
192         },
193         /*elseBuilder=*/
194         [&](OpBuilder &b, Location loc) {
195           if (outOfBoundsCase) {
196             maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
197           } else {
198             b.create<scf::YieldOp>(loc);
199           }
200         });
201 
202     return hasRetVal ? check.getResult(0) : Value();
203   }
204 
205   // Condition is empty, no need for an SCF::IfOp.
206   return inBoundsCase(b, loc);
207 }
208 
209 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
210 /// a return value. Consequently, this function does not have a return value.
211 template <typename OpTy>
generateInBoundsCheck(OpBuilder & b,OpTy xferOp,Value iv,Optional<int64_t> dim,function_ref<void (OpBuilder &,Location)> inBoundsCase,function_ref<void (OpBuilder &,Location)> outOfBoundsCase=nullptr)212 static void generateInBoundsCheck(
213     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
214     function_ref<void(OpBuilder &, Location)> inBoundsCase,
215     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
216   generateInBoundsCheck(
217       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
218       /*inBoundsCase=*/
219       [&](OpBuilder &b, Location loc) {
220         inBoundsCase(b, loc);
221         return Value();
222       },
223       /*outOfBoundsCase=*/
224       [&](OpBuilder &b, Location loc) {
225         if (outOfBoundsCase)
226           outOfBoundsCase(b, loc);
227         return Value();
228       });
229 }
230 
231 /// Given an ArrayAttr, return a copy where the first element is dropped.
dropFirstElem(OpBuilder & b,ArrayAttr attr)232 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
233   if (!attr)
234     return attr;
235   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
236 }
237 
238 /// Add the pass label to a vector transfer op if its rank is not the target
239 /// rank.
240 template <typename OpTy>
maybeApplyPassLabel(OpBuilder & b,OpTy newXferOp,unsigned targetRank)241 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
242                                 unsigned targetRank) {
243   if (newXferOp.getVectorType().getRank() > targetRank)
244     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
245 }
246 
247 /// Return true if this transfer op operates on a source tensor.
248 template <typename OpTy>
isTensorOp(OpTy xferOp)249 static bool isTensorOp(OpTy xferOp) {
250   if (xferOp.getShapedType().template isa<RankedTensorType>()) {
251     if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
252       // TransferWriteOps on tensors have a result.
253       assert(xferOp->getNumResults() > 0);
254     }
255     return true;
256   }
257   return false;
258 }
259 
260 namespace lowering_n_d {
261 
262 /// Helper data structure for data and mask buffers.
263 struct BufferAllocs {
264   Value dataBuffer;
265   Value maskBuffer;
266 };
267 
268 /// Allocate temporary buffers for data (vector) and mask (if present).
269 /// TODO: Parallelism and threadlocal considerations.
270 template <typename OpTy>
allocBuffers(OpBuilder & b,OpTy xferOp)271 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
272   Location loc = xferOp.getLoc();
273   OpBuilder::InsertionGuard guard(b);
274   Operation *scope =
275       xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
276   assert(scope && "Expected op to be inside automatic allocation scope");
277   b.setInsertionPointToStart(&scope->getRegion(0).front());
278 
279   BufferAllocs result;
280   auto bufferType = MemRefType::get({}, xferOp.getVectorType());
281   result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
282 
283   if (xferOp.mask()) {
284     auto maskType = MemRefType::get({}, xferOp.mask().getType());
285     auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
286     b.setInsertionPoint(xferOp);
287     b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
288     result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
289   }
290 
291   return result;
292 }
293 
294 /// Given a MemRefType with VectorType element type, unpack one dimension from
295 /// the VectorType into the MemRefType.
296 ///
297 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
unpackOneDim(MemRefType type)298 static MemRefType unpackOneDim(MemRefType type) {
299   auto vectorType = type.getElementType().dyn_cast<VectorType>();
300   auto memrefShape = type.getShape();
301   SmallVector<int64_t, 8> newMemrefShape;
302   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
303   newMemrefShape.push_back(vectorType.getDimSize(0));
304   return MemRefType::get(newMemrefShape,
305                          VectorType::get(vectorType.getShape().drop_front(),
306                                          vectorType.getElementType()));
307 }
308 
309 /// Given a transfer op, find the memref from which the mask is loaded. This
310 /// is similar to Strategy<TransferWriteOp>::getBuffer.
311 template <typename OpTy>
getMaskBuffer(OpTy xferOp)312 static Value getMaskBuffer(OpTy xferOp) {
313   assert(xferOp.mask() && "Expected that transfer op has mask");
314   auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
315   assert(loadOp && "Expected transfer op mask produced by LoadOp");
316   return loadOp.getMemRef();
317 }
318 
/// Codegen strategy, depending on the operation. Specialized below for
/// TransferReadOp and TransferWriteOp.
template <typename OpTy>
struct Strategy;
322 
323 /// Code strategy for vector TransferReadOp.
324 template <>
325 struct Strategy<TransferReadOp> {
326   /// Find the StoreOp that is used for writing the current TransferReadOp's
327   /// result to the temporary buffer allocation.
getStoreOp__anon32638c370111::lowering_n_d::Strategy328   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
329     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
330     auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
331     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
332     return storeOp;
333   }
334 
335   /// Find the temporary buffer allocation. All labeled TransferReadOps are
336   /// used like this, where %buf is either the buffer allocation or a type cast
337   /// of the buffer allocation:
338   /// ```
339   /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
340   /// memref.store %vec, %buf[...] ...
341   /// ```
getBuffer__anon32638c370111::lowering_n_d::Strategy342   static Value getBuffer(TransferReadOp xferOp) {
343     return getStoreOp(xferOp).getMemRef();
344   }
345 
346   /// Retrieve the indices of the current StoreOp that stores into the buffer.
getBufferIndices__anon32638c370111::lowering_n_d::Strategy347   static void getBufferIndices(TransferReadOp xferOp,
348                                SmallVector<Value, 8> &indices) {
349     auto storeOp = getStoreOp(xferOp);
350     auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
351     indices.append(prevIndices.begin(), prevIndices.end());
352   }
353 
354   /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
355   /// accesses on the to-be-unpacked dimension.
356   ///
357   /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
358   ///    variable `iv`.
359   /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
360   ///
361   /// E.g.:
362   /// ```
363   /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
364   ///     : memref<?x?x?xf32>, vector<4x3xf32>
365   /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
366   /// ```
367   /// Is rewritten to:
368   /// ```
369   /// %casted = vector.type_cast %buf
370   ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
371   /// for %j = 0 to 4 {
372   ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
373   ///       : memref<?x?x?xf32>, vector<3xf32>
374   ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
375   /// }
376   /// ```
377   ///
378   /// Note: The loop and type cast are generated in TransferOpConversion.
379   ///       The original TransferReadOp and store op are deleted in `cleanup`.
380   /// Note: The `mask` operand is set in TransferOpConversion.
rewriteOp__anon32638c370111::lowering_n_d::Strategy381   static TransferReadOp rewriteOp(OpBuilder &b,
382                                   VectorTransferToSCFOptions options,
383                                   TransferReadOp xferOp, Value buffer, Value iv,
384                                   ValueRange /*loopState*/) {
385     SmallVector<Value, 8> storeIndices;
386     getBufferIndices(xferOp, storeIndices);
387     storeIndices.push_back(iv);
388 
389     SmallVector<Value, 8> xferIndices;
390     getXferIndices(b, xferOp, iv, xferIndices);
391 
392     Location loc = xferOp.getLoc();
393     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
394     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
395     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
396     auto newXferOp = b.create<vector::TransferReadOp>(
397         loc, vecType, xferOp.source(), xferIndices,
398         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
399         Value(), inBoundsAttr);
400 
401     maybeApplyPassLabel(b, newXferOp, options.targetRank);
402 
403     b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
404     return newXferOp;
405   }
406 
407   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
408   /// padding value to the temporary buffer.
handleOutOfBoundsDim__anon32638c370111::lowering_n_d::Strategy409   static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
410                                     Value buffer, Value iv,
411                                     ValueRange /*loopState*/) {
412     SmallVector<Value, 8> storeIndices;
413     getBufferIndices(xferOp, storeIndices);
414     storeIndices.push_back(iv);
415 
416     Location loc = xferOp.getLoc();
417     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
418     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
419     auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
420     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
421 
422     return Value();
423   }
424 
425   /// Cleanup after rewriting the op.
cleanup__anon32638c370111::lowering_n_d::Strategy426   static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
427                       scf::ForOp /*forOp*/) {
428     rewriter.eraseOp(getStoreOp(xferOp));
429     rewriter.eraseOp(xferOp);
430   }
431 
432   /// Return the initial loop state for the generated scf.for loop.
initialLoopState__anon32638c370111::lowering_n_d::Strategy433   static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
434 };
435 
436 /// Codegen strategy for vector TransferWriteOp.
437 template <>
438 struct Strategy<TransferWriteOp> {
439   /// Find the temporary buffer allocation. All labeled TransferWriteOps are
440   /// used like this, where %buf is either the buffer allocation or a type cast
441   /// of the buffer allocation:
442   /// ```
443   /// %vec = memref.load %buf[...] ...
444   /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
445   /// ```
getBuffer__anon32638c370111::lowering_n_d::Strategy446   static Value getBuffer(TransferWriteOp xferOp) {
447     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
448     assert(loadOp && "Expected transfer op vector produced by LoadOp");
449     return loadOp.getMemRef();
450   }
451 
452   /// Retrieve the indices of the current LoadOp that loads from the buffer.
getBufferIndices__anon32638c370111::lowering_n_d::Strategy453   static void getBufferIndices(TransferWriteOp xferOp,
454                                SmallVector<Value, 8> &indices) {
455     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
456     auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
457     indices.append(prevIndices.begin(), prevIndices.end());
458   }
459 
460   /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
461   /// accesses on the to-be-unpacked dimension.
462   ///
463   /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
464   ///    using the loop iteration variable `iv`.
465   /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
466   ///    to memory.
467   ///
468   /// Note: For more details, see comments on Strategy<TransferReadOp>.
rewriteOp__anon32638c370111::lowering_n_d::Strategy469   static TransferWriteOp rewriteOp(OpBuilder &b,
470                                    VectorTransferToSCFOptions options,
471                                    TransferWriteOp xferOp, Value buffer,
472                                    Value iv, ValueRange loopState) {
473     SmallVector<Value, 8> loadIndices;
474     getBufferIndices(xferOp, loadIndices);
475     loadIndices.push_back(iv);
476 
477     SmallVector<Value, 8> xferIndices;
478     getXferIndices(b, xferOp, iv, xferIndices);
479 
480     Location loc = xferOp.getLoc();
481     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
482     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
483     auto source = loopState.empty() ? xferOp.source() : loopState[0];
484     Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
485     auto newXferOp = b.create<vector::TransferWriteOp>(
486         loc, type, vec, source, xferIndices,
487         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
488         inBoundsAttr);
489 
490     maybeApplyPassLabel(b, newXferOp, options.targetRank);
491 
492     return newXferOp;
493   }
494 
495   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
handleOutOfBoundsDim__anon32638c370111::lowering_n_d::Strategy496   static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
497                                     Value buffer, Value iv,
498                                     ValueRange loopState) {
499     return isTensorOp(xferOp) ? loopState[0] : Value();
500   }
501 
502   /// Cleanup after rewriting the op.
cleanup__anon32638c370111::lowering_n_d::Strategy503   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
504                       scf::ForOp forOp) {
505     if (isTensorOp(xferOp)) {
506       assert(forOp->getNumResults() == 1 && "Expected one for loop result");
507       rewriter.replaceOp(xferOp, forOp->getResult(0));
508     } else {
509       rewriter.eraseOp(xferOp);
510     }
511   }
512 
513   /// Return the initial loop state for the generated scf.for loop.
initialLoopState__anon32638c370111::lowering_n_d::Strategy514   static Value initialLoopState(TransferWriteOp xferOp) {
515     return isTensorOp(xferOp) ? xferOp.source() : Value();
516   }
517 };
518 
519 template <typename OpTy>
checkPrepareXferOp(OpTy xferOp,VectorTransferToSCFOptions options)520 LogicalResult checkPrepareXferOp(OpTy xferOp,
521                                  VectorTransferToSCFOptions options) {
522   if (xferOp->hasAttr(kPassLabel))
523     return failure();
524   if (xferOp.getVectorType().getRank() <= options.targetRank)
525     return failure();
526   if (isTensorOp(xferOp) && !options.lowerTensors)
527     return failure();
528   // Transfer ops that modify the element type are not supported atm.
529   if (xferOp.getVectorType().getElementType() !=
530       xferOp.getShapedType().getElementType())
531     return failure();
532   return success();
533 }
534 
535 /// Prepare a TransferReadOp for progressive lowering.
536 ///
537 /// 1. Allocate a temporary buffer.
538 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
539 /// 3. Store the result of the TransferReadOp into the temporary buffer.
540 /// 4. Load the result from the temporary buffer and replace all uses of the
541 ///    original TransferReadOp with this load.
542 ///
543 /// E.g.:
544 /// ```
545 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
546 ///     : vector<5x4xf32>, memref<?x?x?xf32>
547 /// ```
548 /// is rewritten to:
549 /// ```
550 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
551 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
552 ///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
553 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
554 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
555 /// ```
556 ///
557 /// Note: A second temporary buffer may be allocated for the `mask` operand.
558 struct PrepareTransferReadConversion
559     : public VectorToSCFPattern<TransferReadOp> {
560   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
561 
matchAndRewrite__anon32638c370111::lowering_n_d::PrepareTransferReadConversion562   LogicalResult matchAndRewrite(TransferReadOp xferOp,
563                                 PatternRewriter &rewriter) const override {
564     if (checkPrepareXferOp(xferOp, options).failed())
565       return failure();
566 
567     auto buffers = allocBuffers(rewriter, xferOp);
568     auto *newXfer = rewriter.clone(*xferOp.getOperation());
569     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
570     if (xferOp.mask()) {
571       dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
572           buffers.maskBuffer);
573     }
574 
575     Location loc = xferOp.getLoc();
576     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
577                                      buffers.dataBuffer);
578     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
579 
580     return success();
581   }
582 };
583 
584 /// Prepare a TransferWriteOp for progressive lowering.
585 ///
586 /// 1. Allocate a temporary buffer.
587 /// 2. Store the vector into the buffer.
588 /// 3. Load the vector from the buffer again.
589 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
590 ///    marking it eligible for progressive lowering via TransferOpConversion.
591 ///
592 /// E.g.:
593 /// ```
594 /// vector.transfer_write %vec, %A[%a, %b, %c]
595 ///     : vector<5x4xf32>, memref<?x?x?xf32>
596 /// ```
597 /// is rewritten to:
598 /// ```
599 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
600 /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
601 /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
602 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
603 ///     : vector<5x4xf32>, memref<?x?x?xf32>
604 /// ```
605 ///
606 /// Note: A second temporary buffer may be allocated for the `mask` operand.
607 struct PrepareTransferWriteConversion
608     : public VectorToSCFPattern<TransferWriteOp> {
609   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
610 
matchAndRewrite__anon32638c370111::lowering_n_d::PrepareTransferWriteConversion611   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
612                                 PatternRewriter &rewriter) const override {
613     if (checkPrepareXferOp(xferOp, options).failed())
614       return failure();
615 
616     Location loc = xferOp.getLoc();
617     auto buffers = allocBuffers(rewriter, xferOp);
618     rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
619     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
620     rewriter.updateRootInPlace(xferOp, [&]() {
621       xferOp.vectorMutable().assign(loadedVec);
622       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
623     });
624 
625     if (xferOp.mask()) {
626       rewriter.updateRootInPlace(
627           xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
628     }
629 
630     return success();
631   }
632 };
633 
634 /// Progressive lowering of vector transfer ops: Unpack one dimension.
635 ///
636 /// 1. Unpack one dimension from the current buffer type and cast the buffer
637 ///    to that new type. E.g.:
638 ///    ```
639 ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
640 ///    vector.transfer_write %vec ...
641 ///    ```
642 ///    The following cast is generated:
643 ///    ```
644 ///    %casted = vector.type_cast %0
645 ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
646 ///    ```
647 /// 2. Generate a for loop and rewrite the transfer op according to the
648 ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
649 ///    out-of-bounds, generate an if-check and handle both cases separately.
650 /// 3. Clean up according to the corresponding Strategy<OpTy>.
651 ///
652 /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
653 /// source (as opposed to a memref source), then each iteration of the generated
654 /// scf.for loop yields the new tensor value. E.g.:
655 /// ```
656 /// %result = scf.for i = 0 to 5 {
657 ///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
658 ///   %1 = vector.transfer_write %0, %source[...]
659 ///       : vector<4x3xf32>, tensor<5x4x3xf32>
660 ///   scf.yield %1 : tensor<5x4x3xf32>
661 /// }
662 /// ```
663 template <typename OpTy>
664 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
665   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
666 
initialize__anon32638c370111::lowering_n_d::TransferOpConversion667   void initialize() {
668     // This pattern recursively unpacks one dimension at a time. The recursion
669     // bounded as the rank is strictly decreasing.
670     this->setHasBoundedRewriteRecursion();
671   }
672 
matchAndRewrite__anon32638c370111::lowering_n_d::TransferOpConversion673   LogicalResult matchAndRewrite(OpTy xferOp,
674                                 PatternRewriter &rewriter) const override {
675     if (!xferOp->hasAttr(kPassLabel))
676       return failure();
677 
678     // Find and cast data buffer. How the buffer can be found depends on OpTy.
679     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
680     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
681     auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
682     auto castedDataType = unpackOneDim(dataBufferType);
683     auto castedDataBuffer =
684         locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
685 
686     // If the xferOp has a mask: Find and cast mask buffer.
687     Value castedMaskBuffer;
688     if (xferOp.mask()) {
689       auto maskBuffer = getMaskBuffer(xferOp);
690       auto maskBufferType =
691           maskBuffer.getType().template dyn_cast<MemRefType>();
692       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
693         // Do not unpack a dimension of the mask, if:
694         // * To-be-unpacked transfer op dimension is a broadcast.
695         // * Mask is 1D, i.e., the mask cannot be further unpacked.
696         //   (That means that all remaining dimensions of the transfer op must
697         //   be broadcasted.)
698         castedMaskBuffer = maskBuffer;
699       } else {
700         auto castedMaskType = unpackOneDim(maskBufferType);
701         castedMaskBuffer =
702             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
703       }
704     }
705 
706     // Loop bounds and step.
707     auto lb = locB.create<ConstantIndexOp>(0);
708     auto ub = locB.create<ConstantIndexOp>(
709         castedDataType.getDimSize(castedDataType.getRank() - 1));
710     auto step = locB.create<ConstantIndexOp>(1);
711     // TransferWriteOps that operate on tensors return the modified tensor and
712     // require a loop state.
713     auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
714 
715     // Generate for loop.
716     auto result = locB.create<scf::ForOp>(
717         lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
718         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
719           Type stateType = loopState.empty() ? Type() : loopState[0].getType();
720 
721           auto result = generateInBoundsCheck(
722               b, xferOp, iv, unpackedDim(xferOp),
723               stateType ? TypeRange(stateType) : TypeRange(),
724               /*inBoundsCase=*/
725               [&](OpBuilder &b, Location loc) {
726                 // Create new transfer op.
727                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
728                     b, this->options, xferOp, castedDataBuffer, iv, loopState);
729 
730                 // If old transfer op has a mask: Set mask on new transfer op.
731                 // Special case: If the mask of the old transfer op is 1D and
732                 // the
733                 //               unpacked dim is not a broadcast, no mask is
734                 //               needed on the new transfer op.
735                 if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
736                                       xferOp.getMaskType().getRank() > 1)) {
737                   OpBuilder::InsertionGuard guard(b);
738                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
739 
740                   SmallVector<Value, 8> loadIndices;
741                   Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
742                   // In case of broadcast: Use same indices to load from memref
743                   // as before.
744                   if (!xferOp.isBroadcastDim(0))
745                     loadIndices.push_back(iv);
746 
747                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
748                                                        loadIndices);
749                   rewriter.updateRootInPlace(
750                       newXfer, [&]() { newXfer.maskMutable().assign(mask); });
751                 }
752 
753                 return loopState.empty() ? Value() : newXfer->getResult(0);
754               },
755               /*outOfBoundsCase=*/
756               [&](OpBuilder &b, Location /*loc*/) {
757                 return Strategy<OpTy>::handleOutOfBoundsDim(
758                     b, xferOp, castedDataBuffer, iv, loopState);
759               });
760 
761           maybeYieldValue(b, loc, !loopState.empty(), result);
762         });
763 
764     Strategy<OpTy>::cleanup(rewriter, xferOp, result);
765     return success();
766   }
767 };
768 
769 } // namespace lowering_n_d
770 
771 namespace lowering_n_d_unrolled {
772 
773 /// If the original transfer op has a mask, compute the mask of the new transfer
774 /// op (for the current iteration `i`) and assign it.
775 template <typename OpTy>
maybeAssignMask(OpBuilder & b,OpTy xferOp,OpTy newXferOp,int64_t i)776 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
777                             int64_t i) {
778   if (!xferOp.mask())
779     return;
780 
781   if (xferOp.isBroadcastDim(0)) {
782     // To-be-unpacked dimension is a broadcast, which does not have a
783     // corresponding mask dimension. Mask attribute remains unchanged.
784     newXferOp.maskMutable().assign(xferOp.mask());
785     return;
786   }
787 
788   if (xferOp.getMaskType().getRank() > 1) {
789     // Unpack one dimension of the mask.
790     OpBuilder::InsertionGuard guard(b);
791     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
792 
793     llvm::SmallVector<int64_t, 1> indices({i});
794     Location loc = xferOp.getLoc();
795     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
796     newXferOp.maskMutable().assign(newMask);
797   }
798 
799   // If we end up here: The mask of the old transfer op is 1D and the unpacked
800   // dim is not a broadcast, so no mask is needed on the new transfer op.
801   // `generateInBoundsCheck` will have evaluated the mask already.
802 }
803 
804 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
805 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
806 /// memref buffer is allocated and the SCF loop is fully unrolled.
807 ///
808 /// ```
809 /// E.g.:
810 /// ```
811 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
812 ///     : memref<?x?x?xf32>, vector<5x4xf32>
813 /// ```
814 /// is rewritten to IR such as (simplified):
815 /// ```
816 /// %v_init = splat %padding : vector<5x4xf32>
817 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
818 ///     : memref<?x?x?xf32>, vector<4xf32>
819 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
820 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
821 ///     : memref<?x?x?xf32>, vector<4xf32>
822 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
823 /// ...
824 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
825 ///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
827 /// ```
828 ///
829 /// Note: As an optimization, if the result of the original TransferReadOp
830 /// was directly inserted into another vector, no new %v_init vector is created.
831 /// Instead, the new TransferReadOp results are inserted into that vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted. If the read result feeds exactly one vector.insert, reuse
  /// that op's destination vector; otherwise splat a fresh vector with the
  /// padding value.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.dest();
    Location loc = xferOp.getLoc();
    return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
                                    xferOp.padding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation. Otherwise return a null op.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, append that operation's indices to `indices`.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      llvm::for_each(insertOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    // Each new transfer op reads one slice: drop the leading dimension.
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<int64_t, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(i);

            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.source(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.padding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};
942 
943 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
944 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
945 /// memref buffer is allocated and the SCF loop is fully unrolled.
946 ///
947 /// ```
948 /// E.g.:
949 /// ```
950 /// vector.transfer_write %vec, %A[%a, %b, %c]
951 ///     : vector<5x4xf32>, memref<?x?x?xf32>
952 /// ```
953 /// is rewritten to IR such as (simplified):
954 /// ```
955 /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
956 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
957 /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
958 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
959 /// ...
960 /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
961 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
962 /// ```
963 ///
964 /// Note: As an optimization, if the vector of the original TransferWriteOp
965 /// was directly extracted from another vector via an ExtractOp `a`, extract
966 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
967 /// doing so, `a` may become dead, and the number of ExtractOps generated during
968 /// recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  /// If the written vector itself was produced by an ExtractOp, extract from
  /// that op's source instead, so the intermediate ExtractOp may become dead.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.vector();
    return xferOp.vector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  /// Otherwise return a null op.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.vector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, append its
  /// indices to `indices`.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      llvm::for_each(extractOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    auto xferVecType = xferOp.getVectorType();
    int64_t dimSize = xferVecType.getShape()[0];
    auto source = xferOp.source(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<int64_t, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(i);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, extracted, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            // On tensors, the write produces a new (loop-carried) tensor.
            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};
1071 
1072 } // namespace lowering_n_d_unrolled
1073 
1074 namespace lowering_1_d {
1075 
1076 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
1077 /// part of TransferOp1dConversion. Return the memref dimension on which
1078 /// the transfer is operating. A return value of None indicates a broadcast.
1079 template <typename OpTy>
1080 static Optional<int64_t>
get1dMemrefIndices(OpBuilder & b,OpTy xferOp,Value iv,SmallVector<Value,8> & memrefIndices)1081 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1082                    SmallVector<Value, 8> &memrefIndices) {
1083   auto indices = xferOp.indices();
1084   auto map = xferOp.permutation_map();
1085 
1086   memrefIndices.append(indices.begin(), indices.end());
1087   assert(map.getNumResults() == 1 &&
1088          "Expected 1 permutation map result for 1D transfer");
1089   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
1090     Location loc = xferOp.getLoc();
1091     auto dim = expr.getPosition();
1092     AffineExpr d0, d1;
1093     bindDims(xferOp.getContext(), d0, d1);
1094     Value offset = memrefIndices[dim];
1095     memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1096     return dim;
1097   }
1098 
1099   assert(xferOp.isBroadcastDim(0) &&
1100          "Expected AffineDimExpr or AffineConstantExpr");
1101   return None;
1102 }
1103 
1104 /// Codegen strategy for TransferOp1dConversion, depending on the
1105 /// operation.
1106 template <typename OpTy>
1107 struct Strategy1d;
1108 
1109 /// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  /// Emit the scf.for loop body for a scalarized 1D transfer_read: load one
  /// element (if in bounds) and insert it into the vector carried as loop
  /// state, yielding the updated vector.
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    // vector.insertelement takes an i32 position; cast the index-typed iv.
    Value ivI32 =
        b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  /// The loop state of a transfer_read is its (partially filled) result
  /// vector.
  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
  }
};
1141 
1142 /// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  /// Emit the scf.for loop body for a scalarized 1D transfer_write: extract
  /// one element from the source vector and store it (if in bounds).
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    // vector.extractelement takes an i32 position; cast the index-typed iv.
    Value ivI32 =
        b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
          b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  /// A memref transfer_write has no result and thus carries no loop state.
  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};
1168 
1169 /// Return true if the last dimension of the MemRefType has unit stride.
isLastMemrefDimUnitStride(MemRefType type)1170 static bool isLastMemrefDimUnitStride(MemRefType type) {
1171   int64_t offset;
1172   SmallVector<int64_t, 4> strides;
1173   auto successStrides = getStridesAndOffset(type, strides, offset);
1174   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
1175 }
1176 
1177 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1178 /// necessary in cases where a 1D vector transfer op cannot be lowered into
1179 /// vector load/stores due to non-unit strides or broadcasts:
1180 ///
1181 /// * Transfer dimension is not the last memref dimension
1182 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1183 /// * Memref has a layout map with non-unit stride on the last dimension
1184 ///
1185 /// This pattern generates IR as follows:
1186 ///
1187 /// 1. Generate a for loop iterating over each vector element.
1188 /// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
1189 ///    depending on OpTy.
1190 ///
1191 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1192 ///       can be generated instead of TransferOp1dConversion. Add such a pattern
1193 ///       to ConvertVectorToLLVM.
1194 ///
1195 /// E.g.:
1196 /// ```
1197 /// vector.transfer_write %vec, %A[%a, %b]
1198 ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1199 ///    : vector<9xf32>, memref<?x?xf32>
1200 /// ```
1201 /// Is rewritten to approximately the following pseudo-IR:
1202 /// ```
1203 /// for i = 0 to 9 {
1204 ///   %t = vector.extractelement %vec[i] : vector<9xf32>
1205 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
1206 /// }
1207 /// ```
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  /// Replace a 1D transfer op on a memref with an scf.for loop of scalar
  /// loads/stores. Only fires for cases ConvertVectorToLLVM cannot turn into
  /// vector load/stores: non-minor-identity permutation maps (transposed or
  /// broadcast access) or non-unit stride on the last memref dimension.
  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    auto map = xferOp.permutation_map();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();

    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
      return failure(); // Handled by ConvertVectorToLLVM

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<ConstantIndexOp>(loc, 0);
    auto ub = rewriter.create<ConstantIndexOp>(loc, vecType.getDimSize(0));
    auto step = rewriter.create<ConstantIndexOp>(loc, 1);
    // transfer_read carries the result vector as loop state; transfer_write
    // has none.
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};
1242 
1243 } // namespace lowering_1_d
1244 } // namespace
1245 
1246 namespace mlir {
1247 
populateVectorToSCFConversionPatterns(RewritePatternSet & patterns,const VectorTransferToSCFOptions & options)1248 void populateVectorToSCFConversionPatterns(
1249     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
1250   if (options.unroll) {
1251     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1252                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1253         patterns.getContext(), options);
1254   } else {
1255     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1256                  lowering_n_d::PrepareTransferWriteConversion,
1257                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1258                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1259         patterns.getContext(), options);
1260   }
1261 
1262   if (options.targetRank == 1) {
1263     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1264                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1265         patterns.getContext(), options);
1266   }
1267 }
1268 
1269 } // namespace mlir
1270 
1271 namespace {
1272 
/// Function pass driving the progressive vector-transfer-to-SCF lowering.
struct ConvertVectorToSCFPass
    : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  /// Copy the given lowering options into the (tablegen-generated) pass
  /// options so they are visible on the command line.
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerPermutationMaps = options.lowerPermutationMaps;
    this->lowerTensors = options.lowerTensors;
  }

  void runOnFunction() override {
    // Rebuild the options struct from the pass options, which may have been
    // overridden via command-line flags.
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerPermutationMaps = lowerPermutationMaps;
    options.lowerTensors = lowerTensors;

    // Lower permutation maps first.
    if (lowerPermutationMaps) {
      RewritePatternSet lowerTransferPatterns(getFunction().getContext());
      mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
          lowerTransferPatterns);
      (void)applyPatternsAndFoldGreedily(getFunction(),
                                         std::move(lowerTransferPatterns));
    }

    RewritePatternSet patterns(getFunction().getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
1304 
1305 } // namespace
1306 
/// Create a pass that lowers vector transfer ops to SCF, configured by
/// `options`.
std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}
1311