1 /*
2  * Wrapper functions for SVE ACLE.
3  *
4  * Copyright (c) 2019-2023, Arm Limited.
5  * SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
6  */
7 
8 #ifndef SV_MATH_H
9 #define SV_MATH_H
10 
11 #ifndef WANT_VMATH
12 /* Enable the build of vector math code.  */
13 #define WANT_VMATH 1
14 #endif
15 #if WANT_VMATH
16 
17 #if WANT_SVE_MATH
18 #define SV_SUPPORTED 1
19 
20 #include <arm_sve.h>
21 #include <stdbool.h>
22 
23 #include "math_config.h"
24 
25 typedef float f32_t;
26 typedef uint32_t u32_t;
27 typedef int32_t s32_t;
28 typedef double f64_t;
29 typedef uint64_t u64_t;
30 typedef int64_t s64_t;
31 
32 typedef svfloat64_t sv_f64_t;
33 typedef svuint64_t sv_u64_t;
34 typedef svint64_t sv_s64_t;
35 
36 typedef svfloat32_t sv_f32_t;
37 typedef svuint32_t sv_u32_t;
38 typedef svint32_t sv_s32_t;
39 
40 /* Double precision.  */
41 static inline sv_s64_t
42 sv_s64 (s64_t x)
43 {
44   return svdup_n_s64 (x);
45 }
46 
47 static inline sv_u64_t
48 sv_u64 (u64_t x)
49 {
50   return svdup_n_u64 (x);
51 }
52 
53 static inline sv_f64_t
54 sv_f64 (f64_t x)
55 {
56   return svdup_n_f64 (x);
57 }
58 
59 static inline sv_f64_t
60 sv_fma_f64_x (svbool_t pg, sv_f64_t x, sv_f64_t y, sv_f64_t z)
61 {
62   return svmla_f64_x (pg, z, x, y);
63 }
64 
65 /* res = z + x * y with x scalar. */
66 static inline sv_f64_t
67 sv_fma_n_f64_x (svbool_t pg, f64_t x, sv_f64_t y, sv_f64_t z)
68 {
69   return svmla_n_f64_x (pg, z, y, x);
70 }
71 
72 static inline sv_s64_t
73 sv_as_s64_u64 (sv_u64_t x)
74 {
75   return svreinterpret_s64_u64 (x);
76 }
77 
78 static inline sv_u64_t
79 sv_as_u64_f64 (sv_f64_t x)
80 {
81   return svreinterpret_u64_f64 (x);
82 }
83 
84 static inline sv_f64_t
85 sv_as_f64_u64 (sv_u64_t x)
86 {
87   return svreinterpret_f64_u64 (x);
88 }
89 
90 static inline sv_f64_t
91 sv_to_f64_s64_x (svbool_t pg, sv_s64_t s)
92 {
93   return svcvt_f64_x (pg, s);
94 }
95 
96 static inline sv_f64_t
97 sv_call_f64 (f64_t (*f) (f64_t), sv_f64_t x, sv_f64_t y, svbool_t cmp)
98 {
99   svbool_t p = svpfirst (cmp, svpfalse ());
100   while (svptest_any (cmp, p))
101     {
102       f64_t elem = svclastb_n_f64 (p, 0, x);
103       elem = (*f) (elem);
104       sv_f64_t y2 = svdup_n_f64 (elem);
105       y = svsel_f64 (p, y2, y);
106       p = svpnext_b64 (cmp, p);
107     }
108   return y;
109 }
110 
111 static inline sv_f64_t
112 sv_call2_f64 (f64_t (*f) (f64_t, f64_t), sv_f64_t x1, sv_f64_t x2, sv_f64_t y,
113 	      svbool_t cmp)
114 {
115   svbool_t p = svpfirst (cmp, svpfalse ());
116   while (svptest_any (cmp, p))
117     {
118       f64_t elem1 = svclastb_n_f64 (p, 0, x1);
119       f64_t elem2 = svclastb_n_f64 (p, 0, x2);
120       f64_t ret = (*f) (elem1, elem2);
121       sv_f64_t y2 = svdup_n_f64 (ret);
122       y = svsel_f64 (p, y2, y);
123       p = svpnext_b64 (cmp, p);
124     }
125   return y;
126 }
127 
128 /* Load array of uint64_t into svuint64_t.  */
129 static inline sv_u64_t
130 sv_lookup_u64_x (svbool_t pg, const u64_t *tab, sv_u64_t idx)
131 {
132   return svld1_gather_u64index_u64 (pg, tab, idx);
133 }
134 
135 /* Load array of double into svfloat64_t.  */
136 static inline sv_f64_t
137 sv_lookup_f64_x (svbool_t pg, const f64_t *tab, sv_u64_t idx)
138 {
139   return svld1_gather_u64index_f64 (pg, tab, idx);
140 }
141 
142 static inline sv_u64_t
143 sv_mod_n_u64_x (svbool_t pg, sv_u64_t x, u64_t y)
144 {
145   sv_u64_t q = svdiv_n_u64_x (pg, x, y);
146   return svmls_n_u64_x (pg, x, q, y);
147 }
148 
149 /* Single precision.  */
150 static inline sv_s32_t
151 sv_s32 (s32_t x)
152 {
153   return svdup_n_s32 (x);
154 }
155 
156 static inline sv_u32_t
157 sv_u32 (u32_t x)
158 {
159   return svdup_n_u32 (x);
160 }
161 
162 static inline sv_f32_t
163 sv_f32 (f32_t x)
164 {
165   return svdup_n_f32 (x);
166 }
167 
168 static inline sv_f32_t
169 sv_fma_f32_x (svbool_t pg, sv_f32_t x, sv_f32_t y, sv_f32_t z)
170 {
171   return svmla_f32_x (pg, z, x, y);
172 }
173 
174 /* res = z + x * y with x scalar.  */
175 static inline sv_f32_t
176 sv_fma_n_f32_x (svbool_t pg, f32_t x, sv_f32_t y, sv_f32_t z)
177 {
178   return svmla_n_f32_x (pg, z, y, x);
179 }
180 
181 static inline sv_u32_t
182 sv_as_u32_f32 (sv_f32_t x)
183 {
184   return svreinterpret_u32_f32 (x);
185 }
186 
187 static inline sv_f32_t
188 sv_as_f32_u32 (sv_u32_t x)
189 {
190   return svreinterpret_f32_u32 (x);
191 }
192 
193 static inline sv_s32_t
194 sv_as_s32_u32 (sv_u32_t x)
195 {
196   return svreinterpret_s32_u32 (x);
197 }
198 
199 static inline sv_f32_t
200 sv_to_f32_s32_x (svbool_t pg, sv_s32_t s)
201 {
202   return svcvt_f32_x (pg, s);
203 }
204 
205 static inline sv_s32_t
206 sv_to_s32_f32_x (svbool_t pg, sv_f32_t x)
207 {
208   return svcvt_s32_f32_x (pg, x);
209 }
210 
211 static inline sv_f32_t
212 sv_call_f32 (f32_t (*f) (f32_t), sv_f32_t x, sv_f32_t y, svbool_t cmp)
213 {
214   svbool_t p = svpfirst (cmp, svpfalse ());
215   while (svptest_any (cmp, p))
216     {
217       f32_t elem = svclastb_n_f32 (p, 0, x);
218       elem = (*f) (elem);
219       sv_f32_t y2 = svdup_n_f32 (elem);
220       y = svsel_f32 (p, y2, y);
221       p = svpnext_b32 (cmp, p);
222     }
223   return y;
224 }
225 
226 static inline sv_f32_t
227 sv_call2_f32 (f32_t (*f) (f32_t, f32_t), sv_f32_t x1, sv_f32_t x2, sv_f32_t y,
228 	      svbool_t cmp)
229 {
230   svbool_t p = svpfirst (cmp, svpfalse ());
231   while (svptest_any (cmp, p))
232     {
233       f32_t elem1 = svclastb_n_f32 (p, 0, x1);
234       f32_t elem2 = svclastb_n_f32 (p, 0, x2);
235       f32_t ret = (*f) (elem1, elem2);
236       sv_f32_t y2 = svdup_n_f32 (ret);
237       y = svsel_f32 (p, y2, y);
238       p = svpnext_b32 (cmp, p);
239     }
240   return y;
241 }
242 
243 #endif
244 #endif
245 #endif
246