1 //===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Insert reshape to binary op's input if needed to match rank
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
17 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
18 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21
22 using namespace mlir;
23 using namespace mlir::tosa;
24
/// There are two potential ways of implementing broadcast:
/// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
/// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
/// TBD: option (a) is picked for now.
29
/// In this pass, we insert RESHAPE operators to increase the rank of the
/// lower-rank operand as a first step in the broadcasting process. The TOSA
/// operators that support broadcast require the ranks of their operands to
/// be equal.
34
35 // Examples:
36 // If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1].
// TODO: If lower=[b], target=[a, b, c], [b] should be, but is NOT YET,
// reshaped into [1, b, 1].
39 // If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c].
40 // If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c].
41 // If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1].
42 // If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c].
43 // If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1].
44 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
45 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
46
computeReshapeOutput(ArrayRef<int64_t> higherRankShape,ArrayRef<int64_t> lowerRankShape,SmallVectorImpl<int64_t> & reshapeOutputShape)47 static void computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
48 ArrayRef<int64_t> lowerRankShape,
49 SmallVectorImpl<int64_t> &reshapeOutputShape) {
50 // Initialize new shapes with [1] * higherRank.
51 int64_t higherRank = higherRankShape.size();
52 int64_t lowerRank = lowerRankShape.size();
53
54 reshapeOutputShape.assign(higherRank, 1);
55
56 int64_t higherLeftIndex = 0;
57 int64_t higherRightIndex = higherRank;
58 int64_t lowerLeftIndex = 0;
59 int64_t lowerRightIndex = lowerRank;
60 int64_t higherRankDim, lowerRankDim;
61
62 if (lowerRightIndex != 0 && higherRightIndex != 0) {
63 // Matches lower rank shape from right dimension first, until not
64 // matching high rank shape or reaching dimension 0.
65 while (true) {
66 higherRankDim = higherRankShape[higherRightIndex - 1];
67 lowerRankDim = lowerRankShape[lowerRightIndex - 1];
68 if (higherRankDim != lowerRankDim)
69 break;
70
71 reshapeOutputShape[higherRightIndex - 1] = higherRankDim;
72
73 if (higherRightIndex > 0)
74 higherRightIndex--;
75
76 if (lowerRightIndex > 0)
77 lowerRightIndex--;
78
79 if (higherRightIndex == 0 || lowerRightIndex == 0)
80 break;
81 }
82 if (lowerRightIndex != 0 && higherRightIndex != 0) {
83 // Matches lower rank shape from left dimension, until not matching
84 // high rank shape or reaching right index.
85 while (true) {
86 higherRankDim = higherRankShape[higherLeftIndex];
87 lowerRankDim = lowerRankShape[lowerLeftIndex];
88 if (higherRankDim != lowerRankDim)
89 break;
90
91 reshapeOutputShape[higherLeftIndex] = higherRankDim;
92
93 if (higherLeftIndex < higherRightIndex)
94 higherLeftIndex++;
95
96 if (lowerLeftIndex < lowerRightIndex)
97 lowerLeftIndex++;
98
99 if (higherLeftIndex == higherRightIndex ||
100 lowerLeftIndex == lowerRightIndex)
101 break;
102 }
103 }
104 }
105 }
106
107 /// Common code to create the reshape op where necessary to make the rank of the
108 /// operations equal. Returns the updated input1 and input2 for the original
109 /// input. The caller is expected to use these to rewrite the original operator
110 /// with the RESHAPE now in the graph.
reshapeLowerToHigher(PatternRewriter & rewriter,Location loc,RankedTensorType outputType,Value input1,Value input2,Value & outInput1,Value & outInput2)111 static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
112 Location loc,
113 RankedTensorType outputType,
114 Value input1, Value input2,
115 Value &outInput1, Value &outInput2) {
116 auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
117 auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
118
119 if (!input1Ty || !input2Ty)
120 return failure();
121
122 int64_t input1Rank = input1Ty.getRank();
123 int64_t input2Rank = input2Ty.getRank();
124
125 Value higherTensorValue, lowerTensorValue;
126 // Cannot rewrite as its already correct.
127 if (input1Rank == input2Rank)
128 return failure();
129
130 if (input1Rank > input2Rank) {
131 higherTensorValue = input1;
132 lowerTensorValue = input2;
133 } else {
134 higherTensorValue = input2;
135 lowerTensorValue = input1;
136 }
137
138 ArrayRef<int64_t> higherRankShape =
139 higherTensorValue.getType().cast<RankedTensorType>().getShape();
140 (void)higherRankShape;
141 ArrayRef<int64_t> lowerRankShape =
142 lowerTensorValue.getType().cast<RankedTensorType>().getShape();
143
144 SmallVector<int64_t, 4> reshapeOutputShape;
145
146 computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape);
147
148 auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
149 auto reshapeOutputType = RankedTensorType::get(
150 ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
151
152 // Verify the rank agrees with the output type if the output type is ranked.
153 if (outputType) {
154 if (outputType.getShape().size() != reshapeOutputShape.size() ||
155 outputType.getShape().size() != higherRankShape.size())
156 return failure();
157 }
158
159 auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
160 loc, reshapeOutputType, lowerTensorValue,
161 rewriter.getI64ArrayAttr(reshapeOutputShape));
162
163 if (input1Rank > input2Rank) {
164 outInput1 = higherTensorValue;
165 outInput2 = reshapeLower.getResult();
166 } else {
167 outInput1 = reshapeLower.getResult();
168 outInput2 = higherTensorValue;
169 }
170
171 return success();
172 }
173
174 namespace {
175 template <typename OpTy>
176 struct ConvertTosaOp : public OpRewritePattern<OpTy> {
177 using OpRewritePattern<OpTy>::OpRewritePattern;
178
matchAndRewrite__anon85774d1c0111::ConvertTosaOp179 LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
180 PatternRewriter &rewriter) const override {
181
182 Value input1 = tosaBinaryOp.input1();
183 Value input2 = tosaBinaryOp.input2();
184 Value output = tosaBinaryOp.getResult();
185
186 auto outputType = output.getType().dyn_cast<RankedTensorType>();
187
188 Value outInput1, outInput2;
189 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
190 input1, input2, outInput1, outInput2)
191 .failed())
192 return failure();
193
194 rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
195 outInput2);
196
197 return success();
198 }
199 };
200
201 // The MulOp has an extra parameter 'shift' not present in other elementwise
202 // binary ops, that necessitates special handling of its builder.
203 template <>
204 struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
205 using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
206
matchAndRewrite__anon85774d1c0111::ConvertTosaOp207 LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
208 PatternRewriter &rewriter) const override {
209
210 Value input1 = tosaBinaryOp.input1();
211 Value input2 = tosaBinaryOp.input2();
212 int32_t shift = tosaBinaryOp.shift();
213 Value output = tosaBinaryOp.getResult();
214 auto outputType = output.getType().dyn_cast<RankedTensorType>();
215
216 Value outInput1, outInput2;
217 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
218 input1, input2, outInput1, outInput2)
219 .failed())
220 return failure();
221
222 rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
223 outInput1, outInput2, shift);
224
225 return success();
226 }
227 };
228
229 // The ArithmeticRightShiftOp has an extra parameter 'round' not present in
230 // other elementwise binary ops, that necessitates special handling of its
231 // builder.
232 template <>
233 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
234 : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
235 using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;
236
matchAndRewrite__anon85774d1c0111::ConvertTosaOp237 LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
238 PatternRewriter &rewriter) const override {
239
240 Value input1 = tosaBinaryOp.input1();
241 Value input2 = tosaBinaryOp.input2();
242 int32_t round = tosaBinaryOp.round();
243 Value output = tosaBinaryOp.getResult();
244 auto outputType = output.getType().dyn_cast<RankedTensorType>();
245
246 Value outInput1, outInput2;
247 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
248 input1, input2, outInput1, outInput2)
249 .failed())
250 return failure();
251
252 rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
253 tosaBinaryOp, outputType, outInput1, outInput2, round);
254
255 return success();
256 }
257 };
258 } // end anonymous namespace
259
260 namespace {
261 /// Pass that enables broadcast by making all input arrays have the same
262 /// number of dimensions. Insert RESHAPE operations to lower rank operand
263 struct TosaMakeBroadcastable
264 : public TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
265 public:
runOnFunction__anon85774d1c0211::TosaMakeBroadcastable266 void runOnFunction() override {
267 auto func = getFunction();
268 RewritePatternSet patterns(func.getContext());
269 MLIRContext *ctx = func.getContext();
270 // Add the generated patterns to the list.
271 patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
272 patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
273 patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
274 patterns.add<ConvertTosaOp<tosa::DivOp>>(ctx);
275 patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
276 patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
277 patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
278 patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
279 patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
280 patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
281 patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
282 patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
283 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
284 }
285 };
286 } // end anonymous namespace
287
createTosaMakeBroadcastablePass()288 std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
289 return std::make_unique<TosaMakeBroadcastable>();
290 }
291