// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// ARM SVE[2] vectors (length not known at compile time).
// External include guard in highway.h - see comment there.

#include <stddef.h>
#include <stdint.h>

#include <arm_sve.h>

#include "hwy/base.h"
#include "hwy/ops/shared-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

template <class V>
struct DFromV_t {};  // specialized in macros
template <class V>
using DFromV = typename DFromV_t<RemoveConst<V>>::type;

template <class V>
using TFromV = TFromD<DFromV<V>>;

#define HWY_IF_UNSIGNED_V(V) HWY_IF_UNSIGNED(TFromV<V>)
#define HWY_IF_SIGNED_V(V) HWY_IF_SIGNED(TFromV<V>)
#define HWY_IF_FLOAT_V(V) HWY_IF_FLOAT(TFromV<V>)
#define HWY_IF_LANE_SIZE_V(V, bytes) HWY_IF_LANE_SIZE(TFromV<V>, bytes)

// ================================================== MACROS

// Generate specializations and function definitions using X macros. Although
// these are harder to read and debug, writing everything manually would be
// too bulky.

namespace detail {  // for code folding

// Unsigned:
#define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, 8, NAME, OP)
#define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, 8, NAME, OP)
#define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \
  X_MACRO(uint, u, 32, 16, NAME, OP)
#define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \
  X_MACRO(uint, u, 64, 32, NAME, OP)

// Signed:
#define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, 8, NAME, OP)
#define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, 8, NAME, OP)
#define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, 16, NAME, OP)
#define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, 32, NAME, OP)

// Float:
#define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \
  X_MACRO(float, f, 16, 16, NAME, OP)
#define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \
  X_MACRO(float, f, 32, 16, NAME, OP)
#define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) \
  X_MACRO(float, f, 64, 32, NAME, OP)

// For all element sizes:
#define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP)

// Commonly used type categories for a given element size:
#define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UI16(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP)          \
  HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP)          \
  HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP)           \
  HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP)

// Commonly used type categories:
#define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_IF(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_F(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_F(X_MACRO, NAME, OP)

// Assemble types for use in x-macros
#define HWY_SVE_T(BASE, BITS) BASE##BITS##_t
#define HWY_SVE_D(BASE, BITS, N, POW2) Simd<HWY_SVE_T(BASE, BITS), N, POW2>
#define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t

}  // namespace detail

#define HWY_SPECIALIZE(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <>                                            \
  struct DFromV_t<HWY_SVE_V(BASE, BITS)> {               \
    using type = ScalableTag<HWY_SVE_T(BASE, BITS)>;     \
  };

HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _)
#undef HWY_SPECIALIZE

// Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX
// instructions, and we only use it when the predicate is ptrue anyway.

// vector = f(vector), e.g. Not
#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v);   \
  }
#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, HALF, NAME, OP)     \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS(v);                            \
  }

// vector = f(vector, scalar), e.g. detail::AddN
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \
  }
#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, HALF, NAME, OP)   \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }

// vector = f(vector, vector), e.g. Add
#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \
  }
#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP)   \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }
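// For reference (expansion shown approximately), a single instantiation such
// as
//   HWY_SVE_FOREACH_U32(HWY_SVE_RETV_ARGPVV, Add, add)
// passes (uint, u, 32, 16, Add, add) as X-macro arguments and thus defines:
//   HWY_API svuint32_t Add(svuint32_t a, svuint32_t b) {
//     return svadd_u32_x(HWY_SVE_PTRUE(32), a, b);
//   }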

// ------------------------------ Lanes

namespace detail {

// Returns the actual number of hardware lanes, not rounded to a power of two.
HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<1> /* tag */) {
  return svcntb_pat(SV_ALL);
}
HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<2> /* tag */) {
  return svcnth_pat(SV_ALL);
}
HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<4> /* tag */) {
  return svcntw_pat(SV_ALL);
}
HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<8> /* tag */) {
  return svcntd_pat(SV_ALL);
}

// Returns the actual number of hardware lanes, rounded down to a power of two.
HWY_INLINE size_t HardwareLanes(hwy::SizeTag<1> /* tag */) {
  return svcntb_pat(SV_POW2);
}
HWY_INLINE size_t HardwareLanes(hwy::SizeTag<2> /* tag */) {
  return svcnth_pat(SV_POW2);
}
HWY_INLINE size_t HardwareLanes(hwy::SizeTag<4> /* tag */) {
  return svcntw_pat(SV_POW2);
}
HWY_INLINE size_t HardwareLanes(hwy::SizeTag<8> /* tag */) {
  return svcntd_pat(SV_POW2);
}

}  // namespace detail

// Returns the actual number of lanes after capping by N and shifting by kPow2.
// May return 0 (e.g. for "1/8th" of a u32x4 vector; 1/8th of u32x8 would be 1).
template <typename T, size_t N, int kPow2>
HWY_API size_t Lanes(Simd<T, N, kPow2> d) {
  const size_t actual = detail::HardwareLanes(hwy::SizeTag<sizeof(T)>());
  // Common case of full vectors: avoid any extra instructions.
  if (detail::IsFull(d)) return actual;
  return HWY_MIN(detail::ScaleByPower(actual, kPow2), N);
}
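// Example (illustrative sketch; `in`, `out` and `size` are hypothetical
// caller-provided names): because the lane count is only known at runtime,
// loops over scalable vectors are written in terms of Lanes:
//   const ScalableTag<float> d;
//   for (size_t i = 0; i + Lanes(d) <= size; i += Lanes(d)) {
//     Store(Add(Load(d, in + i), Load(d, in + i)), d, out + i);
//   }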

// ================================================== MASK INIT

// One mask bit per byte; only the one belonging to the lowest byte is valid.

// ------------------------------ FirstN
#define HWY_SVE_FIRSTN(BASE, CHAR, BITS, HALF, NAME, OP)                       \
  template <size_t N, int kPow2>                                               \
  HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, size_t count) {     \
    const size_t limit = detail::IsFull(d) ? count : HWY_MIN(Lanes(d), count); \
    return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast<uint32_t>(limit));  \
  }
HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt)
#undef HWY_SVE_FIRSTN
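// Example (illustrative; `remaining` and `ptr` are hypothetical): FirstN is
// how loops handle a final partial vector:
//   const ScalableTag<int32_t> d;
//   const svbool_t m = FirstN(d, remaining);  // lanes [0, remaining) active
//   const auto v = MaskedLoad(m, d, ptr);     // inactive lanes are zero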

namespace detail {

// All-true mask from a macro
#define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2)

#define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, HALF, NAME, OP)       \
  template <size_t N, int kPow2>                                   \
  HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \
    return HWY_SVE_PTRUE(BITS);                                    \
  }

HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue)  // return all-true
#undef HWY_SVE_WRAP_PTRUE

HWY_API svbool_t PFalse() { return svpfalse_b(); }

// Returns all-true if d is HWY_FULL or FirstN(N) after capping N.
//
// This is used in functions that load/store memory; other functions (e.g.
// arithmetic) can ignore d and use PTrue instead.
template <class D>
svbool_t MakeMask(D d) {
  return IsFull(d) ? PTrue(d) : FirstN(d, Lanes(d));
}

}  // namespace detail

// ================================================== INIT

// ------------------------------ Set
// vector = f(d, scalar), e.g. Set
#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP)                         \
  template <size_t N, int kPow2>                                              \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
                                     HWY_SVE_T(BASE, BITS) arg) {             \
    return sv##OP##_##CHAR##BITS(arg);                                        \
  }

HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n)
#undef HWY_SVE_SET

// Required for Zero and VFromD
template <size_t N, int kPow2>
svuint16_t Set(Simd<bfloat16_t, N, kPow2> d, bfloat16_t arg) {
  return Set(RebindToUnsigned<decltype(d)>(), arg.bits);
}

template <class D>
using VFromD = decltype(Set(D(), TFromD<D>()));

// ------------------------------ Zero

template <class D>
VFromD<D> Zero(D d) {
  return Set(d, 0);
}
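// Example (illustrative): broadcasting constants.
//   const ScalableTag<float> d;
//   const auto v1 = Set(d, 1.0f);  // every lane is 1.0f
//   const auto v0 = Zero(d);       // every lane is 0.0f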

// ------------------------------ Undefined

#define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <size_t N, int kPow2>                            \
  HWY_API HWY_SVE_V(BASE, BITS)                             \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) {       \
    return sv##OP##_##CHAR##BITS();                         \
  }

HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef)

// ------------------------------ BitCast

namespace detail {

// u8: no change
#define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, HALF, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) {  \
    return v;                                                             \
  }                                                                       \
  template <size_t N, int kPow2>                                          \
  HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte(                          \
      HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \
    return v;                                                             \
  }

// All other types
#define HWY_SVE_CAST(BASE, CHAR, BITS, HALF, NAME, OP)                        \
  HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) {               \
    return sv##OP##_u8_##CHAR##BITS(v);                                       \
  }                                                                           \
  template <size_t N, int kPow2>                                              \
  HWY_INLINE HWY_SVE_V(BASE, BITS)                                            \
      BitCastFromByte(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svuint8_t v) { \
    return sv##OP##_##CHAR##BITS##_u8(v);                                     \
  }

HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _)
HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret)

#undef HWY_SVE_CAST_NOP
#undef HWY_SVE_CAST

template <size_t N, int kPow2>
HWY_INLINE svuint16_t BitCastFromByte(Simd<bfloat16_t, N, kPow2> /* d */,
                                      svuint8_t v) {
  return BitCastFromByte(Simd<uint16_t, N, kPow2>(), v);
}

}  // namespace detail

template <class D, class FromV>
HWY_API VFromD<D> BitCast(D d, FromV v) {
  return detail::BitCastFromByte(d, detail::BitCastToByte(v));
}
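// Example (illustrative): reinterpreting lane bits without any conversion,
// here to read the f32 sign-bit pattern as integers.
//   const ScalableTag<float> df;
//   const RebindToUnsigned<decltype(df)> du;
//   const svuint32_t bits = BitCast(du, Set(df, -0.0f));  // 0x80000000 lanes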

// ================================================== LOGICAL

// detail::*N() functions accept a scalar argument to avoid extra Set().

// ------------------------------ Not

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV, Not, not )

// ------------------------------ And

namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, AndN, and_n)
}  // namespace detail

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, And, and)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V And(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, And(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ Or

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Or(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Or(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ Xor

namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, XorN, eor_n)
}  // namespace detail

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Xor, eor)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Xor(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Xor(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ AndNot

namespace detail {
#define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS)                                    \
      NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {     \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a);   \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n)
#undef HWY_SVE_RETV_ARGPVN_SWAP
}  // namespace detail

#define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS)                                    \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {     \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a);   \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic)
#undef HWY_SVE_RETV_ARGPVV_SWAP

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V AndNot(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ OrAnd

template <class V>
HWY_API V OrAnd(const V o, const V a1, const V a2) {
  return Or(o, And(a1, a2));
}

// ------------------------------ PopulationCount

#ifdef HWY_NATIVE_POPCNT
#undef HWY_NATIVE_POPCNT
#else
#define HWY_NATIVE_POPCNT
#endif

// Need to return original type instead of unsigned.
#define HWY_SVE_POPCNT(BASE, CHAR, BITS, HALF, NAME, OP)               \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {        \
    return BitCast(DFromV<decltype(v)>(),                              \
                   sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v)); \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT, PopulationCount, cnt)
#undef HWY_SVE_POPCNT

// ================================================== SIGN

// ------------------------------ Neg
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Neg, neg)

// ------------------------------ Abs
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs)

// ------------------------------ CopySign[ToAbs]

template <class V>
HWY_API V CopySign(const V magn, const V sign) {
  const auto msb = SignBit(DFromV<V>());
  return Or(AndNot(msb, magn), And(msb, sign));
}

template <class V>
HWY_API V CopySignToAbs(const V abs, const V sign) {
  const auto msb = SignBit(DFromV<V>());
  return Or(abs, And(msb, sign));
}
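// Example: CopySign(Set(d, 1.5f), Set(d, -2.0f)) yields -1.5f in every lane.
// CopySignToAbs assumes the sign bit of `abs` is already clear, which saves
// the AndNot.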

// ================================================== ARITHMETIC

// ------------------------------ Add

namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN, AddN, add_n)
}  // namespace detail

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add)

// ------------------------------ Sub

namespace detail {
// Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg.
#define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, HALF, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS)                                             \
      NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS##_z(pg, a, b);                             \
  }

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n)
#undef HWY_SVE_RETV_ARGPVN_MASK
}  // namespace detail

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub)

// ------------------------------ SumsOf8
HWY_API svuint64_t SumsOf8(const svuint8_t v) {
  const ScalableTag<uint32_t> du32;
  const ScalableTag<uint64_t> du64;
  const svbool_t pg = detail::PTrue(du64);

  const svuint32_t sums_of_4 = svdot_n_u32(Zero(du32), v, 1);
  // Compute pairwise sum of u32 and extend to u64.
  // TODO(janwas): on SVE2, we can instead use svaddp.
  const svuint64_t hi = svlsr_n_u64_x(pg, BitCast(du64, sums_of_4), 32);
  // Isolate the lower 32 bits (to be added to the upper 32 and zero-extended)
  const svuint64_t lo = svextw_u64_x(pg, BitCast(du64, sums_of_4));
  return Add(hi, lo);
}
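// Worked example: for repeating u8 input lanes (1, 2, .., 8), svdot with an
// all-ones operand produces u32 sums of four consecutive bytes (10 and 26);
// the shift/extend/add above then combines each pair into a u64 lane of 36.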

// ------------------------------ SaturatedAdd

HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd)
HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd)

// ------------------------------ SaturatedSub

HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub)
HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub)

// ------------------------------ AbsDiff
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPVV, AbsDiff, abd)

// ------------------------------ ShiftLeft[Same]

#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP)               \
  template <int kBits>                                                  \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {         \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits);    \
  }                                                                     \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME##Same(HWY_SVE_V(BASE, BITS) v, HWY_SVE_T(uint, BITS) bits) { \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, bits);     \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n)

// ------------------------------ ShiftRight[Same]

HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_N, ShiftRight, lsr_n)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n)

#undef HWY_SVE_SHIFT_N

// ------------------------------ RotateRight

// TODO(janwas): svxar on SVE2
template <int kBits, class V>
HWY_API V RotateRight(const V v) {
  constexpr size_t kSizeInBits = sizeof(TFromV<V>) * 8;
  static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count");
  if (kBits == 0) return v;
  return Or(ShiftRight<kBits>(v), ShiftLeft<kSizeInBits - kBits>(v));
}
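// Example: RotateRight<8>(v) on u32 lanes holding 0x12345678 returns
// 0x78123456, i.e. (0x12345678 >> 8) | (0x12345678 << 24).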

// ------------------------------ Shl/r

#define HWY_SVE_SHIFT(BASE, CHAR, BITS, HALF, NAME, OP)           \
  HWY_API HWY_SVE_V(BASE, BITS)                                   \
      NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \
    const RebindToUnsigned<DFromV<decltype(v)>> du;               \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v,      \
                                     BitCast(du, bits));          \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl)

HWY_SVE_FOREACH_U(HWY_SVE_SHIFT, Shr, lsr)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr)

#undef HWY_SVE_SHIFT

// ------------------------------ Min/Max

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Min, min)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Max, max)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Max, maxnm)

namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MinN, min_n)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n)
}  // namespace detail

// ------------------------------ Mul
HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, Mul, mul)
HWY_SVE_FOREACH_UIF3264(HWY_SVE_RETV_ARGPVV, Mul, mul)

// ------------------------------ MulHigh
HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
namespace detail {
HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
HWY_SVE_FOREACH_U64(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
}  // namespace detail

// ------------------------------ Div
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div)

// ------------------------------ ApproximateReciprocal
HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe)

// ------------------------------ Sqrt
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt)

// ------------------------------ ApproximateReciprocalSqrt
HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte)

// ------------------------------ MulAdd
#define HWY_SVE_FMA(BASE, CHAR, BITS, HALF, NAME, OP)                   \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x,          \
           HWY_SVE_V(BASE, BITS) add) {                                 \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), x, mul, add); \
  }

HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulAdd, mad)

// ------------------------------ NegMulAdd
HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulAdd, msb)

// ------------------------------ MulSub
HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulSub, nmsb)

// ------------------------------ NegMulSub
HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad)

#undef HWY_SVE_FMA

// ------------------------------ Round etc.

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Floor, rintm)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Ceil, rintp)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Trunc, rintz)

// ================================================== MASK

// ------------------------------ RebindMask
template <class D, typename MFrom>
HWY_API svbool_t RebindMask(const D /*d*/, const MFrom mask) {
  return mask;
}

// ------------------------------ Mask logical

HWY_API svbool_t Not(svbool_t m) {
  // We don't know the lane type, so assume 8-bit. For larger types, this will
  // de-canonicalize the predicate, i.e. set bits to 1 even though they do not
  // correspond to the lowest byte in the lane. Per ARM, such bits are ignored.
  return svnot_b_z(HWY_SVE_PTRUE(8), m);
}
HWY_API svbool_t And(svbool_t a, svbool_t b) {
  return svand_b_z(b, b, a);  // same order as AndNot for consistency
}
HWY_API svbool_t AndNot(svbool_t a, svbool_t b) {
  return svbic_b_z(b, b, a);  // reversed order like NEON
}
HWY_API svbool_t Or(svbool_t a, svbool_t b) {
  return svsel_b(a, a, b);  // a ? true : b
}
HWY_API svbool_t Xor(svbool_t a, svbool_t b) {
  return svsel_b(a, svnand_b_z(a, a, b), b);  // a ? !(a & b) : b.
}

// ------------------------------ CountTrue

#define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, HALF, NAME, OP)           \
  template <size_t N, int kPow2>                                       \
  HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, svbool_t m) { \
    return sv##OP##_b##BITS(detail::MakeMask(d), m);                   \
  }

HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp)
#undef HWY_SVE_COUNT_TRUE

// For 16-bit Compress: full vector, not limited to SV_POW2.
namespace detail {

#define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, HALF, NAME, OP)            \
  template <size_t N, int kPow2>                                             \
  HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m) { \
    return sv##OP##_b##BITS(svptrue_b##BITS(), m);                           \
  }

HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp)
#undef HWY_SVE_COUNT_TRUE_FULL

}  // namespace detail

// ------------------------------ AllFalse
template <class D>
HWY_API bool AllFalse(D d, svbool_t m) {
  return !svptest_any(detail::MakeMask(d), m);
}

// ------------------------------ AllTrue
template <class D>
HWY_API bool AllTrue(D d, svbool_t m) {
  return CountTrue(d, m) == Lanes(d);
}

// ------------------------------ FindFirstTrue
template <class D>
HWY_API intptr_t FindFirstTrue(D d, svbool_t m) {
  return AllFalse(d, m) ? intptr_t{-1}
                        : static_cast<intptr_t>(
                              CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m)));
}
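// Example: for m = (0, 0, 1, 0, ...), svbrkb keeps the lanes strictly before
// the first true lane, so CountTrue over that prefix returns 2, the index of
// the first true lane; an all-false mask returns -1.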

// ------------------------------ IfThenElse
#define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, HALF, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS)                                               \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \
    return sv##OP##_##CHAR##BITS(m, yes, no);                                 \
  }

HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel)
#undef HWY_SVE_IF_THEN_ELSE

// ------------------------------ IfThenElseZero
template <class M, class V>
HWY_API V IfThenElseZero(const M mask, const V yes) {
  return IfThenElse(mask, yes, Zero(DFromV<V>()));
}

// ------------------------------ IfThenZeroElse
template <class M, class V>
HWY_API V IfThenZeroElse(const M mask, const V no) {
  return IfThenElse(mask, Zero(DFromV<V>()), no);
}

// ================================================== COMPARE

// mask = f(vector, vector)
#define HWY_SVE_COMPARE(BASE, CHAR, BITS, HALF, NAME, OP)                   \
  HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b);                \
  }
#define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, HALF, NAME, OP)                 \
  HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b);                \
  }

// ------------------------------ Eq
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq)
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, EqN, cmpeq_n)
}  // namespace detail

// ------------------------------ Ne
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne)
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, NeN, cmpne_n)
}  // namespace detail

// ------------------------------ Lt
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt)
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LtN, cmplt_n)
}  // namespace detail

// ------------------------------ Le
HWY_SVE_FOREACH_F(HWY_SVE_COMPARE, Le, cmple)

#undef HWY_SVE_COMPARE
#undef HWY_SVE_COMPARE_N

// ------------------------------ Gt/Ge (swapped order)

template <class V>
HWY_API svbool_t Gt(const V a, const V b) {
  return Lt(b, a);
}
template <class V>
HWY_API svbool_t Ge(const V a, const V b) {
  return Le(b, a);
}

// ------------------------------ TestBit
template <class V>
HWY_API svbool_t TestBit(const V a, const V bit) {
  return detail::NeN(And(a, bit), 0);
}

// ------------------------------ MaskFromVec (Ne)
template <class V>
HWY_API svbool_t MaskFromVec(const V v) {
  return detail::NeN(v, static_cast<TFromV<V>>(0));
}

// ------------------------------ VecFromMask

template <class D, HWY_IF_NOT_FLOAT_D(D)>
HWY_API VFromD<D> VecFromMask(const D d, svbool_t mask) {
  const auto v0 = Zero(RebindToSigned<decltype(d)>());
  return BitCast(d, detail::SubN(mask, v0, 1));
}

template <class D, HWY_IF_FLOAT_D(D)>
HWY_API VFromD<D> VecFromMask(const D d, svbool_t mask) {
  return BitCast(d, VecFromMask(RebindToUnsigned<D>(), mask));
}

// ------------------------------ IfVecThenElse (MaskFromVec, IfThenElse)

template <class V>
HWY_API V IfVecThenElse(const V mask, const V yes, const V no) {
  // TODO(janwas): use svbsl for SVE2
  return IfThenElse(MaskFromVec(mask), yes, no);
}

// ================================================== MEMORY

// ------------------------------ Load/MaskedLoad/LoadDup128/Store/Stream

#define HWY_SVE_LOAD(BASE, CHAR, BITS, HALF, NAME, OP)     \
  template <size_t N, int kPow2>                           \
  HWY_API HWY_SVE_V(BASE, BITS)                            \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d,              \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
    return sv##OP##_##CHAR##BITS(detail::MakeMask(d), p);  \
  }

#define HWY_SVE_MASKED_LOAD(BASE, CHAR, BITS, HALF, NAME, OP)   \
  template <size_t N, int kPow2>                                \
  HWY_API HWY_SVE_V(BASE, BITS)                                 \
      NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {      \
    return sv##OP##_##CHAR##BITS(m, p);                         \
  }

#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <size_t N, int kPow2>                              \
  HWY_API HWY_SVE_V(BASE, BITS)                               \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,           \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {    \
    /* All-true predicate to load all 128 bits. */            \
    return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), p);        \
  }

#define HWY_SVE_STORE(BASE, CHAR, BITS, HALF, NAME, OP)       \
  template <size_t N, int kPow2>                              \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v,                  \
                    HWY_SVE_D(BASE, BITS, N, kPow2) d,        \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
    sv##OP##_##CHAR##BITS(detail::MakeMask(d), p, v);         \
  }

#define HWY_SVE_MASKED_STORE(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <size_t N, int kPow2>                               \
  HWY_API void NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v,       \
                    HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,   \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {  \
    sv##OP##_##CHAR##BITS(m, p, v);                            \
  }

HWY_SVE_FOREACH(HWY_SVE_LOAD, Load, ld1)
HWY_SVE_FOREACH(HWY_SVE_MASKED_LOAD, MaskedLoad, ld1)
HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDup128, ld1rq)
HWY_SVE_FOREACH(HWY_SVE_STORE, Store, st1)
HWY_SVE_FOREACH(HWY_SVE_STORE, Stream, stnt1)
HWY_SVE_FOREACH(HWY_SVE_MASKED_STORE, MaskedStore, st1)

#undef HWY_SVE_LOAD
#undef HWY_SVE_MASKED_LOAD
#undef HWY_SVE_LOAD_DUP128
#undef HWY_SVE_STORE
#undef HWY_SVE_MASKED_STORE

// BF16 is the same as svuint16_t because BF16 is optional before v8.6.
template <size_t N, int kPow2>
HWY_API svuint16_t Load(Simd<bfloat16_t, N, kPow2> d,
                        const bfloat16_t* HWY_RESTRICT p) {
  return Load(RebindToUnsigned<decltype(d)>(),
              reinterpret_cast<const uint16_t * HWY_RESTRICT>(p));
}

template <size_t N, int kPow2>
HWY_API void Store(svuint16_t v, Simd<bfloat16_t, N, kPow2> d,
                   bfloat16_t* HWY_RESTRICT p) {
  Store(v, RebindToUnsigned<decltype(d)>(),
        reinterpret_cast<uint16_t * HWY_RESTRICT>(p));
}

// ------------------------------ Load/StoreU

// SVE only requires lane alignment, not natural alignment of the entire
// vector.
template <class D>
HWY_API VFromD<D> LoadU(D d, const TFromD<D>* HWY_RESTRICT p) {
  return Load(d, p);
}

template <class V, class D>
HWY_API void StoreU(const V v, D d, TFromD<D>* HWY_RESTRICT p) {
  Store(v, d, p);
}
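// Example (illustrative; `p` is a hypothetical pointer to at least Lanes(d)
// lane-aligned elements):
//   const ScalableTag<uint8_t> d;
//   const auto v = LoadU(d, p);
//   StoreU(Add(v, v), d, p);  // double each byte (modulo 256) in place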

// ------------------------------ ScatterOffset/Index

#define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP)             \
  template <size_t N, int kPow2>                                             \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v,                                 \
                    HWY_SVE_D(BASE, BITS, N, kPow2) d,                       \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,               \
                    HWY_SVE_V(int, BITS) offset) {                           \
    sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, offset, \
                                          v);                                \
  }

#define HWY_SVE_SCATTER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP)                \
  template <size_t N, int kPow2>                                               \
  HWY_API void NAME(                                                           \
      HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N, kPow2) d,              \
      HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, HWY_SVE_V(int, BITS) index) { \
    sv##OP##_s##BITS##index_##CHAR##BITS(detail::MakeMask(d), base, index, v); \
  }

HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter)
HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_INDEX, ScatterIndex, st1_scatter)
#undef HWY_SVE_SCATTER_OFFSET
#undef HWY_SVE_SCATTER_INDEX

// ------------------------------ GatherOffset/Index

#define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP)             \
  template <size_t N, int kPow2>                                            \
  HWY_API HWY_SVE_V(BASE, BITS)                                             \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d,                               \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,                 \
           HWY_SVE_V(int, BITS) offset) {                                   \
    return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, \
                                                 offset);                   \
  }
#define HWY_SVE_GATHER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP)             \
  template <size_t N, int kPow2>                                           \
  HWY_API HWY_SVE_V(BASE, BITS)                                            \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d,                              \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,                \
           HWY_SVE_V(int, BITS) index) {                                   \
    return sv##OP##_s##BITS##index_##CHAR##BITS(detail::MakeMask(d), base, \
                                                index);                    \
  }

HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather)
HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_INDEX, GatherIndex, ld1_gather)
#undef HWY_SVE_GATHER_OFFSET
#undef HWY_SVE_GATHER_INDEX

// ------------------------------ StoreInterleaved3

#define HWY_SVE_STORE3(BASE, CHAR, BITS, HALF, NAME, OP)                      \
  template <size_t N, int kPow2>                                              \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1,       \
                    HWY_SVE_V(BASE, BITS) v2,                                 \
                    HWY_SVE_D(BASE, BITS, N, kPow2) d,                        \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) {         \
    const sv##BASE##BITS##x3_t triple = svcreate3##_##CHAR##BITS(v0, v1, v2); \
    sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, triple);            \
  }
HWY_SVE_FOREACH_U08(HWY_SVE_STORE3, StoreInterleaved3, st3)

#undef HWY_SVE_STORE3

// ------------------------------ StoreInterleaved4

#define HWY_SVE_STORE4(BASE, CHAR, BITS, HALF, NAME, OP)                \
  template <size_t N, int kPow2>                                        \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \
                    HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3, \
                    HWY_SVE_D(BASE, BITS, N, kPow2) d,                  \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) {   \
    const sv##BASE##BITS##x4_t quad =                                   \
        svcreate4##_##CHAR##BITS(v0, v1, v2, v3);                       \
    sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, quad);        \
  }
HWY_SVE_FOREACH_U08(HWY_SVE_STORE4, StoreInterleaved4, st4)

#undef HWY_SVE_STORE4

// ================================================== CONVERT

// ------------------------------ PromoteTo

// Same sign
#define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, HALF, NAME, OP)                \
  template <size_t N, int kPow2>                                            \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(                                       \
      HWY_SVE_D(BASE, BITS, N, kPow2) /* tag */, HWY_SVE_V(BASE, HALF) v) { \
    return sv##OP##_##CHAR##BITS(v);                                        \
  }

HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo)
HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo)
HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo)

// 2x
template <size_t N, int kPow2>
HWY_API svuint32_t PromoteTo(Simd<uint32_t, N, kPow2> dto, svuint8_t vfrom) {
  const RepartitionToWide<DFromV<decltype(vfrom)>> d2;
  return PromoteTo(dto, PromoteTo(d2, vfrom));
}
template <size_t N, int kPow2>
HWY_API svint32_t PromoteTo(Simd<int32_t, N, kPow2> dto, svint8_t vfrom) {
  const RepartitionToWide<DFromV<decltype(vfrom)>> d2;
  return PromoteTo(dto, PromoteTo(d2, vfrom));
}
template <size_t N, int kPow2>
HWY_API svuint32_t U32FromU8(svuint8_t v) {
  return PromoteTo(Simd<uint32_t, N, kPow2>(), v);
}

// Sign change
template <size_t N, int kPow2>
HWY_API svint16_t PromoteTo(Simd<int16_t, N, kPow2> dto, svuint8_t vfrom) {
  const RebindToUnsigned<decltype(dto)> du;
  return BitCast(dto, PromoteTo(du, vfrom));
}
template <size_t N, int kPow2>
HWY_API svint32_t PromoteTo(Simd<int32_t, N, kPow2> dto, svuint16_t vfrom) {
  const RebindToUnsigned<decltype(dto)> du;
  return BitCast(dto, PromoteTo(du, vfrom));
}
template <size_t N, int kPow2>
HWY_API svint32_t PromoteTo(Simd<int32_t, N, kPow2> dto, svuint8_t vfrom) {
  const Repartition<uint16_t, DFromV<decltype(vfrom)>> du16;
  const Repartition<int16_t, decltype(du16)> di16;
  return PromoteTo(dto, BitCast(di16, PromoteTo(du16, vfrom)));
}

// ------------------------------ PromoteTo F

// svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so
// first replicate each lane once.
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLower, zip1)
// Do not use zip2 to implement PromoteUpperTo or similar because vectors may be
// non-powers of two, so getting the actual "upper half" requires MaskUpperHalf.
}  // namespace detail

template <size_t N, int kPow2>
HWY_API svfloat32_t PromoteTo(Simd<float32_t, N, kPow2> /* d */,
                              const svfloat16_t v) {
  const svfloat16_t vv = detail::ZipLower(v, v);
  return svcvt_f32_f16_x(detail::PTrue(Simd<float16_t, N, kPow2>()), vv);
}

template <size_t N, int kPow2>
HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */,
                              const svfloat32_t v) {
  const svfloat32_t vv = detail::ZipLower(v, v);
  return svcvt_f64_f32_x(detail::PTrue(Simd<float32_t, N, kPow2>()), vv);
}

template <size_t N, int kPow2>
HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */,
                              const svint32_t v) {
  const svint32_t vv = detail::ZipLower(v, v);
  return svcvt_f64_s32_x(detail::PTrue(Simd<int32_t, N, kPow2>()), vv);
}

// For 16-bit Compress
namespace detail {
HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi)
#undef HWY_SVE_PROMOTE_TO

template <size_t N, int kPow2>
HWY_API svfloat32_t PromoteUpperTo(Simd<float, N, kPow2> df, svfloat16_t v) {
  const RebindToUnsigned<decltype(df)> du;
  const RepartitionToNarrow<decltype(du)> dn;
  return BitCast(df, PromoteUpperTo(du, BitCast(dn, v)));
}

}  // namespace detail

// ------------------------------ DemoteTo U

namespace detail {

// Saturates unsigned vectors to half/quarter-width TN.
template <typename TN, class VU>
VU SaturateU(VU v) {
  return detail::MinN(v, static_cast<TFromV<VU>>(LimitsMax<TN>()));
}

// Saturates signed vectors to half/quarter-width TN.
template <typename TN, class VI>
VI SaturateI(VI v) {
  return detail::MinN(detail::MaxN(v, LimitsMin<TN>()), LimitsMax<TN>());
}

}  // namespace detail

template <size_t N, int kPow2>
HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint16_t v) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  using TN = TFromD<decltype(dn)>;
  // First clamp negative numbers to zero and cast to unsigned.
  const svuint16_t clamped = BitCast(du, detail::MaxN(v, 0));
  // Saturate to unsigned-max and halve the width.
  const svuint8_t vn = BitCast(dn, detail::SaturateU<TN>(clamped));
  return svuzp1_u8(vn, vn);
}

template <size_t N, int kPow2>
HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svint32_t v) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  using TN = TFromD<decltype(dn)>;
  // First clamp negative numbers to zero and cast to unsigned.
  const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0));
  // Saturate to unsigned-max and halve the width.
  const svuint16_t vn = BitCast(dn, detail::SaturateU<TN>(clamped));
  return svuzp1_u16(vn, vn);
}

template <size_t N, int kPow2>
HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint32_t v) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  const RepartitionToNarrow<decltype(du)> d2;
  using TN = TFromD<decltype(dn)>;
  // First clamp negative numbers to zero and cast to unsigned.
  const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0));
  // Saturate to unsigned-max and quarter the width.
  const svuint16_t cast16 = BitCast(d2, detail::SaturateU<TN>(clamped));
  const svuint8_t x2 = BitCast(dn, svuzp1_u16(cast16, cast16));
  return svuzp1_u8(x2, x2);
}

HWY_API svuint8_t U8FromU32(const svuint32_t v) {
  const DFromV<svuint32_t> du32;
  const RepartitionToNarrow<decltype(du32)> du16;
  const RepartitionToNarrow<decltype(du16)> du8;

  const svuint16_t cast16 = BitCast(du16, v);
  const svuint16_t x2 = svuzp1_u16(cast16, cast16);
  const svuint8_t cast8 = BitCast(du8, x2);
  return svuzp1_u8(cast8, cast8);
}
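// Example: demoting i16 lanes (300, -5, 7, ...) to u8 yields (255, 0, 7, ...):
// negatives are first clamped to zero, then values are saturated to the
// unsigned max of the narrow type before the width-halving svuzp1.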

// ------------------------------ DemoteTo I

template <size_t N, int kPow2>
HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint16_t v) {
#if HWY_TARGET == HWY_SVE2
  const svint8_t vn = BitCast(dn, svqxtnb_s16(v));
#else
  using TN = TFromD<decltype(dn)>;
  const svint8_t vn = BitCast(dn, detail::SaturateI<TN>(v));
#endif
  return svuzp1_s8(vn, vn);
}

template <size_t N, int kPow2>
HWY_API svint16_t DemoteTo(Simd<int16_t, N, kPow2> dn, const svint32_t v) {
#if HWY_TARGET == HWY_SVE2
  const svint16_t vn = BitCast(dn, svqxtnb_s32(v));
#else
  using TN = TFromD<decltype(dn)>;
  const svint16_t vn = BitCast(dn, detail::SaturateI<TN>(v));
#endif
  return svuzp1_s16(vn, vn);
}

template <size_t N, int kPow2>
HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint32_t v) {
  const RepartitionToWide<decltype(dn)> d2;
#if HWY_TARGET == HWY_SVE2
  const svint16_t cast16 = BitCast(d2, svqxtnb_s16(svqxtnb_s32(v)));
#else
  using TN = TFromD<decltype(dn)>;
  const svint16_t cast16 = BitCast(d2, detail::SaturateI<TN>(v));
#endif
  const svint8_t v2 = BitCast(dn, svuzp1_s16(cast16, cast16));
  return BitCast(dn, svuzp1_s8(v2, v2));
}

// ------------------------------ ConcatEven/ConcatOdd

// WARNING: the upper half of these needs fixing up (uzp1/uzp2 use the
// full vector length, not rounded down to a power of two as we require).
namespace detail {

#define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_INLINE HWY_SVE_V(BASE, BITS)                                    \
      NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) {      \
    return sv##OP##_##CHAR##BITS(lo, hi);                             \
  }
HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEven, uzp1)
HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOdd, uzp2)
#undef HWY_SVE_CONCAT_EVERY_SECOND

// Used to slide up / shift whole register left; mask indicates which range
// to take from lo, and the rest is filled from hi starting at its lowest.
#define HWY_SVE_SPLICE(BASE, CHAR, BITS, HALF, NAME, OP)                   \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(                                      \
      HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo, svbool_t mask) { \
    return sv##OP##_##CHAR##BITS(mask, lo, hi);                            \
  }
HWY_SVE_FOREACH(HWY_SVE_SPLICE, Splice, splice)
#undef HWY_SVE_SPLICE

}  // namespace detail

template <class D>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
#if 0  // if we could assume VL is a power of two
  return detail::ConcatOdd(hi, lo);
#else
  const VFromD<D> hi_odd = detail::ConcatOdd(hi, hi);
  const VFromD<D> lo_odd = detail::ConcatOdd(lo, lo);
  return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2));
#endif
}

template <class D>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
#if 0  // if we could assume VL is a power of two
  return detail::ConcatEven(hi, lo);
#else
  const VFromD<D> hi_odd = detail::ConcatEven(hi, hi);
  const VFromD<D> lo_odd = detail::ConcatEven(lo, lo);
  return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2));
#endif
}
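// Example: for 4-lane vectors hi = (h0, h1, h2, h3) and lo = (l0, l1, l2, l3),
// ConcatOdd returns (l1, l3, h1, h3) and ConcatEven returns (l0, l2, h0, h2);
// the Splice places the lo half below the hi half even for non-power-of-two VL.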
1236 
1237 // ------------------------------ DemoteTo F
1238 
1239 template <size_t N, int kPow2>
DemoteTo(Simd<float16_t,N,kPow2> d,const svfloat32_t v)1240 HWY_API svfloat16_t DemoteTo(Simd<float16_t, N, kPow2> d, const svfloat32_t v) {
1241   const svfloat16_t in_even = svcvt_f16_f32_x(detail::PTrue(d), v);
1242   return detail::ConcatEven(in_even, in_even);  // only low 1/2 of result valid
1243 }
1244 
1245 template <size_t N, int kPow2>
DemoteTo(Simd<bfloat16_t,N,kPow2>,svfloat32_t v)1246 HWY_API svuint16_t DemoteTo(Simd<bfloat16_t, N, kPow2> /* d */, svfloat32_t v) {
1247   const svuint16_t in_even = BitCast(ScalableTag<uint16_t>(), v);
1248   return detail::ConcatOdd(in_even, in_even);  // can ignore upper half of vec
1249 }
1250 
1251 template <size_t N, int kPow2>
DemoteTo(Simd<float32_t,N,kPow2> d,const svfloat64_t v)1252 HWY_API svfloat32_t DemoteTo(Simd<float32_t, N, kPow2> d, const svfloat64_t v) {
1253   const svfloat32_t in_even = svcvt_f32_f64_x(detail::PTrue(d), v);
1254   return detail::ConcatEven(in_even, in_even);  // only low 1/2 of result valid
1255 }
1256 
1257 template <size_t N, int kPow2>
DemoteTo(Simd<int32_t,N,kPow2> d,const svfloat64_t v)1258 HWY_API svint32_t DemoteTo(Simd<int32_t, N, kPow2> d, const svfloat64_t v) {
1259   const svint32_t in_even = svcvt_s32_f64_x(detail::PTrue(d), v);
1260   return detail::ConcatEven(in_even, in_even);  // only low 1/2 of result valid
1261 }
1262 
1263 // ------------------------------ ConvertTo F
1264 
1265 #define HWY_SVE_CONVERT(BASE, CHAR, BITS, HALF, NAME, OP)                     \
1266   template <size_t N, int kPow2>                                              \
1267   HWY_API HWY_SVE_V(BASE, BITS)                                               \
1268       NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(int, BITS) v) { \
1269     return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v);       \
1270   }                                                                           \
1271   /* Truncates (rounds toward zero). */                                       \
1272   template <size_t N, int kPow2>                                              \
1273   HWY_API HWY_SVE_V(int, BITS)                                                \
1274       NAME(HWY_SVE_D(int, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \
1275     return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v);       \
1276   }
1277 
1278 // API only requires f32 but we provide f64 for use by Iota.
HWY_SVE_FOREACH_F(HWY_SVE_CONVERT,ConvertTo,cvt)1279 HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt)
1280 #undef HWY_SVE_CONVERT
1281 
1282 // ------------------------------ NearestInt (Round, ConvertTo)
1283 
1284 template <class VF, class DI = RebindToSigned<DFromV<VF>>>
1285 HWY_API VFromD<DI> NearestInt(VF v) {
1286   // No single instruction, round then truncate.
1287   return ConvertTo(DI(), Round(v));
1288 }
1289 
1290 // ------------------------------ Iota (Add, ConvertTo)
1291 
1292 #define HWY_SVE_IOTA(BASE, CHAR, BITS, HALF, NAME, OP)                        \
1293   template <size_t N, int kPow2>                                              \
1294   HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
1295                                      HWY_SVE_T(BASE, BITS) first) {           \
1296     return sv##OP##_##CHAR##BITS(first, 1);                                   \
1297   }
1298 
1299 HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index)
1300 #undef HWY_SVE_IOTA
1301 
1302 template <class D, HWY_IF_FLOAT_D(D)>
1303 HWY_API VFromD<D> Iota(const D d, TFromD<D> first) {
1304   const RebindToSigned<D> di;
1305   return detail::AddN(ConvertTo(d, Iota(di, 0)), first);
1306 }
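
// For example, Iota(d, 10.5f) converts the integer ramp {0, 1, 2, ...} to
// float and adds the start value, yielding {10.5f, 11.5f, 12.5f, ...}.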
1307 
1308 // ================================================== COMBINE
1309 
1310 namespace detail {
1311 
1312 template <class D>
1313 svbool_t MaskLowerHalf(D d) {
1314   return FirstN(d, Lanes(d) / 2);
1315 }
1316 template <class D>
1317 svbool_t MaskUpperHalf(D d) {
1318   // For Splice to work as intended, make sure bits above Lanes(d) are zero.
1319   return AndNot(MaskLowerHalf(d), detail::MakeMask(d));
1320 }
1321 
1322 // Right-shifts the vector pair {hi, lo} by kIndex lanes; can be used to slide
1323 // down (kIndex = N) or up (kIndex = Lanes() - N).
1324 #define HWY_SVE_EXT(BASE, CHAR, BITS, HALF, NAME, OP)            \
1325   template <size_t kIndex>                                       \
1326   HWY_API HWY_SVE_V(BASE, BITS)                                  \
1327       NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \
1328     return sv##OP##_##CHAR##BITS(lo, hi, kIndex);                \
1329   }
1330 HWY_SVE_FOREACH(HWY_SVE_EXT, Ext, ext)
1331 #undef HWY_SVE_EXT
1332 
1333 }  // namespace detail
1334 
1335 // ------------------------------ ConcatUpperLower
1336 template <class D, class V>
1337 HWY_API V ConcatUpperLower(const D d, const V hi, const V lo) {
1338   return IfThenElse(detail::MaskLowerHalf(d), lo, hi);
1339 }
1340 
1341 // ------------------------------ ConcatLowerLower
1342 template <class D, class V>
1343 HWY_API V ConcatLowerLower(const D d, const V hi, const V lo) {
1344   return detail::Splice(hi, lo, detail::MaskLowerHalf(d));
1345 }
1346 
1347 // ------------------------------ ConcatLowerUpper
1348 template <class D, class V>
1349 HWY_API V ConcatLowerUpper(const D d, const V hi, const V lo) {
1350   return detail::Splice(hi, lo, detail::MaskUpperHalf(d));
1351 }
1352 
1353 // ------------------------------ ConcatUpperUpper
1354 template <class D, class V>
1355 HWY_API V ConcatUpperUpper(const D d, const V hi, const V lo) {
1356   const svbool_t mask_upper = detail::MaskUpperHalf(d);
1357   const V lo_upper = detail::Splice(lo, lo, mask_upper);
1358   return IfThenElse(mask_upper, hi, lo_upper);
1359 }
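
// Worked example for the four Concat ops, assuming Lanes(d) = 4,
// lo = {l0,l1,l2,l3} and hi = {h0,h1,h2,h3}:
//   ConcatUpperLower(d, hi, lo) -> {l0,l1,h2,h3}  (blend; no lane movement)
//   ConcatLowerLower(d, hi, lo) -> {l0,l1,h0,h1}
//   ConcatLowerUpper(d, hi, lo) -> {l2,l3,h0,h1}
//   ConcatUpperUpper(d, hi, lo) -> {l2,l3,h2,h3}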
1360 
1361 // ------------------------------ Combine
1362 template <class D, class V2>
1363 HWY_API VFromD<D> Combine(const D d, const V2 hi, const V2 lo) {
1364   return ConcatLowerLower(d, hi, lo);
1365 }
1366 
1367 // ------------------------------ ZeroExtendVector
1368 
1369 template <class D, class V>
1370 HWY_API V ZeroExtendVector(const D d, const V lo) {
1371   return Combine(d, Zero(Half<D>()), lo);
1372 }
1373 
1374 // ------------------------------ Lower/UpperHalf
1375 
1376 template <class D2, class V>
1377 HWY_API V LowerHalf(D2 /* tag */, const V v) {
1378   return v;
1379 }
1380 
1381 template <class V>
1382 HWY_API V LowerHalf(const V v) {
1383   return v;
1384 }
1385 
1386 template <class D2, class V>
1387 HWY_API V UpperHalf(const D2 /* d2 */, const V v) {
1388   return detail::Splice(v, v, detail::MaskUpperHalf(Twice<D2>()));
1389 }
1390 
1391 // ================================================== SWIZZLE
1392 
1393 // ------------------------------ GetLane
1394 
1395 #define HWY_SVE_GET_LANE(BASE, CHAR, BITS, HALF, NAME, OP)      \
1396   HWY_API HWY_SVE_T(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
1397     return sv##OP##_##CHAR##BITS(detail::PFalse(), v);          \
1398   }
1399 
1400 HWY_SVE_FOREACH(HWY_SVE_GET_LANE, GetLane, lasta)
1401 #undef HWY_SVE_GET_LANE
1402 
1403 // ------------------------------ DupEven
1404 
1405 namespace detail {
1406 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1)
1407 }  // namespace detail
1408 
1409 template <class V>
1410 HWY_API V DupEven(const V v) {
1411   return detail::InterleaveEven(v, v);
1412 }
1413 
1414 // ------------------------------ DupOdd
1415 
1416 namespace detail {
1417 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2)
1418 }  // namespace detail
1419 
1420 template <class V>
1421 HWY_API V DupOdd(const V v) {
1422   return detail::InterleaveOdd(v, v);
1423 }
1424 
1425 // ------------------------------ OddEven
1426 
1427 namespace detail {
1428 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVN, Insert, insr_n)
1429 }  // namespace detail
1430 
1431 template <class V>
1432 HWY_API V OddEven(const V odd, const V even) {
1433   const auto even_in_odd = detail::Insert(even, 0);
1434   return detail::InterleaveOdd(even_in_odd, odd);
1435 }
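
// Worked example (4 lanes): even = {e0,e1,e2,e3}, odd = {o0,o1,o2,o3}.
// Insert shifts `even` up by one lane and writes 0 to lane 0, so
// even_in_odd = {0,e0,e1,e2}; InterleaveOdd (trn2) then takes the odd lanes
// of both inputs, yielding {e0,o1,e2,o3}.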
1436 
1437 // ------------------------------ OddEvenBlocks
1438 template <class V>
1439 HWY_API V OddEvenBlocks(const V odd, const V even) {
1440   const RebindToUnsigned<DFromV<V>> du;
1441   using TU = TFromD<decltype(du)>;
1442   constexpr size_t kShift = CeilLog2(16 / sizeof(TU));
1443   const auto idx_block = ShiftRight<kShift>(Iota(du, 0));
1444   const auto lsb = detail::AndN(idx_block, static_cast<TU>(1));
1445   const svbool_t is_even = detail::EqN(lsb, static_cast<TU>(0));
1446   return IfThenElse(is_even, even, odd);
1447 }
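
// For example, with u32 lanes (kShift = 2) and 8 lanes total, idx_block =
// {0,0,0,0,1,1,1,1}: the first 128-bit block comes from `even`, the second
// from `odd`.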
1448 
1449 // ------------------------------ TableLookupLanes
1450 
1451 template <class D, class VI>
1452 HWY_API VFromD<RebindToUnsigned<D>> IndicesFromVec(D d, VI vec) {
1453   using TI = TFromV<VI>;
1454   static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index/lane size mismatch");
1455   const RebindToUnsigned<D> du;
1456   const auto indices = BitCast(du, vec);
1457 #if HWY_IS_DEBUG_BUILD
1458   HWY_DASSERT(AllTrue(du, detail::LtN(indices, static_cast<TI>(Lanes(d)))));
1459 #else
1460   (void)d;
1461 #endif
1462   return indices;
1463 }
1464 
1465 template <class D, typename TI>
1466 HWY_API VFromD<RebindToUnsigned<D>> SetTableIndices(D d, const TI* idx) {
1467   static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane");
1468   return IndicesFromVec(d, LoadU(Rebind<TI, D>(), idx));
1469 }
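
// Usage sketch (illustrative; assumes Lanes(d) == 4 so `idx` covers the
// whole vector):
#if 0
  const ScalableTag<uint32_t> d;
  const svuint32_t v = Iota(d, 0);                  // {0, 1, 2, 3}
  const uint32_t idx[4] = {3, 2, 1, 0};             // each must be < Lanes(d)
  const svuint32_t r = TableLookupLanes(v, SetTableIndices(d, idx));  // 3210
#endif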
1470 
1471 // Lane sizes < 32 bits are not part of the Highway API, but are used by Broadcast.
1472 #define HWY_SVE_TABLE(BASE, CHAR, BITS, HALF, NAME, OP)          \
1473   HWY_API HWY_SVE_V(BASE, BITS)                                  \
1474       NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(uint, BITS) idx) { \
1475     return sv##OP##_##CHAR##BITS(v, idx);                        \
1476   }
1477 
1478 HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl)
1479 #undef HWY_SVE_TABLE
1480 
1481 // ------------------------------ SwapAdjacentBlocks (TableLookupLanes)
1482 
1483 namespace detail {
1484 
1485 template <typename T, size_t N, int kPow2>
1486 constexpr size_t LanesPerBlock(Simd<T, N, kPow2> /* tag */) {
1487   // We might have a capped vector smaller than a block, so honor that.
1488   return HWY_MIN(16 / sizeof(T), detail::ScaleByPower(N, kPow2));
1489 }
1490 
1491 }  // namespace detail
1492 
1493 template <class V>
1494 HWY_API V SwapAdjacentBlocks(const V v) {
1495   const DFromV<V> d;
1496   const RebindToUnsigned<decltype(d)> du;
1497   constexpr auto kLanesPerBlock =
1498       static_cast<TFromV<V>>(detail::LanesPerBlock(d));
1499   const VFromD<decltype(du)> idx = detail::XorN(Iota(du, 0), kLanesPerBlock);
1500   return TableLookupLanes(v, idx);
1501 }
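
// For example, with u32 lanes and 8 lanes total, kLanesPerBlock = 4 and
// idx = Iota ^ 4 = {4,5,6,7,0,1,2,3}, which swaps the two 128-bit blocks.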
1502 
1503 // ------------------------------ Reverse
1504 
1505 #if 0  // if we could assume VL is a power of two
1506 #error "Update macro"
1507 #endif
1508 #define HWY_SVE_REVERSE(BASE, CHAR, BITS, HALF, NAME, OP)                \
1509   template <size_t N, int kPow2>                                         \
1510   HWY_API HWY_SVE_V(BASE, BITS)                                          \
1511       NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, HWY_SVE_V(BASE, BITS) v) { \
1512     const auto reversed = sv##OP##_##CHAR##BITS(v);                      \
1513     /* Shift right to remove extra (non-pow2 and remainder) lanes. */    \
1514     const size_t all_lanes =                                             \
1515         detail::AllHardwareLanes(hwy::SizeTag<BITS / 8>());              \
1516     /* TODO(janwas): on SVE2, use whilege. */                            \
1517     /* Avoids FirstN truncating to the return vector size. */            \
1518     const ScalableTag<HWY_SVE_T(BASE, BITS)> dfull;                      \
1519     const svbool_t mask = Not(FirstN(dfull, all_lanes - Lanes(d)));      \
1520     return detail::Splice(reversed, reversed, mask);                     \
1521   }
1522 
1523 HWY_SVE_FOREACH(HWY_SVE_REVERSE, Reverse, rev)
1524 #undef HWY_SVE_REVERSE
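
// Worked example: u32 lanes, hardware VL = 8 but Lanes(d) = 4 (fractional d).
// For v = {v0,v1,v2,v3, x,x,x,x}, svrev yields {x,x,x,x, v3,v2,v1,v0}; the
// mask selects lanes [4, 8) and Splice moves that segment to the front,
// yielding {v3,v2,v1,v0, ...}.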
1525 
1526 // ------------------------------ Reverse2
1527 
1528 template <class D, HWY_IF_LANE_SIZE_D(D, 2)>
1529 HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) {
1530   const RebindToUnsigned<decltype(d)> du;
1531   const RepartitionToWide<decltype(du)> dw;
1532   return BitCast(d, svrevh_u32_x(detail::PTrue(d), BitCast(dw, v)));
1533 }
1534 
1535 template <class D, HWY_IF_LANE_SIZE_D(D, 4)>
1536 HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) {
1537   const RebindToUnsigned<decltype(d)> du;
1538   const RepartitionToWide<decltype(du)> dw;
1539   return BitCast(d, svrevw_u64_x(detail::PTrue(d), BitCast(dw, v)));
1540 }
1541 
1542 template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
1543 HWY_API VFromD<D> Reverse2(D /* tag */, const VFromD<D> v) {  // 3210
1544   const auto even_in_odd = detail::Insert(v, 0);              // 210z
1545   return detail::InterleaveOdd(v, even_in_odd);               // 2301
1546 }
1547 
1548 // ------------------------------ Reverse4 (TableLookupLanes)
1549 
1550 // TODO(janwas): is this approach faster than Shuffle0123?
1551 template <class D>
1552 HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) {
1553   const RebindToUnsigned<decltype(d)> du;
1554   const auto idx = detail::XorN(Iota(du, 0), 3);
1555   return TableLookupLanes(v, idx);
1556 }
1557 
1558 // ------------------------------ Reverse8 (TableLookupLanes)
1559 
1560 template <class D>
1561 HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) {
1562   const RebindToUnsigned<decltype(d)> du;
1563   const auto idx = detail::XorN(Iota(du, 0), 7);
1564   return TableLookupLanes(v, idx);
1565 }
1566 
1567 // ------------------------------ Compress (PromoteTo)
1568 
1569 #define HWY_SVE_COMPRESS(BASE, CHAR, BITS, HALF, NAME, OP)                     \
1570   HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \
1571     return sv##OP##_##CHAR##BITS(mask, v);                                     \
1572   }
1573 
1574 HWY_SVE_FOREACH_UIF3264(HWY_SVE_COMPRESS, Compress, compact)
1575 #undef HWY_SVE_COMPRESS
1576 
1577 template <class V, HWY_IF_LANE_SIZE_V(V, 2)>
1578 HWY_API V Compress(V v, svbool_t mask16) {
1579   static_assert(!IsSame<V, svfloat16_t>(), "Must use overload");
1580   const DFromV<V> d16;
1581 
1582   // Promote vector and mask to 32-bit
1583   const RepartitionToWide<decltype(d16)> dw;
1584   const auto v32L = PromoteTo(dw, v);
1585   const auto v32H = detail::PromoteUpperTo(dw, v);
1586   const svbool_t mask32L = svunpklo_b(mask16);
1587   const svbool_t mask32H = svunpkhi_b(mask16);
1588 
1589   const auto compressedL = Compress(v32L, mask32L);
1590   const auto compressedH = Compress(v32H, mask32H);
1591 
1592   // Demote to 16-bit (already in range) - separately so we can splice
1593   const V evenL = BitCast(d16, compressedL);
1594   const V evenH = BitCast(d16, compressedH);
1595   const V v16L = detail::ConcatEven(evenL, evenL);  // only lower half needed
1596   const V v16H = detail::ConcatEven(evenH, evenH);
1597 
1598   // We need to combine two vectors of non-constexpr length, so the only option
1599   // is Splice, which requires us to synthesize a mask. NOTE: this function uses
1600   // full vectors (SV_ALL instead of SV_POW2), hence we need unmasked svcnt.
1601   const size_t countL = detail::CountTrueFull(dw, mask32L);
1602   const auto compressed_maskL = FirstN(d16, countL);
1603   return detail::Splice(v16H, v16L, compressed_maskL);
1604 }
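
// Worked example (u16, 8 lanes): v = {v0..v7}, mask16 selects {v1, v4, v6}.
// The 32-bit halves compact to {v1, ...} and {v4, v6, ...}; countL = 1, so
// Splice keeps one lane of the demoted lower result and appends the upper
// result, yielding {v1, v4, v6, ...}.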
1605 
1606 // Must treat float16_t as integers so we can ConcatEven.
1607 HWY_API svfloat16_t Compress(svfloat16_t v, svbool_t mask16) {
1608   const DFromV<decltype(v)> df;
1609   const RebindToSigned<decltype(df)> di;
1610   return BitCast(df, Compress(BitCast(di, v), mask16));
1611 }
1612 
1613 // ------------------------------ CompressStore
1614 
1615 template <class V, class M, class D>
1616 HWY_API size_t CompressStore(const V v, const M mask, const D d,
1617                              TFromD<D>* HWY_RESTRICT unaligned) {
1618   StoreU(Compress(v, mask), d, unaligned);
1619   return CountTrue(d, mask);
1620 }
1621 
1622 // ------------------------------ CompressBlendedStore
1623 
1624 template <class V, class M, class D>
1625 HWY_API size_t CompressBlendedStore(const V v, const M mask, const D d,
1626                                     TFromD<D>* HWY_RESTRICT unaligned) {
1627   const size_t count = CountTrue(d, mask);
1628   const svbool_t store_mask = FirstN(d, count);
1629   MaskedStore(store_mask, Compress(v, mask), d, unaligned);
1630   return count;
1631 }
1632 
1633 // ================================================== BLOCKWISE
1634 
1635 // ------------------------------ CombineShiftRightBytes
1636 
1637 namespace detail {
1638 
1639 // For x86-compatible behaviour mandated by Highway API: TableLookupBytes
1640 // offsets are implicitly relative to the start of their 128-bit block.
1641 template <class D, class V>
1642 HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) {
1643   using T = MakeUnsigned<TFromD<D>>;
1644   return detail::AndNotN(static_cast<T>(LanesPerBlock(d) - 1), iota0);
1645 }
1646 
1647 template <size_t kLanes, class D>
1648 svbool_t FirstNPerBlock(D d) {
1649   const RebindToSigned<decltype(d)> di;
1650   constexpr size_t kLanesPerBlock = detail::LanesPerBlock(di);
1651   const auto idx_mod = detail::AndN(Iota(di, 0), kLanesPerBlock - 1);
1652   return detail::LtN(BitCast(di, idx_mod), kLanes);
1653 }
1654 
1655 }  // namespace detail
1656 
1657 template <size_t kBytes, class D, class V = VFromD<D>>
1658 HWY_API V CombineShiftRightBytes(const D d, const V hi, const V lo) {
1659   const Repartition<uint8_t, decltype(d)> d8;
1660   const auto hi8 = BitCast(d8, hi);
1661   const auto lo8 = BitCast(d8, lo);
1662   const auto hi_up = detail::Splice(hi8, hi8, FirstN(d8, 16 - kBytes));
1663   const auto lo_down = detail::Ext<kBytes>(lo8, lo8);
1664   const svbool_t is_lo = detail::FirstNPerBlock<16 - kBytes>(d8);
1665   return BitCast(d, IfThenElse(is_lo, lo_down, hi_up));
1666 }
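
// Worked example (kBytes = 4, first 128-bit block): lo_down = Ext<4>(lo, lo)
// places lo[4..15] at bytes [0, 12); hi_up rotates hi so that hi[0..3] land
// at bytes [12, 16). Selecting lo_down for the lower 12 bytes yields
// {lo[4..15], hi[0..3]}, i.e. the pair (hi:lo) shifted right by 4 bytes.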
1667 
1668 // ------------------------------ Shuffle2301
1669 
1670 template <class V>
1671 HWY_API V Shuffle2301(const V v) {
1672   const DFromV<V> d;
1673   static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
1674   return Reverse2(d, v);
1675 }
1676 
1677 // ------------------------------ Shuffle2103
1678 template <class V>
1679 HWY_API V Shuffle2103(const V v) {
1680   const DFromV<V> d;
1681   const Repartition<uint8_t, decltype(d)> d8;
1682   static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
1683   const svuint8_t v8 = BitCast(d8, v);
1684   return BitCast(d, CombineShiftRightBytes<12>(d8, v8, v8));
1685 }
1686 
1687 // ------------------------------ Shuffle0321
1688 template <class V>
1689 HWY_API V Shuffle0321(const V v) {
1690   const DFromV<V> d;
1691   const Repartition<uint8_t, decltype(d)> d8;
1692   static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
1693   const svuint8_t v8 = BitCast(d8, v);
1694   return BitCast(d, CombineShiftRightBytes<4>(d8, v8, v8));
1695 }
1696 
1697 // ------------------------------ Shuffle1032
1698 template <class V>
1699 HWY_API V Shuffle1032(const V v) {
1700   const DFromV<V> d;
1701   const Repartition<uint8_t, decltype(d)> d8;
1702   static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
1703   const svuint8_t v8 = BitCast(d8, v);
1704   return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8));
1705 }
1706 
1707 // ------------------------------ Shuffle01
1708 template <class V>
1709 HWY_API V Shuffle01(const V v) {
1710   const DFromV<V> d;
1711   const Repartition<uint8_t, decltype(d)> d8;
1712   static_assert(sizeof(TFromD<decltype(d)>) == 8, "Defined for 64-bit types");
1713   const svuint8_t v8 = BitCast(d8, v);
1714   return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8));
1715 }
1716 
1717 // ------------------------------ Shuffle0123
1718 template <class V>
1719 HWY_API V Shuffle0123(const V v) {
1720   return Shuffle2301(Shuffle1032(v));
1721 }
1722 
1723 // ------------------------------ ReverseBlocks (Reverse, Shuffle01)
1724 template <class D, class V = VFromD<D>>
1725 HWY_API V ReverseBlocks(D d, V v) {
1726   const Repartition<uint64_t, D> du64;
1727   return BitCast(d, Shuffle01(Reverse(du64, BitCast(du64, v))));
1728 }
1729 
1730 // ------------------------------ TableLookupBytes
1731 
1732 template <class V, class VI>
1733 HWY_API VI TableLookupBytes(const V v, const VI idx) {
1734   const DFromV<VI> d;
1735   const Repartition<uint8_t, decltype(d)> du8;
1736   const auto offsets128 = detail::OffsetsOf128BitBlocks(du8, Iota(du8, 0));
1737   const auto idx8 = Add(BitCast(du8, idx), offsets128);
1738   return BitCast(d, TableLookupLanes(BitCast(du8, v), idx8));
1739 }
1740 
1741 template <class V, class VI>
1742 HWY_API VI TableLookupBytesOr0(const V v, const VI idx) {
1743   const DFromV<VI> d;
1744   // Mask size must match vector type, so cast everything to this type.
1745   const Repartition<int8_t, decltype(d)> di8;
1746 
1747   auto idx8 = BitCast(di8, idx);
1748   const auto msb = detail::LtN(idx8, 0);
1749 
1750   const auto lookup = TableLookupBytes(BitCast(di8, v), idx8);
1751   return BitCast(d, IfThenZeroElse(msb, lookup));
1752 }
1753 
1754 // ------------------------------ Broadcast
1755 
1756 template <int kLane, class V>
1757 HWY_API V Broadcast(const V v) {
1758   const DFromV<V> d;
1759   const RebindToUnsigned<decltype(d)> du;
1760   constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du);
1761   static_assert(0 <= kLane && kLane < kLanesPerBlock, "Invalid lane");
1762   auto idx = detail::OffsetsOf128BitBlocks(du, Iota(du, 0));
1763   if (kLane != 0) {
1764     idx = detail::AddN(idx, kLane);
1765   }
1766   return TableLookupLanes(v, idx);
1767 }
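
// For example, Broadcast<1>(v) with u32 lanes and 8 lanes total uses
// idx = {1,1,1,1, 5,5,5,5}: lane 1 of each 128-bit block is replicated
// within its block.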
1768 
1769 // ------------------------------ ShiftLeftLanes
1770 
1771 template <size_t kLanes, class D, class V = VFromD<D>>
1772 HWY_API V ShiftLeftLanes(D d, const V v) {
1773   const auto zero = Zero(d);
1774   const auto shifted = detail::Splice(v, zero, FirstN(d, kLanes));
1775   // Match x86 semantics by zeroing lower lanes in 128-bit blocks
1776   return IfThenElse(detail::FirstNPerBlock<kLanes>(d), zero, shifted);
1777 }
1778 
1779 template <size_t kLanes, class V>
1780 HWY_API V ShiftLeftLanes(const V v) {
1781   return ShiftLeftLanes<kLanes>(DFromV<V>(), v);
1782 }
1783 
1784 // ------------------------------ ShiftRightLanes
1785 template <size_t kLanes, class D, class V = VFromD<D>>
1786 HWY_API V ShiftRightLanes(D d, V v) {
1787   // For capped/fractional vectors, clear upper lanes so we shift in zeros.
1788   if (!detail::IsFull(d)) {
1789     v = IfThenElseZero(detail::MakeMask(d), v);
1790   }
1791 
1792   const auto shifted = detail::Ext<kLanes>(v, v);
1793   // Match x86 semantics by zeroing upper lanes in 128-bit blocks
1794   constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d);
1795   const svbool_t mask = detail::FirstNPerBlock<kLanesPerBlock - kLanes>(d);
1796   return IfThenElseZero(mask, shifted);
1797 }
1798 
1799 // ------------------------------ ShiftLeftBytes
1800 
1801 template <int kBytes, class D, class V = VFromD<D>>
1802 HWY_API V ShiftLeftBytes(const D d, const V v) {
1803   const Repartition<uint8_t, decltype(d)> d8;
1804   return BitCast(d, ShiftLeftLanes<kBytes>(BitCast(d8, v)));
1805 }
1806 
1807 template <int kBytes, class V>
1808 HWY_API V ShiftLeftBytes(const V v) {
1809   return ShiftLeftBytes<kBytes>(DFromV<V>(), v);
1810 }
1811 
1812 // ------------------------------ ShiftRightBytes
1813 template <int kBytes, class D, class V = VFromD<D>>
1814 HWY_API V ShiftRightBytes(const D d, const V v) {
1815   const Repartition<uint8_t, decltype(d)> d8;
1816   return BitCast(d, ShiftRightLanes<kBytes>(d8, BitCast(d8, v)));
1817 }
1818 
1819 // ------------------------------ InterleaveLower
1820 
1821 template <class D, class V>
1822 HWY_API V InterleaveLower(D d, const V a, const V b) {
1823   static_assert(IsSame<TFromD<D>, TFromV<V>>(), "D/V mismatch");
1824   // Move lower halves of blocks to lower half of vector.
1825   const Repartition<uint64_t, decltype(d)> d64;
1826   const auto a64 = BitCast(d64, a);
1827   const auto b64 = BitCast(d64, b);
1828   const auto a_blocks = detail::ConcatEven(a64, a64);  // only lower half needed
1829   const auto b_blocks = detail::ConcatEven(b64, b64);
1830 
1831   return detail::ZipLower(BitCast(d, a_blocks), BitCast(d, b_blocks));
1832 }
1833 
1834 template <class V>
1835 HWY_API V InterleaveLower(const V a, const V b) {
1836   return InterleaveLower(DFromV<V>(), a, b);
1837 }
1838 
1839 // ------------------------------ InterleaveUpper
1840 
1841 // Full vector: guaranteed to have at least one block
1842 template <class D, class V = VFromD<D>,
1843           hwy::EnableIf<detail::IsFull(D())>* = nullptr>
1844 HWY_API V InterleaveUpper(D d, const V a, const V b) {
1845   // Move upper halves of blocks to lower half of vector.
1846   const Repartition<uint64_t, decltype(d)> d64;
1847   const auto a64 = BitCast(d64, a);
1848   const auto b64 = BitCast(d64, b);
1849   const auto a_blocks = detail::ConcatOdd(a64, a64);  // only lower half needed
1850   const auto b_blocks = detail::ConcatOdd(b64, b64);
1851   return detail::ZipLower(BitCast(d, a_blocks), BitCast(d, b_blocks));
1852 }
1853 
1854 // Capped/fractional: needs a runtime check
1855 template <class D, class V = VFromD<D>,
1856           hwy::EnableIf<!detail::IsFull(D())>* = nullptr>
1857 HWY_API V InterleaveUpper(D d, const V a, const V b) {
1858   // Less than one block: treat as capped
1859   if (Lanes(d) * sizeof(TFromD<D>) < 16) {
1860     const Half<decltype(d)> d2;
1861     return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b));
1862   }
1863   return InterleaveUpper(DFromV<V>(), a, b);
1864 }
1865 
1866 // ------------------------------ ZipLower
1867 
1868 template <class V, class DW = RepartitionToWide<DFromV<V>>>
1869 HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) {
1870   const RepartitionToNarrow<DW> dn;
1871   static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch");
1872   return BitCast(dw, InterleaveLower(dn, a, b));
1873 }
1874 template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>>
1875 HWY_API VFromD<DW> ZipLower(const V a, const V b) {
1876   return BitCast(DW(), InterleaveLower(D(), a, b));
1877 }
1878 
1879 // ------------------------------ ZipUpper
1880 template <class V, class DW = RepartitionToWide<DFromV<V>>>
1881 HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) {
1882   const RepartitionToNarrow<DW> dn;
1883   static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch");
1884   return BitCast(dw, InterleaveUpper(dn, a, b));
1885 }
1886 
1887 // ================================================== REDUCE
1888 
1889 #define HWY_SVE_REDUCE(BASE, CHAR, BITS, HALF, NAME, OP)                 \
1890   template <size_t N, int kPow2>                                         \
1891   HWY_API HWY_SVE_V(BASE, BITS)                                          \
1892       NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, HWY_SVE_V(BASE, BITS) v) { \
1893     return Set(d, static_cast<HWY_SVE_T(BASE, BITS)>(                    \
1894                       sv##OP##_##CHAR##BITS(detail::MakeMask(d), v)));   \
1895   }
1896 
1897 HWY_SVE_FOREACH(HWY_SVE_REDUCE, SumOfLanes, addv)
1898 HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MinOfLanes, minv)
1899 HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MaxOfLanes, maxv)
1900 // NaN if all lanes are NaN
1901 HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MinOfLanes, minnmv)
1902 HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanes, maxnmv)
1903 
1904 #undef HWY_SVE_REDUCE
1905 
1906 // ================================================== Ops with dependencies
1907 
1908 // ------------------------------ PromoteTo bfloat16 (ZipLower)
1909 
1910 template <size_t N, int kPow2>
1911 HWY_API svfloat32_t PromoteTo(Simd<float32_t, N, kPow2> df32,
1912                               const svuint16_t v) {
1913   return BitCast(df32, detail::ZipLower(svdup_n_u16(0), v));
1914 }
1915 
1916 // ------------------------------ ReorderDemote2To (OddEven)
1917 
1918 template <size_t N, int kPow2>
1919 HWY_API svuint16_t ReorderDemote2To(Simd<bfloat16_t, N, kPow2> dbf16,
1920                                     svfloat32_t a, svfloat32_t b) {
1921   const RebindToUnsigned<decltype(dbf16)> du16;
1922   const Repartition<uint32_t, decltype(dbf16)> du32;
1923   const svuint32_t b_in_even = ShiftRight<16>(BitCast(du32, b));
1924   return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even)));
1925 }
1926 
1927 // ------------------------------ ZeroIfNegative (Lt, IfThenElse)
1928 template <class V>
1929 HWY_API V ZeroIfNegative(const V v) {
1930   return IfThenZeroElse(detail::LtN(v, 0), v);
1931 }
1932 
1933 // ------------------------------ BroadcastSignBit (ShiftRight)
1934 template <class V>
1935 HWY_API V BroadcastSignBit(const V v) {
1936   return ShiftRight<sizeof(TFromV<V>) * 8 - 1>(v);
1937 }
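
// For example, for i32 lanes the arithmetic shift by 31 maps -5 (sign bit
// set) to -1 (all bits set) and 7 to 0.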
1938 
1939 // ------------------------------ IfNegativeThenElse (BroadcastSignBit)
1940 template <class V>
1941 HWY_API V IfNegativeThenElse(V v, V yes, V no) {
1942   static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float");
1943   const DFromV<V> d;
1944   const RebindToSigned<decltype(d)> di;
1945 
1946   const svbool_t m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v))));
1947   return IfThenElse(m, yes, no);
1948 }
1949 
1950 // ------------------------------ AverageRound (ShiftRight)
1951 
1952 #if HWY_TARGET == HWY_SVE2
1953 HWY_SVE_FOREACH_U08(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd)
1954 HWY_SVE_FOREACH_U16(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd)
1955 #else
1956 template <class V>
1957 V AverageRound(const V a, const V b) {
1958   return ShiftRight<1>(detail::AddN(Add(a, b), 1));
1959 }
1960 #endif  // HWY_TARGET == HWY_SVE2
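
// Both paths compute the rounded average (a + b + 1) >> 1, e.g.
// AverageRound(3, 4) = 4 (the SVE2 rhadd avoids overflow of the sum).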
1961 
1962 // ------------------------------ LoadMaskBits (TestBit)
1963 
1964 // `bits` points to at least 8 readable bytes, not all of which need be valid.
1965 template <class D, HWY_IF_LANE_SIZE_D(D, 1)>
1966 HWY_INLINE svbool_t LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
1967   const RebindToUnsigned<D> du;
1968   const svuint8_t iota = Iota(du, 0);
1969 
1970   // Load correct number of bytes (bits/8) with 7 zeros after each.
1971   const svuint8_t bytes = BitCast(du, svld1ub_u64(detail::PTrue(d), bits));
1972   // Replicate bytes 8x such that each byte contains the bit that governs it.
1973   const svuint8_t rep8 = svtbl_u8(bytes, detail::AndNotN(7, iota));
1974 
1975   // 1, 2, 4, 8, 16, 32, 64, 128,  1, 2 ..
1976   const svuint8_t bit = Shl(Set(du, 1), detail::AndN(iota, 7));
1977 
1978   return TestBit(rep8, bit);
1979 }
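
// Worked example: bits[0] = 0x05 and bits[1] = 0x01 govern byte lanes 0..15.
// After the widening load, bytes = {0x05, 0,0,0,0,0,0,0, 0x01, 0, ...}; rep8
// replicates each source byte 8x and `bit` cycles {1,2,4,...,128}, so TestBit
// sets exactly mask lanes 0, 2 and 8.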
1980 
1981 template <class D, HWY_IF_LANE_SIZE_D(D, 2)>
1982 HWY_INLINE svbool_t LoadMaskBits(D /* tag */,
1983                                  const uint8_t* HWY_RESTRICT bits) {
1984   const RebindToUnsigned<D> du;
1985   const Repartition<uint8_t, D> du8;
1986 
1987   // There may be up to 128 bits; avoid reading past the end.
1988   const svuint8_t bytes = svld1(FirstN(du8, (Lanes(du) + 7) / 8), bits);
1989 
1990   // Replicate bytes 16x such that each lane contains the bit that governs it.
1991   const svuint8_t rep16 = svtbl_u8(bytes, ShiftRight<4>(Iota(du8, 0)));
1992 
1993   // 1, 2, 4, 8, 16, 32, 64, 128,  1, 2 ..
1994   const svuint16_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7));
1995 
1996   return TestBit(BitCast(du, rep16), bit);
1997 }
1998 
1999 template <class D, HWY_IF_LANE_SIZE_D(D, 4)>
2000 HWY_INLINE svbool_t LoadMaskBits(D /* tag */,
2001                                  const uint8_t* HWY_RESTRICT bits) {
2002   const RebindToUnsigned<D> du;
2003   const Repartition<uint8_t, D> du8;
2004 
2005   // Upper bound = 2048 bits / 32 bit = 64 bits; at least 8 bytes are readable,
2006   // so we can skip computing the actual length (Lanes(du)+7)/8.
2007   const svuint8_t bytes = svld1(FirstN(du8, 8), bits);
2008 
2009   // Replicate bytes 32x such that each lane contains the bit that governs it.
2010   const svuint8_t rep32 = svtbl_u8(bytes, ShiftRight<5>(Iota(du8, 0)));
2011 
2012   // 1, 2, 4, 8, 16, 32, 64, 128,  1, 2 ..
2013   const svuint32_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7));
2014 
2015   return TestBit(BitCast(du, rep32), bit);
2016 }
2017 
2018 template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
2019 HWY_INLINE svbool_t LoadMaskBits(D /* tag */,
2020                                  const uint8_t* HWY_RESTRICT bits) {
2021   const RebindToUnsigned<D> du;
2022 
2023   // Max 2048 bits = 32 lanes = 32 input bits; replicate those into each lane.
2024   // The "at least 8 byte" guarantee in quick_reference ensures this is safe.
2025   uint32_t mask_bits;
2026   CopyBytes<4>(bits, &mask_bits);
2027   const auto vbits = Set(du, mask_bits);
2028 
2029   // 2 ^ {0,1, .., 31}, will not have more lanes than that.
2030   const svuint64_t bit = Shl(Set(du, 1), Iota(du, 0));
2031 
2032   return TestBit(vbits, bit);
2033 }
2034 
2035 // ------------------------------ StoreMaskBits
2036 
2037 namespace detail {
2038 
2039 // For each mask lane (governing lane type T), store 1 or 0 in BYTE lanes.
2040 template <class T, HWY_IF_LANE_SIZE(T, 1)>
2041 HWY_INLINE svuint8_t BoolFromMask(svbool_t m) {
2042   return svdup_n_u8_z(m, 1);
2043 }
2044 template <class T, HWY_IF_LANE_SIZE(T, 2)>
2045 HWY_INLINE svuint8_t BoolFromMask(svbool_t m) {
2046   const ScalableTag<uint8_t> d8;
2047   const svuint8_t b16 = BitCast(d8, svdup_n_u16_z(m, 1));
2048   return detail::ConcatEven(b16, b16);  // only lower half needed
2049 }
2050 template <class T, HWY_IF_LANE_SIZE(T, 4)>
2051 HWY_INLINE svuint8_t BoolFromMask(svbool_t m) {
2052   return U8FromU32(svdup_n_u32_z(m, 1));
2053 }
2054 template <class T, HWY_IF_LANE_SIZE(T, 8)>
2055 HWY_INLINE svuint8_t BoolFromMask(svbool_t m) {
2056   const ScalableTag<uint32_t> d32;
2057   const svuint32_t b64 = BitCast(d32, svdup_n_u64_z(m, 1));
2058   return U8FromU32(detail::ConcatEven(b64, b64));  // only lower half needed
2059 }
2060 
2061 // Compacts groups of 8 u8 into 8 contiguous bits in a 64-bit lane.
2062 HWY_INLINE svuint64_t BitsFromBool(svuint8_t x) {
2063   const ScalableTag<uint8_t> d8;
2064   const ScalableTag<uint16_t> d16;
2065   const ScalableTag<uint32_t> d32;
2066   const ScalableTag<uint64_t> d64;
2067   // TODO(janwas): could use SVE2 BDEP, but it's optional.
2068   x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x))));
2069   x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x))));
2070   x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x))));
2071   return BitCast(d64, x);
2072 }
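
// Worked example for one u64 lane holding bytes {b0..b7}, each 0 or 1: the
// first Or merges pairs (2 bits per u16), the second merges those into 4 bits
// per u32, and the third leaves b0 | b1<<1 | ... | b7<<7 in the low byte of
// the u64.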
2073 
2074 }  // namespace detail
2075 
2076 // `bits` points to at least 8 writable bytes.
2077 template <class D>
2078 HWY_API size_t StoreMaskBits(D d, svbool_t m, uint8_t* bits) {
2079   svuint64_t bits_in_u64 =
2080       detail::BitsFromBool(detail::BoolFromMask<TFromD<D>>(m));
2081 
2082   const size_t num_bits = Lanes(d);
2083   const size_t num_bytes = (num_bits + 8 - 1) / 8;  // Round up, see below
2084 
2085   // Truncate each u64 to 8 bits and store to u8.
2086   svst1b_u64(FirstN(ScalableTag<uint64_t>(), num_bytes), bits, bits_in_u64);
2087 
2088   // Non-full byte, need to clear the undefined upper bits. Can happen for
2089   // capped/fractional vectors or large T and small hardware vectors.
2090   if (num_bits < 8) {
2091     const int mask = (1 << num_bits) - 1;
2092     bits[0] = static_cast<uint8_t>(bits[0] & mask);
2093   }
2094   // Else: we wrote full bytes because num_bits is a power of two >= 8.
2095 
2096   return num_bytes;
2097 }
2098 
2099 // ------------------------------ CompressBits, CompressBitsStore (LoadMaskBits)
2100 
2101 template <class V>
2102 HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) {
2103   return Compress(v, LoadMaskBits(DFromV<V>(), bits));
2104 }
2105 
2106 template <class D>
2107 HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits,
2108                                  D d, TFromD<D>* HWY_RESTRICT unaligned) {
2109   return CompressStore(v, LoadMaskBits(d, bits), d, unaligned);
2110 }
2111 
2112 // ------------------------------ MulEven (InterleaveEven)
2113 
2114 #if HWY_TARGET == HWY_SVE2
2115 namespace detail {
2116 #define HWY_SVE_MUL_EVEN(BASE, CHAR, BITS, HALF, NAME, OP)     \
2117   HWY_API HWY_SVE_V(BASE, BITS)                                \
2118       NAME(HWY_SVE_V(BASE, HALF) a, HWY_SVE_V(BASE, HALF) b) { \
2119     return sv##OP##_##CHAR##BITS(a, b);                        \
2120   }
2121 
2122 HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulEven, mullb)
2123 #undef HWY_SVE_MUL_EVEN
2124 }  // namespace detail
2125 #endif
2126 
2127 template <class V, class DW = RepartitionToWide<DFromV<V>>>
2128 HWY_API VFromD<DW> MulEven(const V a, const V b) {
2129 #if HWY_TARGET == HWY_SVE2
2130   return BitCast(DW(), detail::MulEven(a, b));
2131 #else
2132   const auto lo = Mul(a, b);
2133   const auto hi = detail::MulHigh(a, b);
2134   return BitCast(DW(), detail::InterleaveEven(lo, hi));
2135 #endif
2136 }
2137 
2138 HWY_API svuint64_t MulEven(const svuint64_t a, const svuint64_t b) {
2139   const auto lo = Mul(a, b);
2140   const auto hi = detail::MulHigh(a, b);
2141   return detail::InterleaveEven(lo, hi);
2142 }
2143 
2144 HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) {
2145   const auto lo = Mul(a, b);
2146   const auto hi = detail::MulHigh(a, b);
2147   return detail::InterleaveOdd(lo, hi);
2148 }
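
// For u64 inputs there is no wider lane type, so each 64x64 -> 128-bit
// product is returned in a pair of adjacent lanes: MulEven yields
// {lo64, hi64} of the even-lane products, MulOdd those of the odd lanes.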
2149 
2150 // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower)
2151 
2152 template <size_t N, int kPow2>
2153 HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd<float, N, kPow2> df32,
2154                                               svuint16_t a, svuint16_t b,
2155                                               const svfloat32_t sum0,
2156                                               svfloat32_t& sum1) {
2157   // TODO(janwas): svbfmlalb_f32 if __ARM_FEATURE_SVE_BF16.
2158   const Repartition<uint16_t, decltype(df32)> du16;
2159   const RebindToUnsigned<decltype(df32)> du32;
2160   const svuint16_t zero = Zero(du16);
2161   const svuint32_t a0 = ZipLower(du32, zero, BitCast(du16, a));
2162   const svuint32_t a1 = ZipUpper(du32, zero, BitCast(du16, a));
2163   const svuint32_t b0 = ZipLower(du32, zero, BitCast(du16, b));
2164   const svuint32_t b1 = ZipUpper(du32, zero, BitCast(du16, b));
2165   sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1);
2166   return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0);
2167 }
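
// Zipping zeros into the lower 16 bits converts each bf16 lane to the f32
// with the same value, e.g. bf16 0x3F80 becomes 0x3F800000 = 1.0f; the
// products are then accumulated via two FMAs.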
2168 
2169 // ------------------------------ AESRound / CLMul
2170 
2171 #if defined(__ARM_FEATURE_SVE2_AES)
2172 
2173 // Per-target flag to prevent generic_ops-inl.h from defining AESRound.
2174 #ifdef HWY_NATIVE_AES
2175 #undef HWY_NATIVE_AES
2176 #else
2177 #define HWY_NATIVE_AES
2178 #endif
2179 
2180 HWY_API svuint8_t AESRound(svuint8_t state, svuint8_t round_key) {
2181   // It is not clear whether E and MC fuse like they did on NEON.
2182   const svuint8_t zero = svdup_n_u8(0);
2183   return Xor(svaesmc_u8(svaese_u8(state, zero)), round_key);
2184 }
2185 
2186 HWY_API svuint8_t AESLastRound(svuint8_t state, svuint8_t round_key) {
2187   return Xor(svaese_u8(state, svdup_n_u8(0)), round_key);
2188 }
2189 
2190 HWY_API svuint64_t CLMulLower(const svuint64_t a, const svuint64_t b) {
2191   return svpmullb_pair(a, b);
2192 }
2193 
2194 HWY_API svuint64_t CLMulUpper(const svuint64_t a, const svuint64_t b) {
2195   return svpmullt_pair(a, b);
2196 }
2197 
2198 #endif  // __ARM_FEATURE_SVE2_AES
2199 
2200 // ------------------------------ Lt128
2201 
2202 template <class D>
2203 HWY_INLINE svbool_t Lt128(D /* d */, const svuint64_t a, const svuint64_t b) {
2204   static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, "Use u64");
2205   // Truth table of Eq and Compare for Hi and Lo u64.
2206   // (removed lines with (=H && cH) or (=L && cL) - cannot both be true)
2207   // =H =L cH cL  | out = cH | (=H & cL) = IfThenElse(=H, cL, cH)
2208   //  0  0  0  0  |  0
2209   //  0  0  0  1  |  0
2210   //  0  0  1  0  |  1
2211   //  0  0  1  1  |  1
2212   //  0  1  0  0  |  0
2213   //  0  1  0  1  |  0
2214   //  0  1  1  0  |  1
2215   //  1  0  0  0  |  0
2216   //  1  0  0  1  |  1
2217   //  1  1  0  0  |  0
2218   const svbool_t eqHL = Eq(a, b);
2219   const svbool_t ltHL = Lt(a, b);
2220   // trn (interleave even/odd) allows us to move and copy masks across lanes.
2221   const svbool_t cmpLL = svtrn1_b64(ltHL, ltHL);
2222   const svbool_t outHx = svsel_b(eqHL, cmpLL, ltHL);   // See truth table above.
2223   return svtrn2_b64(outHx, outHx);                     // replicate to HH
2224 }
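
// For example, with each pair of u64 lanes {low, high} denoting one 128-bit
// number, the resulting mask is true in both lanes of a pair iff
// aH < bH || (aH == bH && aL < bL).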
2225 
2226 // ------------------------------ Min128, Max128 (Lt128)
2227 
2228 template <class D>
2229 HWY_INLINE svuint64_t Min128(D d, const svuint64_t a, const svuint64_t b) {
2230   return IfThenElse(Lt128(d, a, b), a, b);
2231 }
2232 
2233 template <class D>
2234 HWY_INLINE svuint64_t Max128(D d, const svuint64_t a, const svuint64_t b) {
2235   return IfThenElse(Lt128(d, a, b), b, a);
2236 }
2237 
2238 // ================================================== END MACROS
2239 namespace detail {  // for code folding
2240 #undef HWY_IF_FLOAT_V
2241 #undef HWY_IF_LANE_SIZE_V
2242 #undef HWY_IF_SIGNED_V
2243 #undef HWY_IF_UNSIGNED_V
2244 #undef HWY_SVE_D
2245 #undef HWY_SVE_FOREACH
2246 #undef HWY_SVE_FOREACH_F
2247 #undef HWY_SVE_FOREACH_F16
2248 #undef HWY_SVE_FOREACH_F32
2249 #undef HWY_SVE_FOREACH_F64
2250 #undef HWY_SVE_FOREACH_I
2251 #undef HWY_SVE_FOREACH_I08
2252 #undef HWY_SVE_FOREACH_I16
2253 #undef HWY_SVE_FOREACH_I32
2254 #undef HWY_SVE_FOREACH_I64
2255 #undef HWY_SVE_FOREACH_IF
2256 #undef HWY_SVE_FOREACH_U
2257 #undef HWY_SVE_FOREACH_U08
2258 #undef HWY_SVE_FOREACH_U16
2259 #undef HWY_SVE_FOREACH_U32
2260 #undef HWY_SVE_FOREACH_U64
2261 #undef HWY_SVE_FOREACH_UI
2262 #undef HWY_SVE_FOREACH_UI08
2263 #undef HWY_SVE_FOREACH_UI16
2264 #undef HWY_SVE_FOREACH_UI32
2265 #undef HWY_SVE_FOREACH_UI64
2266 #undef HWY_SVE_FOREACH_UIF3264
2267 #undef HWY_SVE_PTRUE
2268 #undef HWY_SVE_RETV_ARGPV
2269 #undef HWY_SVE_RETV_ARGPVN
2270 #undef HWY_SVE_RETV_ARGPVV
2271 #undef HWY_SVE_RETV_ARGV
2272 #undef HWY_SVE_RETV_ARGVN
2273 #undef HWY_SVE_RETV_ARGVV
2274 #undef HWY_SVE_T
2275 #undef HWY_SVE_UNDEFINED
2276 #undef HWY_SVE_V
2277 
2278 }  // namespace detail
2279 // NOLINTNEXTLINE(google-readability-namespace-comments)
2280 }  // namespace HWY_NAMESPACE
2281 }  // namespace hwy
2282 HWY_AFTER_NAMESPACE();
2283