1 //===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Insert reshape to binary op's input if needed to match rank
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
17 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
18 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::tosa;
24 
25 /// There are two potential ways implementing broadcast:
26 /// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
27 /// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
28 /// TBD: picking option (a) now.
29 
30 /// In this pass, we insert RESHAPE operators to increase the rank of the
31 /// lower rank operand as a first step in the broadcasting process. The TOSA
32 /// operators that support broadcast require that the rank of the operands
33 /// are equal.
34 
35 // Examples:
36 // If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1].
// TODO: If lower=[b], target=[a, b, c], [b] should be reshaped into
// [1, b, 1], but this is NOT YET implemented.
39 // If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c].
40 // If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c].
41 // If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1].
42 // If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c].
43 // If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1].
44 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
45 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
46 
/// Computes the shape the lower-rank operand must be reshaped to so it has
/// the same rank as the higher-rank operand (see the examples above).
///
/// \param higherRankShape    Shape of the higher-rank operand.
/// \param lowerRankShape     Shape of the lower-rank operand.
/// \param reshapeOutputShape Filled with `higherRank` entries; each entry is
///        either 1 (a broadcast dimension) or the matched size copied from
///        the higher-rank shape.
static void computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
                                 ArrayRef<int64_t> lowerRankShape,
                                 SmallVectorImpl<int64_t> &reshapeOutputShape) {
  // Initialize new shapes with [1] * higherRank.
  int64_t higherRank = higherRankShape.size();
  int64_t lowerRank = lowerRankShape.size();

  reshapeOutputShape.assign(higherRank, 1);

  // Cursor pairs into both shapes; right cursors are exclusive bounds.
  // Dimensions are matched right-to-left first, then left-to-right.
  int64_t higherLeftIndex = 0;
  int64_t higherRightIndex = higherRank;
  int64_t lowerLeftIndex = 0;
  int64_t lowerRightIndex = lowerRank;
  int64_t higherRankDim, lowerRankDim;

  if (lowerRightIndex != 0 && higherRightIndex != 0) {
    // Matches lower rank shape from right dimension first, until not
    // matching high rank shape or reaching dimension 0.
    while (true) {
      higherRankDim = higherRankShape[higherRightIndex - 1];
      lowerRankDim = lowerRankShape[lowerRightIndex - 1];
      if (higherRankDim != lowerRankDim)
        break;

      // Copy the matched size into the aligned (right-justified) position.
      reshapeOutputShape[higherRightIndex - 1] = higherRankDim;

      if (higherRightIndex > 0)
        higherRightIndex--;

      if (lowerRightIndex > 0)
        lowerRightIndex--;

      if (higherRightIndex == 0 || lowerRightIndex == 0)
        break;
    }
    if (lowerRightIndex != 0 && higherRightIndex != 0) {
      // Matches lower rank shape from left dimension, until not matching
      // high rank shape or reaching right index.
      while (true) {
        higherRankDim = higherRankShape[higherLeftIndex];
        lowerRankDim = lowerRankShape[lowerLeftIndex];
        if (higherRankDim != lowerRankDim)
          break;

        // Copy the matched size into the aligned (left-justified) position.
        reshapeOutputShape[higherLeftIndex] = higherRankDim;

        if (higherLeftIndex < higherRightIndex)
          higherLeftIndex++;

        if (lowerLeftIndex < lowerRightIndex)
          lowerLeftIndex++;

        // Stop once either cursor reaches the span already matched from the
        // right, so left and right matches never overlap.
        if (higherLeftIndex == higherRightIndex ||
            lowerLeftIndex == lowerRightIndex)
          break;
      }
    }
  }
}
106 
107 /// Common code to create the reshape op where necessary to make the rank of the
108 /// operations equal. Returns the updated input1 and input2 for the original
109 /// input. The caller is expected to use these to rewrite the original operator
110 /// with the RESHAPE now in the graph.
reshapeLowerToHigher(PatternRewriter & rewriter,Location loc,RankedTensorType outputType,Value input1,Value input2,Value & outInput1,Value & outInput2)111 static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
112                                           Location loc,
113                                           RankedTensorType outputType,
114                                           Value input1, Value input2,
115                                           Value &outInput1, Value &outInput2) {
116   auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
117   auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
118 
119   if (!input1Ty || !input2Ty)
120     return failure();
121 
122   int64_t input1Rank = input1Ty.getRank();
123   int64_t input2Rank = input2Ty.getRank();
124 
125   Value higherTensorValue, lowerTensorValue;
126   // Cannot rewrite as its already correct.
127   if (input1Rank == input2Rank)
128     return failure();
129 
130   if (input1Rank > input2Rank) {
131     higherTensorValue = input1;
132     lowerTensorValue = input2;
133   } else {
134     higherTensorValue = input2;
135     lowerTensorValue = input1;
136   }
137 
138   ArrayRef<int64_t> higherRankShape =
139       higherTensorValue.getType().cast<RankedTensorType>().getShape();
140   (void)higherRankShape;
141   ArrayRef<int64_t> lowerRankShape =
142       lowerTensorValue.getType().cast<RankedTensorType>().getShape();
143 
144   SmallVector<int64_t, 4> reshapeOutputShape;
145 
146   computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape);
147 
148   auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
149   auto reshapeOutputType = RankedTensorType::get(
150       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
151 
152   // Verify the rank agrees with the output type if the output type is ranked.
153   if (outputType) {
154     if (outputType.getShape().size() != reshapeOutputShape.size() ||
155         outputType.getShape().size() != higherRankShape.size())
156       return failure();
157   }
158 
159   auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
160       loc, reshapeOutputType, lowerTensorValue,
161       rewriter.getI64ArrayAttr(reshapeOutputShape));
162 
163   if (input1Rank > input2Rank) {
164     outInput1 = higherTensorValue;
165     outInput2 = reshapeLower.getResult();
166   } else {
167     outInput1 = reshapeLower.getResult();
168     outInput2 = higherTensorValue;
169   }
170 
171   return success();
172 }
173 
174 namespace {
175 template <typename OpTy>
176 struct ConvertTosaOp : public OpRewritePattern<OpTy> {
177   using OpRewritePattern<OpTy>::OpRewritePattern;
178 
matchAndRewrite__anonfab8b9540111::ConvertTosaOp179   LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
180                                 PatternRewriter &rewriter) const override {
181 
182     Value input1 = tosaBinaryOp.input1();
183     Value input2 = tosaBinaryOp.input2();
184     Value output = tosaBinaryOp.getResult();
185 
186     auto outputType = output.getType().dyn_cast<RankedTensorType>();
187 
188     Value outInput1, outInput2;
189     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
190                              input1, input2, outInput1, outInput2)
191             .failed())
192       return failure();
193 
194     rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
195                                       outInput2);
196 
197     return success();
198   }
199 };
200 
201 // The MulOp has an extra parameter 'shift' not present in other elementwise
202 // binary ops, that necessitates special handling of its builder.
203 template <>
204 struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
205   using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
206 
matchAndRewrite__anonfab8b9540111::ConvertTosaOp207   LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
208                                 PatternRewriter &rewriter) const override {
209 
210     Value input1 = tosaBinaryOp.input1();
211     Value input2 = tosaBinaryOp.input2();
212     int32_t shift = tosaBinaryOp.shift();
213     Value output = tosaBinaryOp.getResult();
214     auto outputType = output.getType().dyn_cast<RankedTensorType>();
215 
216     Value outInput1, outInput2;
217     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
218                              input1, input2, outInput1, outInput2)
219             .failed())
220       return failure();
221 
222     rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
223                                              outInput1, outInput2, shift);
224 
225     return success();
226   }
227 };
228 
229 // The ArithmeticRightShiftOp has an extra parameter 'round' not present in
230 // other elementwise binary ops, that necessitates special handling of its
231 // builder.
232 template <>
233 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
234     : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
235   using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;
236 
matchAndRewrite__anonfab8b9540111::ConvertTosaOp237   LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
238                                 PatternRewriter &rewriter) const override {
239 
240     Value input1 = tosaBinaryOp.input1();
241     Value input2 = tosaBinaryOp.input2();
242     int32_t round = tosaBinaryOp.round();
243     Value output = tosaBinaryOp.getResult();
244     auto outputType = output.getType().dyn_cast<RankedTensorType>();
245 
246     Value outInput1, outInput2;
247     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
248                              input1, input2, outInput1, outInput2)
249             .failed())
250       return failure();
251 
252     rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
253         tosaBinaryOp, outputType, outInput1, outInput2, round);
254 
255     return success();
256   }
257 };
258 } // end anonymous namespace
259 
260 namespace {
261 /// Pass that enables broadcast by making all input arrays have the same
262 /// number of dimensions. Insert RESHAPE operations to lower rank operand
263 struct TosaMakeBroadcastable
264     : public TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
265 public:
runOnFunction__anonfab8b9540211::TosaMakeBroadcastable266   void runOnFunction() override {
267     auto func = getFunction();
268     RewritePatternSet patterns(func.getContext());
269     MLIRContext *ctx = func.getContext();
270     // Add the generated patterns to the list.
271     patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
272     patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
273     patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
274     patterns.add<ConvertTosaOp<tosa::DivOp>>(ctx);
275     patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
276     patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
277     patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
278     patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
279     patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
280     patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
281     patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
282     patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
283     (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
284   }
285 };
286 } // end anonymous namespace
287 
/// Factory for the TOSA make-broadcastable function pass; registered with the
/// pass pipeline machinery via Passes.h.
std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
  return std::make_unique<TosaMakeBroadcastable>();
}
291