//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;

namespace {

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  VectorTransferToSCFOptions options;
};

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
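/// E.g.: for permutation_map = (d0, d1, d2) -> (d1, d2), the first map result
/// is d1, so memref dimension 1 is unpacked next.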
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
  auto map = xferOp.permutation_map();
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
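/// E.g.: (d0, d1, d2) -> (d2, d1) becomes (d0, d1, d2) -> (d1).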
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  auto map = xferOp.permutation_map();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                                 ^^^^^^^
///       `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.indices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.hasValue();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.indices()[dim.getValue()];
    indices[dim.getValue()] =
        makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

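/// Generate an scf.yield op, yielding `value` if `hasRetVal` is set and
/// yielding nothing otherwise.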
static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
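///
/// Otherwise, IR such as the following is generated (a sketch, assuming a
/// mask of type vector<16xi1>):
/// ```
/// %ivI32 = index_cast %iv : index to i32
/// %maskBit = vector.extractelement %mask[%ivI32] : vector<16xi1>
/// ```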
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.mask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  Value ivI32 =
      b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
  return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b], %cst
///     : memref<?x?xf32>, vector<5x4xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.indices()[dim.getValue()];
    Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<CmpIOp>(CmpIPredicate::sgt, memrefDim, memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<AndOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
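/// E.g.: [true, false, true] becomes [false, true].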
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank is not the target
/// rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

/// Return true if this transfer op operates on a source tensor.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}

namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

/// Allocate temporary buffers for data (vector) and mask (if present).
/// TODO: Parallelism and threadlocal considerations.
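///
/// E.g. (sketch): for a transfer op with vector type vector<5x4xf32>, a
/// zero-dimensional buffer %buf = memref.alloca() : memref<vector<5x4xf32>>
/// is created at the top of the enclosing allocation scope.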
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope =
      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.mask()) {
    auto maskType = MemRefType::get({}, xferOp.mask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
  }

  return result;
}

/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
  auto vectorType = type.getElementType().dyn_cast<VectorType>();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::get(vectorType.getShape().drop_front(),
                                         vectorType.getElementType()));
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.mask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.source(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
        Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};

/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
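  ///
  /// E.g. (a sketch, mirroring the TransferReadOp example above):
  /// ```
  /// %vec = memref.load %buf[%i] : memref<5xvector<4x3xf32>>
  /// vector.transfer_write %vec, %A[%a+%i, %b, %c]
  ///     : vector<4x3xf32>, memref<?x?x?xf32>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
  ///   vector.transfer_write %vec, %A[%a+%i, %b+%j, %c]
  ///       : vector<3xf32>, memref<?x?x?xf32>
  /// }
  /// ```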
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto source = loopState.empty() ? xferOp.source() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.source() : Value();
  }
};

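/// Return success if the transfer op can be prepared for progressive lowering:
/// it is not yet labeled, its rank exceeds the target rank, tensor sources are
/// enabled if needed, and the op does not change the element type.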
template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported atm.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.mask()) {
      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.vectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.mask()) {
      rewriter.updateRootInPlace(
          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }

    return success();
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the
/// generated scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.mask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = locB.create<ConstantIndexOp>(0);
    auto ub = locB.create<ConstantIndexOp>(
        castedDataType.getDimSize(castedDataType.getRank() - 1));
    auto step = locB.create<ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
                                      xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from
                  // memref as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {

/// If the original transfer op has a mask, compute the mask of the new
/// transfer op (for the current iteration `i`) and assign it.
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.mask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.maskMutable().assign(xferOp.mask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
    newXferOp.maskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
/// ```
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is
/// created. Instead, the new TransferReadOp results are inserted into that
/// vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.dest();
    Location loc = xferOp.getLoc();
    return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
                                    xferOp.padding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      llvm::for_each(insertOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<int64_t, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(i);

            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.source(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.padding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};

/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v0 = vector.extract %vec[0] : vector<5x4xf32>
/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
/// %v1 = vector.extract %vec[1] : vector<5x4xf32>
/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
/// ...
/// %v4 = vector.extract %vec[4] : vector<5x4xf32>
/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
/// ```
///
/// Note: As an optimization, if the vector of the original TransferWriteOp
/// was directly extracted from another vector via an ExtractOp `a`, extract
/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
/// doing so, `a` may become dead, and the number of ExtractOps generated
/// during recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.vector();
    return xferOp.vector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.vector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      llvm::for_each(extractOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    auto xferVecType = xferOp.getVectorType();
    int64_t dimSize = xferVecType.getShape()[0];
    auto source = xferOp.source(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<int64_t, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(i);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, extracted, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};

} // namespace lowering_n_d_unrolled

namespace lowering_1_d {

/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of None indicates a broadcast.
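///
/// E.g.: for indices [%a, %b] and permutation_map = (d0, d1) -> (d0), the
/// computed memref indices are [%a + iv, %b] and the returned dimension is 0.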
template <typename OpTy>
static Optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.indices();
  auto map = xferOp.permutation_map();

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
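  /// Generate one loop iteration: if the current element is in bounds (and
  /// not masked out), load it from the source memref and insert it into the
  /// loop-carried vector; otherwise, keep the vector unchanged.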
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    Value ivI32 =
        b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
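  /// Generate one loop iteration: if the current element is in bounds (and
  /// not masked out), extract it from the vector and store it to the source
  /// memref; otherwise, do nothing.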
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    Value ivI32 =
        b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
          b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Return true if the last dimension of the MemRefType has unit stride.
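/// E.g., memref<2x3xf32> (canonical strides [3, 1]) qualifies; a strided
/// layout whose innermost stride is not 1 does not.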
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
///       can be generated instead of TransferOp1dConversion. Add such a
///       pattern to ConvertVectorToLLVM.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : vector<9xf32>, memref<?x?xf32>
/// ```
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    auto map = xferOp.permutation_map();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();

    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
      return failure(); // Handled by ConvertVectorToLLVM.

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<ConstantIndexOp>(loc, 0);
    auto ub = rewriter.create<ConstantIndexOp>(loc, vecType.getDimSize(0));
    auto step = rewriter.create<ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};

} // namespace lowering_1_d
} // namespace

namespace mlir {

void populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }

  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
}

} // namespace mlir

namespace {

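/// Pass that applies the patterns above to lower vector transfer ops to SCF.
/// Example invocation (a sketch; the registered pass name and option spellings
/// are assumptions):
///   mlir-opt -convert-vector-to-scf="full-unroll=true" input.mlir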
struct ConvertVectorToSCFPass
    : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerPermutationMaps = options.lowerPermutationMaps;
    this->lowerTensors = options.lowerTensors;
  }

  void runOnFunction() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerPermutationMaps = lowerPermutationMaps;
    options.lowerTensors = lowerTensors;

    // Lower permutation maps first.
    if (lowerPermutationMaps) {
      RewritePatternSet lowerTransferPatterns(getFunction().getContext());
      mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
          lowerTransferPatterns);
      (void)applyPatternsAndFoldGreedily(getFunction(),
                                         std::move(lowerTransferPatterns));
    }

    RewritePatternSet patterns(getFunction().getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}