//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites of MultiDimReductionOp.
//
//===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Vector/VectorOps.h"
14 #include "mlir/Dialect/Vector/VectorTransforms.h"
15 #include "mlir/Dialect/Vector/VectorUtils.h"
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/TypeUtilities.h"
23
24 #define DEBUG_TYPE "vector-multi-reduction"
25
26 using namespace mlir;
27
28 /// This file implements the following transformations as composable atomic
29 /// patterns.
30
31 /// Converts vector.multi_reduction into inner-most/outer-most reduction form
32 /// by using vector.transpose
33 class InnerOuterDimReductionConversion
34 : public OpRewritePattern<vector::MultiDimReductionOp> {
35 public:
36 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
37
InnerOuterDimReductionConversion(MLIRContext * context,bool useInnerDimsForReduction)38 explicit InnerOuterDimReductionConversion(MLIRContext *context,
39 bool useInnerDimsForReduction)
40 : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
41 useInnerDimsForReduction(useInnerDimsForReduction) {}
42
matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,PatternRewriter & rewriter) const43 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
44 PatternRewriter &rewriter) const override {
45 auto src = multiReductionOp.source();
46 auto loc = multiReductionOp.getLoc();
47 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
48
49 // Separate reduction and parallel dims
50 auto reductionDimsRange =
51 multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
52 auto reductionDims = llvm::to_vector<4>(llvm::map_range(
53 reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
54 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
55 reductionDims.end());
56 int64_t reductionSize = reductionDims.size();
57 SmallVector<int64_t, 4> parallelDims;
58 for (int64_t i = 0; i < srcRank; ++i)
59 if (!reductionDimsSet.contains(i))
60 parallelDims.push_back(i);
61
62 // Add transpose only if inner-most/outer-most dimensions are not parallel
63 if (useInnerDimsForReduction &&
64 (parallelDims ==
65 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
66 return failure();
67
68 if (!useInnerDimsForReduction &&
69 (parallelDims !=
70 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
71 return failure();
72
73 SmallVector<int64_t, 4> indices;
74 if (useInnerDimsForReduction) {
75 indices.append(parallelDims.begin(), parallelDims.end());
76 indices.append(reductionDims.begin(), reductionDims.end());
77 } else {
78 indices.append(reductionDims.begin(), reductionDims.end());
79 indices.append(parallelDims.begin(), parallelDims.end());
80 }
81 auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
82 SmallVector<bool> reductionMask(srcRank, false);
83 for (int i = 0; i < reductionSize; ++i) {
84 if (useInnerDimsForReduction)
85 reductionMask[srcRank - i - 1] = true;
86 else
87 reductionMask[i] = true;
88 }
89 rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
90 multiReductionOp, transposeOp.result(), reductionMask,
91 multiReductionOp.kind());
92 return success();
93 }
94
95 private:
96 const bool useInnerDimsForReduction;
97 };
98
99 /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
100 /// dimensions are either inner most or outer most.
101 class ReduceMultiDimReductionRank
102 : public OpRewritePattern<vector::MultiDimReductionOp> {
103 public:
104 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
105
ReduceMultiDimReductionRank(MLIRContext * context,bool useInnerDimsForReduction)106 explicit ReduceMultiDimReductionRank(MLIRContext *context,
107 bool useInnerDimsForReduction)
108 : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
109 useInnerDimsForReduction(useInnerDimsForReduction) {}
110
matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,PatternRewriter & rewriter) const111 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
112 PatternRewriter &rewriter) const override {
113 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
114 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
115 auto loc = multiReductionOp.getLoc();
116
117 // If rank less than 2, nothing to do.
118 if (srcRank < 2)
119 return failure();
120
121 // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
122 SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
123 if (srcRank == 2 && reductionMask.front() != reductionMask.back())
124 return failure();
125
126 // 1. Separate reduction and parallel dims.
127 SmallVector<int64_t, 4> parallelDims, parallelShapes;
128 SmallVector<int64_t, 4> reductionDims, reductionShapes;
129 for (auto it : llvm::enumerate(reductionMask)) {
130 int64_t i = it.index();
131 bool isReduction = it.value();
132 if (isReduction) {
133 reductionDims.push_back(i);
134 reductionShapes.push_back(srcShape[i]);
135 } else {
136 parallelDims.push_back(i);
137 parallelShapes.push_back(srcShape[i]);
138 }
139 }
140
141 // 2. Compute flattened parallel and reduction sizes.
142 int flattenedParallelDim = 0;
143 int flattenedReductionDim = 0;
144 if (parallelShapes.size() > 0) {
145 flattenedParallelDim = 1;
146 for (auto d : parallelShapes)
147 flattenedParallelDim *= d;
148 }
149 if (reductionShapes.size() > 0) {
150 flattenedReductionDim = 1;
151 for (auto d : reductionShapes)
152 flattenedReductionDim *= d;
153 }
154 // We must at least have some parallel or some reduction.
155 assert((flattenedParallelDim || flattenedReductionDim) &&
156 "expected at least one parallel or reduction dim");
157
158 // 3. Fail if reduction/parallel dims are not contiguous.
159 // Check parallelDims are exactly [0 .. size).
160 int64_t counter = 0;
161 if (useInnerDimsForReduction &&
162 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
163 return failure();
164 // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
165 counter = reductionDims.size();
166 if (!useInnerDimsForReduction &&
167 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
168 return failure();
169
170 // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
171 // a single parallel (resp. reduction) dim.
172 SmallVector<bool, 2> mask;
173 SmallVector<int64_t, 2> vectorShape;
174 if (flattenedParallelDim) {
175 mask.push_back(false);
176 vectorShape.push_back(flattenedParallelDim);
177 }
178 if (flattenedReductionDim) {
179 mask.push_back(true);
180 vectorShape.push_back(flattenedReductionDim);
181 }
182 if (!useInnerDimsForReduction && vectorShape.size() == 2) {
183 std::swap(mask.front(), mask.back());
184 std::swap(vectorShape.front(), vectorShape.back());
185 }
186 auto castedType = VectorType::get(
187 vectorShape, multiReductionOp.getSourceVectorType().getElementType());
188 Value cast = rewriter.create<vector::ShapeCastOp>(
189 loc, castedType, multiReductionOp.source());
190
191 // 5. Creates the flattened form of vector.multi_reduction with inner/outer
192 // most dim as reduction.
193 auto newOp = rewriter.create<vector::MultiDimReductionOp>(
194 loc, cast, mask, multiReductionOp.kind());
195
196 // 6. If there are no parallel shapes, the result is a scalar.
197 // TODO: support 0-d vectors when available.
198 if (parallelShapes.empty()) {
199 rewriter.replaceOp(multiReductionOp, newOp.dest());
200 return success();
201 }
202
203 // 7. Creates shape cast for the output n-D -> 2-D
204 VectorType outputCastedType = VectorType::get(
205 parallelShapes,
206 multiReductionOp.getSourceVectorType().getElementType());
207 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
208 multiReductionOp, outputCastedType, newOp.dest());
209 return success();
210 }
211
212 private:
213 const bool useInnerDimsForReduction;
214 };
215
216 /// Unrolls vector.multi_reduction with outermost reductions
217 /// and combines results
218 struct TwoDimMultiReductionToElementWise
219 : public OpRewritePattern<vector::MultiDimReductionOp> {
220 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
221
matchAndRewriteTwoDimMultiReductionToElementWise222 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
223 PatternRewriter &rewriter) const override {
224 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
225 // Rank-2 ["parallel", "reduce"] or bail.
226 if (srcRank != 2)
227 return failure();
228
229 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
230 return failure();
231
232 auto loc = multiReductionOp.getLoc();
233 ArrayRef<int64_t> srcShape =
234 multiReductionOp.getSourceVectorType().getShape();
235
236 Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
237 if (!elementType.isIntOrIndexOrFloat())
238 return failure();
239
240 Value condition;
241 Value result =
242 rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
243 .getResult();
244 for (int64_t i = 1; i < srcShape[0]; i++) {
245 auto operand =
246 rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
247 switch (multiReductionOp.kind()) {
248 case vector::CombiningKind::ADD:
249 if (elementType.isIntOrIndex())
250 result = rewriter.create<AddIOp>(loc, operand, result);
251 else
252 result = rewriter.create<AddFOp>(loc, operand, result);
253 break;
254 case vector::CombiningKind::MUL:
255 if (elementType.isIntOrIndex())
256 result = rewriter.create<MulIOp>(loc, operand, result);
257 else
258 result = rewriter.create<MulFOp>(loc, operand, result);
259 break;
260 case vector::CombiningKind::MINUI:
261 result = rewriter.create<MinUIOp>(loc, operand, result);
262 break;
263 case vector::CombiningKind::MINSI:
264 result = rewriter.create<MinSIOp>(loc, operand, result);
265 break;
266 case vector::CombiningKind::MINF:
267 result = rewriter.create<MinFOp>(loc, operand, result);
268 break;
269 case vector::CombiningKind::MAXUI:
270 result = rewriter.create<MaxUIOp>(loc, operand, result);
271 break;
272 case vector::CombiningKind::MAXSI:
273 result = rewriter.create<MaxSIOp>(loc, operand, result);
274 break;
275 case vector::CombiningKind::MAXF:
276 result = rewriter.create<MaxFOp>(loc, operand, result);
277 break;
278 case vector::CombiningKind::AND:
279 result = rewriter.create<AndOp>(loc, operand, result);
280 break;
281 case vector::CombiningKind::OR:
282 result = rewriter.create<OrOp>(loc, operand, result);
283 break;
284 case vector::CombiningKind::XOR:
285 result = rewriter.create<XOrOp>(loc, operand, result);
286 break;
287 }
288 }
289
290 rewriter.replaceOp(multiReductionOp, result);
291 return success();
292 }
293 };
294
295 /// Converts 2d vector.multi_reduction with inner most reduction dimension into
296 /// a sequence of vector.reduction ops.
297 struct TwoDimMultiReductionToReduction
298 : public OpRewritePattern<vector::MultiDimReductionOp> {
299 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
300
matchAndRewriteTwoDimMultiReductionToReduction301 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
302 PatternRewriter &rewriter) const override {
303 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
304 if (srcRank != 2)
305 return failure();
306
307 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
308 return failure();
309
310 auto loc = multiReductionOp.getLoc();
311 Value result = rewriter.create<ConstantOp>(
312 loc, multiReductionOp.getDestType(),
313 rewriter.getZeroAttr(multiReductionOp.getDestType()));
314 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
315
316 // TODO: Add vector::CombiningKind attribute instead of string to
317 // vector.reduction.
318 auto getKindStr = [](vector::CombiningKind kind) {
319 switch (kind) {
320 case vector::CombiningKind::ADD:
321 return "add";
322 case vector::CombiningKind::MUL:
323 return "mul";
324 case vector::CombiningKind::MINUI:
325 return "minui";
326 case vector::CombiningKind::MINSI:
327 return "minsi";
328 case vector::CombiningKind::MINF:
329 return "minf";
330 case vector::CombiningKind::MAXUI:
331 return "maxui";
332 case vector::CombiningKind::MAXSI:
333 return "maxsi";
334 case vector::CombiningKind::MAXF:
335 return "maxf";
336 case vector::CombiningKind::AND:
337 return "and";
338 case vector::CombiningKind::OR:
339 return "or";
340 case vector::CombiningKind::XOR:
341 return "xor";
342 }
343 llvm_unreachable("unknown combining kind");
344 };
345
346 for (int i = 0; i < outerDim; ++i) {
347 auto v = rewriter.create<vector::ExtractOp>(
348 loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
349 auto reducedValue = rewriter.create<vector::ReductionOp>(
350 loc, getElementTypeOrSelf(multiReductionOp.getDestType()),
351 rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
352 ValueRange{});
353 result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
354 result, i);
355 }
356 rewriter.replaceOp(multiReductionOp, result);
357 return success();
358 }
359 };
360
361 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
362 /// form with both a single parallel and reduction dimension.
363 /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
364 /// The case with a single parallel dimension is a noop and folds away
365 /// separately.
366 struct OneDimMultiReductionToTwoDim
367 : public OpRewritePattern<vector::MultiDimReductionOp> {
368 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
369
matchAndRewriteOneDimMultiReductionToTwoDim370 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
371 PatternRewriter &rewriter) const override {
372 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
373 // Rank-1 or bail.
374 if (srcRank != 1)
375 return failure();
376
377 auto loc = multiReductionOp.getLoc();
378 auto srcVectorType = multiReductionOp.getSourceVectorType();
379 auto srcShape = srcVectorType.getShape();
380 auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
381 srcVectorType.getElementType());
382 assert(!multiReductionOp.getDestType().isa<VectorType>() &&
383 "multi_reduction with a single dimension expects a scalar result");
384
385 // If the unique dim is reduced and we insert a parallel in front, we need a
386 // {false, true} mask.
387 SmallVector<bool, 2> mask{false, true};
388
389 /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
390 Value cast = rewriter.create<vector::ShapeCastOp>(
391 loc, castedType, multiReductionOp.source());
392 Value reduced = rewriter.create<vector::MultiDimReductionOp>(
393 loc, cast, mask, multiReductionOp.kind());
394 rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
395 ArrayRef<int64_t>{0});
396 return success();
397 }
398 };
399
populateVectorMultiReductionLoweringPatterns(RewritePatternSet & patterns,bool useInnerDimsForReduction)400 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
401 RewritePatternSet &patterns, bool useInnerDimsForReduction) {
402 patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank,
403 OneDimMultiReductionToTwoDim>(patterns.getContext(),
404 useInnerDimsForReduction);
405 if (useInnerDimsForReduction)
406 patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
407 else
408 patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
409 }
410