1 //
2 //  Arm82Unary.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2018/08/02.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 #if defined(__ANDROID__) || defined(__aarch64__)
9 
10 #include <vector>
11 #include <cmath>
12 #include <algorithm>
13 #include "Arm82Unary.hpp"
14 #include "Arm82Backend.hpp"
15 #include "core/Macro.h"
16 #include "core/OpCommonUtils.hpp"
17 #include "core/Concurrency.h"
18 #include "backend/cpu/UnaryUtils.hpp"
19 #include "Arm82OptFunc.hpp"
20 #include "MNN_generated.h"
21 #include <arm_neon.h>
22 namespace MNN {
23 
// Element-wise square: every lane becomes x * x.
struct VecSquare {
    float16x8_t operator()(float16x8_t &x) const {
        // vmulq_f16 is the explicit intrinsic form of the lane-wise product.
        return vmulq_f16(x, x);
    }
};
// Element-wise reciprocal square root, 1/sqrt(x).
struct VecRsqrt {
    float16x8_t operator()(float16x8_t &x) const {
        // Hardware *estimate* instruction only — low precision, no
        // Newton-Raphson refinement step is applied here.
        return vrsqrteq_f16(x);
    }
};
34 
// Element-wise negation: flips the sign bit of every lane (including ±0, NaN).
struct VecNeg {
    float16x8_t operator()(float16x8_t &x) const {
        return vnegq_f16(x);
    }
};
40 
// Element-wise absolute value.
struct VecAbs {
    float16x8_t operator()(float16x8_t &x) const {
        return vabsq_f16(x);
    }
};
// Element-wise reciprocal, 1/x. (Name keeps the historical misspelling of
// "Reciprocal"; renaming would break the reference in Arm82Unary::select.)
struct VecRecipocal {
    float16x8_t operator()(float16x8_t &x) const {
        // Hardware *estimate* instruction only — low precision, no
        // refinement iteration.
        return vrecpeq_f16(x);
    }
};
51 
#if defined(__aarch64__)
// Element-wise square root.
// Fixed: this previously returned vabsq_f16(x) (absolute value), so the SQRT
// op on aarch64 silently computed |x| instead of sqrt(x). vsqrtq_f16 is the
// exact FP16 square-root instruction, available on aarch64 with FP16
// vector arithmetic (which this file already relies on elsewhere).
struct VecSqrt {
    float16x8_t operator()(float16x8_t &x) const {
        return vsqrtq_f16(x);
    }
};
#endif
59 
// Apply the vectorized functor Compute lane-wise over an fp16 buffer.
// Processes eight half-precision lanes per iteration; the tail (< 8 elements)
// is staged through small stack buffers so the same 8-lane functor can be
// reused without reading or writing past the caller's arrays.
template<typename Compute>
void FP16VecUnary(void *dstRaw, const void *src0Raw, int elementSize) {
    Compute op;
    auto out = (float16_t*)dstRaw;
    auto in  = (const float16_t*)src0Raw;
    const int fullVecs = elementSize / 8;
    const int tail     = elementSize % 8;
    for (int i = 0; i < fullVecs; ++i, in += 8, out += 8) {
        float16x8_t v = vld1q_f16(in);
        vst1q_f16(out, op(v));
    }
    if (tail > 0) {
        float16_t padIn[8];
        float16_t padOut[8];
        // Only the first `tail` lanes are meaningful; only they are copied back.
        ::memcpy(padIn, in, tail * sizeof(float16_t));
        float16x8_t v = vld1q_f16(padIn);
        vst1q_f16(padOut, op(v));
        ::memcpy(out, padOut, tail * sizeof(float16_t));
    }
}
85 #define BLOCK_SIZE 16
86 template<typename Compute>
_Wrap(void * outRaw,const void * inpRaw,int realSize)87 static void _Wrap(void* outRaw, const void* inpRaw, int realSize) {
88     Compute execute;
89     float out[BLOCK_SIZE];
90     float inp[BLOCK_SIZE];
91     int b = realSize / BLOCK_SIZE;
92     int remain = realSize % BLOCK_SIZE;
93     auto outR = (int16_t*)outRaw;
94     auto inpR = (const int16_t*)inpRaw;
95     for (int i=0; i<b; ++i) {
96         MNNDequantizeFP16(inpR, inp, BLOCK_SIZE);
97         execute(out, inp, BLOCK_SIZE);
98         MNNQuantizeFP16(out, outR, BLOCK_SIZE);
99         outR += BLOCK_SIZE;
100         inpR += BLOCK_SIZE;
101     }
102     if (remain > 0) {
103         MNNDequantizeFP16(inpR, inp, remain);
104         execute(out, inp, remain);
105         MNNQuantizeFP16(out, outR, remain);
106     }
107 }
108 
// Float-domain exp(x), intended to be run through _Wrap.
struct _Exp {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        auto out = (float*)outRaw;
        auto inp = (const float*)inpRaw;
        // Negate the input first (scale -1, bias 0 — assuming the helper's
        // signature is (dst, src, bias, scale, n); TODO confirm). MNNExp
        // presumably evaluates exp of the negated argument, so the pair
        // yields exp(inp) — verify against MNNExp's contract.
        MNNScaleAndAddBiasScalar(out, inp, 0.0f, -1.0f, realSize);
        MNNExp(out, out, realSize);
    }
};
// Float-domain expm1(x) = exp(x) - 1, intended to be run through _Wrap.
struct _ExpM1 {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        auto out = (float*)outRaw;
        auto inp = (const float*)inpRaw;
        // Same negate-then-MNNExp pairing as _Exp (see note there; assumes
        // MNNExp works on the negated argument — TODO confirm).
        MNNScaleAndAddBiasScalar(out, inp, 0.0f, -1.0f, realSize);
        MNNExp(out, out, realSize);
        // Subtract 1 afterwards; note this loses expm1's precision advantage
        // for tiny x, matching the existing behavior.
        for (int i=0; i<realSize; ++i) {
            out[i] = out[i] - 1.0f;
        }
    }
};
128 
// Float-domain tanh, delegating to the optimized MNNTanh kernel.
struct _Tanh {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        auto out = (float*)outRaw;
        auto inp = (const float*)inpRaw;
        MNNTanh(out, inp, realSize);
    }
};
// Float-domain sigmoid, delegating to the low-precision (fast) variant —
// acceptable here since results are quantized back to fp16 anyway.
struct _Sigmoid {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        auto out = (float*)outRaw;
        auto inp = (const float*)inpRaw;
        MNNSigmoidLowp(out, inp, realSize);
    }
};
143 
FP16HardSwish(void * outRaw,const void * inpRaw,int realSize)144 void FP16HardSwish(void* outRaw, const void* inpRaw, int realSize) {
145     auto out = (FLOAT16*)outRaw;
146     auto inp = (const FLOAT16*)inpRaw;
147     int sizeC8 = realSize / 8;
148     int sizeRemain = realSize % 8;
149     if (sizeC8 > 0) {
150         float16x8_t zero = vdupq_n_f16(0.f);
151         float16x8_t three = vdupq_n_f16(3.f);
152         float16x8_t six = vdupq_n_f16(6.f);
153         float16x8_t divsix = vdupq_n_f16(1.0f/6.f);
154         for (int i = 0; i < sizeC8; i++) {
155             auto x = vld1q_f16(inp);
156             auto y = vmulq_f16(vmulq_f16(x, vminq_f16(vmaxq_f16(vaddq_f16(x, three), zero), six)), divsix);
157             vst1q_f16(out, y);
158             out += 8;
159             inp += 8;
160         }
161     }
162     for (int i=0; i<sizeRemain; ++i) {
163         auto x = inp[i];
164         float16_t y;
165         if (x <= -3) {
166             y = 0;
167         } else if (x >= 3) {
168             y = x;
169         } else {
170             y = x * (x + 3) / 6;
171         }
172         out[i] = y;
173     }
174 }
175 
// Generic scalar fallback: apply the functor Func to every element of a
// T-typed buffer, one element at a time.
template <typename Func, typename T>
struct _Unary {
    void operator()(void* outputPtr, const void* inputPtr, int elementSize) const {
        Func apply;
        auto src = reinterpret_cast<const T*>(inputPtr);
        auto dst = reinterpret_cast<T*>(outputPtr);
        int i = 0;
        while (i < elementSize) {
            dst[i] = apply(src[i]);
            ++i;
        }
    }
};
187 
select(int type,int precision)188 MNNUnaryExecute Arm82Unary::select(int type, int precision) {
189     switch (type) {
190         case UnaryOpOperation_ABS:
191             return FP16VecUnary<VecAbs>;
192         case UnaryOpOperation_SQUARE:
193             return FP16VecUnary<VecSquare>;
194         case UnaryOpOperation_NEG:
195             return FP16VecUnary<VecNeg>;
196         case UnaryOpOperation_RSQRT:
197             return FP16VecUnary<VecRsqrt>;
198         case UnaryOpOperation_EXP:
199             return _Wrap<_Exp>;
200         case UnaryOpOperation_COS:
201             return _Wrap<_Unary<UnaryCos<float>, float>>;
202         case UnaryOpOperation_SIN:
203             return _Wrap<_Unary<UnarySin<float>, float>>;
204         case UnaryOpOperation_SIGMOID:
205             return _Wrap<_Sigmoid>;
206         case UnaryOpOperation_TANH:
207             return _Wrap<_Tanh>;
208         case UnaryOpOperation_TAN:
209             return _Wrap<_Unary<UnaryTan<float>, float>>;
210         case UnaryOpOperation_ATAN:
211             return _Wrap<_Unary<UnaryATan<float>, float>>;
212 #if defined(__aarch64__)
213         case UnaryOpOperation_SQRT:
214             return FP16VecUnary<VecSqrt>;
215 #else
216         case UnaryOpOperation_SQRT:
217             return _Wrap<_Unary<UnarySqrt<float>, float>>;
218 #endif
219         case UnaryOpOperation_CEIL:
220             return _Wrap<_Unary<UnaryCeil<float>, float>>;
221         case UnaryOpOperation_RECIPROCAL:
222             return FP16VecUnary<VecRecipocal>;
223         case UnaryOpOperation_LOG1P:
224             return _Wrap<_Unary<UnaryLog1p<float>, float>>;
225         case UnaryOpOperation_LOG:
226             return _Wrap<_Unary<UnaryLog<float>, float>>;
227         case UnaryOpOperation_FLOOR:
228             return _Wrap<_Unary<UnaryFloor<float>, float>>;
229         case UnaryOpOperation_BNLL:
230             return _Wrap<_Unary<UnaryBNLL<float>, float>>;
231         case UnaryOpOperation_ACOSH:
232             return _Wrap<_Unary<UnaryAcosh<float>, float>>;
233         case UnaryOpOperation_SINH:
234             return _Wrap<_Unary<UnarySinh<float>, float>>;
235         case UnaryOpOperation_ASINH:
236             return _Wrap<_Unary<UnaryAsinh<float>, float>>;
237         case UnaryOpOperation_ATANH:
238             return _Wrap<_Unary<UnaryAtanh<float>, float>>;
239         case UnaryOpOperation_SIGN:
240             return _Wrap<_Unary<UnarySign<float>, float>>;
241         case UnaryOpOperation_ROUND:
242             return _Wrap<_Unary<UnaryRound<float>, float>>;
243         case UnaryOpOperation_COSH:
244             return _Wrap<_Unary<UnaryCosh<float>, float>>;
245         case UnaryOpOperation_ERF:
246             return _Wrap<_Unary<UnaryErf<float>, float>>;
247         case UnaryOpOperation_ERFC:
248             return _Wrap<_Unary<UnaryErfc<float>, float>>;
249         case UnaryOpOperation_ERFINV:
250             return _Wrap<_Unary<UnaryErfinv<float>, float>>;
251         case UnaryOpOperation_EXPM1:
252             return _Wrap<_ExpM1>;
253         case UnaryOpOperation_ASIN:
254             return _Wrap<_Unary<UnaryAsin<float>, float>>;
255         case UnaryOpOperation_ACOS:
256             return _Wrap<_Unary<UnaryAcos<float>, float>>;
257         case UnaryOpOperation_HARDSWISH:
258             return FP16HardSwish;
259         default:
260             MNN_ASSERT(false);
261             break;
262     }
263     return nullptr;
264 }
265 } // namespace MNN
266 
267 #endif
268