/*
 * Vector math abstractions.
 *
 * Copyright (c) 2019-2023, Arm Limited.
 * SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
 */

#ifndef _V_MATH_H
#define _V_MATH_H

#ifndef WANT_VMATH
/* Enable the build of vector math code.  */
# define WANT_VMATH 1
#endif

#if WANT_VMATH

# if __aarch64__
#  define VPCS_ATTR __attribute__ ((aarch64_vector_pcs))
# else
#  error "Cannot build without AArch64"
# endif

# include <stdint.h>
# include "math_config.h"
# if __aarch64__

#  include <arm_neon.h>

/* Shorthand helpers for declaring constants.  */
#  define V2(X) { X, X }
#  define V4(X) { X, X, X, X }
#  define V8(X) { X, X, X, X, X, X, X, X }

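/* Example (illustrative, not used by this header): the broadcast
   initializers above let a routine declare vector constants directly:

     static const float32x4_t Half = V4 (0.5f);
     static const float64x2_t Ln2 = V2 (0x1.62e42fefa39efp-1);  */
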
/* True if any element of a half-width (64-bit) compare result is non-zero.  */
static inline int
v_any_u16h (uint16x4_t x)
{
  return vget_lane_u64 (vreinterpret_u64_u16 (x), 0) != 0;
}

/* Broadcast a scalar to all lanes of a vector.  */
static inline float32x4_t
v_f32 (float x)
{
  return (float32x4_t) V4 (x);
}
static inline uint32x4_t
v_u32 (uint32_t x)
{
  return (uint32x4_t) V4 (x);
}
static inline int32x4_t
v_s32 (int32_t x)
{
  return (int32x4_t) V4 (x);
}

/* True if any element of a vector compare result is non-zero.  */
static inline int
v_any_u32 (uint32x4_t x)
{
  /* Assume elements in x are either 0 or -1u.  */
  return vpaddd_u64 (vreinterpretq_u64_u32 (x)) != 0;
}
static inline int
v_any_u32h (uint32x2_t x)
{
  return vget_lane_u64 (vreinterpret_u64_u32 (x), 0) != 0;
}
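/* Example (illustrative): a routine typically builds a lane mask with a
   NEON compare and only takes a slow path when some lane is set:

     uint32x4_t special = vcgeq_u32 (ia, v_u32 (Thresh));
     if (unlikely (v_any_u32 (special)))
       return special_case (x, y, special);

   ia, Thresh and special_case are hypothetical: the reinterpreted input
   bits, an overflow bound and a fallback handler.  */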
/* Gather four table entries at the given indices using scalar loads.  */
static inline float32x4_t
v_lookup_f32 (const float *tab, uint32x4_t idx)
{
  return (float32x4_t){ tab[idx[0]], tab[idx[1]], tab[idx[2]], tab[idx[3]] };
}
static inline uint32x4_t
v_lookup_u32 (const uint32_t *tab, uint32x4_t idx)
{
  return (uint32x4_t){ tab[idx[0]], tab[idx[1]], tab[idx[2]], tab[idx[3]] };
}
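/* Example (illustrative): table-driven routines derive per-lane indices in
   vector registers, then gather the coefficients:

     float32x4_t c = v_lookup_f32 (Tab, vshrq_n_u32 (ia, 26));

   Tab is a hypothetical 64-entry coefficient table and ia the input bits;
   shifting right by 26 keeps the top 6 bits as the index.  */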
/* Apply the scalar callback f to the lanes of x selected by p; lanes where
   p is zero keep their value from y.  */
static inline float32x4_t
v_call_f32 (float (*f) (float), float32x4_t x, float32x4_t y, uint32x4_t p)
{
  return (float32x4_t){ p[0] ? f (x[0]) : y[0], p[1] ? f (x[1]) : y[1],
			p[2] ? f (x[2]) : y[2], p[3] ? f (x[3]) : y[3] };
}
static inline float32x4_t
v_call2_f32 (float (*f) (float, float), float32x4_t x1, float32x4_t x2,
	     float32x4_t y, uint32x4_t p)
{
  return (float32x4_t){ p[0] ? f (x1[0], x2[0]) : y[0],
			p[1] ? f (x1[1], x2[1]) : y[1],
			p[2] ? f (x1[2], x2[2]) : y[2],
			p[3] ? f (x1[3], x2[3]) : y[3] };
}
/* Clear (zero) the lanes of x selected by mask.  */
static inline float32x4_t
v_zerofy_f32 (float32x4_t x, uint32x4_t mask)
{
  return vreinterpretq_f32_u32 (vbicq_u32 (vreinterpretq_u32_f32 (x), mask));
}

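/* Example (illustrative): the helpers above combine into the usual
   special-case pattern.  Zerofy the offending lanes so the vector path
   cannot raise spurious exceptions, then patch those lanes via the scalar
   routine:

     float32x4_t y = eval_poly (v_zerofy_f32 (x, special));
     if (unlikely (v_any_u32 (special)))
       return v_call_f32 (scalar_fn, x, y, special);
     return y;

   eval_poly and scalar_fn are hypothetical stand-ins for a routine's
   vector approximation and its scalar fallback.  */
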
static inline float64x2_t
v_f64 (double x)
{
  return (float64x2_t) V2 (x);
}
static inline uint64x2_t
v_u64 (uint64_t x)
{
  return (uint64x2_t) V2 (x);
}
static inline int64x2_t
v_s64 (int64_t x)
{
  return (int64x2_t) V2 (x);
}

/* True if any element of a vector compare result is non-zero.  */
static inline int
v_any_u64 (uint64x2_t x)
{
  /* Assume elements in x are either 0 or -1u.  */
  return vpaddd_u64 (x) != 0;
}
/* True if all elements of a vector compare result are set.  */
static inline int
v_all_u64 (uint64x2_t x)
{
  /* Assume elements in x are either 0 or -1u.  Reinterpreted as signed,
     each set lane is -1, so the pairwise sum of two set lanes is -2.  */
  return vpaddd_s64 (vreinterpretq_s64_u64 (x)) == -2;
}
/* Gather two table entries at the given indices using scalar loads.  */
static inline float64x2_t
v_lookup_f64 (const double *tab, uint64x2_t idx)
{
  return (float64x2_t){ tab[idx[0]], tab[idx[1]] };
}
static inline uint64x2_t
v_lookup_u64 (const uint64_t *tab, uint64x2_t idx)
{
  return (uint64x2_t){ tab[idx[0]], tab[idx[1]] };
}

/* Apply the scalar callback f to the lanes of x selected by p; lanes where
   p is zero keep their value from y.  The high lanes of p and x are
   extracted before the first scalar call.  */
static inline float64x2_t
v_call_f64 (double (*f) (double), float64x2_t x, float64x2_t y, uint64x2_t p)
{
  double p1 = p[1];
  double x1 = x[1];
  if (likely (p[0]))
    y[0] = f (x[0]);
  if (likely (p1))
    y[1] = f (x1);
  return y;
}

/* Two-argument variant of v_call_f64.  */
static inline float64x2_t
v_call2_f64 (double (*f) (double, double), float64x2_t x1, float64x2_t x2,
	     float64x2_t y, uint64x2_t p)
{
  double p1 = p[1];
  double x1h = x1[1];
  double x2h = x2[1];
  if (likely (p[0]))
    y[0] = f (x1[0], x2[0]);
  if (likely (p1))
    y[1] = f (x1h, x2h);
  return y;
}
/* Clear (zero) the lanes of x selected by mask.  */
static inline float64x2_t
v_zerofy_f64 (float64x2_t x, uint64x2_t mask)
{
  return vreinterpretq_f64_u64 (vbicq_u64 (vreinterpretq_u64_f64 (x), mask));
}
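/* Example (illustrative): the double-precision helpers mirror the
   single-precision pattern:

     float64x2_t y = eval_poly (v_zerofy_f64 (x, special));
     if (unlikely (v_any_u64 (special)))
       return v_call_f64 (scalar_fn, x, y, special);
     return y;

   with eval_poly and scalar_fn again hypothetical.  */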

# endif
#endif

#endif