1 //
2 //  Arm82Binary.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2021/01/05.
6 //  Copyright © 2021, Alibaba Group Holding Limited
7 //
8 
9 #if defined(__ANDROID__) || defined(__aarch64__)
10 #include <algorithm>
11 #include "backend/arm82/Arm82Binary.hpp"
12 #include "backend/arm82/Arm82Backend.hpp"
13 #include "backend/cpu/BinaryUtils.hpp"
14 #include "core/Macro.h"
15 
16 #include <arm_neon.h>
17 
18 namespace MNN {
// Element-wise binary kernel for FP16 buffers that evaluates a scalar functor in FP32.
//
// Four FP16 values at a time are widened to float, `compute` is invoked once per
// lane on the float values, and the four results are narrowed back to FP16.
//
// Func               default-constructible functor invoked as compute(float, float).
// dstRaw             output buffer, accessed as FLOAT16[elementSize].
// src0Raw / src1Raw  input buffers, accessed as FLOAT16; a broadcast input is only
//                    read at index 0.
// elementSize        number of output elements.
// needBroadcastIndex -1: both inputs hold elementSize values;
//                     0: src0 is a single value applied to every element of src1;
//                     any other value: src1 is a single value applied to every element of src0.
template<typename Func>
void Arm82BinaryWrap(void *dstRaw, const void *src0Raw, const void *src1Raw, const int elementSize, const int needBroadcastIndex) {
    // Lanes handled per step: one float16x4_t widened to one float32x4_t.
    constexpr int kPack = 4;

    auto dst  = (FLOAT16*)dstRaw;
    auto src0 = (const FLOAT16*)src0Raw;
    auto src1 = (const FLOAT16*)src1Raw;
    Func compute;
    const int sizeDivUnit = elementSize / kPack;
    const int remainCount = elementSize - sizeDivUnit * kPack;
    const bool broadcast0 = (0 == needBroadcastIndex);
    const bool broadcast1 = (-1 != needBroadcastIndex) && (0 != needBroadcastIndex);

    // Zero-initialised so that the full-width vector load of C in the tail path
    // never touches indeterminate lanes, even when there was no full step before it.
    float A[kPack] = {};
    float B[kPack] = {};
    float C[kPack] = {};

    // A broadcast operand is widened once up front and reused for every step.
    if (broadcast0) {
        vst1q_f32(A, vcvt_f32_f16(vmov_n_f16(src0[0])));
    } else if (broadcast1) {
        vst1q_f32(B, vcvt_f32_f16(vmov_n_f16(src1[0])));
    }

    // Full steps of kPack elements.
    for (int i = 0; i < sizeDivUnit; ++i) {
        if (!broadcast0) {
            vst1q_f32(A, vcvt_f32_f16(vld1_f16(src0)));
            src0 += kPack;
        }
        if (!broadcast1) {
            vst1q_f32(B, vcvt_f32_f16(vld1_f16(src1)));
            src1 += kPack;
        }
        for (int v = 0; v < kPack; ++v) {
            C[v] = compute(A[v], B[v]);
        }
        vst1_f16(dst, vcvt_f16_f32(vld1q_f32(C)));
        dst += kPack;
    }

    // Tail of fewer than kPack elements: stage through zero-padded scratch buffers
    // so the vector loads/stores never run past the caller's buffers and never read
    // uninitialised memory. `compute` is only invoked on the valid lanes.
    if (remainCount > 0) {
        FLOAT16 tempSrc0[kPack] = {};
        FLOAT16 tempSrc1[kPack] = {};
        FLOAT16 tempDst[kPack];
        if (!broadcast0) {
            ::memcpy(tempSrc0, src0, remainCount * sizeof(FLOAT16));
            vst1q_f32(A, vcvt_f32_f16(vld1_f16(tempSrc0)));
        }
        if (!broadcast1) {
            ::memcpy(tempSrc1, src1, remainCount * sizeof(FLOAT16));
            vst1q_f32(B, vcvt_f32_f16(vld1_f16(tempSrc1)));
        }
        for (int v = 0; v < remainCount; ++v) {
            C[v] = compute(A[v], B[v]);
        }
        vst1_f16(tempDst, vcvt_f16_f32(vld1q_f32(C)));
        ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
    }
}
120 
121 
// Element-wise binary kernel for FP16 buffers using a native 8-lane FP16 vector functor.
//
// Func               default-constructible functor invoked as
//                    compute(float16x8_t, float16x8_t) and returning float16x8_t.
// dstRaw             output buffer, accessed as FLOAT16[elementSize].
// src0Raw / src1Raw  input buffers, accessed as FLOAT16; a broadcast input is only
//                    read at index 0.
// elementSize        number of output elements.
// needBroadcastIndex -1: both inputs hold elementSize values;
//                     0: src0 is a single value applied to every element of src1;
//                     any other value: src1 is a single value applied to every element of src0.
template<typename Func>
void Arm82Binary(void *dstRaw, const void *src0Raw, const void *src1Raw, const int elementSize, const int needBroadcastIndex) {
    // A float16x8_t always carries 8 lanes, so the step count, the pointer stride
    // and the scratch-buffer size are all tied to the vector width.
    constexpr int kLanes = 8;

    auto dst  = (FLOAT16*)dstRaw;
    auto src0 = (const FLOAT16*)src0Raw;
    auto src1 = (const FLOAT16*)src1Raw;
    Func compute;
    const int sizeDivUnit = elementSize / kLanes;
    const int remainCount = elementSize - sizeDivUnit * kLanes;

    // The three cases are kept as separate loops so each hot loop stays branch-free.
    // In every tail path the scratch inputs are zero-initialised: the functor runs on
    // all 8 lanes, so the padding lanes must hold defined values even though their
    // results are discarded by the final partial memcpy.
    if (-1 == needBroadcastIndex) {
        for (int i = 0; i < sizeDivUnit; ++i) {
            const float16x8_t a = vld1q_f16(src0);
            const float16x8_t b = vld1q_f16(src1);
            vst1q_f16(dst, compute(a, b));
            src0 += kLanes;
            src1 += kLanes;
            dst += kLanes;
        }
        if (remainCount > 0) {
            FLOAT16 tempSrc0[kLanes] = {};
            FLOAT16 tempSrc1[kLanes] = {};
            FLOAT16 tempDst[kLanes];
            ::memcpy(tempSrc0, src0, remainCount * sizeof(FLOAT16));
            ::memcpy(tempSrc1, src1, remainCount * sizeof(FLOAT16));
            const float16x8_t a = vld1q_f16(tempSrc0);
            const float16x8_t b = vld1q_f16(tempSrc1);
            vst1q_f16(tempDst, compute(a, b));
            ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
        }
    } else if (0 == needBroadcastIndex) {
        // src0 is a scalar: splat it once across all lanes.
        const FLOAT16 srcValue0 = src0[0];
        const float16x8_t a = vmovq_n_f16(srcValue0);
        for (int i = 0; i < sizeDivUnit; ++i) {
            const float16x8_t b = vld1q_f16(src1);
            vst1q_f16(dst, compute(a, b));
            src1 += kLanes;
            dst += kLanes;
        }
        if (remainCount > 0) {
            FLOAT16 tempSrc1[kLanes] = {};
            FLOAT16 tempDst[kLanes];
            ::memcpy(tempSrc1, src1, remainCount * sizeof(FLOAT16));
            const float16x8_t b = vld1q_f16(tempSrc1);
            vst1q_f16(tempDst, compute(a, b));
            ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
        }
    } else {
        // src1 is a scalar: splat it once across all lanes.
        const FLOAT16 srcValue1 = src1[0];
        const float16x8_t b = vmovq_n_f16(srcValue1);
        for (int i = 0; i < sizeDivUnit; ++i) {
            const float16x8_t a = vld1q_f16(src0);
            vst1q_f16(dst, compute(a, b));
            src0 += kLanes;
            dst += kLanes;
        }
        if (remainCount > 0) {
            FLOAT16 tempSrc0[kLanes] = {};
            FLOAT16 tempDst[kLanes];
            ::memcpy(tempSrc0, src0, remainCount * sizeof(FLOAT16));
            const float16x8_t a = vld1q_f16(tempSrc0);
            vst1q_f16(tempDst, compute(a, b));
            ::memcpy(dst, tempDst, remainCount * sizeof(FLOAT16));
        }
    }
}
200 
201 
// Lane-wise FP16 addition on 8-lane vectors: result = x + y.
// The nested aliases are spelled out by hand instead of inheriting from
// std::binary_function, which was removed from the standard library in C++17.
struct VecBinaryAdd {
    using first_argument_type  = float16x8_t;
    using second_argument_type = float16x8_t;
    using result_type          = float16x8_t;

    float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
        return vaddq_f16(x, y);
    }
};
207 
// Lane-wise FP16 subtraction on 8-lane vectors: result = x - y.
// The nested aliases are spelled out by hand instead of inheriting from
// std::binary_function, which was removed from the standard library in C++17.
struct VecBinarySub {
    using first_argument_type  = float16x8_t;
    using second_argument_type = float16x8_t;
    using result_type          = float16x8_t;

    float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
        return vsubq_f16(x, y);
    }
};
213 
// Lane-wise FP16 multiplication on 8-lane vectors: result = x * y.
// The nested aliases are spelled out by hand instead of inheriting from
// std::binary_function, which was removed from the standard library in C++17.
struct VecBinaryMul {
    using first_argument_type  = float16x8_t;
    using second_argument_type = float16x8_t;
    using result_type          = float16x8_t;

    float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
        return vmulq_f16(x, y);
    }
};
219 
// Lane-wise FP16 minimum on 8-lane vectors, as computed by vminq_f16(x, y).
// The nested aliases are spelled out by hand instead of inheriting from
// std::binary_function, which was removed from the standard library in C++17.
struct VecBinaryMin {
    using first_argument_type  = float16x8_t;
    using second_argument_type = float16x8_t;
    using result_type          = float16x8_t;

    float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
        return vminq_f16(x, y);
    }
};
225 
// Lane-wise FP16 maximum on 8-lane vectors, as computed by vmaxq_f16(x, y).
// The nested aliases are spelled out by hand instead of inheriting from
// std::binary_function, which was removed from the standard library in C++17.
struct VecBinaryMax {
    using first_argument_type  = float16x8_t;
    using second_argument_type = float16x8_t;
    using result_type          = float16x8_t;

    float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
        return vmaxq_f16(x, y);
    }
};
231 
// Lane-wise FP16 squared difference on 8-lane vectors: result = (x - y) * (x - y).
// The nested aliases are spelled out by hand instead of inheriting from
// std::binary_function, which was removed from the standard library in C++17.
struct VecBinarySqd {
    using first_argument_type  = float16x8_t;
    using second_argument_type = float16x8_t;
    using result_type          = float16x8_t;

    float16x8_t operator()(const float16x8_t& x, const float16x8_t& y) const {
        // Compute the difference once and square it.
        const float16x8_t diff = vsubq_f16(x, y);
        return vmulq_f16(diff, diff);
    }
};
237 
238 
select(int32_t type)239 MNNBinaryExecute Arm82BinaryFloat::select(int32_t type) {
240     switch (type) {
241         case BinaryOpOperation_ADD:
242             return Arm82Binary<VecBinaryAdd>;
243             break;
244         case BinaryOpOperation_SUB:
245             return Arm82Binary<VecBinarySub>;
246             break;
247         case BinaryOpOperation_MUL:
248             return Arm82Binary<VecBinaryMul>;
249             break;
250         case BinaryOpOperation_MINIMUM:
251             return Arm82Binary<VecBinaryMin>;
252             break;
253         case BinaryOpOperation_MAXIMUM:
254             return Arm82Binary<VecBinaryMax>;
255             break;
256         case BinaryOpOperation_SquaredDifference:
257             return Arm82Binary<VecBinarySqd>;
258             break;
259         case BinaryOpOperation_REALDIV:
260             return Arm82BinaryWrap<BinaryRealDiv<float, float, float>>;
261             break;
262         case BinaryOpOperation_FLOORDIV:
263             return Arm82BinaryWrap<BinaryFloorDiv<float, float, float>>;
264             break;
265         case BinaryOpOperation_FLOORMOD:
266             return Arm82BinaryWrap<BinaryFloorMod<float, float, float>>;
267             break;
268         case BinaryOpOperation_POW:
269             return Arm82BinaryWrap<BinaryPow<float, float, float>>;
270             break;
271         case BinaryOpOperation_ATAN2:
272             return Arm82BinaryWrap<BinaryAtan2<float, float, float>>;
273             break;
274         case BinaryOpOperation_MOD:
275             return Arm82BinaryWrap<BinaryMod<float, float, float>>;
276             break;
277         default:
278             return nullptr;
279             break;
280     }
281     return nullptr;
282 }
283 
284 } // namespace MNN
285 #endif
286