1 //===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_
10 #define DIALECT_VECTOR_VECTORTRANSFORMS_H_
11 
12 #include "mlir/Dialect/Vector/VectorOps.h"
13 #include "mlir/Dialect/Vector/VectorUtils.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 namespace mlir {
18 class MLIRContext;
19 class VectorTransferOpInterface;
20 class RewritePatternSet;
21 using OwningRewritePatternList = RewritePatternSet;
22 
23 namespace scf {
24 class IfOp;
25 } // namespace scf
26 
27 /// Collect a set of patterns to convert from the Vector dialect to itself.
28 /// Should be merged with populateVectorToSCFLoweringPattern.
29 void populateVectorToVectorConversionPatterns(
30     MLIRContext *context, RewritePatternSet &patterns,
31     ArrayRef<int64_t> coarseVectorShape = {},
32     ArrayRef<int64_t> fineVectorShape = {});
33 
34 namespace vector {
35 
36 /// Options that control the vector unrolling.
37 struct UnrollVectorOptions {
38   using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
39   /// Callback function that indicates whether vector unrolling should be
40   /// attempted on the operation.
41   FilterConstraintFnType filterConstraint = nullptr;
setFilterConstraintUnrollVectorOptions42   UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
43     filterConstraint = constraint;
44     return *this;
45   }
46 
47   using NativeShapeFnType =
48       std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
49   /// Function that returns the shape of the vector to unroll to for a given
50   /// operation. The unrolling is aborted if the function returns `llvm::None`.
51   NativeShapeFnType nativeShape = nullptr;
setNativeShapeFnUnrollVectorOptions52   UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
53     nativeShape = fn;
54     return *this;
55   }
56 
57   /// Set the native shape to use for unrolling.
setNativeShapeUnrollVectorOptions58   UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
59     SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
60     nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
61       return tsShape;
62     };
63     return *this;
64   }
65 };
66 
67 /// Collect a set of pattern to unroll vector operations to a smaller shapes.
68 /// `options` structure controls which operations are unrolled and the target
69 /// shape.
70 /// `op` is unrolled to the `targetShape` as follows, for each of its operands:
71 ///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
72 ///   `numUnrolledInstances` are computed from the `targetShape`. For now it is
73 ///   assumed the unrolling factors divide the vector sizes.
74 ///   2. ExtractStridedSlice are created to break-up the vector operands.
75 ///   3. the original op is cloned `numUnrolledInstances` times, once for each
76 ///   result.
77 ///   4. InsertStridedSlice are inserted to re-assemble the slices into the
78 ///   original vectore shape.
79 ///
80 /// Example:
81 ///
82 ///    opA(operand0, operand1)  // numUnrolledInstances = 3
83 ///
84 ///            operand0                   operand1
85 ///               |                          |
86 ///             fork                       fork
87 ///        <----------gather all fork ops --------->
88 ///              /|\                        /|\
89 ///          f00 f01 f02                f10 f11 f12
90 ///        <---------- clone op 3 times --------->
91 ///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
92 ///                 \            |            /
93 ///      <-------------------- join ------------------------->
94 ///
95 /// Other local patterns then kick in iteratively (including DCE) and compose
96 /// to combine the ExtractStridedSlice/InsertStridedSlice.
97 void populateVectorUnrollPatterns(RewritePatternSet &patterns,
98                                   const UnrollVectorOptions &options);
99 
100 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
101 /// masking) fastpath and a slowpath.
102 /// If `ifOp` is not null and the result is `success, the `ifOp` points to the
103 /// newly created conditional upon function return.
104 /// To accomodate for the fact that the original vector.transfer indexing may be
105 /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
106 /// scf.if op returns a view and values of type index.
107 /// At this time, only vector.transfer_read case is implemented.
108 ///
109 /// Example (a 2-D vector.transfer_read):
110 /// ```
111 ///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
112 /// ```
113 /// is transformed into:
114 /// ```
115 ///    %1:3 = scf.if (%inBounds) {
116 ///      // fastpath, direct cast
117 ///      memref.cast %A: memref<A...> to compatibleMemRefType
118 ///      scf.yield %view : compatibleMemRefType, index, index
119 ///    } else {
120 ///      // slowpath, not in-bounds vector.transfer or linalg.copy.
121 ///      memref.cast %alloc: memref<B...> to compatibleMemRefType
122 ///      scf.yield %4 : compatibleMemRefType, index, index
123 //     }
124 ///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
125 /// ```
126 /// where `alloc` is a top of the function alloca'ed buffer of one vector.
127 ///
128 /// Preconditions:
129 ///  1. `xferOp.permutation_map()` must be a minor identity map
130 ///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
131 ///  must be equal. This will be relaxed in the future but requires
132 ///  rank-reducing subviews.
133 LogicalResult
134 splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
135 LogicalResult splitFullAndPartialTransfer(
136     OpBuilder &b, VectorTransferOpInterface xferOp,
137     VectorTransformsOptions options = VectorTransformsOptions(),
138     scf::IfOp *ifOp = nullptr);
139 
140 /// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
141 /// may take an extra filter to perform selection at a finer granularity.
142 struct VectorTransferFullPartialRewriter : public RewritePattern {
143   using FilterConstraintType =
144       std::function<LogicalResult(VectorTransferOpInterface op)>;
145 
146   explicit VectorTransferFullPartialRewriter(
147       MLIRContext *context,
148       VectorTransformsOptions options = VectorTransformsOptions(),
149       FilterConstraintType filter =
150           [](VectorTransferOpInterface op) { return success(); },
151       PatternBenefit benefit = 1)
RewritePatternVectorTransferFullPartialRewriter152       : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
153         filter(filter) {}
154 
155   /// Performs the rewrite.
156   LogicalResult matchAndRewrite(Operation *op,
157                                 PatternRewriter &rewriter) const override;
158 
159 private:
160   VectorTransformsOptions options;
161   FilterConstraintType filter;
162 };
163 
/// Pair of ops handed back by `distributPointwiseVectorOp`: one
/// `ExtractMapOp` and one `InsertMapOp`.
struct DistributeOps {
  ExtractMapOp extract;
  InsertMapOp insert;
};
168 
169 /// Distribute a N-D vector pointwise operation over a range of given ids taking
170 /// *all* values in [0 .. multiplicity - 1] (e.g. loop induction variable or
171 /// SPMD id). This transformation only inserts
172 /// vector.extract_map/vector.insert_map. It is meant to be used with
173 /// canonicalizations pattern to propagate and fold the vector
174 /// insert_map/extract_map operations.
175 /// Transforms:
176 //  %v = addf %a, %b : vector<32xf32>
177 /// to:
178 /// %v = addf %a, %b : vector<32xf32>
179 /// %ev = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
180 /// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> into vector<32xf32>
181 Optional<DistributeOps>
182 distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
183                            ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
184                            const AffineMap &map);
185 
/// Implements transfer op write to read forwarding and dead transfer write
/// optimizations.
/// NOTE(review): presumably applied to the transfer ops nested in `func`; the
/// implementation is not visible here -- confirm.
void transferOpflowOpt(FuncOp func);
189 
190 } // namespace vector
191 
192 //===----------------------------------------------------------------------===//
193 // Finer-grained patterns exposed for more control over individual lowerings.
194 //===----------------------------------------------------------------------===//
195 
196 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
197 /// semantics to:
198 /// ```
199 ///    %flattened_a = vector.shape_cast %a
200 ///    %flattened_b = vector.shape_cast %b
201 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
202 ///    %d = vector.shape_cast %%flattened_d
203 ///    %e = add %c, %d
204 /// ```
205 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
206 //
207 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
208 /// the vector.contract op is a row-major matrix multiply.
209 class ContractionOpToMatmulOpLowering
210     : public OpRewritePattern<vector::ContractionOp> {
211 public:
212   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
213   using FilterConstraintType =
214       std::function<LogicalResult(vector::ContractionOp op)>;
215 
defaultFilter(vector::ContractionOp op)216   static LogicalResult defaultFilter(vector::ContractionOp op) {
217     return success();
218   }
219 
220   ContractionOpToMatmulOpLowering(
221       vector::VectorTransformsOptions vectorTransformsOptions,
222       MLIRContext *context, FilterConstraintType constraint = defaultFilter)
223       : OpRewritePattern<vector::ContractionOp>(context),
224         vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
225 
226   LogicalResult matchAndRewrite(vector::ContractionOp op,
227                                 PatternRewriter &rewriter) const override;
228 
229 private:
230   /// Options to control the vector patterns.
231   vector::VectorTransformsOptions vectorTransformsOptions;
232   FilterConstraintType filter;
233 };
234 
235 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
236 /// semantics to a reduction_size-unrolled sequence:
237 /// ```
238 ///    %at = vector.transpose %a, [1, 0]
239 ///    %bRow0 = vector.extract %b[0]
240 ///    %atRow0 = vector.extract %at[0]
241 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
242 ///    ...
243 ///    %bRowK = vector.extract %b[K]
244 ///    %atRowK = vector.extract %at[K]
245 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
246 /// ```
247 ///
248 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
249 /// the vector.contract op is a row-major matrix multiply.
250 class ContractionOpToOuterProductOpLowering
251     : public OpRewritePattern<vector::ContractionOp> {
252 public:
253   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
254   using FilterConstraintType =
255       std::function<LogicalResult(vector::ContractionOp op)>;
256 
defaultFilter(vector::ContractionOp op)257   static LogicalResult defaultFilter(vector::ContractionOp op) {
258     return success();
259   }
260 
261   ContractionOpToOuterProductOpLowering(
262       vector::VectorTransformsOptions vectorTransformsOptions,
263       MLIRContext *context, FilterConstraintType constraint = defaultFilter)
264       : OpRewritePattern<vector::ContractionOp>(context),
265         vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
266 
267   LogicalResult matchAndRewrite(vector::ContractionOp op,
268                                 PatternRewriter &rewriter) const override;
269 
270 private:
271   /// Options to control the vector patterns.
272   vector::VectorTransformsOptions vectorTransformsOptions;
273   FilterConstraintType filter;
274 };
275 
276 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
277 /// semantics to an output-size-unrolled sequence:
278 /// ```
279 ///    %out = constant ... : vector<MxNxelt_type>
280 ///    %bt = vector.transpose %b, [1, 0]
281 ///    %aRow0 = vector.extract %a[0]
282 ///    %btRow0 = vector.extract %bt[0]
283 ///    %c00 = vector.reduce %atRow0, %bRow0
284 ///    %out00 = vector.insert %c00, %out[0, 0]
285 ///    ...
286 ///    %aRowLast = vector.extract %at[M-1]
287 ///    %btRowLast = vector.extract %b[N-1]
288 ///    %cLastLast = vector.reduce %atRowLast, %bRowLast
289 ///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
290 /// ```
291 ///
292 /// This only kicks in when VectorTransformsOptions is set to Dot and
293 /// the vector.contract op is a row-major matmul or matvec.
294 class ContractionOpToDotLowering
295     : public OpRewritePattern<vector::ContractionOp> {
296 public:
297   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
298   using FilterConstraintType =
299       std::function<LogicalResult(vector::ContractionOp op)>;
300 
defaultFilter(vector::ContractionOp op)301   static LogicalResult defaultFilter(vector::ContractionOp op) {
302     return success();
303   }
304 
305   ContractionOpToDotLowering(
306       vector::VectorTransformsOptions vectorTransformsOptions,
307       MLIRContext *context, FilterConstraintType constraint = defaultFilter)
308       : OpRewritePattern<vector::ContractionOp>(context),
309         vectorTransformsOptions(vectorTransformsOptions),
310         filter(defaultFilter) {}
311 
312   LogicalResult matchAndRewrite(vector::ContractionOp op,
313                                 PatternRewriter &rewriter) const override;
314 
315 private:
316   /// Options to control the vector patterns.
317   vector::VectorTransformsOptions vectorTransformsOptions;
318   FilterConstraintType filter;
319 };
320 
321 /// Progressive lowering of ContractionOp.
322 ///
323 /// One:
324 ///   %x = vector.contract with at least one free/batch dimension
325 /// is replaced by:
326 ///   %a = vector.contract with one less free/batch dimension
327 ///   %b = vector.contract with one less free/batch dimension
328 ///   ..
329 ///   %x = combine %a %b ..
330 /// until a pure contraction is reached (no free/batch dimensions),
331 /// which is replaced by a dot-product.
332 ///
333 /// This only kicks in when either VectorTransformsOptions is set
334 /// to Dot or when other contraction patterns fail.
335 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
336 public:
337   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
338   using FilterConstraintType =
339       std::function<LogicalResult(vector::ContractionOp op)>;
340 
defaultFilter(vector::ContractionOp op)341   static LogicalResult defaultFilter(vector::ContractionOp op) {
342     return success();
343   }
344 
345   ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
346                         MLIRContext *context,
347                         FilterConstraintType constraint = defaultFilter)
348       : OpRewritePattern<vector::ContractionOp>(context),
349         vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
350 
351   LogicalResult matchAndRewrite(vector::ContractionOp op,
352                                 PatternRewriter &rewriter) const override;
353 
354 private:
355   /// Options to control the vector patterns.
356   vector::VectorTransformsOptions vectorTransformsOptions;
357   FilterConstraintType filter;
358   // Lower one parallel dimension.
359   Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
360                       int64_t rhsIndex, PatternRewriter &rewriter) const;
361   // Lower one reduction dimension.
362   Value lowerReduction(vector::ContractionOp op,
363                        PatternRewriter &rewriter) const;
364 };
365 
366 } // namespace mlir
367 
368 #endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_
369