/* NEON implementation of sin, cos, exp and log
 *
 *   Inspired by Intel Approximate Math library, and based on the
 *   corresponding algorithms of the cephes math library
 */

/* Copyright (C) 2011  Julien Pommier
 *
 *  This software is provided 'as-is', without any express or implied
 *  warranty.  In no event will the authors be held liable for any damages
 *  arising from the use of this software.
 *
 *  Permission is granted to anyone to use this software for any purpose,
 *  including commercial applications, and to alter it and redistribute it
 *  freely, subject to the following restrictions:
 *
 *  1. The origin of this software must not be misrepresented; you must not
 *     claim that you wrote the original software. If you use this software
 *     in a product, an acknowledgment in the product documentation would be
 *     appreciated but is not required.
 *  2. Altered source versions must be plainly marked as such, and must not be
 *     misrepresented as being the original software.
 *  3. This notice may not be removed or altered from any source distribution.
 *
 *  (this is the zlib license)
 */

#include <arm_neon.h>

#if (__ARM_FP & 2)
static inline float32x4_t loadfp16(const void* ptr)
{
#if __ARM_FP16_FORMAT_IEEE
    return vcvt_f32_f16(vld1_f16((const __fp16*)ptr));
#else // __ARM_FP16_FORMAT_IEEE
    float32x4_t v;
#if __aarch64__
    asm volatile(
        "ld1    {v0.4h}, [%2]       \n"
        "fcvtl  %0.4s, v0.4h        \n"
        : "=w"(v) // %0
        : "0"(v),
        "r"(ptr) // %2
        : "memory", "v0");
#else
    asm volatile(
        "vld1.s16       {d0}, [%2]  \n"
        "vcvt.f32.f16   %q0, d0     \n"
        : "=w"(v) // %0
        : "0"(v),
        "r"(ptr) // %2
        : "memory", "d0");
#endif // __aarch64__
    return v;
#endif // __ARM_FP16_FORMAT_IEEE
}
#endif
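
/* Usage sketch (illustrative, not part of the library): load four IEEE fp16
 * values and widen them to fp32. The bit patterns below encode 1.0, 2.0, 3.0
 * and 4.0 in half precision.
 *
 *   unsigned short half4[4] = {0x3c00, 0x4000, 0x4200, 0x4400};
 *   float32x4_t f = loadfp16(half4); // f = {1.f, 2.f, 3.f, 4.f}
 */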

#define c_inv_mant_mask ~0x7f800000u
#define c_cephes_SQRTHF 0.707106781186547524
#define c_cephes_log_p0 7.0376836292E-2
#define c_cephes_log_p1 -1.1514610310E-1
#define c_cephes_log_p2 1.1676998740E-1
#define c_cephes_log_p3 -1.2420140846E-1
#define c_cephes_log_p4 +1.4249322787E-1
#define c_cephes_log_p5 -1.6668057665E-1
#define c_cephes_log_p6 +2.0000714765E-1
#define c_cephes_log_p7 -2.4999993993E-1
#define c_cephes_log_p8 +3.3333331174E-1
#define c_cephes_log_q1 -2.12194440e-4
#define c_cephes_log_q2 0.693359375
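
/* Background note: the cephes algorithm decomposes x = m * 2^e with m in
 * [0.5, 1), so that log(x) = e * log(2) + log(m), and approximates log(m)
 * with the degree-8 polynomial p0..p8. log(2) is split as q2 + q1
 * (0.693359375 - 2.12194440e-4) so that the large part q2 is exactly
 * representable and the rounding error stays in the small correction q1. */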

/* natural logarithm computed for 4 simultaneous floats
 *   returns NaN for x <= 0
 */
static inline float32x4_t log_ps(float32x4_t x)
{
    float32x4_t one = vdupq_n_f32(1);

    x = vmaxq_f32(x, vdupq_n_f32(0)); /* force flush to zero on denormal values */
    uint32x4_t invalid_mask = vcleq_f32(x, vdupq_n_f32(0));

    int32x4_t ux = vreinterpretq_s32_f32(x);

    int32x4_t emm0 = vshrq_n_s32(ux, 23);

    /* keep only the mantissa and force the exponent to -1, mapping x into [0.5, 1) */
    ux = vandq_s32(ux, vdupq_n_s32(c_inv_mant_mask));
    ux = vorrq_s32(ux, vreinterpretq_s32_f32(vdupq_n_f32(0.5f)));
    x = vreinterpretq_f32_s32(ux);

    emm0 = vsubq_s32(emm0, vdupq_n_s32(0x7f));
    float32x4_t e = vcvtq_f32_s32(emm0);

    e = vaddq_f32(e, one);

    /* part2:
     *     if( x < SQRTHF ) {
     *       e -= 1;
     *       x = x + x - 1.0;
     *     } else { x = x - 1.0; }
     */
    uint32x4_t mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF));
    float32x4_t tmp = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
    x = vsubq_f32(x, one);
    e = vsubq_f32(e, vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(one), mask)));
    x = vaddq_f32(x, tmp);

    float32x4_t z = vmulq_f32(x, x);

    float32x4_t y = vdupq_n_f32(c_cephes_log_p0);
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1));
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2));
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3));
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4));
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5));
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6));
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7));
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8));
    y = vmulq_f32(y, x);

    y = vmulq_f32(y, z);

    tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1));
    y = vaddq_f32(y, tmp);

    tmp = vmulq_f32(z, vdupq_n_f32(0.5f));
    y = vsubq_f32(y, tmp);

    tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2));
    x = vaddq_f32(x, y);
    x = vaddq_f32(x, tmp);
    x = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(x), invalid_mask)); // negative arg will be NaN
    return x;
}
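
/* Usage sketch (illustrative): four logarithms at once.
 *
 *   const float in[4] = {1.f, 2.f, 4.f, 8.f};
 *   float out[4];
 *   vst1q_f32(out, log_ps(vld1q_f32(in)));
 *   // out ~= {0.f, 0.6931f, 1.3863f, 2.0794f}
 */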

#define c_exp_hi 88.3762626647949f
#define c_exp_lo -88.3762626647949f

#define c_cephes_LOG2EF 1.44269504088896341
#define c_cephes_exp_C1 0.693359375
#define c_cephes_exp_C2 -2.12194440e-4

#define c_cephes_exp_p0 1.9875691500E-4
#define c_cephes_exp_p1 1.3981999507E-3
#define c_cephes_exp_p2 8.3334519073E-3
#define c_cephes_exp_p3 4.1665795894E-2
#define c_cephes_exp_p4 1.6666665459E-1
#define c_cephes_exp_p5 5.0000001201E-1
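
/* Background note: as in log_ps, log(2) is split into C1 + C2 so the range
 * reduction x - n*log(2) loses as little precision as possible. LOG2EF is
 * 1/log(2), and the hi/lo clamps bound |x| so the 2^n factor built below
 * stays within the single-precision exponent range. */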

/* exp() computed for 4 floats at once */
static inline float32x4_t exp_ps(float32x4_t x)
{
    float32x4_t tmp, fx;

    float32x4_t one = vdupq_n_f32(1);
    x = vminq_f32(x, vdupq_n_f32(c_exp_hi));
    x = vmaxq_f32(x, vdupq_n_f32(c_exp_lo));

    /* express exp(x) as exp(g + n*log(2)) */
    fx = vmlaq_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF));

    /* perform a floorf */
    tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx));

    /* if greater, subtract 1 */
    uint32x4_t mask = vcgtq_f32(tmp, fx);
    mask = vandq_u32(mask, vreinterpretq_u32_f32(one));

    fx = vsubq_f32(tmp, vreinterpretq_f32_u32(mask));

    tmp = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C1));
    float32x4_t z = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C2));
    x = vsubq_f32(x, tmp);
    x = vsubq_f32(x, z);

    static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1, c_cephes_exp_p2, c_cephes_exp_p3, c_cephes_exp_p4, c_cephes_exp_p5};
    float32x4_t y = vld1q_dup_f32(cephes_exp_p + 0);
    float32x4_t c1 = vld1q_dup_f32(cephes_exp_p + 1);
    float32x4_t c2 = vld1q_dup_f32(cephes_exp_p + 2);
    float32x4_t c3 = vld1q_dup_f32(cephes_exp_p + 3);
    float32x4_t c4 = vld1q_dup_f32(cephes_exp_p + 4);
    float32x4_t c5 = vld1q_dup_f32(cephes_exp_p + 5);

    y = vmulq_f32(y, x);
    z = vmulq_f32(x, x);

    y = vaddq_f32(y, c1);
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, c2);
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, c3);
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, c4);
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, c5);

    y = vmulq_f32(y, z);
    y = vaddq_f32(y, x);
    y = vaddq_f32(y, one);

    /* build 2^n by placing n+127 directly into the float exponent bits */
    int32x4_t mm;
    mm = vcvtq_s32_f32(fx);
    mm = vaddq_s32(mm, vdupq_n_s32(0x7f));
    mm = vshlq_n_s32(mm, 23);
    float32x4_t pow2n = vreinterpretq_f32_s32(mm);

    y = vmulq_f32(y, pow2n);
    return y;
}
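
/* Usage sketch (illustrative): four exponentials at once.
 *
 *   const float in[4] = {0.f, 1.f, -1.f, 2.f};
 *   float out[4];
 *   vst1q_f32(out, exp_ps(vld1q_f32(in)));
 *   // out ~= {1.f, 2.7183f, 0.3679f, 7.3891f}
 */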

#define c_minus_cephes_DP1 -0.78515625
#define c_minus_cephes_DP2 -2.4187564849853515625e-4
#define c_minus_cephes_DP3 -3.77489497744594108e-8
#define c_sincof_p0        -1.9515295891E-4
#define c_sincof_p1        8.3321608736E-3
#define c_sincof_p2        -1.6666654611E-1
#define c_coscof_p0        2.443315711809948E-005
#define c_coscof_p1        -1.388731625493765E-003
#define c_coscof_p2        4.166664568298827E-002
#define c_cephes_FOPI      1.27323954473516 // 4 / M_PI
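
/* Background note: DP1 + DP2 + DP3 = -(pi/4) split into three parts of
 * decreasing magnitude. Subtracting them from x one at a time ("extended
 * precision modular arithmetic") keeps the range reduction x - j*pi/4 far
 * more accurate than a single single-precision pi/4 constant would. */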

/* evaluation of 4 sines & cosines at once.
 *
 *   The code is an exact rewriting of the cephes sinf function.
 *   Precision is excellent as long as x < 8192 (I did not bother to
 *   take into account the special handling they have for greater values
 *   -- it does not return garbage for arguments over 8192, though, but
 *   the extra precision is missing).
 *
 *   Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
 *   surprising but correct result.
 *
 *   Note also that when you compute sin(x), cos(x) is available at
 *   almost no extra cost, so both sin_ps and cos_ps make use of
 *   sincos_ps.
 */
static inline void sincos_ps(float32x4_t x, float32x4_t* ysin, float32x4_t* ycos)
{
    // any x
    float32x4_t xmm1, xmm2, xmm3, y;

    uint32x4_t emm2;

    uint32x4_t sign_mask_sin, sign_mask_cos;
    sign_mask_sin = vcltq_f32(x, vdupq_n_f32(0));
    x = vabsq_f32(x);

    /* scale by 4/Pi */
    y = vmulq_f32(x, vdupq_n_f32(c_cephes_FOPI));

    /* store the integer part of y in emm2 */
    emm2 = vcvtq_u32_f32(y);
    /* j=(j+1) & (~1) (see the cephes sources) */
    emm2 = vaddq_u32(emm2, vdupq_n_u32(1));
    emm2 = vandq_u32(emm2, vdupq_n_u32(~1));
    y = vcvtq_f32_u32(emm2);

    /* get the polynomial selection mask
     *     there is one polynomial for 0 <= x <= Pi/4
     *     and another one for Pi/4 < x <= Pi/2
     *
     *     Both branches will be computed.
     */
    uint32x4_t poly_mask = vtstq_u32(emm2, vdupq_n_u32(2));

    /* The magic pass: "Extended precision modular arithmetic"
     *     x = ((x - y * DP1) - y * DP2) - y * DP3; */
    xmm1 = vmulq_n_f32(y, c_minus_cephes_DP1);
    xmm2 = vmulq_n_f32(y, c_minus_cephes_DP2);
    xmm3 = vmulq_n_f32(y, c_minus_cephes_DP3);
    x = vaddq_f32(x, xmm1);
    x = vaddq_f32(x, xmm2);
    x = vaddq_f32(x, xmm3);

    sign_mask_sin = veorq_u32(sign_mask_sin, vtstq_u32(emm2, vdupq_n_u32(4)));
    sign_mask_cos = vtstq_u32(vsubq_u32(emm2, vdupq_n_u32(2)), vdupq_n_u32(4));

    /* Evaluate the first polynomial (0 <= x <= Pi/4) in y1,
     *     and the second polynomial (Pi/4 < x <= Pi/2) in y2 */
    float32x4_t z = vmulq_f32(x, x);
    float32x4_t y1, y2;

    y1 = vmulq_n_f32(z, c_coscof_p0);
    y2 = vmulq_n_f32(z, c_sincof_p0);
    y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p1));
    y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p1));
    y1 = vmulq_f32(y1, z);
    y2 = vmulq_f32(y2, z);
    y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p2));
    y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p2));
    y1 = vmulq_f32(y1, z);
    y2 = vmulq_f32(y2, z);
    y1 = vmulq_f32(y1, z);
    y2 = vmulq_f32(y2, x);
    y1 = vsubq_f32(y1, vmulq_f32(z, vdupq_n_f32(0.5f)));
    y2 = vaddq_f32(y2, x);
    y1 = vaddq_f32(y1, vdupq_n_f32(1));

    /* select the correct result from the two polynomials */
    float32x4_t ys = vbslq_f32(poly_mask, y1, y2);
    float32x4_t yc = vbslq_f32(poly_mask, y2, y1);
    *ysin = vbslq_f32(sign_mask_sin, vnegq_f32(ys), ys);
    *ycos = vbslq_f32(sign_mask_cos, yc, vnegq_f32(yc));
}
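
/* Usage sketch (illustrative): sine and cosine of four angles in one call.
 *
 *   const float rad[4] = {0.f, 0.5f, 1.f, 2.f};
 *   float32x4_t s, c;
 *   sincos_ps(vld1q_f32(rad), &s, &c);
 *   // s ~= {0.f, 0.4794f, 0.8415f, 0.9093f}
 *   // c ~= {1.f, 0.8776f, 0.5403f, -0.4161f}
 */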

static inline float32x4_t sin_ps(float32x4_t x)
{
    float32x4_t ysin, ycos;
    sincos_ps(x, &ysin, &ycos);
    return ysin;
}

static inline float32x4_t cos_ps(float32x4_t x)
{
    float32x4_t ysin, ycos;
    sincos_ps(x, &ysin, &ycos);
    return ycos;
}

static inline float32x4_t div_ps(float32x4_t a, float32x4_t b)
{
#if __aarch64__
    return vdivq_f32(a, b);
#else
    /* armv7 NEON has no divide instruction: approximate 1/b and refine */
    float32x4_t reciprocal = vrecpeq_f32(b);
    reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal);
    // reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal);
    return vmulq_f32(a, reciprocal);
#endif
}
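
/* Background note on the armv7 path: vrecpeq_f32 yields a reciprocal
 * estimate good to roughly 8 bits, and each vrecpsq_f32 step performs a
 * Newton-Raphson refinement r <- r * (2 - b*r), roughly doubling the number
 * of correct bits. The commented-out second step can be enabled when the
 * extra precision is worth an additional multiply. */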

static inline float32x4_t pow_ps(float32x4_t a, float32x4_t b)
{
    // pow(x, m) = exp(m * log(x))
    return exp_ps(vmulq_f32(b, log_ps(a)));
}
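
/* Note: since log_ps returns NaN for non-positive inputs, this identity is
 * only valid for a > 0; e.g. pow_ps({2,2,2,2}, {0,1,2,3}) ~= {1, 2, 4, 8}. */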

static inline float32x4_t sigmoid_ps(float32x4_t _v)
{
    /* sigmoid(v) = 1 / (1 + exp(-v)), with the division done via a
     * reciprocal estimate plus one Newton-Raphson refinement */
    float32x4_t _one = vdupq_n_f32(1.f);
    _v = vnegq_f32(_v);
    _v = exp_ps(_v);
    _v = vaddq_f32(_v, _one);
    float32x4_t _outp = vrecpeq_f32(_v);
    // _outp = vmulq_f32(vrecpsq_f32(_v, _outp), _outp);
    return vmulq_f32(vrecpsq_f32(_v, _outp), _outp);
}
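
/* Usage sketch (illustrative):
 *
 *   float out[4];
 *   vst1q_f32(out, sigmoid_ps(vdupq_n_f32(0.f)));
 *   // out ~= {0.5f, 0.5f, 0.5f, 0.5f}
 */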

#include "neon_mathfun_tanh.h"