//
//  Arm82Unary.cpp
//  MNN
//
//  Created by MNN on 2018/08/02.
//  Copyright © 2018, Alibaba Group Holding Limited
//
#if defined(__ANDROID__) || defined(__aarch64__)

#include <vector>
#include <cmath>
#include <cstring>
#include <algorithm>
#include "Arm82Unary.hpp"
#include "Arm82Backend.hpp"
#include "core/Macro.h"
#include "core/OpCommonUtils.hpp"
#include "core/Concurrency.h"
#include "backend/cpu/UnaryUtils.hpp"
#include "Arm82OptFunc.hpp"
#include "MNN_generated.h"
#include <arm_neon.h>
namespace MNN {

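// FP16 unary kernels for the ARMv8.2 half-precision backend. Cheap element-wise
// ops run natively on float16x8_t (8 lanes per register); the remaining ops fall
// back to the existing FP32 implementations through the _Wrap helper below.
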
struct VecSquare {
    float16x8_t operator()(float16x8_t &x) const {
        return x * x;
    }
};
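
// vrsqrteq_f16 is the NEON reciprocal-square-root *estimate* instruction; no
// Newton-Raphson refinement step (vrsqrtsq_f16) is applied here, trading some
// precision for speed.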
struct VecRsqrt {
    float16x8_t operator()(float16x8_t &x) const {
        return vrsqrteq_f16(x);
    }
};

struct VecNeg {
    float16x8_t operator()(float16x8_t &x) const {
        return vnegq_f16(x);
    }
};

struct VecAbs {
    float16x8_t operator()(float16x8_t &x) const {
        return vabsq_f16(x);
    }
};
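// Like VecRsqrt, this uses the estimate instruction (vrecpeq_f16) without a
// vrecpsq_f16 refinement iteration.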
struct VecReciprocal {
    float16x8_t operator()(float16x8_t &x) const {
        return vrecpeq_f16(x);
    }
};

#if defined(__aarch64__)
struct VecSqrt {
    float16x8_t operator()(float16x8_t &x) const {
        return vsqrtq_f16(x);
    }
};
#endif
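
// Apply a vectorized unary functor to elementSize FP16 values, 8 lanes at a
// time. The tail is staged through stack buffers so vector loads/stores never
// touch memory past the end of the tensor.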
template<typename Compute>
void FP16VecUnary(void *dstRaw, const void *src0Raw, int elementSize) {
    Compute Func;
    auto dst = (float16_t*)dstRaw;
    auto src0 = (const float16_t*)src0Raw;
    const int sizeDivUnit = elementSize / 8;
    const int remainCount = elementSize - sizeDivUnit * 8;

    for (int i = 0; i < sizeDivUnit; ++i) {
        float16x8_t a = vld1q_f16(src0);
        vst1q_f16(dst, Func(a));
        src0 += 8;
        dst += 8;
    }
    if (remainCount > 0) {
        // Stage the tail through full 8-lane buffers, then copy back only the
        // valid remainCount elements.
        float16_t tempSrc0[8];
        float16_t tempDst[8];
        ::memcpy(tempSrc0, src0, remainCount * sizeof(float16_t));
        float16x8_t a = vld1q_f16(tempSrc0);
        vst1q_f16(tempDst, Func(a));
        ::memcpy(dst, tempDst, remainCount * sizeof(float16_t));
    }
}
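
// The ops below have no FP16 implementation here. _Wrap bridges to the FP32
// kernels: dequantize a block of FP16 values to float, compute, and quantize
// the result back, BLOCK_SIZE elements at a time.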
#define BLOCK_SIZE 16
template<typename Compute>
static void _Wrap(void* outRaw, const void* inpRaw, int realSize) {
    Compute execute;
    float out[BLOCK_SIZE];
    float inp[BLOCK_SIZE];
    int b = realSize / BLOCK_SIZE;
    int remain = realSize % BLOCK_SIZE;
    auto outR = (int16_t*)outRaw;
    auto inpR = (const int16_t*)inpRaw;
    for (int i = 0; i < b; ++i) {
        MNNDequantizeFP16(inpR, inp, BLOCK_SIZE);
        execute(out, inp, BLOCK_SIZE);
        MNNQuantizeFP16(out, outR, BLOCK_SIZE);
        outR += BLOCK_SIZE;
        inpR += BLOCK_SIZE;
    }
    if (remain > 0) {
        MNNDequantizeFP16(inpR, inp, remain);
        execute(out, inp, remain);
        MNNQuantizeFP16(out, outR, remain);
    }
}
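
// Note: MNNExp evaluates exp(-x), which is why both exp wrappers first negate
// the input via MNNScaleAndAddBiasScalar with scale = -1 and bias = 0.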
struct _Exp {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        auto out = (float*)outRaw;
        auto inp = (const float*)inpRaw;
        MNNScaleAndAddBiasScalar(out, inp, 0.0f, -1.0f, realSize); // out = -inp
        MNNExp(out, out, realSize);
    }
};
struct _ExpM1 {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        auto out = (float*)outRaw;
        auto inp = (const float*)inpRaw;
        MNNScaleAndAddBiasScalar(out, inp, 0.0f, -1.0f, realSize); // out = -inp
        MNNExp(out, out, realSize);
        // expm1(x) = exp(x) - 1
        for (int i = 0; i < realSize; ++i) {
            out[i] = out[i] - 1.0f;
        }
    }
};

struct _Tanh {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        auto out = (float*)outRaw;
        auto inp = (const float*)inpRaw;
        MNNTanh(out, inp, realSize);
    }
};
struct _Sigmoid {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        auto out = (float*)outRaw;
        auto inp = (const float*)inpRaw;
        MNNSigmoidLowp(out, inp, realSize);
    }
};
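
// hardswish(x) = x * relu6(x + 3) / 6, computed natively in FP16: eight lanes
// per iteration, with a scalar loop for the remaining elements.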
void FP16HardSwish(void* outRaw, const void* inpRaw, int realSize) {
    auto out = (FLOAT16*)outRaw;
    auto inp = (const FLOAT16*)inpRaw;
    int sizeC8 = realSize / 8;
    int sizeRemain = realSize % 8;
    if (sizeC8 > 0) {
        float16x8_t zero = vdupq_n_f16(0.f);
        float16x8_t three = vdupq_n_f16(3.f);
        float16x8_t six = vdupq_n_f16(6.f);
        float16x8_t divsix = vdupq_n_f16(1.0f / 6.f);
        for (int i = 0; i < sizeC8; i++) {
            auto x = vld1q_f16(inp);
            // y = x * min(max(x + 3, 0), 6) / 6
            auto y = vmulq_f16(vmulq_f16(x, vminq_f16(vmaxq_f16(vaddq_f16(x, three), zero), six)), divsix);
            vst1q_f16(out, y);
            out += 8;
            inp += 8;
        }
    }
    for (int i = 0; i < sizeRemain; ++i) {
        auto x = inp[i];
        float16_t y;
        if (x <= -3) {
            y = 0;
        } else if (x >= 3) {
            y = x;
        } else {
            y = x * (x + 3) / 6;
        }
        out[i] = y;
    }
}
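
// Generic scalar fallback that applies functor Func elementwise; instantiated
// with T = float and wrapped by _Wrap for the FP32-only ops in select().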
template <typename Func, typename T>
struct _Unary {
    void operator()(void* outputPtr, const void* inputPtr, int elementSize) const {
        Func f;
        const T* inputData = (const T*)inputPtr;
        T* outputData = (T*)outputPtr;
        for (int i = 0; i < elementSize; ++i) {
            outputData[i] = f(inputData[i]);
        }
    }
};
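
// Kernel dispatch. ABS/SQUARE/NEG/RSQRT/RECIPROCAL (and SQRT on aarch64) run
// natively in FP16; everything else routes through the FP32 fallback. The
// precision argument is currently unused.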
MNNUnaryExecute Arm82Unary::select(int type, int precision) {
    switch (type) {
        case UnaryOpOperation_ABS:
            return FP16VecUnary<VecAbs>;
        case UnaryOpOperation_SQUARE:
            return FP16VecUnary<VecSquare>;
        case UnaryOpOperation_NEG:
            return FP16VecUnary<VecNeg>;
        case UnaryOpOperation_RSQRT:
            return FP16VecUnary<VecRsqrt>;
        case UnaryOpOperation_EXP:
            return _Wrap<_Exp>;
        case UnaryOpOperation_COS:
            return _Wrap<_Unary<UnaryCos<float>, float>>;
        case UnaryOpOperation_SIN:
            return _Wrap<_Unary<UnarySin<float>, float>>;
        case UnaryOpOperation_SIGMOID:
            return _Wrap<_Sigmoid>;
        case UnaryOpOperation_TANH:
            return _Wrap<_Tanh>;
        case UnaryOpOperation_TAN:
            return _Wrap<_Unary<UnaryTan<float>, float>>;
        case UnaryOpOperation_ATAN:
            return _Wrap<_Unary<UnaryATan<float>, float>>;
#if defined(__aarch64__)
        case UnaryOpOperation_SQRT:
            return FP16VecUnary<VecSqrt>;
#else
        case UnaryOpOperation_SQRT:
            return _Wrap<_Unary<UnarySqrt<float>, float>>;
#endif
        case UnaryOpOperation_CEIL:
            return _Wrap<_Unary<UnaryCeil<float>, float>>;
        case UnaryOpOperation_RECIPROCAL:
            return FP16VecUnary<VecReciprocal>;
        case UnaryOpOperation_LOG1P:
            return _Wrap<_Unary<UnaryLog1p<float>, float>>;
        case UnaryOpOperation_LOG:
            return _Wrap<_Unary<UnaryLog<float>, float>>;
        case UnaryOpOperation_FLOOR:
            return _Wrap<_Unary<UnaryFloor<float>, float>>;
        case UnaryOpOperation_BNLL:
            return _Wrap<_Unary<UnaryBNLL<float>, float>>;
        case UnaryOpOperation_ACOSH:
            return _Wrap<_Unary<UnaryAcosh<float>, float>>;
        case UnaryOpOperation_SINH:
            return _Wrap<_Unary<UnarySinh<float>, float>>;
        case UnaryOpOperation_ASINH:
            return _Wrap<_Unary<UnaryAsinh<float>, float>>;
        case UnaryOpOperation_ATANH:
            return _Wrap<_Unary<UnaryAtanh<float>, float>>;
        case UnaryOpOperation_SIGN:
            return _Wrap<_Unary<UnarySign<float>, float>>;
        case UnaryOpOperation_ROUND:
            return _Wrap<_Unary<UnaryRound<float>, float>>;
        case UnaryOpOperation_COSH:
            return _Wrap<_Unary<UnaryCosh<float>, float>>;
        case UnaryOpOperation_ERF:
            return _Wrap<_Unary<UnaryErf<float>, float>>;
        case UnaryOpOperation_ERFC:
            return _Wrap<_Unary<UnaryErfc<float>, float>>;
        case UnaryOpOperation_ERFINV:
            return _Wrap<_Unary<UnaryErfinv<float>, float>>;
        case UnaryOpOperation_EXPM1:
            return _Wrap<_ExpM1>;
        case UnaryOpOperation_ASIN:
            return _Wrap<_Unary<UnaryAsin<float>, float>>;
        case UnaryOpOperation_ACOS:
            return _Wrap<_Unary<UnaryAcos<float>, float>>;
        case UnaryOpOperation_HARDSWISH:
            return FP16HardSwish;
        default:
            MNN_ASSERT(false);
            break;
    }
    return nullptr;
}
} // namespace MNN

#endif