//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites of MultiDimReductionOp.
//
//===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/VectorOps.h"
14 #include "mlir/Dialect/Vector/VectorTransforms.h"
15 #include "mlir/Dialect/Vector/VectorUtils.h"
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/TypeUtilities.h"
23 
24 #define DEBUG_TYPE "vector-multi-reduction"
25 
26 using namespace mlir;
27 
28 /// This file implements the following transformations as composable atomic
29 /// patterns.
30 
31 /// Converts vector.multi_reduction into inner-most/outer-most reduction form
32 /// by using vector.transpose
33 class InnerOuterDimReductionConversion
34     : public OpRewritePattern<vector::MultiDimReductionOp> {
35 public:
36   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
37 
InnerOuterDimReductionConversion(MLIRContext * context,bool useInnerDimsForReduction)38   explicit InnerOuterDimReductionConversion(MLIRContext *context,
39                                             bool useInnerDimsForReduction)
40       : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
41         useInnerDimsForReduction(useInnerDimsForReduction) {}
42 
matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,PatternRewriter & rewriter) const43   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
44                                 PatternRewriter &rewriter) const override {
45     auto src = multiReductionOp.source();
46     auto loc = multiReductionOp.getLoc();
47     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
48 
49     // Separate reduction and parallel dims
50     auto reductionDimsRange =
51         multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
52     auto reductionDims = llvm::to_vector<4>(llvm::map_range(
53         reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
54     llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
55                                                   reductionDims.end());
56     int64_t reductionSize = reductionDims.size();
57     SmallVector<int64_t, 4> parallelDims;
58     for (int64_t i = 0; i < srcRank; ++i)
59       if (!reductionDimsSet.contains(i))
60         parallelDims.push_back(i);
61 
62     // Add transpose only if inner-most/outer-most dimensions are not parallel
63     if (useInnerDimsForReduction &&
64         (parallelDims ==
65          llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
66       return failure();
67 
68     if (!useInnerDimsForReduction &&
69         (parallelDims !=
70          llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
71       return failure();
72 
73     SmallVector<int64_t, 4> indices;
74     if (useInnerDimsForReduction) {
75       indices.append(parallelDims.begin(), parallelDims.end());
76       indices.append(reductionDims.begin(), reductionDims.end());
77     } else {
78       indices.append(reductionDims.begin(), reductionDims.end());
79       indices.append(parallelDims.begin(), parallelDims.end());
80     }
81     auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
82     SmallVector<bool> reductionMask(srcRank, false);
83     for (int i = 0; i < reductionSize; ++i) {
84       if (useInnerDimsForReduction)
85         reductionMask[srcRank - i - 1] = true;
86       else
87         reductionMask[i] = true;
88     }
89     rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
90         multiReductionOp, transposeOp.result(), reductionMask,
91         multiReductionOp.kind());
92     return success();
93   }
94 
95 private:
96   const bool useInnerDimsForReduction;
97 };
98 
99 /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
100 /// dimensions are either inner most or outer most.
101 class ReduceMultiDimReductionRank
102     : public OpRewritePattern<vector::MultiDimReductionOp> {
103 public:
104   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
105 
ReduceMultiDimReductionRank(MLIRContext * context,bool useInnerDimsForReduction)106   explicit ReduceMultiDimReductionRank(MLIRContext *context,
107                                        bool useInnerDimsForReduction)
108       : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
109         useInnerDimsForReduction(useInnerDimsForReduction) {}
110 
matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,PatternRewriter & rewriter) const111   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
112                                 PatternRewriter &rewriter) const override {
113     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
114     auto srcShape = multiReductionOp.getSourceVectorType().getShape();
115     auto loc = multiReductionOp.getLoc();
116 
117     // If rank less than 2, nothing to do.
118     if (srcRank < 2)
119       return failure();
120 
121     // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
122     SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
123     if (srcRank == 2 && reductionMask.front() != reductionMask.back())
124       return failure();
125 
126     // 1. Separate reduction and parallel dims.
127     SmallVector<int64_t, 4> parallelDims, parallelShapes;
128     SmallVector<int64_t, 4> reductionDims, reductionShapes;
129     for (auto it : llvm::enumerate(reductionMask)) {
130       int64_t i = it.index();
131       bool isReduction = it.value();
132       if (isReduction) {
133         reductionDims.push_back(i);
134         reductionShapes.push_back(srcShape[i]);
135       } else {
136         parallelDims.push_back(i);
137         parallelShapes.push_back(srcShape[i]);
138       }
139     }
140 
141     // 2. Compute flattened parallel and reduction sizes.
142     int flattenedParallelDim = 0;
143     int flattenedReductionDim = 0;
144     if (parallelShapes.size() > 0) {
145       flattenedParallelDim = 1;
146       for (auto d : parallelShapes)
147         flattenedParallelDim *= d;
148     }
149     if (reductionShapes.size() > 0) {
150       flattenedReductionDim = 1;
151       for (auto d : reductionShapes)
152         flattenedReductionDim *= d;
153     }
154     // We must at least have some parallel or some reduction.
155     assert((flattenedParallelDim || flattenedReductionDim) &&
156            "expected at least one parallel or reduction dim");
157 
158     // 3. Fail if reduction/parallel dims are not contiguous.
159     // Check parallelDims are exactly [0 .. size).
160     int64_t counter = 0;
161     if (useInnerDimsForReduction &&
162         llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
163       return failure();
164     // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
165     counter = reductionDims.size();
166     if (!useInnerDimsForReduction &&
167         llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
168       return failure();
169 
170     // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
171     // a single parallel (resp. reduction) dim.
172     SmallVector<bool, 2> mask;
173     SmallVector<int64_t, 2> vectorShape;
174     if (flattenedParallelDim) {
175       mask.push_back(false);
176       vectorShape.push_back(flattenedParallelDim);
177     }
178     if (flattenedReductionDim) {
179       mask.push_back(true);
180       vectorShape.push_back(flattenedReductionDim);
181     }
182     if (!useInnerDimsForReduction && vectorShape.size() == 2) {
183       std::swap(mask.front(), mask.back());
184       std::swap(vectorShape.front(), vectorShape.back());
185     }
186     auto castedType = VectorType::get(
187         vectorShape, multiReductionOp.getSourceVectorType().getElementType());
188     Value cast = rewriter.create<vector::ShapeCastOp>(
189         loc, castedType, multiReductionOp.source());
190 
191     // 5. Creates the flattened form of vector.multi_reduction with inner/outer
192     // most dim as reduction.
193     auto newOp = rewriter.create<vector::MultiDimReductionOp>(
194         loc, cast, mask, multiReductionOp.kind());
195 
196     // 6. If there are no parallel shapes, the result is a scalar.
197     // TODO: support 0-d vectors when available.
198     if (parallelShapes.empty()) {
199       rewriter.replaceOp(multiReductionOp, newOp.dest());
200       return success();
201     }
202 
203     // 7. Creates shape cast for the output n-D -> 2-D
204     VectorType outputCastedType = VectorType::get(
205         parallelShapes,
206         multiReductionOp.getSourceVectorType().getElementType());
207     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
208         multiReductionOp, outputCastedType, newOp.dest());
209     return success();
210   }
211 
212 private:
213   const bool useInnerDimsForReduction;
214 };
215 
216 /// Unrolls vector.multi_reduction with outermost reductions
217 /// and combines results
218 struct TwoDimMultiReductionToElementWise
219     : public OpRewritePattern<vector::MultiDimReductionOp> {
220   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
221 
matchAndRewriteTwoDimMultiReductionToElementWise222   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
223                                 PatternRewriter &rewriter) const override {
224     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
225     // Rank-2 ["parallel", "reduce"] or bail.
226     if (srcRank != 2)
227       return failure();
228 
229     if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
230       return failure();
231 
232     auto loc = multiReductionOp.getLoc();
233     ArrayRef<int64_t> srcShape =
234         multiReductionOp.getSourceVectorType().getShape();
235 
236     Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
237     if (!elementType.isIntOrIndexOrFloat())
238       return failure();
239 
240     Value condition;
241     Value result =
242         rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
243             .getResult();
244     for (int64_t i = 1; i < srcShape[0]; i++) {
245       auto operand =
246           rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
247       switch (multiReductionOp.kind()) {
248       case vector::CombiningKind::ADD:
249         if (elementType.isIntOrIndex())
250           result = rewriter.create<AddIOp>(loc, operand, result);
251         else
252           result = rewriter.create<AddFOp>(loc, operand, result);
253         break;
254       case vector::CombiningKind::MUL:
255         if (elementType.isIntOrIndex())
256           result = rewriter.create<MulIOp>(loc, operand, result);
257         else
258           result = rewriter.create<MulFOp>(loc, operand, result);
259         break;
260       case vector::CombiningKind::MINUI:
261         result = rewriter.create<MinUIOp>(loc, operand, result);
262         break;
263       case vector::CombiningKind::MINSI:
264         result = rewriter.create<MinSIOp>(loc, operand, result);
265         break;
266       case vector::CombiningKind::MINF:
267         result = rewriter.create<MinFOp>(loc, operand, result);
268         break;
269       case vector::CombiningKind::MAXUI:
270         result = rewriter.create<MaxUIOp>(loc, operand, result);
271         break;
272       case vector::CombiningKind::MAXSI:
273         result = rewriter.create<MaxSIOp>(loc, operand, result);
274         break;
275       case vector::CombiningKind::MAXF:
276         result = rewriter.create<MaxFOp>(loc, operand, result);
277         break;
278       case vector::CombiningKind::AND:
279         result = rewriter.create<AndOp>(loc, operand, result);
280         break;
281       case vector::CombiningKind::OR:
282         result = rewriter.create<OrOp>(loc, operand, result);
283         break;
284       case vector::CombiningKind::XOR:
285         result = rewriter.create<XOrOp>(loc, operand, result);
286         break;
287       }
288     }
289 
290     rewriter.replaceOp(multiReductionOp, result);
291     return success();
292   }
293 };
294 
295 /// Converts 2d vector.multi_reduction with inner most reduction dimension into
296 /// a sequence of vector.reduction ops.
297 struct TwoDimMultiReductionToReduction
298     : public OpRewritePattern<vector::MultiDimReductionOp> {
299   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
300 
matchAndRewriteTwoDimMultiReductionToReduction301   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
302                                 PatternRewriter &rewriter) const override {
303     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
304     if (srcRank != 2)
305       return failure();
306 
307     if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
308       return failure();
309 
310     auto loc = multiReductionOp.getLoc();
311     Value result = rewriter.create<ConstantOp>(
312         loc, multiReductionOp.getDestType(),
313         rewriter.getZeroAttr(multiReductionOp.getDestType()));
314     int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
315 
316     // TODO: Add vector::CombiningKind attribute instead of string to
317     // vector.reduction.
318     auto getKindStr = [](vector::CombiningKind kind) {
319       switch (kind) {
320       case vector::CombiningKind::ADD:
321         return "add";
322       case vector::CombiningKind::MUL:
323         return "mul";
324       case vector::CombiningKind::MINUI:
325         return "minui";
326       case vector::CombiningKind::MINSI:
327         return "minsi";
328       case vector::CombiningKind::MINF:
329         return "minf";
330       case vector::CombiningKind::MAXUI:
331         return "maxui";
332       case vector::CombiningKind::MAXSI:
333         return "maxsi";
334       case vector::CombiningKind::MAXF:
335         return "maxf";
336       case vector::CombiningKind::AND:
337         return "and";
338       case vector::CombiningKind::OR:
339         return "or";
340       case vector::CombiningKind::XOR:
341         return "xor";
342       }
343       llvm_unreachable("unknown combining kind");
344     };
345 
346     for (int i = 0; i < outerDim; ++i) {
347       auto v = rewriter.create<vector::ExtractOp>(
348           loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
349       auto reducedValue = rewriter.create<vector::ReductionOp>(
350           loc, getElementTypeOrSelf(multiReductionOp.getDestType()),
351           rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
352           ValueRange{});
353       result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
354                                                         result, i);
355     }
356     rewriter.replaceOp(multiReductionOp, result);
357     return success();
358   }
359 };
360 
361 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
362 /// form with both a single parallel and reduction dimension.
363 /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
364 /// The case with a single parallel dimension is a noop and folds away
365 /// separately.
366 struct OneDimMultiReductionToTwoDim
367     : public OpRewritePattern<vector::MultiDimReductionOp> {
368   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
369 
matchAndRewriteOneDimMultiReductionToTwoDim370   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
371                                 PatternRewriter &rewriter) const override {
372     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
373     // Rank-1 or bail.
374     if (srcRank != 1)
375       return failure();
376 
377     auto loc = multiReductionOp.getLoc();
378     auto srcVectorType = multiReductionOp.getSourceVectorType();
379     auto srcShape = srcVectorType.getShape();
380     auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
381                                       srcVectorType.getElementType());
382     assert(!multiReductionOp.getDestType().isa<VectorType>() &&
383            "multi_reduction with a single dimension expects a scalar result");
384 
385     // If the unique dim is reduced and we insert a parallel in front, we need a
386     // {false, true} mask.
387     SmallVector<bool, 2> mask{false, true};
388 
389     /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
390     Value cast = rewriter.create<vector::ShapeCastOp>(
391         loc, castedType, multiReductionOp.source());
392     Value reduced = rewriter.create<vector::MultiDimReductionOp>(
393         loc, cast, mask, multiReductionOp.kind());
394     rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
395                                                    ArrayRef<int64_t>{0});
396     return success();
397   }
398 };
399 
populateVectorMultiReductionLoweringPatterns(RewritePatternSet & patterns,bool useInnerDimsForReduction)400 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
401     RewritePatternSet &patterns, bool useInnerDimsForReduction) {
402   patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank,
403                OneDimMultiReductionToTwoDim>(patterns.getContext(),
404                                              useInnerDimsForReduction);
405   if (useInnerDimsForReduction)
406     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
407   else
408     patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
409 }
410