//
//  Arm82Binary.cpp
//  MNN
//
//  Created by MNN on 2021/01/05.
//  Copyright © 2021, Alibaba Group Holding Limited
//

#if defined(__ANDROID__) || defined(__aarch64__)
#include <algorithm>
#include <cstring>
#include "backend/arm82/Arm82Binary.hpp"
#include "backend/arm82/Arm82Backend.hpp"
#include "backend/cpu/BinaryUtils.hpp"
#include "core/Macro.h"

#include <arm_neon.h>

namespace MNN {
// Fallback path for ops without a native fp16 NEON implementation: widen four
// half-precision lanes to fp32, apply the scalar functor lane by lane, then
// narrow the results back to fp16.
template<typename Func>
void Arm82BinaryWrap(void *dstRaw, const void *src0Raw, const void *src1Raw, const int elementSize, const int needBroadcastIndex) {
    auto dst = (FLOAT16*)dstRaw;
    auto src0 = (const FLOAT16*)src0Raw;
    auto src1 = (const FLOAT16*)src1Raw;
    Func compute;
    const int sizeDivUnit = elementSize / 4;
    const int remainCount = elementSize - sizeDivUnit * 4;

    float A[4];
    float B[4];
    float C[4];
    if (-1 == needBroadcastIndex) {
        // No broadcast: both inputs are full-size.
        if (sizeDivUnit > 0) {
            for (int i = 0; i < sizeDivUnit; ++i) {
                const auto src0Ptr = src0;
                const auto src1Ptr = src1;
                auto dstPtr = dst;
                vst1q_f32(A, vcvt_f32_f16(vld1_f16(src0Ptr)));
                vst1q_f32(B, vcvt_f32_f16(vld1_f16(src1Ptr)));
                for (int v = 0; v < 4; ++v) {
                    C[v] = compute(A[v], B[v]);
                }
                vst1_f16(dstPtr, vcvt_f16_f32(vld1q_f32(C)));
                src0 += 4;
                src1 += 4;
                dst += 4;
            }
        }
        if (remainCount > 0) {
            // Tail: stage the leftover elements in padded buffers so the
            // 4-lane load/store path can still be used.
            FLOAT16 tempSrc0[4];
            FLOAT16 tempSrc1[4];
            FLOAT16 tempDst[4];
            ::memcpy(tempSrc0, src0, remainCount * sizeof(FLOAT16));
            ::memcpy(tempSrc1, src1, remainCount * sizeof(FLOAT16));
            vst1q_f32(A, vcvt_f32_f16(vld1_f16(tempSrc0)));
            vst1q_f32(B, vcvt_f32_f16(vld1_f16(tempSrc1)));
            for (int v = 0; v < remainCount; ++v) {
                C[v] = compute(A[v], B[v]);
            }
            vst1_f16(tempDst, vcvt_f16_f32(vld1q_f32(C)));
            ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
        }
    } else if (0 == needBroadcastIndex) {
        // src0 is a scalar broadcast against src1.
        const FLOAT16 srcValue0 = src0[0];
        float16x4_t a = vmov_n_f16(srcValue0);
        vst1q_f32(A, vcvt_f32_f16(a));
        if (sizeDivUnit > 0) {
            for (int i = 0; i < sizeDivUnit; ++i) {
                const auto src1Ptr = src1;
                auto dstPtr = dst;
                vst1q_f32(B, vcvt_f32_f16(vld1_f16(src1Ptr)));
                for (int v = 0; v < 4; ++v) {
                    C[v] = compute(A[v], B[v]);
                }
                vst1_f16(dstPtr, vcvt_f16_f32(vld1q_f32(C)));
                src1 += 4;
                dst += 4;
            }
        }
        if (remainCount > 0) {
            FLOAT16 tempSrc1[4];
            FLOAT16 tempDst[4];
            ::memcpy(tempSrc1, src1, remainCount * sizeof(FLOAT16));
            vst1q_f32(B, vcvt_f32_f16(vld1_f16(tempSrc1)));
            for (int v = 0; v < remainCount; ++v) {
                C[v] = compute(A[v], B[v]);
            }
            vst1_f16(tempDst, vcvt_f16_f32(vld1q_f32(C)));
            ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
        }
    } else {
        // src1 is a scalar broadcast against src0.
        const FLOAT16 srcValue1 = src1[0];
        float16x4_t b = vmov_n_f16(srcValue1);
        vst1q_f32(B, vcvt_f32_f16(b));
        if (sizeDivUnit > 0) {
            for (int i = 0; i < sizeDivUnit; ++i) {
                const auto src0Ptr = src0;
                auto dstPtr = dst;
                vst1q_f32(A, vcvt_f32_f16(vld1_f16(src0Ptr)));
                for (int v = 0; v < 4; ++v) {
                    C[v] = compute(A[v], B[v]);
                }
                vst1_f16(dstPtr, vcvt_f16_f32(vld1q_f32(C)));
                src0 += 4;
                dst += 4;
            }
        }
        if (remainCount > 0) {
            FLOAT16 tempSrc0[4];
            FLOAT16 tempDst[4];
            ::memcpy(tempSrc0, src0, remainCount * sizeof(FLOAT16));
            vst1q_f32(A, vcvt_f32_f16(vld1_f16(tempSrc0)));
            for (int v = 0; v < remainCount; ++v) {
                C[v] = compute(A[v], B[v]);
            }
            vst1_f16(tempDst, vcvt_f16_f32(vld1q_f32(C)));
            ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
        }
    }
}
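
// Usage sketch (illustrative only, not part of the original file): the wrapper is
// what select() hands back for ops such as REALDIV that lack an fp16 NEON intrinsic.
// The buffer names and length below are hypothetical; inputs and output are fp16
// (FLOAT16) arrays.
//
//     FLOAT16 out[37], lhs[37], rhs[37];   // 37 also exercises the tail path
//     Arm82BinaryWrap<BinaryRealDiv<float, float, float>>(out, lhs, rhs, 37, -1);  // -1: no broadcast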

// Fast path for ops that have direct fp16 NEON intrinsics: process
// ARMV82_CHANNEL_UNIT (8) half-precision lanes per iteration with no
// fp16 <-> fp32 conversion.
template<typename Func>
void Arm82Binary(void *dstRaw, const void *src0Raw, const void *src1Raw, const int elementSize, const int needBroadcastIndex) {
    auto dst = (FLOAT16*)dstRaw;
    auto src0 = (FLOAT16*)src0Raw;
    auto src1 = (FLOAT16*)src1Raw;
    Func compute;
    const int sizeDivUnit = elementSize / ARMV82_CHANNEL_UNIT;
    const int remainCount = elementSize - sizeDivUnit * ARMV82_CHANNEL_UNIT;

    if (-1 == needBroadcastIndex) {
        // No broadcast: both inputs are full-size.
        if (sizeDivUnit > 0) {
            for (int i = 0; i < sizeDivUnit; ++i) {
                const auto src0Ptr = src0;
                const auto src1Ptr = src1;
                auto dstPtr = dst;
                float16x8_t a = vld1q_f16(src0Ptr);
                float16x8_t b = vld1q_f16(src1Ptr);
                vst1q_f16(dstPtr, compute(a, b));
                src0 += 8;
                src1 += 8;
                dst += 8;
            }
        }
        if (remainCount > 0) {
            // Tail: stage the leftover elements in padded buffers.
            FLOAT16 tempSrc0[8];
            FLOAT16 tempSrc1[8];
            FLOAT16 tempDst[8];
            ::memcpy(tempSrc0, src0, remainCount * sizeof(FLOAT16));
            ::memcpy(tempSrc1, src1, remainCount * sizeof(FLOAT16));
            float16x8_t a = vld1q_f16(tempSrc0);
            float16x8_t b = vld1q_f16(tempSrc1);
            vst1q_f16(tempDst, compute(a, b));
            ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
        }
    } else if (0 == needBroadcastIndex) {
        // src0 is a scalar broadcast against src1.
        const FLOAT16 srcValue0 = src0[0];
        float16x8_t a = vmovq_n_f16(srcValue0);
        if (sizeDivUnit > 0) {
            for (int i = 0; i < sizeDivUnit; ++i) {
                const auto src1Ptr = src1;
                auto dstPtr = dst;
                float16x8_t b = vld1q_f16(src1Ptr);
                vst1q_f16(dstPtr, compute(a, b));
                src1 += 8;
                dst += 8;
            }
        }
        if (remainCount > 0) {
            FLOAT16 tempSrc1[8];
            FLOAT16 tempDst[8];
            ::memcpy(tempSrc1, src1, remainCount * sizeof(FLOAT16));
            float16x8_t b = vld1q_f16(tempSrc1);
            vst1q_f16(tempDst, compute(a, b));
            ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
        }
    } else {
        // src1 is a scalar broadcast against src0.
        const FLOAT16 srcValue1 = src1[0];
        float16x8_t b = vmovq_n_f16(srcValue1);
        if (sizeDivUnit > 0) {
            for (int i = 0; i < sizeDivUnit; ++i) {
                const auto src0Ptr = src0;
                auto dstPtr = dst;
                float16x8_t a = vld1q_f16(src0Ptr);
                vst1q_f16(dstPtr, compute(a, b));
                src0 += 8;
                dst += 8;
            }
        }
        if (remainCount > 0) {
            FLOAT16 tempSrc0[8];
            FLOAT16 tempDst[8];
            ::memcpy(tempSrc0, src0, remainCount * sizeof(FLOAT16));
            float16x8_t a = vld1q_f16(tempSrc0);
            vst1q_f16(tempDst, compute(a, b));
            ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
        }
    }
}
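
// Usage sketch (illustrative only): needBroadcastIndex selects the addressing mode,
// -1 for elementwise, 0 when src0 is a single scalar, 1 when src1 is. The buffer
// names below are hypothetical fp16 arrays.
//
//     Arm82Binary<VecBinaryAdd>(out, lhs, rhs, size, -1);   // out[i] = lhs[i] + rhs[i]
//     Arm82Binary<VecBinaryMul>(out, scale, rhs, size, 0);  // out[i] = scale[0] * rhs[i]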

struct VecBinaryAdd : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
    float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
        return vaddq_f16(x, y);
    }
};

struct VecBinarySub : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
    float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
        return vsubq_f16(x, y);
    }
};

struct VecBinaryMul : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
    float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
        return vmulq_f16(x, y);
    }
};

struct VecBinaryMin : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
    float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
        return vminq_f16(x, y);
    }
};

struct VecBinaryMax : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
    float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
        return vmaxq_f16(x, y);
    }
};

struct VecBinarySqd : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
    float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
        return vmulq_f16(vsubq_f16(x, y), vsubq_f16(x, y));
    }
};
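
// Extension sketch (hypothetical, not part of this file): an op could move from the
// fp32 wrapper onto the fast path by adding a functor of the same shape. For example,
// REALDIV could use the AArch64-only vdivq_f16 intrinsic; this is an assumption and
// would need extra guarding, because the 32-bit Android builds covered by the #if
// above have no vector fp16 divide.
//
//     struct VecBinaryDiv : std::binary_function<float16x8_t, float16x8_t, float16x8_t> {
//         float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
//             return vdivq_f16(x, y);   // ARMv8.2-A FP16, A64 only
//         }
//     };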

// Ops with native fp16 intrinsics dispatch to Arm82Binary; the rest fall back to
// the fp32 conversion wrapper.
MNNBinaryExecute Arm82BinaryFloat::select(int32_t type) {
    switch (type) {
        case BinaryOpOperation_ADD:
            return Arm82Binary<VecBinaryAdd>;
            break;
        case BinaryOpOperation_SUB:
            return Arm82Binary<VecBinarySub>;
            break;
        case BinaryOpOperation_MUL:
            return Arm82Binary<VecBinaryMul>;
            break;
        case BinaryOpOperation_MINIMUM:
            return Arm82Binary<VecBinaryMin>;
            break;
        case BinaryOpOperation_MAXIMUM:
            return Arm82Binary<VecBinaryMax>;
            break;
        case BinaryOpOperation_SquaredDifference:
            return Arm82Binary<VecBinarySqd>;
            break;
        case BinaryOpOperation_REALDIV:
            return Arm82BinaryWrap<BinaryRealDiv<float, float, float>>;
            break;
        case BinaryOpOperation_FLOORDIV:
            return Arm82BinaryWrap<BinaryFloorDiv<float, float, float>>;
            break;
        case BinaryOpOperation_FLOORMOD:
            return Arm82BinaryWrap<BinaryFloorMod<float, float, float>>;
            break;
        case BinaryOpOperation_POW:
            return Arm82BinaryWrap<BinaryPow<float, float, float>>;
            break;
        case BinaryOpOperation_ATAN2:
            return Arm82BinaryWrap<BinaryAtan2<float, float, float>>;
            break;
        case BinaryOpOperation_MOD:
            return Arm82BinaryWrap<BinaryMod<float, float, float>>;
            break;
        default:
            return nullptr;
            break;
    }
    return nullptr;
}
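
// Usage sketch (illustrative only): callers receive an MNNBinaryExecute and invoke it
// on fp16 buffers; the buffer names and size here are hypothetical.
//
//     auto exec = Arm82BinaryFloat::select(BinaryOpOperation_ADD);
//     if (exec != nullptr) {
//         exec(out, lhs, rhs, elementSize, -1);   // -1: no broadcast
//     }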

} // namespace MNN
#endif