// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// ARM SVE[2] vectors (length not known at compile time).
// External include guard in highway.h - see comment there.

#include <stddef.h>
#include <stdint.h>

#if defined(HWY_EMULATE_SVE)
#include "third_party/farm_sve/farm_sve.h"
#else
#include <arm_sve.h>
#endif

#include "hwy/base.h"
#include "hwy/ops/shared-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

// SVE only supports fractions, not LMUL > 1.
template <typename T, int kShift = 0>
using Full = Simd<T, (kShift <= 0) ? (HWY_LANES(T) >> (-kShift)) : 0>;

template <class V>
struct DFromV_t {};  // specialized in macros
template <class V>
using DFromV = typename DFromV_t<RemoveConst<V>>::type;

template <class V>
using TFromV = TFromD<DFromV<V>>;

#define HWY_IF_UNSIGNED_V(V) HWY_IF_UNSIGNED(TFromV<V>)
#define HWY_IF_SIGNED_V(V) HWY_IF_SIGNED(TFromV<V>)
#define HWY_IF_FLOAT_V(V) HWY_IF_FLOAT(TFromV<V>)
#define HWY_IF_LANE_SIZE_V(V, bytes) HWY_IF_LANE_SIZE(TFromV<V>, bytes)

// ================================================== MACROS

// Generate specializations and function definitions using X macros. Although
// harder to read and debug, this is far less bulky than writing it all out.

namespace detail {  // for code folding

// Unsigned:
#define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, NAME, OP)
#define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, NAME, OP)
#define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) X_MACRO(uint, u, 32, NAME, OP)
#define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) X_MACRO(uint, u, 64, NAME, OP)

// Signed:
#define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, NAME, OP)
#define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, NAME, OP)
#define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, NAME, OP)
#define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, NAME, OP)

// Float:
#define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) X_MACRO(float, f, 16, NAME, OP)
#define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) X_MACRO(float, f, 32, NAME, OP)
#define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) X_MACRO(float, f, 64, NAME, OP)

// For all element sizes:
#define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP)

// Commonly used type categories for a given element size:
#define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UI16(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP)          \
  HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP)          \
  HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP)           \
  HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP)

// Commonly used type categories:
#define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_IF(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_F(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_F(X_MACRO, NAME, OP)

// Assemble types for use in x-macros
#define HWY_SVE_T(BASE, BITS) BASE##BITS##_t
#define HWY_SVE_D(BASE, BITS, N) Simd<HWY_SVE_T(BASE, BITS), N>
#define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t

}  // namespace detail

#define HWY_SPECIALIZE(BASE, CHAR, BITS, NAME, OP)                        \
  template <>                                                             \
  struct DFromV_t<HWY_SVE_V(BASE, BITS)> {                                \
    using type = HWY_SVE_D(BASE, BITS, HWY_LANES(HWY_SVE_T(BASE, BITS))); \
  };

HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _)
#undef HWY_SPECIALIZE

// vector = f(d), e.g. Undefined
#define HWY_SVE_RETV_ARGD(BASE, CHAR, BITS, NAME, OP)              \
  template <size_t N>                                              \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N) d) { \
    return sv##OP##_##CHAR##BITS();                                \
  }

// Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX
// instructions; we only use it when the predicate is ptrue anyway.

// vector = f(vector), e.g. Not
#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v);   \
  }
#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, NAME, OP)           \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS(v);                            \
  }

// vector = f(vector, scalar), e.g. detail::AddN
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \
  }
#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, NAME, OP)         \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }

// vector = f(vector, vector), e.g. Add
#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \
  }
#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, NAME, OP)         \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }
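
// For example, HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGPVV, Add, add) expands to
// (reformatted for clarity):
//   HWY_API svfloat32_t Add(svfloat32_t a, svfloat32_t b) {
//     return svadd_f32_x(HWY_SVE_PTRUE(32), a, b);
//   }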

// ------------------------------ Lanes

namespace detail {

// Returns actual lanes of a hardware vector without rounding to a power of
// two.
HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<1> /* tag */) {
  return svcntb_pat(SV_ALL);
}
HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<2> /* tag */) {
  return svcnth_pat(SV_ALL);
}
HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<4> /* tag */) {
  return svcntw_pat(SV_ALL);
}
HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<8> /* tag */) {
  return svcntd_pat(SV_ALL);
}

// Returns actual lanes of a hardware vector, rounded down to a power of two.
HWY_INLINE size_t HardwareLanes(hwy::SizeTag<1> /* tag */) {
  return svcntb_pat(SV_POW2);
}
HWY_INLINE size_t HardwareLanes(hwy::SizeTag<2> /* tag */) {
  return svcnth_pat(SV_POW2);
}
HWY_INLINE size_t HardwareLanes(hwy::SizeTag<4> /* tag */) {
  return svcntw_pat(SV_POW2);
}
HWY_INLINE size_t HardwareLanes(hwy::SizeTag<8> /* tag */) {
  return svcntd_pat(SV_POW2);
}

}  // namespace detail

// Capped to <= 128 bits: SVE vectors are at least that large, so there is no
// need to query the actual length.
template <typename T, size_t N, HWY_IF_LE128(T, N)>
HWY_API constexpr size_t Lanes(Simd<T, N> /* tag */) {
  return N;
}

// Returns actual number of lanes after dividing by div={1,2,4,8}.
// May return 0 if div > 16/sizeof(T): there is no "1/8th" of a u32x4, but it
// would be valid for u32x8 (i.e. hardware vectors >= 256 bits).
template <typename T, size_t N, HWY_IF_GT128(T, N)>
HWY_API size_t Lanes(Simd<T, N> /* tag */) {
  static_assert(N <= HWY_LANES(T), "N cannot exceed a full vector");

  const size_t actual = detail::HardwareLanes(hwy::SizeTag<sizeof(T)>());
  const size_t div = HWY_LANES(T) / N;
  return (div <= 8) ? actual / div : HWY_MIN(actual, N);
}
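
// A minimal usage sketch (hypothetical caller code, not part of this header;
// Example240 is a made-up name), kept under #if 0 so it is not compiled:
#if 0
void Example240() {
  const Full<float> df;      // full vector
  const Full<float, -1> dh;  // half-length fraction of the vector
  const size_t n = Lanes(df);     // e.g. 4 on 128-bit SVE, 8 on 256-bit
  const size_t half = Lanes(dh);  // n / 2
  (void)n;
  (void)half;
}
#endif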

// ================================================== MASK INIT

// One mask bit per byte; only the one belonging to the lowest byte is valid.

// ------------------------------ FirstN
#define HWY_SVE_FIRSTN(BASE, CHAR, BITS, NAME, OP)                        \
  template <size_t KN>                                                    \
  HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, KN) /* d */, size_t N) {    \
    return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast<uint32_t>(N)); \
  }
HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt)
#undef HWY_SVE_FIRSTN

namespace detail {

// All-true mask from a macro
#define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2)

#define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, NAME, OP) \
  template <size_t N>                                  \
  HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N) d) {  \
    return HWY_SVE_PTRUE(BITS);                        \
  }

HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue)  // return all-true
#undef HWY_SVE_WRAP_PTRUE

HWY_API svbool_t PFalse() { return svpfalse_b(); }

// Returns all-true if d is HWY_FULL or FirstN(N) after capping N.
//
// This is used in functions that load/store memory; other functions (e.g.
// arithmetic on partial vectors) can ignore d and use PTrue instead.
template <typename T, size_t N>
svbool_t Mask(Simd<T, N> d) {
  return N == HWY_LANES(T) ? PTrue(d) : FirstN(d, Lanes(d));
}

}  // namespace detail

// ================================================== INIT

// ------------------------------ Set
// vector = f(d, scalar), e.g. Set
#define HWY_SVE_SET(BASE, CHAR, BITS, NAME, OP)                     \
  template <size_t N>                                               \
  HWY_API HWY_SVE_V(BASE, BITS)                                     \
      NAME(HWY_SVE_D(BASE, BITS, N) d, HWY_SVE_T(BASE, BITS) arg) { \
    return sv##OP##_##CHAR##BITS(arg);                              \
  }

HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n)
#undef HWY_SVE_SET

// Required for Zero and VFromD
template <size_t N>
svuint16_t Set(Simd<bfloat16_t, N> d, bfloat16_t arg) {
  return Set(RebindToUnsigned<decltype(d)>(), arg.bits);
}

template <class D>
using VFromD = decltype(Set(D(), TFromD<D>()));

// ------------------------------ Zero

template <class D>
VFromD<D> Zero(D d) {
  return Set(d, 0);
}
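
// A minimal usage sketch (hypothetical, not part of this header; Example310
// is a made-up name):
#if 0
void Example310() {
  const Full<int32_t> d;
  const VFromD<decltype(d)> ones = Set(d, 1);  // broadcast 1 to all lanes
  const auto zero = Zero(d);                   // same as Set(d, 0)
  (void)ones;
  (void)zero;
}
#endif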

// ------------------------------ Undefined

#if defined(HWY_EMULATE_SVE)
template <class D>
VFromD<D> Undefined(D d) {
  return Zero(d);
}
#else
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGD, Undefined, undef)
#endif

// ------------------------------ BitCast

namespace detail {

// u8: no change
#define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, NAME, OP)                     \
  HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \
    return v;                                                            \
  }                                                                      \
  template <size_t N>                                                    \
  HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte(                         \
      HWY_SVE_D(BASE, BITS, N) /* d */, HWY_SVE_V(BASE, BITS) v) {       \
    return v;                                                            \
  }

// All other types
#define HWY_SVE_CAST(BASE, CHAR, BITS, NAME, OP)                       \
  HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) {        \
    return sv##OP##_u8_##CHAR##BITS(v);                                \
  }                                                                    \
  template <size_t N>                                                  \
  HWY_INLINE HWY_SVE_V(BASE, BITS)                                     \
      BitCastFromByte(HWY_SVE_D(BASE, BITS, N) /* d */, svuint8_t v) { \
    return sv##OP##_##CHAR##BITS##_u8(v);                              \
  }

HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _)
HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret)

#undef HWY_SVE_CAST_NOP
#undef HWY_SVE_CAST

template <size_t N>
HWY_INLINE svuint16_t BitCastFromByte(Simd<bfloat16_t, N> /* d */,
                                      svuint8_t v) {
  return BitCastFromByte(Simd<uint16_t, N>(), v);
}

}  // namespace detail

template <class D, class FromV>
HWY_API VFromD<D> BitCast(D d, FromV v) {
  return detail::BitCastFromByte(d, detail::BitCastToByte(v));
}
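
// A minimal usage sketch (hypothetical, not part of this header; Example370
// is a made-up name):
#if 0
void Example370() {
  const Full<float> df;
  const RebindToUnsigned<decltype(df)> du;  // Simd<uint32_t, ...>
  const svuint32_t bits = BitCast(du, Set(df, 1.0f));  // 0x3F800000 per lane
  const svfloat32_t back = BitCast(df, bits);          // 1.0f again
  (void)back;
}
#endif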

// ================================================== LOGICAL

// detail::*N() functions accept a scalar argument to avoid extra Set().

// ------------------------------ Not

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV, Not, not )

// ------------------------------ And

namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, AndN, and_n)
}  // namespace detail

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, And, and)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V And(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, And(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ Or

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Or(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Or(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ Xor

namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, XorN, eor_n)
}  // namespace detail

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Xor, eor)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Xor(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Xor(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ AndNot

namespace detail {
#define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, NAME, OP)     \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n)
#undef HWY_SVE_RETV_ARGPVN_SWAP
}  // namespace detail

#define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, NAME, OP)     \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic)
#undef HWY_SVE_RETV_ARGPVV_SWAP

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V AndNot(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ PopulationCount

#ifdef HWY_NATIVE_POPCNT
#undef HWY_NATIVE_POPCNT
#else
#define HWY_NATIVE_POPCNT
#endif

// Need to return original type instead of unsigned.
#define HWY_SVE_POPCNT(BASE, CHAR, BITS, NAME, OP)                     \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {        \
    return BitCast(DFromV<decltype(v)>(),                              \
                   sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v)); \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT, PopulationCount, cnt)
#undef HWY_SVE_POPCNT

// ================================================== SIGN

// ------------------------------ Neg
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Neg, neg)

// ------------------------------ Abs
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs)

// ------------------------------ CopySign[ToAbs]

template <class V>
HWY_API V CopySign(const V magn, const V sign) {
  const auto msb = SignBit(DFromV<V>());
  return Or(AndNot(msb, magn), And(msb, sign));
}

template <class V>
HWY_API V CopySignToAbs(const V abs, const V sign) {
  const auto msb = SignBit(DFromV<V>());
  return Or(abs, And(msb, sign));
}
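
// For example, CopySign(Set(d, 1.5f), Set(d, -0.0f)) yields -1.5f in every
// lane: AndNot clears the sign bit of `magn`, and And keeps only the sign bit
// of `sign`.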

// ================================================== ARITHMETIC

// ------------------------------ Add

namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN, AddN, add_n)
}  // namespace detail

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add)

// ------------------------------ Sub

namespace detail {
// Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg.
#define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS)                                             \
      NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS##_z(pg, a, b);                             \
  }

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n)
#undef HWY_SVE_RETV_ARGPVN_MASK
}  // namespace detail

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub)

// ------------------------------ SaturatedAdd

HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd)
HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd)

// ------------------------------ SaturatedSub

HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub)
HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub)

// ------------------------------ AbsDiff
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPVV, AbsDiff, abd)

// ------------------------------ ShiftLeft[Same]

#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, NAME, OP)                     \
  template <int kBits>                                                  \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {         \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits);    \
  }                                                                     \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME##Same(HWY_SVE_V(BASE, BITS) v, HWY_SVE_T(uint, BITS) bits) { \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, bits);     \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n)

// ------------------------------ ShiftRight[Same]

HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_N, ShiftRight, lsr_n)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n)

#undef HWY_SVE_SHIFT_N

// ------------------------------ RotateRight

// TODO(janwas): svxar on SVE2
template <int kBits, class V>
HWY_API V RotateRight(const V v) {
  constexpr size_t kSizeInBits = sizeof(TFromV<V>) * 8;
  static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count");
  if (kBits == 0) return v;
  return Or(ShiftRight<kBits>(v), ShiftLeft<kSizeInBits - kBits>(v));
}
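
// For example, RotateRight<8> of u32 lanes holding 0x12345678 yields
// 0x78123456: the low 8 bits wrap around to the top.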

// ------------------------------ Shl/r

#define HWY_SVE_SHIFT(BASE, CHAR, BITS, NAME, OP)                          \
  HWY_API HWY_SVE_V(BASE, BITS)                                            \
      NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) {          \
    using TU = HWY_SVE_T(uint, BITS);                                      \
    return sv##OP##_##CHAR##BITS##_x(                                      \
        HWY_SVE_PTRUE(BITS), v, BitCast(Simd<TU, HWY_LANES(TU)>(), bits)); \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl)

HWY_SVE_FOREACH_U(HWY_SVE_SHIFT, Shr, lsr)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr)

#undef HWY_SVE_SHIFT

// ------------------------------ Min/Max

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Min, min)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Max, max)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Max, maxnm)

namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MinN, min_n)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n)
}  // namespace detail

// ------------------------------ Mul
HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, Mul, mul)
HWY_SVE_FOREACH_UIF3264(HWY_SVE_RETV_ARGPVV, Mul, mul)

// ------------------------------ MulHigh
HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
namespace detail {
HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
HWY_SVE_FOREACH_U64(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
}  // namespace detail

// ------------------------------ Div
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div)

// ------------------------------ ApproximateReciprocal
HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe)

// ------------------------------ Sqrt
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt)

// ------------------------------ ApproximateReciprocalSqrt
HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte)

// ------------------------------ MulAdd
#define HWY_SVE_FMA(BASE, CHAR, BITS, NAME, OP)                         \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x,          \
           HWY_SVE_V(BASE, BITS) add) {                                 \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), x, mul, add); \
  }

HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulAdd, mad)

// ------------------------------ NegMulAdd
HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulAdd, msb)

// ------------------------------ MulSub
HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulSub, nmsb)

// ------------------------------ NegMulSub
HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad)

#undef HWY_SVE_FMA

// ------------------------------ Round etc.

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Floor, rintm)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Ceil, rintp)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Trunc, rintz)

// ================================================== MASK

// ------------------------------ RebindMask
template <class D, typename MFrom>
HWY_API svbool_t RebindMask(const D /*d*/, const MFrom mask) {
  return mask;
}

// ------------------------------ Mask logical

HWY_API svbool_t Not(svbool_t m) {
  // We don't know the lane type, so assume 8-bit. For larger types, this will
  // de-canonicalize the predicate, i.e. set bits to 1 even though they do not
  // correspond to the lowest byte in the lane. Per ARM, such bits are ignored.
  return svnot_b_z(HWY_SVE_PTRUE(8), m);
}
HWY_API svbool_t And(svbool_t a, svbool_t b) {
  return svand_b_z(b, b, a);  // same order as AndNot for consistency
}
HWY_API svbool_t AndNot(svbool_t a, svbool_t b) {
  return svbic_b_z(b, b, a);  // reversed order like NEON
}
HWY_API svbool_t Or(svbool_t a, svbool_t b) {
  return svsel_b(a, a, b);  // a ? true : b
}
HWY_API svbool_t Xor(svbool_t a, svbool_t b) {
  return svsel_b(a, svnand_b_z(a, a, b), b);  // a ? !(a & b) : b.
}
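
// A minimal sketch of combining masks (hypothetical, not part of this header;
// Example666 is a made-up name):
#if 0
void Example666() {
  const Full<uint32_t> d;
  const svbool_t first4 = FirstN(d, 4);
  const svbool_t first2 = FirstN(d, 2);
  const svbool_t lanes23 = AndNot(first2, first4);  // first4 & ~first2
  (void)lanes23;
}
#endif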

// ------------------------------ CountTrue

#define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, NAME, OP)          \
  template <size_t N>                                           \
  HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N) d, svbool_t m) { \
    return sv##OP##_b##BITS(detail::Mask(d), m);                \
  }

HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp)
#undef HWY_SVE_COUNT_TRUE

// For 16-bit Compress: full vector, not limited to SV_POW2.
namespace detail {

#define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, NAME, OP)     \
  template <size_t N>                                           \
  HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N) d, svbool_t m) { \
    return sv##OP##_b##BITS(svptrue_b##BITS(), m);              \
  }

HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp)
#undef HWY_SVE_COUNT_TRUE_FULL

}  // namespace detail

// ------------------------------ AllFalse
template <typename T, size_t N>
HWY_API bool AllFalse(Simd<T, N> d, svbool_t m) {
  return !svptest_any(detail::Mask(d), m);
}

// ------------------------------ AllTrue
template <typename T, size_t N>
HWY_API bool AllTrue(Simd<T, N> d, svbool_t m) {
  return CountTrue(d, m) == Lanes(d);
}

// ------------------------------ FindFirstTrue
template <typename T, size_t N>
HWY_API intptr_t FindFirstTrue(Simd<T, N> d, svbool_t m) {
  return AllFalse(d, m) ? -1 : CountTrue(d, svbrkb_b_z(detail::Mask(d), m));
}

// ------------------------------ IfThenElse
#define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, NAME, OP)                      \
  HWY_API HWY_SVE_V(BASE, BITS)                                               \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \
    return sv##OP##_##CHAR##BITS(m, yes, no);                                 \
  }

HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel)
#undef HWY_SVE_IF_THEN_ELSE

// ------------------------------ IfThenElseZero
template <class M, class V>
HWY_API V IfThenElseZero(const M mask, const V yes) {
  return IfThenElse(mask, yes, Zero(DFromV<V>()));
}

// ------------------------------ IfThenZeroElse
template <class M, class V>
HWY_API V IfThenZeroElse(const M mask, const V no) {
  return IfThenElse(mask, Zero(DFromV<V>()), no);
}
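
// A minimal usage sketch (hypothetical, not part of this header; Example731
// is a made-up name):
#if 0
void Example731() {
  const Full<float> d;
  const svbool_t m = FirstN(d, 2);
  const auto v = IfThenElse(m, Set(d, 1.0f), Set(d, 2.0f));  // 1,1,2,2,...
  const auto z = IfThenElseZero(m, v);                       // 1,1,0,0,...
  (void)z;
}
#endif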

// ================================================== COMPARE

// mask = f(vector, vector)
#define HWY_SVE_COMPARE(BASE, CHAR, BITS, NAME, OP)                         \
  HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b);                \
  }
#define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, NAME, OP)                       \
  HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b);                \
  }

// ------------------------------ Eq
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq)

// ------------------------------ Ne
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne)

// ------------------------------ Lt
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt)
namespace detail {
HWY_SVE_FOREACH_IF(HWY_SVE_COMPARE_N, LtN, cmplt_n)
}  // namespace detail

// ------------------------------ Le
HWY_SVE_FOREACH_F(HWY_SVE_COMPARE, Le, cmple)

#undef HWY_SVE_COMPARE
#undef HWY_SVE_COMPARE_N

// ------------------------------ Gt/Ge (swapped order)

template <class V>
HWY_API svbool_t Gt(const V a, const V b) {
  return Lt(b, a);
}
template <class V>
HWY_API svbool_t Ge(const V a, const V b) {
  return Le(b, a);
}

// ------------------------------ TestBit
template <class V>
HWY_API svbool_t TestBit(const V a, const V bit) {
  return Ne(And(a, bit), Zero(DFromV<V>()));
}

// ------------------------------ MaskFromVec (Ne)
template <class V>
HWY_API svbool_t MaskFromVec(const V v) {
  return Ne(v, Zero(DFromV<V>()));
}

// ------------------------------ VecFromMask

template <class D, HWY_IF_NOT_FLOAT_D(D)>
HWY_API VFromD<D> VecFromMask(const D d, svbool_t mask) {
  const auto v0 = Zero(RebindToSigned<decltype(d)>());
  return BitCast(d, detail::SubN(mask, v0, 1));
}

template <class D, HWY_IF_FLOAT_D(D)>
HWY_API VFromD<D> VecFromMask(const D d, svbool_t mask) {
  return BitCast(d, VecFromMask(RebindToUnsigned<D>(), mask));
}

// ================================================== MEMORY

// ------------------------------ Load/MaskedLoad/LoadDup128/Store/Stream

#define HWY_SVE_LOAD(BASE, CHAR, BITS, NAME, OP)           \
  template <size_t N>                                      \
  HWY_API HWY_SVE_V(BASE, BITS)                            \
      NAME(HWY_SVE_D(BASE, BITS, N) d,                     \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
    return sv##OP##_##CHAR##BITS(detail::Mask(d), p);      \
  }

#define HWY_SVE_MASKED_LOAD(BASE, CHAR, BITS, NAME, OP)    \
  template <size_t N>                                      \
  HWY_API HWY_SVE_V(BASE, BITS)                            \
      NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N) d,         \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
    return sv##OP##_##CHAR##BITS(m, p);                    \
  }

#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, NAME, OP)    \
  template <size_t N>                                      \
  HWY_API HWY_SVE_V(BASE, BITS)                            \
      NAME(HWY_SVE_D(BASE, BITS, N) d,                     \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
    /* All-true predicate to load all 128 bits. */         \
    return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), p);     \
  }

#define HWY_SVE_STORE(BASE, CHAR, BITS, NAME, OP)                        \
  template <size_t N>                                                    \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N) d, \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {            \
    sv##OP##_##CHAR##BITS(detail::Mask(d), p, v);                        \
  }

#define HWY_SVE_MASKED_STORE(BASE, CHAR, BITS, NAME, OP)      \
  template <size_t N>                                         \
  HWY_API void NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v,      \
                    HWY_SVE_D(BASE, BITS, N) d,               \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
    sv##OP##_##CHAR##BITS(m, p, v);                           \
  }

HWY_SVE_FOREACH(HWY_SVE_LOAD, Load, ld1)
HWY_SVE_FOREACH(HWY_SVE_MASKED_LOAD, MaskedLoad, ld1)
HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDup128, ld1rq)
HWY_SVE_FOREACH(HWY_SVE_STORE, Store, st1)
HWY_SVE_FOREACH(HWY_SVE_STORE, Stream, stnt1)
HWY_SVE_FOREACH(HWY_SVE_MASKED_STORE, MaskedStore, st1)

#undef HWY_SVE_LOAD
#undef HWY_SVE_MASKED_LOAD
#undef HWY_SVE_LOAD_DUP128
#undef HWY_SVE_STORE
#undef HWY_SVE_MASKED_STORE

// BF16 is the same as svuint16_t because BF16 is optional before v8.6.
template <size_t N>
HWY_API svuint16_t Load(Simd<bfloat16_t, N> d,
                        const bfloat16_t* HWY_RESTRICT p) {
  return Load(RebindToUnsigned<decltype(d)>(),
              reinterpret_cast<const uint16_t * HWY_RESTRICT>(p));
}

template <size_t N>
HWY_API void Store(svuint16_t v, Simd<bfloat16_t, N> d,
                   bfloat16_t* HWY_RESTRICT p) {
  Store(v, RebindToUnsigned<decltype(d)>(),
        reinterpret_cast<uint16_t * HWY_RESTRICT>(p));
}

// ------------------------------ Load/StoreU

// SVE only requires lane alignment, not natural alignment of the entire
// vector.
template <class D>
HWY_API VFromD<D> LoadU(D d, const TFromD<D>* HWY_RESTRICT p) {
  return Load(d, p);
}

template <class V, class D>
HWY_API void StoreU(const V v, D d, TFromD<D>* HWY_RESTRICT p) {
  Store(v, d, p);
}
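
// A minimal strip-mining sketch (hypothetical caller code, not part of this
// header; AddLoop is a made-up name and count is assumed to be a multiple of
// Lanes(d)):
#if 0
void AddLoop(const float* HWY_RESTRICT a, const float* HWY_RESTRICT b,
             size_t count, float* HWY_RESTRICT out) {
  const Full<float> d;
  for (size_t i = 0; i < count; i += Lanes(d)) {
    const auto va = LoadU(d, a + i);
    const auto vb = LoadU(d, b + i);
    StoreU(Add(va, vb), d, out + i);
  }
}
#endif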

// ------------------------------ ScatterOffset/Index

#define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, NAME, OP)                   \
  template <size_t N>                                                        \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N) d,     \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,               \
                    HWY_SVE_V(int, BITS) offset) {                           \
    sv##OP##_s##BITS##offset_##CHAR##BITS(detail::Mask(d), base, offset, v); \
  }

#define HWY_SVE_SCATTER_INDEX(BASE, CHAR, BITS, NAME, OP)                  \
  template <size_t N>                                                      \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N) d,   \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,             \
                    HWY_SVE_V(int, BITS) index) {                          \
    sv##OP##_s##BITS##index_##CHAR##BITS(detail::Mask(d), base, index, v); \
  }

HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter)
HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_INDEX, ScatterIndex, st1_scatter)
#undef HWY_SVE_SCATTER_OFFSET
#undef HWY_SVE_SCATTER_INDEX

// ------------------------------ GatherOffset/Index

#define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, NAME, OP)               \
  template <size_t N>                                                   \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME(HWY_SVE_D(BASE, BITS, N) d,                                  \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,             \
           HWY_SVE_V(int, BITS) offset) {                               \
    return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::Mask(d), base, \
                                                 offset);               \
  }
#define HWY_SVE_GATHER_INDEX(BASE, CHAR, BITS, NAME, OP)                       \
  template <size_t N>                                                          \
  HWY_API HWY_SVE_V(BASE, BITS)                                                \
      NAME(HWY_SVE_D(BASE, BITS, N) d,                                         \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,                    \
           HWY_SVE_V(int, BITS) index) {                                       \
    return sv##OP##_s##BITS##index_##CHAR##BITS(detail::Mask(d), base, index); \
  }

HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather)
HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_INDEX, GatherIndex, ld1_gather)
#undef HWY_SVE_GATHER_OFFSET
#undef HWY_SVE_GATHER_INDEX

// ------------------------------ StoreInterleaved3

#define HWY_SVE_STORE3(BASE, CHAR, BITS, NAME, OP)                            \
  template <size_t N>                                                         \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1,       \
                    HWY_SVE_V(BASE, BITS) v2, HWY_SVE_D(BASE, BITS, N) d,     \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) {         \
    const sv##BASE##BITS##x3_t triple = svcreate3##_##CHAR##BITS(v0, v1, v2); \
    sv##OP##_##CHAR##BITS(detail::Mask(d), unaligned, triple);                \
  }
HWY_SVE_FOREACH_U08(HWY_SVE_STORE3, StoreInterleaved3, st3)

#undef HWY_SVE_STORE3

// ------------------------------ StoreInterleaved4

#define HWY_SVE_STORE4(BASE, CHAR, BITS, NAME, OP)                      \
  template <size_t N>                                                   \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \
                    HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3, \
                    HWY_SVE_D(BASE, BITS, N) d,                         \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) {   \
    const sv##BASE##BITS##x4_t quad =                                   \
        svcreate4##_##CHAR##BITS(v0, v1, v2, v3);                       \
    sv##OP##_##CHAR##BITS(detail::Mask(d), unaligned, quad);            \
  }
HWY_SVE_FOREACH_U08(HWY_SVE_STORE4, StoreInterleaved4, st4)

#undef HWY_SVE_STORE4

// ================================================== CONVERT

// ------------------------------ PromoteTo

// Same sign
#define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, NAME, OP)        \
  template <size_t N>                                         \
  HWY_API HWY_SVE_V(BASE, BITS)                               \
      NAME(HWY_SVE_D(BASE, BITS, N) /* tag */,                \
           VFromD<Simd<MakeNarrow<HWY_SVE_T(BASE, BITS)>,     \
                       HWY_LANES(HWY_SVE_T(BASE, BITS)) * 2>> \
               v) {                                           \
    return sv##OP##_##CHAR##BITS(v);                          \
  }

HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo)
HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo)
HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo)

// 2x
template <size_t N>
HWY_API svuint32_t PromoteTo(Simd<uint32_t, N> dto, svuint8_t vfrom) {
  const RepartitionToWide<DFromV<decltype(vfrom)>> d2;
  return PromoteTo(dto, PromoteTo(d2, vfrom));
}
template <size_t N>
HWY_API svint32_t PromoteTo(Simd<int32_t, N> dto, svint8_t vfrom) {
  const RepartitionToWide<DFromV<decltype(vfrom)>> d2;
  return PromoteTo(dto, PromoteTo(d2, vfrom));
}
template <size_t N>
HWY_API svuint32_t U32FromU8(svuint8_t v) {
  return PromoteTo(Simd<uint32_t, N>(), v);
}

// Sign change
template <size_t N>
HWY_API svint16_t PromoteTo(Simd<int16_t, N> dto, svuint8_t vfrom) {
  const RebindToUnsigned<decltype(dto)> du;
  return BitCast(dto, PromoteTo(du, vfrom));
}
template <size_t N>
HWY_API svint32_t PromoteTo(Simd<int32_t, N> dto, svuint16_t vfrom) {
  const RebindToUnsigned<decltype(dto)> du;
  return BitCast(dto, PromoteTo(du, vfrom));
}
template <size_t N>
HWY_API svint32_t PromoteTo(Simd<int32_t, N> dto, svuint8_t vfrom) {
  const Repartition<uint16_t, DFromV<decltype(vfrom)>> du16;
  const Repartition<int16_t, decltype(du16)> di16;
  return PromoteTo(dto, BitCast(di16, PromoteTo(du16, vfrom)));
}
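
// For example, promoting a u8 lane holding 0x80 to i32 yields 128, not -128:
// the value is zero-extended to u16 first, and only then widened again.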

// ------------------------------ PromoteTo F

template <size_t N>
HWY_API svfloat32_t PromoteTo(Simd<float32_t, N> /* d */, const svfloat16_t v) {
  return svcvt_f32_f16_x(detail::PTrue(Simd<float16_t, N>()), v);
}

template <size_t N>
HWY_API svfloat64_t PromoteTo(Simd<float64_t, N> /* d */, const svfloat32_t v) {
  return svcvt_f64_f32_x(detail::PTrue(Simd<float32_t, N>()), v);
}

template <size_t N>
HWY_API svfloat64_t PromoteTo(Simd<float64_t, N> /* d */, const svint32_t v) {
  return svcvt_f64_s32_x(detail::PTrue(Simd<int32_t, N>()), v);
}

// For 16-bit Compress
namespace detail {
HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi)
#undef HWY_SVE_PROMOTE_TO

template <size_t N>
HWY_API svfloat32_t PromoteUpperTo(Simd<float, N> df, const svfloat16_t v) {
  const RebindToUnsigned<decltype(df)> du;
  const RepartitionToNarrow<decltype(du)> dn;
  return BitCast(df, PromoteUpperTo(du, BitCast(dn, v)));
}

}  // namespace detail

// ------------------------------ DemoteTo U

namespace detail {

// Saturates unsigned vectors to half/quarter-width TN.
template <typename TN, class VU>
VU SaturateU(VU v) {
  return detail::MinN(v, static_cast<TFromV<VU>>(LimitsMax<TN>()));
}

// Saturates signed vectors to half/quarter-width TN.
template <typename TN, class VI>
VI SaturateI(VI v) {
  const DFromV<VI> di;
  return detail::MinN(detail::MaxN(v, LimitsMin<TN>()), LimitsMax<TN>());
}

}  // namespace detail

template <size_t N>
HWY_API svuint8_t DemoteTo(Simd<uint8_t, N> dn, const svint16_t v) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  using TN = TFromD<decltype(dn)>;
  // First clamp negative numbers to zero and cast to unsigned.
  const svuint16_t clamped = BitCast(du, Max(Zero(di), v));
  // Saturate to unsigned-max and halve the width.
  const svuint8_t vn = BitCast(dn, detail::SaturateU<TN>(clamped));
  return svuzp1_u8(vn, vn);
}

template <size_t N>
HWY_API svuint16_t DemoteTo(Simd<uint16_t, N> dn, const svint32_t v) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  using TN = TFromD<decltype(dn)>;
  // First clamp negative numbers to zero and cast to unsigned.
  const svuint32_t clamped = BitCast(du, Max(Zero(di), v));
  // Saturate to unsigned-max and halve the width.
  const svuint16_t vn = BitCast(dn, detail::SaturateU<TN>(clamped));
  return svuzp1_u16(vn, vn);
}

template <size_t N>
HWY_API svuint8_t DemoteTo(Simd<uint8_t, N> dn, const svint32_t v) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  const RepartitionToNarrow<decltype(du)> d2;
  using TN = TFromD<decltype(dn)>;
  // First clamp negative numbers to zero and cast to unsigned.
  const svuint32_t clamped = BitCast(du, Max(Zero(di), v));
  // Saturate to unsigned-max and quarter the width.
  const svuint16_t cast16 = BitCast(d2, detail::SaturateU<TN>(clamped));
  const svuint8_t x2 = BitCast(dn, svuzp1_u16(cast16, cast16));
  return svuzp1_u8(x2, x2);
}

HWY_API svuint8_t U8FromU32(const svuint32_t v) {
  const DFromV<svuint32_t> du32;
  const RepartitionToNarrow<decltype(du32)> du16;
  const RepartitionToNarrow<decltype(du16)> du8;

  const svuint16_t cast16 = BitCast(du16, v);
  const svuint16_t x2 = svuzp1_u16(cast16, cast16);
  const svuint8_t cast8 = BitCast(du8, x2);
  return svuzp1_u8(cast8, cast8);
}
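
// For example, DemoteTo(Simd<uint8_t, N>(), v) with i16 lanes {300, -5, ...}
// yields {255, 0, ...}: negatives are clamped to zero before saturating to
// the unsigned maximum.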

// ------------------------------ DemoteTo I

template <size_t N>
HWY_API svint8_t DemoteTo(Simd<int8_t, N> dn, const svint16_t v) {
  const DFromV<decltype(v)> di;
  using TN = TFromD<decltype(dn)>;
#if HWY_TARGET == HWY_SVE2
  const svint8_t vn = BitCast(dn, svqxtnb_s16(v));
#else
  const svint8_t vn = BitCast(dn, detail::SaturateI<TN>(v));
#endif
  return svuzp1_s8(vn, vn);
}

template <size_t N>
HWY_API svint16_t DemoteTo(Simd<int16_t, N> dn, const svint32_t v) {
  const DFromV<decltype(v)> di;
  using TN = TFromD<decltype(dn)>;
#if HWY_TARGET == HWY_SVE2
  const svint16_t vn = BitCast(dn, svqxtnb_s32(v));
#else
  const svint16_t vn = BitCast(dn, detail::SaturateI<TN>(v));
#endif
  return svuzp1_s16(vn, vn);
}

template <size_t N>
HWY_API svint8_t DemoteTo(Simd<int8_t, N> dn, const svint32_t v) {
  const DFromV<decltype(v)> di;
  using TN = TFromD<decltype(dn)>;
  const RepartitionToWide<decltype(dn)> d2;
#if HWY_TARGET == HWY_SVE2
  const svint16_t cast16 = BitCast(d2, svqxtnb_s16(svqxtnb_s32(v)));
#else
  const svint16_t cast16 = BitCast(d2, detail::SaturateI<TN>(v));
#endif
  const svint8_t v2 = BitCast(dn, svuzp1_s16(cast16, cast16));
  return BitCast(dn, svuzp1_s8(v2, v2));
}

// ------------------------------ ConcatEven/ConcatOdd

// WARNING: the upper half of these needs fixing up (uzp1/uzp2 use the
// full vector length, not rounded down to a power of two as we require).
namespace detail {

#define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, NAME, OP)  \
  HWY_INLINE HWY_SVE_V(BASE, BITS)                               \
      NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \
    return sv##OP##_##CHAR##BITS(lo, hi);                        \
  }
HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEven, uzp1)
HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOdd, uzp2)
#undef HWY_SVE_CONCAT_EVERY_SECOND

// Used to slide up / shift whole register left; mask indicates which range
// to take from lo, and the rest is filled from hi starting at its lowest.
#define HWY_SVE_SPLICE(BASE, CHAR, BITS, NAME, OP)                         \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(                                      \
      HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo, svbool_t mask) { \
    return sv##OP##_##CHAR##BITS(mask, lo, hi);                            \
  }
HWY_SVE_FOREACH(HWY_SVE_SPLICE, Splice, splice)
#undef HWY_SVE_SPLICE

}  // namespace detail

template <class D>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
#if 0  // if we could assume VL is a power of two
  return detail::ConcatOdd(hi, lo);
#else
  const VFromD<D> hi_odd = detail::ConcatOdd(hi, hi);
  const VFromD<D> lo_odd = detail::ConcatOdd(lo, lo);
  return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2));
#endif
}

template <class D>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
#if 0  // if we could assume VL is a power of two
  return detail::ConcatEven(hi, lo);
#else
  const VFromD<D> hi_odd = detail::ConcatEven(hi, hi);
  const VFromD<D> lo_odd = detail::ConcatEven(lo, lo);
  return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2));
#endif
}
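
// For example, with 4 lanes, lo = {0,1,2,3} and hi = {4,5,6,7}:
// ConcatOdd(d, hi, lo) = {1,3,5,7} and ConcatEven(d, hi, lo) = {0,2,4,6}.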
1203 
1204 // ------------------------------ DemoteTo F
1205 
1206 template <size_t N>
DemoteTo(Simd<float16_t,N> d,const svfloat32_t v)1207 HWY_API svfloat16_t DemoteTo(Simd<float16_t, N> d, const svfloat32_t v) {
1208   return svcvt_f16_f32_x(detail::PTrue(d), v);
1209 }
1210 
1211 template <size_t N>
DemoteTo(Simd<bfloat16_t,N> d,const svfloat32_t v)1212 HWY_API svuint16_t DemoteTo(Simd<bfloat16_t, N> d, const svfloat32_t v) {
1213   const svuint16_t halves = BitCast(Full<uint16_t>(), v);
1214   return detail::ConcatOdd(halves, halves);  // can ignore upper half of vec
1215 }
1216 
1217 template <size_t N>
DemoteTo(Simd<float32_t,N> d,const svfloat64_t v)1218 HWY_API svfloat32_t DemoteTo(Simd<float32_t, N> d, const svfloat64_t v) {
1219   return svcvt_f32_f64_x(detail::PTrue(d), v);
1220 }
1221 
1222 template <size_t N>
DemoteTo(Simd<int32_t,N> d,const svfloat64_t v)1223 HWY_API svint32_t DemoteTo(Simd<int32_t, N> d, const svfloat64_t v) {
1224   return svcvt_s32_f64_x(detail::PTrue(d), v);
1225 }
1226 
1227 // ------------------------------ ConvertTo F
1228 
1229 #define HWY_SVE_CONVERT(BASE, CHAR, BITS, NAME, OP)                     \
1230   template <size_t N>                                                   \
1231   HWY_API HWY_SVE_V(BASE, BITS)                                         \
1232       NAME(HWY_SVE_D(BASE, BITS, N) /* d */, HWY_SVE_V(int, BITS) v) {  \
1233     return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \
1234   }                                                                     \
1235   /* Truncates (rounds toward zero). */                                 \
1236   template <size_t N>                                                   \
1237   HWY_API HWY_SVE_V(int, BITS)                                          \
1238       NAME(HWY_SVE_D(int, BITS, N) /* d */, HWY_SVE_V(BASE, BITS) v) {  \
1239     return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \
1240   }
1241 
1242 // API only requires f32 but we provide f64 for use by Iota.
HWY_SVE_FOREACH_F(HWY_SVE_CONVERT,ConvertTo,cvt)1243 HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt)
1244 #undef HWY_SVE_CONVERT
1245 
1246 // ------------------------------ NearestInt (Round, ConvertTo)
1247 
1248 template <class VF, class DI = RebindToSigned<DFromV<VF>>>
1249 HWY_API VFromD<DI> NearestInt(VF v) {
1250   // No single instruction, round then truncate.
1251   return ConvertTo(DI(), Round(v));
1252 }
1253 
1254 // ------------------------------ Iota (Add, ConvertTo)
1255 
1256 #define HWY_SVE_IOTA(BASE, CHAR, BITS, NAME, OP)                      \
1257   template <size_t N>                                                 \
1258   HWY_API HWY_SVE_V(BASE, BITS)                                       \
1259       NAME(HWY_SVE_D(BASE, BITS, N) d, HWY_SVE_T(BASE, BITS) first) { \
1260     return sv##OP##_##CHAR##BITS(first, 1);                           \
1261   }
1262 
HWY_SVE_FOREACH_UI(HWY_SVE_IOTA,Iota,index)1263 HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index)
1264 #undef HWY_SVE_IOTA
1265 
1266 template <class D, HWY_IF_FLOAT_D(D)>
1267 HWY_API VFromD<D> Iota(const D d, TFromD<D> first) {
1268   const RebindToSigned<D> di;
1269   return detail::AddN(ConvertTo(d, Iota(di, 0)), first);
1270 }
1271 
1272 // ================================================== COMBINE
1273 
1274 namespace detail {
1275 
1276 template <typename T, size_t N>
MaskLowerHalf(Simd<T,N> d)1277 svbool_t MaskLowerHalf(Simd<T, N> d) {
1278   return FirstN(d, Lanes(d) / 2);
1279 }
1280 template <typename T, size_t N>
MaskUpperHalf(Simd<T,N> d)1281 svbool_t MaskUpperHalf(Simd<T, N> d) {
1282   // For Splice to work as intended, make sure bits above Lanes(d) are zero.
1283   return AndNot(MaskLowerHalf(d), detail::Mask(d));
1284 }
1285 
1286 // Right-shift vector pair by constexpr; can be used to slide down (=N) or up
1287 // (=Lanes()-N).
1288 #define HWY_SVE_EXT(BASE, CHAR, BITS, NAME, OP)                  \
1289   template <size_t kIndex>                                       \
1290   HWY_API HWY_SVE_V(BASE, BITS)                                  \
1291       NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \
1292     return sv##OP##_##CHAR##BITS(lo, hi, kIndex);                \
1293   }
1294 HWY_SVE_FOREACH(HWY_SVE_EXT, Ext, ext)
1295 #undef HWY_SVE_EXT
1296 
1297 }  // namespace detail
1298 
1299 // ------------------------------ ConcatUpperLower
1300 template <class D, class V>
ConcatUpperLower(const D d,const V hi,const V lo)1301 HWY_API V ConcatUpperLower(const D d, const V hi, const V lo) {
1302   return IfThenElse(detail::MaskLowerHalf(d), lo, hi);
1303 }

// ------------------------------ ConcatLowerLower
template <class D, class V>
HWY_API V ConcatLowerLower(const D d, const V hi, const V lo) {
  return detail::Splice(hi, lo, detail::MaskLowerHalf(d));
}

// ------------------------------ ConcatLowerUpper
template <class D, class V>
HWY_API V ConcatLowerUpper(const D d, const V hi, const V lo) {
  return detail::Splice(hi, lo, detail::MaskUpperHalf(d));
}

// ------------------------------ ConcatUpperUpper
template <class D, class V>
HWY_API V ConcatUpperUpper(const D d, const V hi, const V lo) {
  const svbool_t mask_upper = detail::MaskUpperHalf(d);
  const V lo_upper = detail::Splice(lo, lo, mask_upper);
  return IfThenElse(mask_upper, hi, lo_upper);
}

// ------------------------------ Combine
template <class D, class V2>
HWY_API VFromD<D> Combine(const D d, const V2 hi, const V2 lo) {
  return ConcatLowerLower(d, hi, lo);
}

// ------------------------------ ZeroExtendVector

template <class D, class V>
HWY_API V ZeroExtendVector(const D d, const V lo) {
  return Combine(d, Zero(Half<D>()), lo);
}

// ------------------------------ Lower/UpperHalf

template <class D2, class V>
HWY_API V LowerHalf(D2 /* tag */, const V v) {
  return v;
}

template <class V>
HWY_API V LowerHalf(const V v) {
  return v;
}

template <class D2, class V>
HWY_API V UpperHalf(const D2 d2, const V v) {
  return detail::Splice(v, v, detail::MaskUpperHalf(Twice<D2>()));
}

// ================================================== SWIZZLE

// ------------------------------ GetLane

#define HWY_SVE_GET_LANE(BASE, CHAR, BITS, NAME, OP)            \
  HWY_API HWY_SVE_T(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS(detail::PFalse(), v);          \
  }

HWY_SVE_FOREACH(HWY_SVE_GET_LANE, GetLane, lasta)
#undef HWY_SVE_GET_LANE

// ------------------------------ OddEven

namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVN, Insert, insr_n)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2)
}  // namespace detail

template <class V>
HWY_API V OddEven(const V odd, const V even) {
  const auto even_in_odd = detail::Insert(even, 0);
  return detail::InterleaveOdd(even_in_odd, odd);
}
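
// Example (illustrative, 4 lanes): odd = {o0..o3}, even = {e0..e3} yields
// {e0, o1, e2, o3}: even-indexed lanes from `even`, odd-indexed from `odd`.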

// ------------------------------ OddEvenBlocks
template <class V>
HWY_API V OddEvenBlocks(const V odd, const V even) {
  const RebindToUnsigned<DFromV<V>> du;
  constexpr size_t kShift = CeilLog2(16 / sizeof(TFromV<V>));
  const auto idx_block = ShiftRight<kShift>(Iota(du, 0));
  const svbool_t is_even = Eq(detail::AndN(idx_block, 1), Zero(du));
  return IfThenElse(is_even, even, odd);
}

// ------------------------------ SwapAdjacentBlocks

namespace detail {

template <typename T, size_t N>
constexpr size_t LanesPerBlock(Simd<T, N> /* tag */) {
  // We might have a capped vector smaller than a block, so honor that.
  return HWY_MIN(16 / sizeof(T), N);
}

}  // namespace detail

template <class V>
HWY_API V SwapAdjacentBlocks(const V v) {
  const DFromV<V> d;
  constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d);
  const V down = detail::Ext<kLanesPerBlock>(v, v);
  const V up = detail::Splice(v, v, FirstN(d, kLanesPerBlock));
  return OddEvenBlocks(up, down);
}
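
// Example (illustrative, 32-bit lanes, 256-bit vector = two 128-bit blocks):
// {a0 a1 a2 a3 | b0 b1 b2 b3} becomes {b0 b1 b2 b3 | a0 a1 a2 a3}.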

// ------------------------------ TableLookupLanes

template <class D, class VI>
HWY_API VFromD<RebindToUnsigned<D>> IndicesFromVec(D d, VI vec) {
  static_assert(sizeof(TFromD<D>) == sizeof(TFromV<VI>), "Index != lane");
  const RebindToUnsigned<D> du;
  const auto indices = BitCast(du, vec);
#if HWY_IS_DEBUG_BUILD
  HWY_DASSERT(AllTrue(du, Lt(indices, Set(du, Lanes(d)))));
#endif
  return indices;
}

template <class D, typename TI>
HWY_API VFromD<RebindToUnsigned<D>> SetTableIndices(D d, const TI* idx) {
  static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane");
  return IndicesFromVec(d, LoadU(Rebind<TI, D>(), idx));
}
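
// Usage sketch (illustrative; `idx_lanes` is a hypothetical array holding
// Lanes(d) indices, each < Lanes(d)):
//   const Full<uint32_t> d;
//   const auto idx = SetTableIndices(d, idx_lanes);
//   const auto permuted = TableLookupLanes(v, idx);  // lane i = v[idx_lanes[i]]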

// Lane types <32 bit are not part of the Highway API for this op, but are
// used in Broadcast.
#define HWY_SVE_TABLE(BASE, CHAR, BITS, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(uint, BITS) idx) { \
    return sv##OP##_##CHAR##BITS(v, idx);                        \
  }

HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl)
#undef HWY_SVE_TABLE

// ------------------------------ Reverse

#if 0  // if we could assume VL is a power of two
#error "Update macro"
#endif
#define HWY_SVE_REVERSE(BASE, CHAR, BITS, NAME, OP)                     \
  template <size_t N>                                                   \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME(Simd<HWY_SVE_T(BASE, BITS), N> d, HWY_SVE_V(BASE, BITS) v) { \
    const auto reversed = sv##OP##_##CHAR##BITS(v);                     \
    /* Shift right to remove extra (non-pow2 and remainder) lanes. */   \
    const size_t all_lanes =                                            \
        detail::AllHardwareLanes(hwy::SizeTag<BITS / 8>());             \
    /* TODO(janwas): on SVE2, use whilege. */                           \
    const svbool_t mask = Not(FirstN(d, all_lanes - Lanes(d)));         \
    return detail::Splice(reversed, reversed, mask);                    \
  }

HWY_SVE_FOREACH(HWY_SVE_REVERSE, Reverse, rev)
#undef HWY_SVE_REVERSE
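
// Example (illustrative): given lanes {0, 1, ..., N-1} with N = Lanes(d),
// Reverse(d, v) returns {N-1, ..., 1, 0}.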

// ------------------------------ Compress (PromoteTo)

#define HWY_SVE_COMPRESS(BASE, CHAR, BITS, NAME, OP)                           \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \
    return sv##OP##_##CHAR##BITS(mask, v);                                     \
  }

HWY_SVE_FOREACH_UIF3264(HWY_SVE_COMPRESS, Compress, compact)
#undef HWY_SVE_COMPRESS

template <class V, HWY_IF_LANE_SIZE_V(V, 2)>
HWY_API V Compress(V v, svbool_t mask16) {
  static_assert(!IsSame<V, svfloat16_t>(), "Must use overload");
  const DFromV<V> d16;

  // Promote vector and mask to 32-bit
  const RepartitionToWide<decltype(d16)> dw;
  const auto v32L = PromoteTo(dw, v);
  const auto v32H = detail::PromoteUpperTo(dw, v);
  const svbool_t mask32L = svunpklo_b(mask16);
  const svbool_t mask32H = svunpkhi_b(mask16);

  const auto compressedL = Compress(v32L, mask32L);
  const auto compressedH = Compress(v32H, mask32H);

  // Demote to 16-bit (already in range) - separately so we can splice
  const V evenL = BitCast(d16, compressedL);
  const V evenH = BitCast(d16, compressedH);
  const V v16L = detail::ConcatEven(evenL, evenL);  // only lower half needed
  const V v16H = detail::ConcatEven(evenH, evenH);

  // We need to combine two vectors of non-constexpr length, so the only option
  // is Splice, which requires us to synthesize a mask. NOTE: this function uses
  // full vectors (SV_ALL instead of SV_POW2), hence we need unmasked svcnt.
  const size_t countL = detail::CountTrueFull(dw, mask32L);
  const auto compressed_maskL = FirstN(d16, countL);
  return detail::Splice(v16H, v16L, compressed_maskL);
}

// Must treat float16_t as integers so we can ConcatEven.
HWY_API svfloat16_t Compress(svfloat16_t v, svbool_t mask16) {
  const DFromV<decltype(v)> df;
  const RebindToSigned<decltype(df)> di;
  return BitCast(df, Compress(BitCast(di, v), mask16));
}
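
// Usage sketch (illustrative, 4 lanes): v = {1, 2, 3, 4} with
// mask = {1, 0, 1, 0} packs to {1, 3, ?, ?}; lanes past the number of true
// mask bits should be treated as unspecified.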

// ------------------------------ CompressStore

template <class V, class M, class D>
HWY_API size_t CompressStore(const V v, const M mask, const D d,
                             TFromD<D>* HWY_RESTRICT unaligned) {
  StoreU(Compress(v, mask), d, unaligned);
  return CountTrue(d, mask);
}

// ------------------------------ CompressBlendedStore

template <class V, class M, class D>
HWY_API size_t CompressBlendedStore(const V v, const M mask, const D d,
                                    TFromD<D>* HWY_RESTRICT unaligned) {
  const size_t count = CountTrue(d, mask);
  const svbool_t store_mask = FirstN(d, count);
  MaskedStore(store_mask, Compress(v, mask), d, unaligned);
  return count;
}
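
// Note: CompressStore writes all Lanes(d) lanes to `unaligned` (those past
// CountTrue are unspecified), whereas CompressBlendedStore writes only the
// first CountTrue lanes and leaves the rest of the buffer untouched.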

// ================================================== BLOCKWISE

// ------------------------------ CombineShiftRightBytes

namespace detail {

// For x86-compatible behaviour mandated by Highway API: TableLookupBytes
// offsets are implicitly relative to the start of their 128-bit block.
template <class D, class V>
HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) {
  using T = MakeUnsigned<TFromD<D>>;
  return detail::AndNotN(static_cast<T>(LanesPerBlock(d) - 1), iota0);
}

template <size_t kLanes, class D>
svbool_t FirstNPerBlock(D d) {
  const RebindToSigned<D> di;
  constexpr size_t kLanesPerBlock = detail::LanesPerBlock(di);
  const auto idx_mod = detail::AndN(Iota(di, 0), kLanesPerBlock - 1);
  return detail::LtN(BitCast(di, idx_mod), kLanes);
}

}  // namespace detail

template <size_t kBytes, class D, class V = VFromD<D>>
HWY_API V CombineShiftRightBytes(const D d, const V hi, const V lo) {
  const Repartition<uint8_t, decltype(d)> d8;
  const auto hi8 = BitCast(d8, hi);
  const auto lo8 = BitCast(d8, lo);
  const auto hi_up = detail::Splice(hi8, hi8, FirstN(d8, 16 - kBytes));
  const auto lo_down = detail::Ext<kBytes>(lo8, lo8);
  const svbool_t is_lo = detail::FirstNPerBlock<16 - kBytes>(d8);
  return BitCast(d, IfThenElse(is_lo, lo_down, hi_up));
}
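
// Example (illustrative, kBytes = 4, per 128-bit block of bytes):
// result = {lo[4..15], hi[0..3]}, i.e. the byte pair (hi:lo) shifted right
// by 4 bytes within each block.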

// ------------------------------ Shuffle2301

#define HWY_SVE_SHUFFLE_2301(BASE, CHAR, BITS, NAME, OP)                      \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {               \
    const DFromV<decltype(v)> d;                                              \
    const svuint64_t vu64 = BitCast(Repartition<uint64_t, decltype(d)>(), v); \
    return BitCast(d, sv##OP##_u64_x(HWY_SVE_PTRUE(64), vu64));               \
  }

HWY_SVE_FOREACH_UI32(HWY_SVE_SHUFFLE_2301, Shuffle2301, revw)
#undef HWY_SVE_SHUFFLE_2301

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Shuffle2301(const V v) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Shuffle2301(BitCast(du, v)));
}

// ------------------------------ Shuffle2103
template <class V>
HWY_API V Shuffle2103(const V v) {
  const DFromV<V> d;
  const Repartition<uint8_t, decltype(d)> d8;
  static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
  const svuint8_t v8 = BitCast(d8, v);
  return BitCast(d, CombineShiftRightBytes<12>(d8, v8, v8));
}

// ------------------------------ Shuffle0321
template <class V>
HWY_API V Shuffle0321(const V v) {
  const DFromV<V> d;
  const Repartition<uint8_t, decltype(d)> d8;
  static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
  const svuint8_t v8 = BitCast(d8, v);
  return BitCast(d, CombineShiftRightBytes<4>(d8, v8, v8));
}

// ------------------------------ Shuffle1032
template <class V>
HWY_API V Shuffle1032(const V v) {
  const DFromV<V> d;
  const Repartition<uint8_t, decltype(d)> d8;
  static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
  const svuint8_t v8 = BitCast(d8, v);
  return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8));
}

// ------------------------------ Shuffle01
template <class V>
HWY_API V Shuffle01(const V v) {
  const DFromV<V> d;
  const Repartition<uint8_t, decltype(d)> d8;
  static_assert(sizeof(TFromD<decltype(d)>) == 8, "Defined for 64-bit types");
  const svuint8_t v8 = BitCast(d8, v);
  return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8));
}

// ------------------------------ Shuffle0123
template <class V>
HWY_API V Shuffle0123(const V v) {
  return Shuffle2301(Shuffle1032(v));
}
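
// Summary (illustrative, per 128-bit block; v = {v0, v1, v2, v3} for 32-bit
// lanes, v = {v0, v1} for 64-bit lanes):
//   Shuffle2301: {v1, v0, v3, v2}   Shuffle1032: {v2, v3, v0, v1}
//   Shuffle0321: {v1, v2, v3, v0}   Shuffle2103: {v3, v0, v1, v2}
//   Shuffle0123: {v3, v2, v1, v0}   Shuffle01:   {v1, v0}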

// ------------------------------ TableLookupBytes

template <class V, class VI>
HWY_API VI TableLookupBytes(const V v, const VI idx) {
  const DFromV<VI> d;
  const Repartition<uint8_t, decltype(d)> du8;
  const auto offsets128 = detail::OffsetsOf128BitBlocks(du8, Iota(du8, 0));
  const auto idx8 = Add(BitCast(du8, idx), offsets128);
  return BitCast(d, TableLookupLanes(BitCast(du8, v), idx8));
}

template <class V, class VI>
HWY_API VI TableLookupBytesOr0(const V v, const VI idx) {
  const DFromV<VI> d;
  // Mask size must match vector type, so cast everything to this type.
  const Repartition<int8_t, decltype(d)> di8;

  auto idx8 = BitCast(di8, idx);
  const auto msb = Lt(idx8, Zero(di8));
// Prevent overflow in table lookups (unnecessary if native)
#if defined(HWY_EMULATE_SVE)
  idx8 = IfThenZeroElse(msb, idx8);
#endif

  const auto lookup = TableLookupBytes(BitCast(di8, v), idx8);
  return BitCast(d, IfThenZeroElse(msb, lookup));
}

// ------------------------------ Broadcast

template <int kLane, class V>
HWY_API V Broadcast(const V v) {
  const DFromV<V> d;
  const RebindToUnsigned<decltype(d)> du;
  constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du);
  static_assert(0 <= kLane && kLane < kLanesPerBlock, "Invalid lane");
  auto idx = detail::OffsetsOf128BitBlocks(du, Iota(du, 0));
  if (kLane != 0) {
    idx = detail::AddN(idx, kLane);
  }
  return TableLookupLanes(v, idx);
}
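
// Example (illustrative, 32-bit lanes): Broadcast<2>({a, b, c, d | e, f, g, h})
// returns {c, c, c, c | g, g, g, g}, independently per 128-bit block.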

// ------------------------------ ShiftLeftLanes

template <size_t kLanes, class D, class V = VFromD<D>>
HWY_API V ShiftLeftLanes(D d, const V v) {
  const RebindToSigned<decltype(d)> di;
  const auto zero = Zero(d);
  const auto shifted = detail::Splice(v, zero, FirstN(d, kLanes));
  // Match x86 semantics by zeroing lower lanes in 128-bit blocks
  return IfThenElse(detail::FirstNPerBlock<kLanes>(d), zero, shifted);
}

template <size_t kLanes, class V>
HWY_API V ShiftLeftLanes(const V v) {
  return ShiftLeftLanes<kLanes>(DFromV<V>(), v);
}

// ------------------------------ ShiftRightLanes
template <size_t kLanes, typename T, size_t N, class V = VFromD<Simd<T, N>>>
HWY_API V ShiftRightLanes(Simd<T, N> d, V v) {
  const RebindToSigned<decltype(d)> di;
  // For partial vectors, clear upper lanes so we shift in zeros.
  if (N != HWY_LANES(T)) {
    v = IfThenElseZero(detail::Mask(d), v);
  }

  const auto shifted = detail::Ext<kLanes>(v, v);
  // Match x86 semantics by zeroing upper lanes in 128-bit blocks
  constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d);
  const svbool_t mask = detail::FirstNPerBlock<kLanesPerBlock - kLanes>(d);
  return IfThenElseZero(mask, shifted);
}

// ------------------------------ ShiftLeftBytes

template <int kBytes, class D, class V = VFromD<D>>
HWY_API V ShiftLeftBytes(const D d, const V v) {
  const Repartition<uint8_t, decltype(d)> d8;
  return BitCast(d, ShiftLeftLanes<kBytes>(BitCast(d8, v)));
}

template <int kBytes, class V>
HWY_API V ShiftLeftBytes(const V v) {
  return ShiftLeftBytes<kBytes>(DFromV<V>(), v);
}

// ------------------------------ ShiftRightBytes
template <int kBytes, class D, class V = VFromD<D>>
HWY_API V ShiftRightBytes(const D d, const V v) {
  const Repartition<uint8_t, decltype(d)> d8;
  return BitCast(d, ShiftRightLanes<kBytes>(d8, BitCast(d8, v)));
}

// ------------------------------ InterleaveLower

namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLower, zip1)
// Do not use zip2 to implement PromoteUpperTo or similar because vectors may be
// non-powers of two, so getting the actual "upper half" requires MaskUpperHalf.
}  // namespace detail

template <class D, class V>
HWY_API V InterleaveLower(D d, const V a, const V b) {
  static_assert(IsSame<TFromD<D>, TFromV<V>>(), "D/V mismatch");
  // Move lower halves of blocks to lower half of vector.
  const Repartition<uint64_t, decltype(d)> d64;
  const auto a64 = BitCast(d64, a);
  const auto b64 = BitCast(d64, b);
  const auto a_blocks = detail::ConcatEven(a64, a64);  // only lower half needed
  const auto b_blocks = detail::ConcatEven(b64, b64);

  return detail::ZipLower(BitCast(d, a_blocks), BitCast(d, b_blocks));
}

template <class V>
HWY_API V InterleaveLower(const V a, const V b) {
  return InterleaveLower(DFromV<V>(), a, b);
}

// ------------------------------ InterleaveUpper

// Full vector: guaranteed to have at least one block
template <typename T, class V = VFromD<Full<T>>>
HWY_API V InterleaveUpper(Simd<T, HWY_LANES(T)> d, const V a, const V b) {
  // Move upper halves of blocks to lower half of vector.
  const Repartition<uint64_t, decltype(d)> d64;
  const auto a64 = BitCast(d64, a);
  const auto b64 = BitCast(d64, b);
  const auto a_blocks = detail::ConcatOdd(a64, a64);  // only lower half needed
  const auto b_blocks = detail::ConcatOdd(b64, b64);
  return detail::ZipLower(BitCast(d, a_blocks), BitCast(d, b_blocks));
}

// Capped: less than one block
template <typename T, size_t N, HWY_IF_LE64(T, N), class V = VFromD<Simd<T, N>>>
HWY_API V InterleaveUpper(Simd<T, N> d, const V a, const V b) {
  static_assert(IsSame<T, TFromV<V>>(), "D/V mismatch");
  const Half<decltype(d)> d2;
  return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b));
}

// Partial: need runtime check
template <typename T, size_t N,
          hwy::EnableIf<(N < HWY_LANES(T) && N * sizeof(T) >= 16)>* = nullptr,
          class V = VFromD<Simd<T, N>>>
HWY_API V InterleaveUpper(Simd<T, N> d, const V a, const V b) {
  static_assert(IsSame<T, TFromV<V>>(), "D/V mismatch");
  // Less than one block: treat as capped
  if (Lanes(d) * sizeof(T) < 16) {
    const Half<decltype(d)> d2;
    return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b));
  }
  return InterleaveUpper(Full<T>(), a, b);
}

// ------------------------------ ZipLower

template <class V, class DW = RepartitionToWide<DFromV<V>>>
HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) {
  const RepartitionToNarrow<DW> dn;
  static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch");
  return BitCast(dw, InterleaveLower(dn, a, b));
}
template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>>
HWY_API VFromD<DW> ZipLower(const V a, const V b) {
  return BitCast(DW(), InterleaveLower(D(), a, b));
}

// ------------------------------ ZipUpper
template <class V, class DW = RepartitionToWide<DFromV<V>>>
HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) {
  const RepartitionToNarrow<DW> dn;
  static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch");
  return BitCast(dw, InterleaveUpper(dn, a, b));
}

// ================================================== REDUCE

#define HWY_SVE_REDUCE(BASE, CHAR, BITS, NAME, OP)                \
  template <size_t N>                                             \
  HWY_API HWY_SVE_V(BASE, BITS)                                   \
      NAME(HWY_SVE_D(BASE, BITS, N) d, HWY_SVE_V(BASE, BITS) v) { \
    return Set(d, sv##OP##_##CHAR##BITS(detail::Mask(d), v));     \
  }

HWY_SVE_FOREACH(HWY_SVE_REDUCE, SumOfLanes, addv)
HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MinOfLanes, minv)
HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MaxOfLanes, maxv)
// Returns NaN only if all lanes are NaN.
HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MinOfLanes, minnmv)
HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanes, maxnmv)

#undef HWY_SVE_REDUCE

// ================================================== Ops with dependencies

// ------------------------------ PromoteTo bfloat16 (ZipLower)

template <size_t N>
HWY_API svfloat32_t PromoteTo(Simd<float32_t, N> df32, const svuint16_t v) {
  return BitCast(df32, detail::ZipLower(svdup_n_u16(0), v));
}

// ------------------------------ ReorderDemote2To (OddEven)

template <size_t N>
HWY_API svuint16_t ReorderDemote2To(Simd<bfloat16_t, N> dbf16, svfloat32_t a,
                                    svfloat32_t b) {
  const RebindToUnsigned<decltype(dbf16)> du16;
  const Repartition<uint32_t, decltype(dbf16)> du32;
  const svuint32_t b_in_even = ShiftRight<16>(BitCast(du32, b));
  return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even)));
}

// ------------------------------ ZeroIfNegative (Lt, IfThenElse)
template <class V>
HWY_API V ZeroIfNegative(const V v) {
  const auto v0 = Zero(DFromV<V>());
  // We already have a zero constant, so avoid IfThenZeroElse.
  return IfThenElse(Lt(v, v0), v0, v);
}

// ------------------------------ BroadcastSignBit (ShiftRight)
template <class V>
HWY_API V BroadcastSignBit(const V v) {
  return ShiftRight<sizeof(TFromV<V>) * 8 - 1>(v);
}
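
// Example (illustrative, int32 lanes): {-2, 3} becomes {-1 (all bits set), 0}.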

// ------------------------------ AverageRound (ShiftRight)

#if HWY_TARGET == HWY_SVE2
HWY_SVE_FOREACH_U08(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd)
HWY_SVE_FOREACH_U16(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd)
#else
template <class V>
HWY_API V AverageRound(const V a, const V b) {
  return ShiftRight<1>(Add(Add(a, b), Set(DFromV<V>(), 1)));
}
#endif  // HWY_TARGET == HWY_SVE2
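
// Example (illustrative, uint8): AverageRound(5, 6) = (5 + 6 + 1) >> 1 = 6,
// i.e. the average rounded up (assuming a + b + 1 does not wrap).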

// ------------------------------ LoadMaskBits (TestBit)

// `bits` points to at least 8 readable bytes, not all of which need be valid.
template <class D, HWY_IF_LANE_SIZE_D(D, 1)>
HWY_INLINE svbool_t LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
  const RebindToUnsigned<D> du;
  const svuint8_t iota = Iota(du, 0);

  // Load correct number of bytes (bits/8) with 7 zeros after each.
  const svuint8_t bytes = BitCast(du, svld1ub_u64(detail::PTrue(d), bits));
  // Replicate bytes 8x such that each byte contains the bit that governs it.
  const svuint8_t rep8 = svtbl_u8(bytes, detail::AndNotN(7, iota));

  // 1, 2, 4, 8, 16, 32, 64, 128,  1, 2 ..
  const svuint8_t bit = Shl(Set(du, 1), detail::AndN(iota, 7));

  return TestBit(rep8, bit);
}

template <class D, HWY_IF_LANE_SIZE_D(D, 2)>
HWY_INLINE svbool_t LoadMaskBits(D /* tag */,
                                 const uint8_t* HWY_RESTRICT bits) {
  const RebindToUnsigned<D> du;
  const Repartition<uint8_t, D> du8;

  // There may be up to 128 bits; avoid reading past the end.
  const svuint8_t bytes = svld1(FirstN(du8, (Lanes(du) + 7) / 8), bits);

  // Replicate bytes 16x such that each lane contains the bit that governs it.
  const svuint8_t rep16 = svtbl_u8(bytes, ShiftRight<4>(Iota(du8, 0)));

  // 1, 2, 4, 8, 16, 32, 64, 128,  1, 2 ..
  const svuint16_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7));

  return TestBit(BitCast(du, rep16), bit);
}

template <class D, HWY_IF_LANE_SIZE_D(D, 4)>
HWY_INLINE svbool_t LoadMaskBits(D /* tag */,
                                 const uint8_t* HWY_RESTRICT bits) {
  const RebindToUnsigned<D> du;
  const Repartition<uint8_t, D> du8;

  // Upper bound = 2048 bits / 32 bit = 64 bits; at least 8 bytes are readable,
  // so we can skip computing the actual length (Lanes(du)+7)/8.
  const svuint8_t bytes = svld1(FirstN(du8, 8), bits);

  // Replicate bytes 32x such that each lane contains the bit that governs it.
  const svuint8_t rep32 = svtbl_u8(bytes, ShiftRight<5>(Iota(du8, 0)));

  // 1, 2, 4, 8, 16, 32, 64, 128,  1, 2 ..
  const svuint32_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7));

  return TestBit(BitCast(du, rep32), bit);
}

template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
HWY_INLINE svbool_t LoadMaskBits(D /* tag */,
                                 const uint8_t* HWY_RESTRICT bits) {
  const RebindToUnsigned<D> du;

  // Max 2048 bits = 32 lanes = 32 input bits; replicate those into each lane.
  // The "at least 8 byte" guarantee in quick_reference ensures this is safe.
  uint32_t mask_bits;
  CopyBytes<4>(bits, &mask_bits);
  const auto vbits = Set(du, mask_bits);

  // 2 ^ {0,1, .., 31}, will not have more lanes than that.
  const svuint64_t bit = Shl(Set(du, 1), Iota(du, 0));

  return TestBit(vbits, bit);
}

// ------------------------------ StoreMaskBits

namespace detail {

// Returns mask ? 1 : 0 in BYTE lanes.
template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)>
HWY_API svuint8_t BoolFromMask(Simd<T, N> d, svbool_t m) {
  return svdup_n_u8_z(m, 1);
}
template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)>
HWY_API svuint8_t BoolFromMask(Simd<T, N> d, svbool_t m) {
  const Repartition<uint8_t, decltype(d)> d8;
  const svuint8_t b16 = BitCast(d8, svdup_n_u16_z(m, 1));
  return detail::ConcatEven(b16, b16);  // only lower half needed
}
template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)>
HWY_API svuint8_t BoolFromMask(Simd<T, N> d, svbool_t m) {
  return U8FromU32(svdup_n_u32_z(m, 1));
}
template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)>
HWY_API svuint8_t BoolFromMask(Simd<T, N> d, svbool_t m) {
  const Repartition<uint32_t, decltype(d)> d32;
  const svuint32_t b64 = BitCast(d32, svdup_n_u64_z(m, 1));
  return U8FromU32(detail::ConcatEven(b64, b64));  // only lower half needed
}

}  // namespace detail

// `bits` points to at least 8 writable bytes.
template <typename T, size_t N>
HWY_API size_t StoreMaskBits(Simd<T, N> d, svbool_t m, uint8_t* bits) {
  const Repartition<uint8_t, decltype(d)> d8;
  const Repartition<uint16_t, decltype(d)> d16;
  const Repartition<uint32_t, decltype(d)> d32;
  const Repartition<uint64_t, decltype(d)> d64;
  auto x = detail::BoolFromMask(d, m);
  // Compact bytes to bits. Could use SVE2 BDEP, but it's optional.
  x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x))));
  x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x))));
  x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x))));

  const size_t num_bits = Lanes(d);
  const size_t num_bytes = (num_bits + 8 - 1) / 8;  // Round up, see below

  // Truncate to 8 bits and store.
  svst1b_u64(FirstN(d64, num_bytes), bits, BitCast(d64, x));

  // Non-full byte, need to clear the undefined upper bits. Can happen for
  // capped/partial vectors or large T and small hardware vectors.
  if (num_bits < 8) {
    const int mask = (1 << num_bits) - 1;
    bits[0] = static_cast<uint8_t>(bits[0] & mask);
  }
  // Else: we wrote full bytes because num_bits is a power of two >= 8.

  return num_bytes;
}
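
// Usage sketch (illustrative round trip; `bits` is a hypothetical buffer of
// at least 8 bytes, per the guarantees above):
//   uint8_t bits[8];
//   const size_t num_bytes = StoreMaskBits(d, m, bits);  // bit i = lane i
//   const svbool_t m2 = LoadMaskBits(d, bits);           // recovers m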

// ------------------------------ CompressBits, CompressBitsStore (LoadMaskBits)

template <class V>
HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) {
  return Compress(v, LoadMaskBits(DFromV<V>(), bits));
}

template <class D>
HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits,
                                 D d, TFromD<D>* HWY_RESTRICT unaligned) {
  return CompressStore(v, LoadMaskBits(d, bits), d, unaligned);
}

// ------------------------------ MulEven (InterleaveEven)

#if HWY_TARGET == HWY_SVE2
namespace detail {
HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, MulEven, mullb)
}  // namespace detail
#endif

template <class V, class DW = RepartitionToWide<DFromV<V>>>
HWY_API VFromD<DW> MulEven(const V a, const V b) {
#if HWY_TARGET == HWY_SVE2
  return BitCast(DW(), detail::MulEven(a, b));
#else
  const auto lo = Mul(a, b);
  const auto hi = detail::MulHigh(a, b);
  return BitCast(DW(), detail::InterleaveEven(lo, hi));
#endif
}

HWY_API svuint64_t MulEven(const svuint64_t a, const svuint64_t b) {
  const auto lo = Mul(a, b);
  const auto hi = detail::MulHigh(a, b);
  return detail::InterleaveEven(lo, hi);
}

HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) {
  const auto lo = Mul(a, b);
  const auto hi = detail::MulHigh(a, b);
  return detail::InterleaveOdd(lo, hi);
}
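
// Example (illustrative, u64 lanes {a0, a1} * {b0, b1}): MulEven returns the
// 128-bit product a0*b0 as adjacent lanes {low64, high64}; MulOdd does the
// same for a1*b1.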

// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower)

template <size_t N>
HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd<float, N> df32, svuint16_t a,
                                              svuint16_t b,
                                              const svfloat32_t sum0,
                                              svfloat32_t& sum1) {
  // TODO(janwas): svbfmlalb_f32 if __ARM_FEATURE_SVE_BF16.
  const Repartition<uint16_t, decltype(df32)> du16;
  const RebindToUnsigned<decltype(df32)> du32;
  const svuint16_t zero = Zero(du16);
  const svuint32_t a0 = ZipLower(du32, zero, BitCast(du16, a));
  const svuint32_t a1 = ZipUpper(du32, zero, BitCast(du16, a));
  const svuint32_t b0 = ZipLower(du32, zero, BitCast(du16, b));
  const svuint32_t b1 = ZipUpper(du32, zero, BitCast(du16, b));
  sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1);
  return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0);
}

// ------------------------------ AESRound / CLMul

#if defined(__ARM_FEATURE_SVE2_AES)

// Per-target flag to prevent generic_ops-inl.h from defining AESRound.
#ifdef HWY_NATIVE_AES
#undef HWY_NATIVE_AES
#else
#define HWY_NATIVE_AES
#endif

HWY_API svuint8_t AESRound(svuint8_t state, svuint8_t round_key) {
  // NOTE: it is important that AESE and AESMC be consecutive instructions so
  // they can be fused. AESE includes AddRoundKey, which is a different ordering
  // than the AES-NI semantics we adopted, so XOR by 0 and later with the actual
  // round key (the compiler will hopefully optimize this for multiple rounds).
  const svuint8_t zero = svdup_n_u8(0);
  return Xor(svaesmc_u8(svaese_u8(state, zero)), round_key);
}

HWY_API svuint64_t CLMulLower(const svuint64_t a, const svuint64_t b) {
  return svpmullb_pair(a, b);
}

HWY_API svuint64_t CLMulUpper(const svuint64_t a, const svuint64_t b) {
  return svpmullt_pair(a, b);
}

#endif  // __ARM_FEATURE_SVE2_AES

// ================================================== END MACROS
namespace detail {  // for code folding
#undef HWY_IF_FLOAT_V
#undef HWY_IF_LANE_SIZE_V
#undef HWY_IF_SIGNED_V
#undef HWY_IF_UNSIGNED_V
#undef HWY_SVE_D
#undef HWY_SVE_FOREACH
#undef HWY_SVE_FOREACH_F
#undef HWY_SVE_FOREACH_F16
#undef HWY_SVE_FOREACH_F32
#undef HWY_SVE_FOREACH_F64
#undef HWY_SVE_FOREACH_I
#undef HWY_SVE_FOREACH_I08
#undef HWY_SVE_FOREACH_I16
#undef HWY_SVE_FOREACH_I32
#undef HWY_SVE_FOREACH_I64
#undef HWY_SVE_FOREACH_IF
#undef HWY_SVE_FOREACH_U
#undef HWY_SVE_FOREACH_U08
#undef HWY_SVE_FOREACH_U16
#undef HWY_SVE_FOREACH_U32
#undef HWY_SVE_FOREACH_U64
#undef HWY_SVE_FOREACH_UI
#undef HWY_SVE_FOREACH_UI08
#undef HWY_SVE_FOREACH_UI16
#undef HWY_SVE_FOREACH_UI32
#undef HWY_SVE_FOREACH_UI64
#undef HWY_SVE_FOREACH_UIF3264
#undef HWY_SVE_PTRUE
#undef HWY_SVE_RETV_ARGD
#undef HWY_SVE_RETV_ARGPV
#undef HWY_SVE_RETV_ARGPVN
#undef HWY_SVE_RETV_ARGPVV
#undef HWY_SVE_RETV_ARGV
#undef HWY_SVE_RETV_ARGVN
#undef HWY_SVE_RETV_ARGVV
#undef HWY_SVE_T
#undef HWY_SVE_V

}  // namespace detail
// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();