1 //===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Insert reshape to binary op's input if needed to match rank
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
16 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
17 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21 using namespace mlir;
22 using namespace mlir::tosa;
23
24 /// There are two potential ways implementing broadcast:
25 /// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
26 /// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
27 /// TBD: picking option (a) now.
28
29 /// In this pass, we insert RESHAPE operators to increase the rank of the
30 /// lower rank operand as a first step in the broadcasting process. The TOSA
31 /// operators that support broadcast require that the rank of the operands
32 /// are equal.
33
34 // Examples:
35 // If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1].
// TODO: If lower=[b], target=[a, b, c], [b] should be reshaped into
// [1, b, 1], but this case is NOT YET supported.
38 // If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c].
39 // If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c].
40 // If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1].
41 // If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c].
42 // If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1].
43 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
44 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
45
computeReshapeOutput(ArrayRef<int64_t> higherRankShape,ArrayRef<int64_t> lowerRankShape,SmallVectorImpl<int64_t> & reshapeOutputShape)46 static void computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
47 ArrayRef<int64_t> lowerRankShape,
48 SmallVectorImpl<int64_t> &reshapeOutputShape) {
49 // Intialize new shapes with [1] * higherRank.
50 int64_t higherRank = higherRankShape.size();
51 int64_t lowerRank = lowerRankShape.size();
52
53 reshapeOutputShape.assign(higherRank, 1);
54
55 int64_t higherLeftIndex = 0;
56 int64_t higherRightIndex = higherRank;
57 int64_t lowerLeftIndex = 0;
58 int64_t lowerRightIndex = lowerRank;
59 int64_t higherRankDim, lowerRankDim;
60
61 if (lowerRightIndex != 0 && higherRightIndex != 0) {
62 // Matches lower rank shape from right dimension first, until not
63 // matching high rank shape or reaching dimension 0.
64 while (true) {
65 higherRankDim = higherRankShape[higherRightIndex - 1];
66 lowerRankDim = lowerRankShape[lowerRightIndex - 1];
67 if (higherRankDim != lowerRankDim)
68 break;
69
70 reshapeOutputShape[higherRightIndex - 1] = higherRankDim;
71
72 if (higherRightIndex > 0)
73 higherRightIndex--;
74
75 if (lowerRightIndex > 0)
76 lowerRightIndex--;
77
78 if (higherRightIndex == 0 || lowerRightIndex == 0)
79 break;
80 }
81 if (lowerRightIndex != 0 && higherRightIndex != 0) {
82 // Matches lower rank shape from left dimension, until not matching
83 // high rank shape or reaching right index.
84 while (true) {
85 higherRankDim = higherRankShape[higherLeftIndex];
86 lowerRankDim = lowerRankShape[lowerLeftIndex];
87 if (higherRankDim != lowerRankDim)
88 break;
89
90 reshapeOutputShape[higherLeftIndex] = higherRankDim;
91
92 if (higherLeftIndex < higherRightIndex)
93 higherLeftIndex++;
94
95 if (lowerLeftIndex < lowerRightIndex)
96 lowerLeftIndex++;
97
98 if (higherLeftIndex == higherRightIndex ||
99 lowerLeftIndex == lowerRightIndex)
100 break;
101 }
102 }
103 }
104 }
105
106 /// Common code to reate the reshape op where necessary to make the rank of the
107 /// operations equal. Returns the updated input1 and input2 for the original
108 /// input. The caller is expected to use these to rewrite the original operator
109 /// with the RESHAPE now in the graph.
reshapeLowerToHigher(PatternRewriter & rewriter,Location loc,RankedTensorType outputType,Value input1,Value input2,Value & outInput1,Value & outInput2)110 static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
111 RankedTensorType outputType, Value input1,
112 Value input2, Value &outInput1,
113 Value &outInput2) {
114
115 int64_t input1Rank = input1.getType().cast<RankedTensorType>().getRank();
116 int64_t input2Rank = input2.getType().cast<RankedTensorType>().getRank();
117
118 Value higherTensorValue, lowerTensorValue;
119 // return if rank already match
120 if (input1Rank == input2Rank) {
121 return 1;
122 } else if (input1Rank > input2Rank) {
123 higherTensorValue = input1;
124 lowerTensorValue = input2;
125 } else {
126 higherTensorValue = input2;
127 lowerTensorValue = input1;
128 }
129
130 ArrayRef<int64_t> outputRankShape = outputType.getShape();
131 ArrayRef<int64_t> higherRankShape =
132 higherTensorValue.getType().cast<RankedTensorType>().getShape();
133 (void)higherRankShape;
134 ArrayRef<int64_t> lowerRankShape =
135 lowerTensorValue.getType().cast<RankedTensorType>().getShape();
136
137 // outputRank == higherRank == max(input1Rank, input2Rank)
138 assert(higherRankShape.size() == outputRankShape.size());
139
140 SmallVector<int64_t, 4> reshapeOutputShape;
141
142 computeReshapeOutput(outputRankShape, lowerRankShape, reshapeOutputShape);
143
144 auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
145 auto reshapeOutputType = RankedTensorType::get(
146 ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
147
148 auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
149 loc, reshapeOutputType, lowerTensorValue,
150 rewriter.getI64ArrayAttr(reshapeOutputShape));
151
152 if (input1Rank > input2Rank) {
153 outInput1 = higherTensorValue;
154 outInput2 = reshapeLower.getResult();
155 } else {
156 outInput1 = reshapeLower.getResult();
157 outInput2 = higherTensorValue;
158 }
159
160 return 0;
161 }
162
163 namespace {
164 template <typename OpTy>
165 struct ConvertTosaOp : public OpRewritePattern<OpTy> {
166 using OpRewritePattern<OpTy>::OpRewritePattern;
167
matchAndRewrite__anon7823b7870111::ConvertTosaOp168 LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
169 PatternRewriter &rewriter) const {
170
171 Value input1 = tosaBinaryOp.input1();
172 Value input2 = tosaBinaryOp.input2();
173 Value output = tosaBinaryOp.getResult();
174 auto outputType = output.getType().cast<RankedTensorType>();
175
176 Value outInput1, outInput2;
177 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
178 input1, input2, outInput1, outInput2))
179 return failure();
180
181 rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
182 outInput2);
183
184 return success();
185 }
186 };
187
188 // The MulOp has an extra parameter 'shift' not present in other elementwise
189 // binary ops, that necessitates special handling of its builder.
190 template <>
191 struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
192 using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
193
matchAndRewrite__anon7823b7870111::ConvertTosaOp194 LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
195 PatternRewriter &rewriter) const {
196
197 Value input1 = tosaBinaryOp.input1();
198 Value input2 = tosaBinaryOp.input2();
199 int32_t shift = tosaBinaryOp.shift();
200 Value output = tosaBinaryOp.getResult();
201 auto outputType = output.getType().cast<RankedTensorType>();
202
203 Value outInput1, outInput2;
204 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
205 input1, input2, outInput1, outInput2))
206 return failure();
207
208 rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
209 outInput1, outInput2, shift);
210
211 return success();
212 }
213 };
214
215 // The ArithmeticRightShiftOp has an extra parameter 'round' not present in
216 // other elementwise binary ops, that necessitates special handling of its
217 // builder.
218 template <>
219 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
220 : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
221 using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;
222
matchAndRewrite__anon7823b7870111::ConvertTosaOp223 LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
224 PatternRewriter &rewriter) const {
225
226 Value input1 = tosaBinaryOp.input1();
227 Value input2 = tosaBinaryOp.input2();
228 int32_t round = tosaBinaryOp.round();
229 Value output = tosaBinaryOp.getResult();
230 auto outputType = output.getType().dyn_cast<RankedTensorType>();
231
232 Value outInput1, outInput2;
233 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
234 input1, input2, outInput1, outInput2))
235 return failure();
236
237 rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
238 tosaBinaryOp, outputType, outInput1, outInput2, round);
239
240 return success();
241 }
242 };
243 } // end anonymous namespace
244
245 namespace {
246 /// Pass that enables broadcast by making all input arrays have the same
247 /// number of dimensions. Insert RESHAPE operations to lower rank operand
248 struct TosaMakeBroadcastable
249 : public TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
250 public:
runOnFunction__anon7823b7870211::TosaMakeBroadcastable251 void runOnFunction() override {
252 auto func = getFunction();
253 OwningRewritePatternList patterns;
254 MLIRContext *ctx = func.getContext();
255 // Add the generated patterns to the list.
256 patterns.insert<ConvertTosaOp<tosa::AddOp>>(ctx);
257 patterns.insert<ConvertTosaOp<tosa::SubOp>>(ctx);
258 patterns.insert<ConvertTosaOp<tosa::MulOp>>(ctx);
259 patterns.insert<ConvertTosaOp<tosa::MaximumOp>>(ctx);
260 patterns.insert<ConvertTosaOp<tosa::MinimumOp>>(ctx);
261 patterns.insert<ConvertTosaOp<tosa::EqualOp>>(ctx);
262 patterns.insert<ConvertTosaOp<tosa::GreaterOp>>(ctx);
263 patterns.insert<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
264 patterns.insert<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
265 patterns.insert<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
266 patterns.insert<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
267 applyPatternsAndFoldGreedily(func, std::move(patterns));
268 }
269 };
270 } // end anonymous namespace
271
createTosaMakeBroadcastablePass()272 std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
273 return std::make_unique<TosaMakeBroadcastable>();
274 }
275