1 //===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_ 10 #define DIALECT_VECTOR_VECTORTRANSFORMS_H_ 11 12 #include "mlir/Dialect/Vector/VectorOps.h" 13 #include "mlir/Dialect/Vector/VectorUtils.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/PatternMatch.h" 16 17 namespace mlir { 18 class MLIRContext; 19 class VectorTransferOpInterface; 20 class RewritePatternSet; 21 using OwningRewritePatternList = RewritePatternSet; 22 23 namespace scf { 24 class IfOp; 25 } // namespace scf 26 27 /// Collect a set of patterns to convert from the Vector dialect to itself. 28 /// Should be merged with populateVectorToSCFLoweringPattern. 29 void populateVectorToVectorConversionPatterns( 30 MLIRContext *context, RewritePatternSet &patterns, 31 ArrayRef<int64_t> coarseVectorShape = {}, 32 ArrayRef<int64_t> fineVectorShape = {}); 33 34 namespace vector { 35 36 /// Options that control the vector unrolling. 37 struct UnrollVectorOptions { 38 using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>; 39 /// Callback function that indicates whether vector unrolling should be 40 /// attempted on the operation. 41 FilterConstraintFnType filterConstraint = nullptr; setFilterConstraintUnrollVectorOptions42 UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) { 43 filterConstraint = constraint; 44 return *this; 45 } 46 47 using NativeShapeFnType = 48 std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>; 49 /// Function that returns the shape of the vector to unroll to for a given 50 /// operation. The unrolling is aborted if the function returns `llvm::None`. 51 NativeShapeFnType nativeShape = nullptr; setNativeShapeFnUnrollVectorOptions52 UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) { 53 nativeShape = fn; 54 return *this; 55 } 56 57 /// Set the native shape to use for unrolling. setNativeShapeUnrollVectorOptions58 UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) { 59 SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end()); 60 nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> { 61 return tsShape; 62 }; 63 return *this; 64 } 65 }; 66 67 /// Collect a set of pattern to unroll vector operations to a smaller shapes. 68 /// `options` structure controls which operations are unrolled and the target 69 /// shape. 70 /// `op` is unrolled to the `targetShape` as follows, for each of its operands: 71 /// 1. the unrolled type `unrolledVectorType` and number of unrolled instances 72 /// `numUnrolledInstances` are computed from the `targetShape`. For now it is 73 /// assumed the unrolling factors divide the vector sizes. 74 /// 2. ExtractStridedSlice are created to break-up the vector operands. 75 /// 3. the original op is cloned `numUnrolledInstances` times, once for each 76 /// result. 77 /// 4. InsertStridedSlice are inserted to re-assemble the slices into the 78 /// original vectore shape. 79 /// 80 /// Example: 81 /// 82 /// opA(operand0, operand1) // numUnrolledInstances = 3 83 /// 84 /// operand0 operand1 85 /// | | 86 /// fork fork 87 /// <----------gather all fork ops ---------> 88 /// /|\ /|\ 89 /// f00 f01 f02 f10 f11 f12 90 /// <---------- clone op 3 times ---------> 91 /// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12) 92 /// \ | / 93 /// <-------------------- join -------------------------> 94 /// 95 /// Other local patterns then kick in iteratively (including DCE) and compose 96 /// to combine the ExtractStridedSlice/InsertStridedSlice. 97 void populateVectorUnrollPatterns(RewritePatternSet &patterns, 98 const UnrollVectorOptions &options); 99 100 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds 101 /// masking) fastpath and a slowpath. 102 /// If `ifOp` is not null and the result is `success, the `ifOp` points to the 103 /// newly created conditional upon function return. 104 /// To accomodate for the fact that the original vector.transfer indexing may be 105 /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the 106 /// scf.if op returns a view and values of type index. 107 /// At this time, only vector.transfer_read case is implemented. 108 /// 109 /// Example (a 2-D vector.transfer_read): 110 /// ``` 111 /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...> 112 /// ``` 113 /// is transformed into: 114 /// ``` 115 /// %1:3 = scf.if (%inBounds) { 116 /// // fastpath, direct cast 117 /// memref.cast %A: memref<A...> to compatibleMemRefType 118 /// scf.yield %view : compatibleMemRefType, index, index 119 /// } else { 120 /// // slowpath, not in-bounds vector.transfer or linalg.copy. 121 /// memref.cast %alloc: memref<B...> to compatibleMemRefType 122 /// scf.yield %4 : compatibleMemRefType, index, index 123 // } 124 /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]} 125 /// ``` 126 /// where `alloc` is a top of the function alloca'ed buffer of one vector. 127 /// 128 /// Preconditions: 129 /// 1. `xferOp.permutation_map()` must be a minor identity map 130 /// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()` 131 /// must be equal. This will be relaxed in the future but requires 132 /// rank-reducing subviews. 133 LogicalResult 134 splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp); 135 LogicalResult splitFullAndPartialTransfer( 136 OpBuilder &b, VectorTransferOpInterface xferOp, 137 VectorTransformsOptions options = VectorTransformsOptions(), 138 scf::IfOp *ifOp = nullptr); 139 140 /// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern 141 /// may take an extra filter to perform selection at a finer granularity. 142 struct VectorTransferFullPartialRewriter : public RewritePattern { 143 using FilterConstraintType = 144 std::function<LogicalResult(VectorTransferOpInterface op)>; 145 146 explicit VectorTransferFullPartialRewriter( 147 MLIRContext *context, 148 VectorTransformsOptions options = VectorTransformsOptions(), 149 FilterConstraintType filter = 150 [](VectorTransferOpInterface op) { return success(); }, 151 PatternBenefit benefit = 1) RewritePatternVectorTransferFullPartialRewriter152 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options), 153 filter(filter) {} 154 155 /// Performs the rewrite. 156 LogicalResult matchAndRewrite(Operation *op, 157 PatternRewriter &rewriter) const override; 158 159 private: 160 VectorTransformsOptions options; 161 FilterConstraintType filter; 162 }; 163 164 struct DistributeOps { 165 ExtractMapOp extract; 166 InsertMapOp insert; 167 }; 168 169 /// Distribute a N-D vector pointwise operation over a range of given ids taking 170 /// *all* values in [0 .. multiplicity - 1] (e.g. loop induction variable or 171 /// SPMD id). This transformation only inserts 172 /// vector.extract_map/vector.insert_map. It is meant to be used with 173 /// canonicalizations pattern to propagate and fold the vector 174 /// insert_map/extract_map operations. 175 /// Transforms: 176 // %v = addf %a, %b : vector<32xf32> 177 /// to: 178 /// %v = addf %a, %b : vector<32xf32> 179 /// %ev = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32> 180 /// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> into vector<32xf32> 181 Optional<DistributeOps> 182 distributPointwiseVectorOp(OpBuilder &builder, Operation *op, 183 ArrayRef<Value> id, ArrayRef<int64_t> multiplicity, 184 const AffineMap &map); 185 186 /// Implements transfer op write to read forwarding and dead transfer write 187 /// optimizations. 188 void transferOpflowOpt(FuncOp func); 189 190 } // namespace vector 191 192 //===----------------------------------------------------------------------===// 193 // Finer-grained patterns exposed for more control over individual lowerings. 194 //===----------------------------------------------------------------------===// 195 196 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul 197 /// semantics to: 198 /// ``` 199 /// %flattened_a = vector.shape_cast %a 200 /// %flattened_b = vector.shape_cast %b 201 /// %flattened_d = vector.matmul %flattened_a, %flattened_b 202 /// %d = vector.shape_cast %%flattened_d 203 /// %e = add %c, %d 204 /// ``` 205 /// `vector.matmul` later lowers to `llvm.matrix.multiply`. 206 // 207 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and 208 /// the vector.contract op is a row-major matrix multiply. 209 class ContractionOpToMatmulOpLowering 210 : public OpRewritePattern<vector::ContractionOp> { 211 public: 212 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; 213 using FilterConstraintType = 214 std::function<LogicalResult(vector::ContractionOp op)>; 215 defaultFilter(vector::ContractionOp op)216 static LogicalResult defaultFilter(vector::ContractionOp op) { 217 return success(); 218 } 219 220 ContractionOpToMatmulOpLowering( 221 vector::VectorTransformsOptions vectorTransformsOptions, 222 MLIRContext *context, FilterConstraintType constraint = defaultFilter) 223 : OpRewritePattern<vector::ContractionOp>(context), 224 vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {} 225 226 LogicalResult matchAndRewrite(vector::ContractionOp op, 227 PatternRewriter &rewriter) const override; 228 229 private: 230 /// Options to control the vector patterns. 231 vector::VectorTransformsOptions vectorTransformsOptions; 232 FilterConstraintType filter; 233 }; 234 235 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul 236 /// semantics to a reduction_size-unrolled sequence: 237 /// ``` 238 /// %at = vector.transpose %a, [1, 0] 239 /// %bRow0 = vector.extract %b[0] 240 /// %atRow0 = vector.extract %at[0] 241 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c 242 /// ... 243 /// %bRowK = vector.extract %b[K] 244 /// %atRowK = vector.extract %at[K] 245 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 246 /// ``` 247 /// 248 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and 249 /// the vector.contract op is a row-major matrix multiply. 250 class ContractionOpToOuterProductOpLowering 251 : public OpRewritePattern<vector::ContractionOp> { 252 public: 253 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; 254 using FilterConstraintType = 255 std::function<LogicalResult(vector::ContractionOp op)>; 256 defaultFilter(vector::ContractionOp op)257 static LogicalResult defaultFilter(vector::ContractionOp op) { 258 return success(); 259 } 260 261 ContractionOpToOuterProductOpLowering( 262 vector::VectorTransformsOptions vectorTransformsOptions, 263 MLIRContext *context, FilterConstraintType constraint = defaultFilter) 264 : OpRewritePattern<vector::ContractionOp>(context), 265 vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {} 266 267 LogicalResult matchAndRewrite(vector::ContractionOp op, 268 PatternRewriter &rewriter) const override; 269 270 private: 271 /// Options to control the vector patterns. 272 vector::VectorTransformsOptions vectorTransformsOptions; 273 FilterConstraintType filter; 274 }; 275 276 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul 277 /// semantics to an output-size-unrolled sequence: 278 /// ``` 279 /// %out = constant ... : vector<MxNxelt_type> 280 /// %bt = vector.transpose %b, [1, 0] 281 /// %aRow0 = vector.extract %a[0] 282 /// %btRow0 = vector.extract %bt[0] 283 /// %c00 = vector.reduce %atRow0, %bRow0 284 /// %out00 = vector.insert %c00, %out[0, 0] 285 /// ... 286 /// %aRowLast = vector.extract %at[M-1] 287 /// %btRowLast = vector.extract %b[N-1] 288 /// %cLastLast = vector.reduce %atRowLast, %bRowLast 289 /// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1] 290 /// ``` 291 /// 292 /// This only kicks in when VectorTransformsOptions is set to Dot and 293 /// the vector.contract op is a row-major matmul or matvec. 294 class ContractionOpToDotLowering 295 : public OpRewritePattern<vector::ContractionOp> { 296 public: 297 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; 298 using FilterConstraintType = 299 std::function<LogicalResult(vector::ContractionOp op)>; 300 defaultFilter(vector::ContractionOp op)301 static LogicalResult defaultFilter(vector::ContractionOp op) { 302 return success(); 303 } 304 305 ContractionOpToDotLowering( 306 vector::VectorTransformsOptions vectorTransformsOptions, 307 MLIRContext *context, FilterConstraintType constraint = defaultFilter) 308 : OpRewritePattern<vector::ContractionOp>(context), 309 vectorTransformsOptions(vectorTransformsOptions), 310 filter(defaultFilter) {} 311 312 LogicalResult matchAndRewrite(vector::ContractionOp op, 313 PatternRewriter &rewriter) const override; 314 315 private: 316 /// Options to control the vector patterns. 317 vector::VectorTransformsOptions vectorTransformsOptions; 318 FilterConstraintType filter; 319 }; 320 321 /// Progressive lowering of ContractionOp. 322 /// 323 /// One: 324 /// %x = vector.contract with at least one free/batch dimension 325 /// is replaced by: 326 /// %a = vector.contract with one less free/batch dimension 327 /// %b = vector.contract with one less free/batch dimension 328 /// .. 329 /// %x = combine %a %b .. 330 /// until a pure contraction is reached (no free/batch dimensions), 331 /// which is replaced by a dot-product. 332 /// 333 /// This only kicks in when either VectorTransformsOptions is set 334 /// to Dot or when other contraction patterns fail. 335 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> { 336 public: 337 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; 338 using FilterConstraintType = 339 std::function<LogicalResult(vector::ContractionOp op)>; 340 defaultFilter(vector::ContractionOp op)341 static LogicalResult defaultFilter(vector::ContractionOp op) { 342 return success(); 343 } 344 345 ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions, 346 MLIRContext *context, 347 FilterConstraintType constraint = defaultFilter) 348 : OpRewritePattern<vector::ContractionOp>(context), 349 vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {} 350 351 LogicalResult matchAndRewrite(vector::ContractionOp op, 352 PatternRewriter &rewriter) const override; 353 354 private: 355 /// Options to control the vector patterns. 356 vector::VectorTransformsOptions vectorTransformsOptions; 357 FilterConstraintType filter; 358 // Lower one parallel dimension. 359 Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex, 360 int64_t rhsIndex, PatternRewriter &rewriter) const; 361 // Lower one reduction dimension. 362 Value lowerReduction(vector::ContractionOp op, 363 PatternRewriter &rewriter) const; 364 }; 365 366 } // namespace mlir 367 368 #endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_ 369