1 //===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Insert reshape to binary op's input if needed to match rank
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
16 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
17 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace mlir::tosa;
23 
24 /// There are two potential ways implementing broadcast:
25 /// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
26 /// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
27 /// TBD: picking option (a) now.
28 
29 /// In this pass, we insert RESHAPE operators to increase the rank of the
30 /// lower rank operand as a first step in the broadcasting process. The TOSA
31 /// operators that support broadcast require that the rank of the operands
32 /// are equal.
33 
34 // Examples:
35 // If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1].
36 // TODO: If lower=[b], target=[a, b, c], [b] should but NOT YET reshaped into
37 // [1, b, 1].
38 // If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c].
39 // If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c].
40 // If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1].
41 // If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c].
42 // If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1].
43 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
44 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
45 
/// Computes the shape that the lower-rank operand must be RESHAPE'd to so its
/// rank matches `higherRankShape`. The result is written to
/// `reshapeOutputShape`, which always ends up with higherRankShape.size()
/// elements. Dimensions of the lower-rank shape are matched against the
/// higher-rank shape from the right first, then from the left; all remaining
/// positions are filled with 1. Dimensions that cannot be matched this way are
/// left as 1 (see the TODO in the file header about middle-dimension cases).
static void computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
                                 ArrayRef<int64_t> lowerRankShape,
                                 SmallVectorImpl<int64_t> &reshapeOutputShape) {
  // Initialize new shapes with [1] * higherRank.
  int64_t higherRank = higherRankShape.size();
  int64_t lowerRank = lowerRankShape.size();

  reshapeOutputShape.assign(higherRank, 1);

  // Left/right cursors into each shape. The "right" indices are exclusive
  // (one past the last unmatched dimension); the "left" indices are inclusive.
  int64_t higherLeftIndex = 0;
  int64_t higherRightIndex = higherRank;
  int64_t lowerLeftIndex = 0;
  int64_t lowerRightIndex = lowerRank;
  int64_t higherRankDim, lowerRankDim;

  if (lowerRightIndex != 0 && higherRightIndex != 0) {
    // Matches lower rank shape from right dimension first, until not
    // matching high rank shape or reaching dimension 0.
    while (true) {
      higherRankDim = higherRankShape[higherRightIndex - 1];
      lowerRankDim = lowerRankShape[lowerRightIndex - 1];
      if (higherRankDim != lowerRankDim)
        break;

      // Dimensions agree: copy the size into the output at this position.
      reshapeOutputShape[higherRightIndex - 1] = higherRankDim;

      if (higherRightIndex > 0)
        higherRightIndex--;

      if (lowerRightIndex > 0)
        lowerRightIndex--;

      // Stop once either shape is fully consumed from the right.
      if (higherRightIndex == 0 || lowerRightIndex == 0)
        break;
    }
    if (lowerRightIndex != 0 && higherRightIndex != 0) {
      // Matches lower rank shape from left dimension, until not matching
      // high rank shape or reaching right index.
      while (true) {
        higherRankDim = higherRankShape[higherLeftIndex];
        lowerRankDim = lowerRankShape[lowerLeftIndex];
        if (higherRankDim != lowerRankDim)
          break;

        // Dimensions agree: copy the size into the output at this position.
        reshapeOutputShape[higherLeftIndex] = higherRankDim;

        if (higherLeftIndex < higherRightIndex)
          higherLeftIndex++;

        if (lowerLeftIndex < lowerRightIndex)
          lowerLeftIndex++;

        // Stop when the left cursor meets the right cursor on either shape.
        if (higherLeftIndex == higherRightIndex ||
            lowerLeftIndex == lowerRightIndex)
          break;
      }
    }
  }
}
105 
/// Common code to create the reshape op where necessary to make the rank of
/// the operations equal. Returns the updated input1 and input2 for the
/// original input. The caller is expected to use these to rewrite the original
/// operator with the RESHAPE now in the graph.
reshapeLowerToHigher(PatternRewriter & rewriter,Location loc,RankedTensorType outputType,Value input1,Value input2,Value & outInput1,Value & outInput2)110 static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
111                                 RankedTensorType outputType, Value input1,
112                                 Value input2, Value &outInput1,
113                                 Value &outInput2) {
114 
115   int64_t input1Rank = input1.getType().cast<RankedTensorType>().getRank();
116   int64_t input2Rank = input2.getType().cast<RankedTensorType>().getRank();
117 
118   Value higherTensorValue, lowerTensorValue;
119   // return if rank already match
120   if (input1Rank == input2Rank) {
121     return 1;
122   } else if (input1Rank > input2Rank) {
123     higherTensorValue = input1;
124     lowerTensorValue = input2;
125   } else {
126     higherTensorValue = input2;
127     lowerTensorValue = input1;
128   }
129 
130   ArrayRef<int64_t> outputRankShape = outputType.getShape();
131   ArrayRef<int64_t> higherRankShape =
132       higherTensorValue.getType().cast<RankedTensorType>().getShape();
133   (void)higherRankShape;
134   ArrayRef<int64_t> lowerRankShape =
135       lowerTensorValue.getType().cast<RankedTensorType>().getShape();
136 
137   // outputRank == higherRank == max(input1Rank, input2Rank)
138   assert(higherRankShape.size() == outputRankShape.size());
139 
140   SmallVector<int64_t, 4> reshapeOutputShape;
141 
142   computeReshapeOutput(outputRankShape, lowerRankShape, reshapeOutputShape);
143 
144   auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
145   auto reshapeOutputType = RankedTensorType::get(
146       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
147 
148   auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
149       loc, reshapeOutputType, lowerTensorValue,
150       rewriter.getI64ArrayAttr(reshapeOutputShape));
151 
152   if (input1Rank > input2Rank) {
153     outInput1 = higherTensorValue;
154     outInput2 = reshapeLower.getResult();
155   } else {
156     outInput1 = reshapeLower.getResult();
157     outInput2 = higherTensorValue;
158   }
159 
160   return 0;
161 }
162 
163 namespace {
164 template <typename OpTy>
165 struct ConvertTosaOp : public OpRewritePattern<OpTy> {
166   using OpRewritePattern<OpTy>::OpRewritePattern;
167 
matchAndRewrite__anon7823b7870111::ConvertTosaOp168   LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
169                                 PatternRewriter &rewriter) const {
170 
171     Value input1 = tosaBinaryOp.input1();
172     Value input2 = tosaBinaryOp.input2();
173     Value output = tosaBinaryOp.getResult();
174     auto outputType = output.getType().cast<RankedTensorType>();
175 
176     Value outInput1, outInput2;
177     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
178                              input1, input2, outInput1, outInput2))
179       return failure();
180 
181     rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
182                                       outInput2);
183 
184     return success();
185   }
186 };
187 
188 // The MulOp has an extra parameter 'shift' not present in other elementwise
189 // binary ops, that necessitates special handling of its builder.
190 template <>
191 struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
192   using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
193 
matchAndRewrite__anon7823b7870111::ConvertTosaOp194   LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
195                                 PatternRewriter &rewriter) const {
196 
197     Value input1 = tosaBinaryOp.input1();
198     Value input2 = tosaBinaryOp.input2();
199     int32_t shift = tosaBinaryOp.shift();
200     Value output = tosaBinaryOp.getResult();
201     auto outputType = output.getType().cast<RankedTensorType>();
202 
203     Value outInput1, outInput2;
204     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
205                              input1, input2, outInput1, outInput2))
206       return failure();
207 
208     rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
209                                              outInput1, outInput2, shift);
210 
211     return success();
212   }
213 };
214 
215 // The ArithmeticRightShiftOp has an extra parameter 'round' not present in
216 // other elementwise binary ops, that necessitates special handling of its
217 // builder.
218 template <>
219 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
220     : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
221   using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;
222 
matchAndRewrite__anon7823b7870111::ConvertTosaOp223   LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
224                                 PatternRewriter &rewriter) const {
225 
226     Value input1 = tosaBinaryOp.input1();
227     Value input2 = tosaBinaryOp.input2();
228     int32_t round = tosaBinaryOp.round();
229     Value output = tosaBinaryOp.getResult();
230     auto outputType = output.getType().dyn_cast<RankedTensorType>();
231 
232     Value outInput1, outInput2;
233     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
234                              input1, input2, outInput1, outInput2))
235       return failure();
236 
237     rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
238         tosaBinaryOp, outputType, outInput1, outInput2, round);
239 
240     return success();
241   }
242 };
243 } // end anonymous namespace
244 
245 namespace {
246 /// Pass that enables broadcast by making all input arrays have the same
247 /// number of dimensions. Insert RESHAPE operations to lower rank operand
248 struct TosaMakeBroadcastable
249     : public TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
250 public:
runOnFunction__anon7823b7870211::TosaMakeBroadcastable251   void runOnFunction() override {
252     auto func = getFunction();
253     OwningRewritePatternList patterns;
254     MLIRContext *ctx = func.getContext();
255     // Add the generated patterns to the list.
256     patterns.insert<ConvertTosaOp<tosa::AddOp>>(ctx);
257     patterns.insert<ConvertTosaOp<tosa::SubOp>>(ctx);
258     patterns.insert<ConvertTosaOp<tosa::MulOp>>(ctx);
259     patterns.insert<ConvertTosaOp<tosa::MaximumOp>>(ctx);
260     patterns.insert<ConvertTosaOp<tosa::MinimumOp>>(ctx);
261     patterns.insert<ConvertTosaOp<tosa::EqualOp>>(ctx);
262     patterns.insert<ConvertTosaOp<tosa::GreaterOp>>(ctx);
263     patterns.insert<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
264     patterns.insert<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
265     patterns.insert<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
266     patterns.insert<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
267     applyPatternsAndFoldGreedily(func, std::move(patterns));
268   }
269 };
270 } // end anonymous namespace
271 
/// Factory: creates a new instance of the TosaMakeBroadcastable pass.
std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
  return std::make_unique<TosaMakeBroadcastable>();
}
275