1 //===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Insert reshape to binary op's input if needed to match rank
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
16 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
17 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace mlir::tosa;
23 
/// There are two potential ways of implementing broadcast:
25 /// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
26 /// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
27 /// TBD: picking option (a) now.
28 
29 /// In this pass, we insert RESHAPE operators to increase the rank of the
30 /// lower rank operand as a first step in the broadcasting process. The TOSA
31 /// operators that support broadcast require that the rank of the operands
32 /// are equal.
33 
34 // Examples:
35 // If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1].
// TODO: If lower=[b], target=[a, b, c], [b] should be, but is NOT YET,
// reshaped into [1, b, 1].
38 // If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c].
39 // If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c].
40 // If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1].
41 // If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c].
42 // If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1].
43 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
44 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
45 
computeReshapeOutput(ArrayRef<int64_t> higherRankShape,ArrayRef<int64_t> lowerRankShape,SmallVectorImpl<int64_t> & reshapeOutputShape)46 static void computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
47                                  ArrayRef<int64_t> lowerRankShape,
48                                  SmallVectorImpl<int64_t> &reshapeOutputShape) {
49   // Initialize new shapes with [1] * higherRank.
50   int64_t higherRank = higherRankShape.size();
51   int64_t lowerRank = lowerRankShape.size();
52 
53   reshapeOutputShape.assign(higherRank, 1);
54 
55   int64_t higherLeftIndex = 0;
56   int64_t higherRightIndex = higherRank;
57   int64_t lowerLeftIndex = 0;
58   int64_t lowerRightIndex = lowerRank;
59   int64_t higherRankDim, lowerRankDim;
60 
61   if (lowerRightIndex != 0 && higherRightIndex != 0) {
62     // Matches lower rank shape from right dimension first, until not
63     // matching high rank shape or reaching dimension 0.
64     while (true) {
65       higherRankDim = higherRankShape[higherRightIndex - 1];
66       lowerRankDim = lowerRankShape[lowerRightIndex - 1];
67       if (higherRankDim != lowerRankDim)
68         break;
69 
70       reshapeOutputShape[higherRightIndex - 1] = higherRankDim;
71 
72       if (higherRightIndex > 0)
73         higherRightIndex--;
74 
75       if (lowerRightIndex > 0)
76         lowerRightIndex--;
77 
78       if (higherRightIndex == 0 || lowerRightIndex == 0)
79         break;
80     }
81     if (lowerRightIndex != 0 && higherRightIndex != 0) {
82       // Matches lower rank shape from left dimension, until not matching
83       // high rank shape or reaching right index.
84       while (true) {
85         higherRankDim = higherRankShape[higherLeftIndex];
86         lowerRankDim = lowerRankShape[lowerLeftIndex];
87         if (higherRankDim != lowerRankDim)
88           break;
89 
90         reshapeOutputShape[higherLeftIndex] = higherRankDim;
91 
92         if (higherLeftIndex < higherRightIndex)
93           higherLeftIndex++;
94 
95         if (lowerLeftIndex < lowerRightIndex)
96           lowerLeftIndex++;
97 
98         if (higherLeftIndex == higherRightIndex ||
99             lowerLeftIndex == lowerRightIndex)
100           break;
101       }
102     }
103   }
104 }
105 
106 /// Common code to reate the reshape op where necessary to make the rank of the
107 /// operations equal. Returns the updated input1 and input2 for the original
108 /// input. The caller is expected to use these to rewrite the original operator
109 /// with the RESHAPE now in the graph.
reshapeLowerToHigher(PatternRewriter & rewriter,Location loc,RankedTensorType outputType,Value input1,Value input2,Value & outInput1,Value & outInput2)110 static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
111                                 RankedTensorType outputType, Value input1,
112                                 Value input2, Value &outInput1,
113                                 Value &outInput2) {
114 
115   int64_t input1Rank = input1.getType().cast<RankedTensorType>().getRank();
116   int64_t input2Rank = input2.getType().cast<RankedTensorType>().getRank();
117 
118   Value higherTensorValue, lowerTensorValue;
119   // return if rank already match
120   if (input1Rank == input2Rank)
121     return 1;
122 
123   if (input1Rank > input2Rank) {
124     higherTensorValue = input1;
125     lowerTensorValue = input2;
126   } else {
127     higherTensorValue = input2;
128     lowerTensorValue = input1;
129   }
130 
131   ArrayRef<int64_t> outputRankShape = outputType.getShape();
132   ArrayRef<int64_t> higherRankShape =
133       higherTensorValue.getType().cast<RankedTensorType>().getShape();
134   (void)higherRankShape;
135   ArrayRef<int64_t> lowerRankShape =
136       lowerTensorValue.getType().cast<RankedTensorType>().getShape();
137 
138   // outputRank == higherRank == max(input1Rank, input2Rank)
139   assert(higherRankShape.size() == outputRankShape.size());
140 
141   SmallVector<int64_t, 4> reshapeOutputShape;
142 
143   computeReshapeOutput(outputRankShape, lowerRankShape, reshapeOutputShape);
144 
145   auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
146   auto reshapeOutputType = RankedTensorType::get(
147       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
148 
149   auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
150       loc, reshapeOutputType, lowerTensorValue,
151       rewriter.getI64ArrayAttr(reshapeOutputShape));
152 
153   if (input1Rank > input2Rank) {
154     outInput1 = higherTensorValue;
155     outInput2 = reshapeLower.getResult();
156   } else {
157     outInput1 = reshapeLower.getResult();
158     outInput2 = higherTensorValue;
159   }
160 
161   return 0;
162 }
163 
164 namespace {
165 template <typename OpTy>
166 struct ConvertTosaOp : public OpRewritePattern<OpTy> {
167   using OpRewritePattern<OpTy>::OpRewritePattern;
168 
matchAndRewrite__anon7b5be6460111::ConvertTosaOp169   LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
170                                 PatternRewriter &rewriter) const override {
171 
172     Value input1 = tosaBinaryOp.input1();
173     Value input2 = tosaBinaryOp.input2();
174     Value output = tosaBinaryOp.getResult();
175     auto outputType = output.getType().cast<RankedTensorType>();
176 
177     Value outInput1, outInput2;
178     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
179                              input1, input2, outInput1, outInput2))
180       return failure();
181 
182     rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
183                                       outInput2);
184 
185     return success();
186   }
187 };
188 
189 // The MulOp has an extra parameter 'shift' not present in other elementwise
190 // binary ops, that necessitates special handling of its builder.
191 template <>
192 struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
193   using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
194 
matchAndRewrite__anon7b5be6460111::ConvertTosaOp195   LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
196                                 PatternRewriter &rewriter) const override {
197 
198     Value input1 = tosaBinaryOp.input1();
199     Value input2 = tosaBinaryOp.input2();
200     int32_t shift = tosaBinaryOp.shift();
201     Value output = tosaBinaryOp.getResult();
202     auto outputType = output.getType().cast<RankedTensorType>();
203 
204     Value outInput1, outInput2;
205     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
206                              input1, input2, outInput1, outInput2))
207       return failure();
208 
209     rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
210                                              outInput1, outInput2, shift);
211 
212     return success();
213   }
214 };
215 
216 // The ArithmeticRightShiftOp has an extra parameter 'round' not present in
217 // other elementwise binary ops, that necessitates special handling of its
218 // builder.
219 template <>
220 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
221     : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
222   using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;
223 
matchAndRewrite__anon7b5be6460111::ConvertTosaOp224   LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
225                                 PatternRewriter &rewriter) const override {
226 
227     Value input1 = tosaBinaryOp.input1();
228     Value input2 = tosaBinaryOp.input2();
229     int32_t round = tosaBinaryOp.round();
230     Value output = tosaBinaryOp.getResult();
231     auto outputType = output.getType().dyn_cast<RankedTensorType>();
232 
233     Value outInput1, outInput2;
234     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
235                              input1, input2, outInput1, outInput2))
236       return failure();
237 
238     rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
239         tosaBinaryOp, outputType, outInput1, outInput2, round);
240 
241     return success();
242   }
243 };
244 } // end anonymous namespace
245 
246 namespace {
247 /// Pass that enables broadcast by making all input arrays have the same
248 /// number of dimensions. Insert RESHAPE operations to lower rank operand
249 struct TosaMakeBroadcastable
250     : public TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
251 public:
runOnFunction__anon7b5be6460211::TosaMakeBroadcastable252   void runOnFunction() override {
253     auto func = getFunction();
254     OwningRewritePatternList patterns;
255     MLIRContext *ctx = func.getContext();
256     // Add the generated patterns to the list.
257     patterns.insert<ConvertTosaOp<tosa::AddOp>>(ctx);
258     patterns.insert<ConvertTosaOp<tosa::SubOp>>(ctx);
259     patterns.insert<ConvertTosaOp<tosa::MulOp>>(ctx);
260     patterns.insert<ConvertTosaOp<tosa::MaximumOp>>(ctx);
261     patterns.insert<ConvertTosaOp<tosa::MinimumOp>>(ctx);
262     patterns.insert<ConvertTosaOp<tosa::EqualOp>>(ctx);
263     patterns.insert<ConvertTosaOp<tosa::GreaterOp>>(ctx);
264     patterns.insert<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
265     patterns.insert<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
266     patterns.insert<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
267     patterns.insert<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
268     applyPatternsAndFoldGreedily(func, std::move(patterns));
269   }
270 };
271 } // end anonymous namespace
272 
createTosaMakeBroadcastablePass()273 std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
274   return std::make_unique<TosaMakeBroadcastable>();
275 }
276