1 // Copyright 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // ARM SVE[2] vectors (length not known at compile time).
16 // External include guard in highway.h - see comment there.
17
18 #include <stddef.h>
19 #include <stdint.h>
20
21 #if defined(HWY_EMULATE_SVE)
22 #include "third_party/farm_sve/farm_sve.h"
23 #else
24 #include <arm_sve.h>
25 #endif
26
27 #include "hwy/base.h"
28 #include "hwy/ops/shared-inl.h"
29
30 HWY_BEFORE_NAMESPACE();
31 namespace hwy {
32 namespace HWY_NAMESPACE {
33
// SVE only supports fractions, not LMUL > 1.
// Full<T, kShift>: full vector of T when kShift == 0; a negative kShift
// selects a 1/2^(-kShift) fraction. A positive kShift (LMUL > 1) is not
// representable and yields a zero-lane Simd.
template <typename T, int kShift = 0>
using Full = Simd<T, (kShift <= 0) ? (HWY_LANES(T) >> (-kShift)) : 0>;

// Maps a vector type V back to its Simd<> descriptor. The primary template is
// intentionally empty; HWY_SPECIALIZE (below) generates one specialization
// per SVE vector type.
template <class V>
struct DFromV_t {};  // specialized in macros
template <class V>
using DFromV = typename DFromV_t<RemoveConst<V>>::type;

// Lane type of the vector type V.
template <class V>
using TFromV = TFromD<DFromV<V>>;

// SFINAE helpers that accept a vector type instead of a lane type.
#define HWY_IF_UNSIGNED_V(V) HWY_IF_UNSIGNED(TFromV<V>)
#define HWY_IF_SIGNED_V(V) HWY_IF_SIGNED(TFromV<V>)
#define HWY_IF_FLOAT_V(V) HWY_IF_FLOAT(TFromV<V>)
#define HWY_IF_LANE_SIZE_V(V, bytes) HWY_IF_LANE_SIZE(TFromV<V>, bytes)
50
51 // ================================================== MACROS
52
53 // Generate specializations and function definitions using X macros. Although
54 // harder to read and debug, writing everything manually is too bulky.
55
56 namespace detail { // for code folding
57
58 // Unsigned:
59 #define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, NAME, OP)
60 #define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, NAME, OP)
61 #define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) X_MACRO(uint, u, 32, NAME, OP)
62 #define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) X_MACRO(uint, u, 64, NAME, OP)
63
64 // Signed:
65 #define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, NAME, OP)
66 #define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, NAME, OP)
67 #define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, NAME, OP)
68 #define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, NAME, OP)
69
70 // Float:
71 #define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) X_MACRO(float, f, 16, NAME, OP)
72 #define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) X_MACRO(float, f, 32, NAME, OP)
73 #define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) X_MACRO(float, f, 64, NAME, OP)
74
75 // For all element sizes:
76 #define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \
77 HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \
78 HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \
79 HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \
80 HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP)
81
82 #define HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \
83 HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) \
84 HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) \
85 HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) \
86 HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP)
87
88 #define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \
89 HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \
90 HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \
91 HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP)
92
93 // Commonly used type categories for a given element size:
94 #define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \
95 HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \
96 HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP)
97
98 #define HWY_SVE_FOREACH_UI16(X_MACRO, NAME, OP) \
99 HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \
100 HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP)
101
102 #define HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \
103 HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \
104 HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP)
105
106 #define HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \
107 HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \
108 HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP)
109
110 #define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \
111 HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \
112 HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \
113 HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \
114 HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP)
115
116 // Commonly used type categories:
117 #define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \
118 HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \
119 HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)
120
121 #define HWY_SVE_FOREACH_IF(X_MACRO, NAME, OP) \
122 HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \
123 HWY_SVE_FOREACH_F(X_MACRO, NAME, OP)
124
125 #define HWY_SVE_FOREACH(X_MACRO, NAME, OP) \
126 HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \
127 HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \
128 HWY_SVE_FOREACH_F(X_MACRO, NAME, OP)
129
130 // Assemble types for use in x-macros
131 #define HWY_SVE_T(BASE, BITS) BASE##BITS##_t
132 #define HWY_SVE_D(BASE, BITS, N) Simd<HWY_SVE_T(BASE, BITS), N>
133 #define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t
134
135 } // namespace detail
136
// Generates the DFromV_t specialization that maps each svBASE BITS vector
// type to its full-vector Simd<> descriptor.
#define HWY_SPECIALIZE(BASE, CHAR, BITS, NAME, OP)                        \
  template <>                                                             \
  struct DFromV_t<HWY_SVE_V(BASE, BITS)> {                                \
    using type = HWY_SVE_D(BASE, BITS, HWY_LANES(HWY_SVE_T(BASE, BITS))); \
  };

HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _)
#undef HWY_SPECIALIZE
145
// vector = f(d), e.g. Undefined
#define HWY_SVE_RETV_ARGD(BASE, CHAR, BITS, NAME, OP)              \
  template <size_t N>                                              \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N) d) { \
    return sv##OP##_##CHAR##BITS();                                \
  }

// Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX
// instructions, and we anyway only use it when the predicate is ptrue.

// vector = f(vector), e.g. Not
#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v);   \
  }
// Same, but for unpredicated intrinsics (no governing predicate argument).
#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, NAME, OP)           \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS(v);                            \
  }

// vector = f(vector, scalar), e.g. detail::AddK
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \
  }
// Same, unpredicated.
#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, NAME, OP)         \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }

// vector = f(vector, vector), e.g. Add
#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \
  }
// Same, unpredicated.
#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, NAME, OP)         \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }
189
190 // ------------------------------ Lanes
191
192 namespace detail {
193
194 // Returns actual lanes of a hardware vector without rounding to a power of two.
195 HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<1> /* tag */) {
196 return svcntb_pat(SV_ALL);
197 }
198 HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<2> /* tag */) {
199 return svcnth_pat(SV_ALL);
200 }
201 HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<4> /* tag */) {
202 return svcntw_pat(SV_ALL);
203 }
204 HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<8> /* tag */) {
205 return svcntd_pat(SV_ALL);
206 }
207
208 // Returns actual lanes of a hardware vector, rounded down to a power of two.
209 HWY_INLINE size_t HardwareLanes(hwy::SizeTag<1> /* tag */) {
210 return svcntb_pat(SV_POW2);
211 }
212 HWY_INLINE size_t HardwareLanes(hwy::SizeTag<2> /* tag */) {
213 return svcnth_pat(SV_POW2);
214 }
215 HWY_INLINE size_t HardwareLanes(hwy::SizeTag<4> /* tag */) {
216 return svcntw_pat(SV_POW2);
217 }
218 HWY_INLINE size_t HardwareLanes(hwy::SizeTag<8> /* tag */) {
219 return svcntd_pat(SV_POW2);
220 }
221
222 } // namespace detail
223
224 // Capped to <= 128-bit: SVE is at least that large, so no need to query actual.
225 template <typename T, size_t N, HWY_IF_LE128(T, N)>
Lanes(Simd<T,N>)226 HWY_API constexpr size_t Lanes(Simd<T, N> /* tag */) {
227 return N;
228 }
229
230 // Returns actual number of lanes after dividing by div={1,2,4,8}.
231 // May return 0 if div > 16/sizeof(T): there is no "1/8th" of a u32x4, but it
232 // would be valid for u32x8 (i.e. hardware vectors >= 256 bits).
233 template <typename T, size_t N, HWY_IF_GT128(T, N)>
Lanes(Simd<T,N>)234 HWY_API size_t Lanes(Simd<T, N> /* tag */) {
235 static_assert(N <= HWY_LANES(T), "N cannot exceed a full vector");
236
237 const size_t actual = detail::HardwareLanes(hwy::SizeTag<sizeof(T)>());
238 const size_t div = HWY_LANES(T) / N;
239 return (div <= 8) ? actual / div : HWY_MIN(actual, N);
240 }
241
242 // ================================================== MASK INIT
243
244 // One mask bit per byte; only the one belonging to the lowest byte is valid.
245
246 // ------------------------------ FirstN
247 #define HWY_SVE_FIRSTN(BASE, CHAR, BITS, NAME, OP) \
248 template <size_t KN> \
249 HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, KN) /* d */, size_t N) { \
250 return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast<uint32_t>(N)); \
251 }
HWY_SVE_FOREACH(HWY_SVE_FIRSTN,FirstN,whilelt)252 HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt)
253 #undef HWY_SVE_FIRSTN
254
255 namespace detail {
256
257 // All-true mask from a macro
258 #define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2)
259
260 #define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, NAME, OP) \
261 template <size_t N> \
262 HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N) d) { \
263 return HWY_SVE_PTRUE(BITS); \
264 }
265
266 HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) // return all-true
267 #undef HWY_SVE_WRAP_PTRUE
268
269 HWY_API svbool_t PFalse() { return svpfalse_b(); }
270
271 // Returns all-true if d is HWY_FULL or FirstN(N) after capping N.
272 //
273 // This is used in functions that load/store memory; other functions (e.g.
274 // arithmetic on partial vectors) can ignore d and use PTrue instead.
275 template <typename T, size_t N>
276 svbool_t Mask(Simd<T, N> d) {
277 return N == HWY_LANES(T) ? PTrue(d) : FirstN(d, Lanes(d));
278 }
279
280 } // namespace detail
281
282 // ================================================== INIT
283
284 // ------------------------------ Set
285 // vector = f(d, scalar), e.g. Set
286 #define HWY_SVE_SET(BASE, CHAR, BITS, NAME, OP) \
287 template <size_t N> \
288 HWY_API HWY_SVE_V(BASE, BITS) \
289 NAME(HWY_SVE_D(BASE, BITS, N) d, HWY_SVE_T(BASE, BITS) arg) { \
290 return sv##OP##_##CHAR##BITS(arg); \
291 }
292
HWY_SVE_FOREACH(HWY_SVE_SET,Set,dup_n)293 HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n)
294 #undef HWY_SVE_SET
295
296 // Required for Zero and VFromD
297 template <size_t N>
298 svuint16_t Set(Simd<bfloat16_t, N> d, bfloat16_t arg) {
299 return Set(RebindToUnsigned<decltype(d)>(), arg.bits);
300 }
301
302 template <class D>
303 using VFromD = decltype(Set(D(), TFromD<D>()));
304
305 // ------------------------------ Zero
306
307 template <class D>
Zero(D d)308 VFromD<D> Zero(D d) {
309 return Set(d, 0);
310 }
311
312 // ------------------------------ Undefined
313
314 #if defined(HWY_EMULATE_SVE)
315 template <class D>
Undefined(D d)316 VFromD<D> Undefined(D d) {
317 return Zero(d);
318 }
319 #else
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGD,Undefined,undef)320 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGD, Undefined, undef)
321 #endif
322
323 // ------------------------------ BitCast
324
325 namespace detail {
326
327 // u8: no change
328 #define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, NAME, OP) \
329 HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \
330 return v; \
331 } \
332 template <size_t N> \
333 HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte( \
334 HWY_SVE_D(BASE, BITS, N) /* d */, HWY_SVE_V(BASE, BITS) v) { \
335 return v; \
336 }
337
338 // All other types
339 #define HWY_SVE_CAST(BASE, CHAR, BITS, NAME, OP) \
340 HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \
341 return sv##OP##_u8_##CHAR##BITS(v); \
342 } \
343 template <size_t N> \
344 HWY_INLINE HWY_SVE_V(BASE, BITS) \
345 BitCastFromByte(HWY_SVE_D(BASE, BITS, N) /* d */, svuint8_t v) { \
346 return sv##OP##_##CHAR##BITS##_u8(v); \
347 }
348
349 HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _)
350 HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret)
351 HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret)
352 HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret)
353 HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret)
354 HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret)
355
356 #undef HWY_SVE_CAST_NOP
357 #undef HWY_SVE_CAST
358
359 template <size_t N>
360 HWY_INLINE svuint16_t BitCastFromByte(Simd<bfloat16_t, N> /* d */,
361 svuint8_t v) {
362 return BitCastFromByte(Simd<uint16_t, N>(), v);
363 }
364
365 } // namespace detail
366
367 template <class D, class FromV>
BitCast(D d,FromV v)368 HWY_API VFromD<D> BitCast(D d, FromV v) {
369 return detail::BitCastFromByte(d, detail::BitCastToByte(v));
370 }
371
372 // ================================================== LOGICAL
373
374 // detail::*N() functions accept a scalar argument to avoid extra Set().
375
376 // ------------------------------ Not
377
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV,Not,not)378 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV, Not, not )
379
380 // ------------------------------ And
381
382 namespace detail {
383 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, AndN, and_n)
384 } // namespace detail
385
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV,And,and)386 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, And, and)
387
388 template <class V, HWY_IF_FLOAT_V(V)>
389 HWY_API V And(const V a, const V b) {
390 const DFromV<V> df;
391 const RebindToUnsigned<decltype(df)> du;
392 return BitCast(df, And(BitCast(du, a), BitCast(du, b)));
393 }
394
395 // ------------------------------ Or
396
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV,Or,orr)397 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr)
398
399 template <class V, HWY_IF_FLOAT_V(V)>
400 HWY_API V Or(const V a, const V b) {
401 const DFromV<V> df;
402 const RebindToUnsigned<decltype(df)> du;
403 return BitCast(df, Or(BitCast(du, a), BitCast(du, b)));
404 }
405
406 // ------------------------------ Xor
407
408 namespace detail {
409 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, XorN, eor_n)
410 } // namespace detail
411
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV,Xor,eor)412 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Xor, eor)
413
414 template <class V, HWY_IF_FLOAT_V(V)>
415 HWY_API V Xor(const V a, const V b) {
416 const DFromV<V> df;
417 const RebindToUnsigned<decltype(df)> du;
418 return BitCast(df, Xor(BitCast(du, a), BitCast(du, b)));
419 }
420
421 // ------------------------------ AndNot
422
423 namespace detail {
424 #define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, NAME, OP) \
425 HWY_API HWY_SVE_V(BASE, BITS) \
426 NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
427 return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \
428 }
429
430 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n)
431 #undef HWY_SVE_RETV_ARGPVN_SWAP
432 } // namespace detail
433
434 #define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, NAME, OP) \
435 HWY_API HWY_SVE_V(BASE, BITS) \
436 NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
437 return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \
438 }
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP,AndNot,bic)439 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic)
440 #undef HWY_SVE_RETV_ARGPVV_SWAP
441
442 template <class V, HWY_IF_FLOAT_V(V)>
443 HWY_API V AndNot(const V a, const V b) {
444 const DFromV<V> df;
445 const RebindToUnsigned<decltype(df)> du;
446 return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b)));
447 }
448
449 // ------------------------------ PopulationCount
450
451 #ifdef HWY_NATIVE_POPCNT
452 #undef HWY_NATIVE_POPCNT
453 #else
454 #define HWY_NATIVE_POPCNT
455 #endif
456
457 // Need to return original type instead of unsigned.
458 #define HWY_SVE_POPCNT(BASE, CHAR, BITS, NAME, OP) \
459 HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
460 return BitCast(DFromV<decltype(v)>(), \
461 sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v)); \
462 }
HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT,PopulationCount,cnt)463 HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT, PopulationCount, cnt)
464 #undef HWY_SVE_POPCNT
465
466 // ================================================== SIGN
467
468 // ------------------------------ Neg
469 HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Neg, neg)
470
471 // ------------------------------ Abs
472 HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs)
473
474 // ------------------------------ CopySign[ToAbs]
475
476 template <class V>
477 HWY_API V CopySign(const V magn, const V sign) {
478 const auto msb = SignBit(DFromV<V>());
479 return Or(AndNot(msb, magn), And(msb, sign));
480 }
481
482 template <class V>
CopySignToAbs(const V abs,const V sign)483 HWY_API V CopySignToAbs(const V abs, const V sign) {
484 const auto msb = SignBit(DFromV<V>());
485 return Or(abs, And(msb, sign));
486 }
487
488 // ================================================== ARITHMETIC
489
490 // ------------------------------ Add
491
492 namespace detail {
493 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN, AddN, add_n)
494 } // namespace detail
495
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV,Add,add)496 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add)
497
498 // ------------------------------ Sub
499
500 namespace detail {
501 // Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg.
502 #define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, NAME, OP) \
503 HWY_API HWY_SVE_V(BASE, BITS) \
504 NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
505 return sv##OP##_##CHAR##BITS##_z(pg, a, b); \
506 }
507
508 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n)
509 #undef HWY_SVE_RETV_ARGPVN_MASK
510 } // namespace detail
511
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV,Sub,sub)512 HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub)
513
514 // ------------------------------ SaturatedAdd
515
516 HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd)
517 HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd)
518
519 // ------------------------------ SaturatedSub
520
521 HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub)
522 HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub)
523
524 // ------------------------------ AbsDiff
525 HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPVV, AbsDiff, abd)
526
527 // ------------------------------ ShiftLeft[Same]
528
529 #define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, NAME, OP) \
530 template <int kBits> \
531 HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
532 return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits); \
533 } \
534 HWY_API HWY_SVE_V(BASE, BITS) \
535 NAME##Same(HWY_SVE_V(BASE, BITS) v, HWY_SVE_T(uint, BITS) bits) { \
536 return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, bits); \
537 }
538
539 HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n)
540
541 // ------------------------------ ShiftRight[Same]
542
543 HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_N, ShiftRight, lsr_n)
544 HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n)
545
546 #undef HWY_SVE_SHIFT_N
547
548 // ------------------------------ RotateRight
549
550 // TODO(janwas): svxar on SVE2
551 template <int kBits, class V>
552 HWY_API V RotateRight(const V v) {
553 constexpr size_t kSizeInBits = sizeof(TFromV<V>) * 8;
554 static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count");
555 if (kBits == 0) return v;
556 return Or(ShiftRight<kBits>(v), ShiftLeft<kSizeInBits - kBits>(v));
557 }
558
559 // ------------------------------ Shl/r
560
561 #define HWY_SVE_SHIFT(BASE, CHAR, BITS, NAME, OP) \
562 HWY_API HWY_SVE_V(BASE, BITS) \
563 NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \
564 using TU = HWY_SVE_T(uint, BITS); \
565 return sv##OP##_##CHAR##BITS##_x( \
566 HWY_SVE_PTRUE(BITS), v, BitCast(Simd<TU, HWY_LANES(TU)>(), bits)); \
567 }
568
HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT,Shl,lsl)569 HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl)
570
571 HWY_SVE_FOREACH_U(HWY_SVE_SHIFT, Shr, lsr)
572 HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr)
573
574 #undef HWY_SVE_SHIFT
575
576 // ------------------------------ Min/Max
577
578 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Min, min)
579 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Max, max)
580 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm)
581 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Max, maxnm)
582
583 namespace detail {
584 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MinN, min_n)
585 HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n)
586 } // namespace detail
587
588 // ------------------------------ Mul
HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV,Mul,mul)589 HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, Mul, mul)
590 HWY_SVE_FOREACH_UIF3264(HWY_SVE_RETV_ARGPVV, Mul, mul)
591
592 // ------------------------------ MulHigh
593 HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
594 namespace detail {
595 HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
596 HWY_SVE_FOREACH_U64(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
597 } // namespace detail
598
599 // ------------------------------ Div
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV,Div,div)600 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div)
601
602 // ------------------------------ ApproximateReciprocal
603 HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe)
604
605 // ------------------------------ Sqrt
606 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt)
607
608 // ------------------------------ ApproximateReciprocalSqrt
609 HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte)
610
611 // ------------------------------ MulAdd
612 #define HWY_SVE_FMA(BASE, CHAR, BITS, NAME, OP) \
613 HWY_API HWY_SVE_V(BASE, BITS) \
614 NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x, \
615 HWY_SVE_V(BASE, BITS) add) { \
616 return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), x, mul, add); \
617 }
618
619 HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulAdd, mad)
620
621 // ------------------------------ NegMulAdd
622 HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulAdd, msb)
623
624 // ------------------------------ MulSub
625 HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulSub, nmsb)
626
627 // ------------------------------ NegMulSub
628 HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad)
629
630 #undef HWY_SVE_FMA
631
632 // ------------------------------ Round etc.
633
634 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn)
635 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Floor, rintm)
636 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Ceil, rintp)
637 HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Trunc, rintz)
638
639 // ================================================== MASK
640
641 // ------------------------------ RebindMask
642 template <class D, typename MFrom>
643 HWY_API svbool_t RebindMask(const D /*d*/, const MFrom mask) {
644 return mask;
645 }
646
647 // ------------------------------ Mask logical
648
Not(svbool_t m)649 HWY_API svbool_t Not(svbool_t m) {
650 // We don't know the lane type, so assume 8-bit. For larger types, this will
651 // de-canonicalize the predicate, i.e. set bits to 1 even though they do not
652 // correspond to the lowest byte in the lane. Per ARM, such bits are ignored.
653 return svnot_b_z(HWY_SVE_PTRUE(8), m);
654 }
And(svbool_t a,svbool_t b)655 HWY_API svbool_t And(svbool_t a, svbool_t b) {
656 return svand_b_z(b, b, a); // same order as AndNot for consistency
657 }
AndNot(svbool_t a,svbool_t b)658 HWY_API svbool_t AndNot(svbool_t a, svbool_t b) {
659 return svbic_b_z(b, b, a); // reversed order like NEON
660 }
Or(svbool_t a,svbool_t b)661 HWY_API svbool_t Or(svbool_t a, svbool_t b) {
662 return svsel_b(a, a, b); // a ? true : b
663 }
Xor(svbool_t a,svbool_t b)664 HWY_API svbool_t Xor(svbool_t a, svbool_t b) {
665 return svsel_b(a, svnand_b_z(a, a, b), b); // a ? !(a & b) : b.
666 }
667
668 // ------------------------------ CountTrue
669
670 #define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, NAME, OP) \
671 template <size_t N> \
672 HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N) d, svbool_t m) { \
673 return sv##OP##_b##BITS(detail::Mask(d), m); \
674 }
675
HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE,CountTrue,cntp)676 HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp)
677 #undef HWY_SVE_COUNT_TRUE
678
679 // For 16-bit Compress: full vector, not limited to SV_POW2.
680 namespace detail {
681
682 #define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, NAME, OP) \
683 template <size_t N> \
684 HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N) d, svbool_t m) { \
685 return sv##OP##_b##BITS(svptrue_b##BITS(), m); \
686 }
687
688 HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp)
689 #undef HWY_SVE_COUNT_TRUE_FULL
690
691 } // namespace detail
692
693 // ------------------------------ AllFalse
694 template <typename T, size_t N>
AllFalse(Simd<T,N> d,svbool_t m)695 HWY_API bool AllFalse(Simd<T, N> d, svbool_t m) {
696 return !svptest_any(detail::Mask(d), m);
697 }
698
699 // ------------------------------ AllTrue
700 template <typename T, size_t N>
AllTrue(Simd<T,N> d,svbool_t m)701 HWY_API bool AllTrue(Simd<T, N> d, svbool_t m) {
702 return CountTrue(d, m) == Lanes(d);
703 }
704
705 // ------------------------------ FindFirstTrue
706 template <typename T, size_t N>
FindFirstTrue(Simd<T,N> d,svbool_t m)707 HWY_API intptr_t FindFirstTrue(Simd<T, N> d, svbool_t m) {
708 return AllFalse(d, m) ? -1 : CountTrue(d, svbrkb_b_z(detail::Mask(d), m));
709 }
710
711 // ------------------------------ IfThenElse
712 #define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, NAME, OP) \
713 HWY_API HWY_SVE_V(BASE, BITS) \
714 NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \
715 return sv##OP##_##CHAR##BITS(m, yes, no); \
716 }
717
HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE,IfThenElse,sel)718 HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel)
719 #undef HWY_SVE_IF_THEN_ELSE
720
721 // ------------------------------ IfThenElseZero
722 template <class M, class V>
723 HWY_API V IfThenElseZero(const M mask, const V yes) {
724 return IfThenElse(mask, yes, Zero(DFromV<V>()));
725 }
726
727 // ------------------------------ IfThenZeroElse
728 template <class M, class V>
IfThenZeroElse(const M mask,const V no)729 HWY_API V IfThenZeroElse(const M mask, const V no) {
730 return IfThenElse(mask, Zero(DFromV<V>()), no);
731 }
732
733 // ================================================== COMPARE
734
735 // mask = f(vector, vector)
736 #define HWY_SVE_COMPARE(BASE, CHAR, BITS, NAME, OP) \
737 HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
738 return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \
739 }
740 #define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, NAME, OP) \
741 HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
742 return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \
743 }
744
745 // ------------------------------ Eq
HWY_SVE_FOREACH(HWY_SVE_COMPARE,Eq,cmpeq)746 HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq)
747
748 // ------------------------------ Ne
749 HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne)
750
751 // ------------------------------ Lt
752 HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt)
753 namespace detail {
754 HWY_SVE_FOREACH_IF(HWY_SVE_COMPARE_N, LtN, cmplt_n)
755 } // namespace detail
756
757 // ------------------------------ Le
HWY_SVE_FOREACH_F(HWY_SVE_COMPARE,Le,cmple)758 HWY_SVE_FOREACH_F(HWY_SVE_COMPARE, Le, cmple)
759
760 #undef HWY_SVE_COMPARE
761 #undef HWY_SVE_COMPARE_N
762
763 // ------------------------------ Gt/Ge (swapped order)
764
765 template <class V>
766 HWY_API svbool_t Gt(const V a, const V b) {
767 return Lt(b, a);
768 }
769 template <class V>
Ge(const V a,const V b)770 HWY_API svbool_t Ge(const V a, const V b) {
771 return Le(b, a);
772 }
773
774 // ------------------------------ TestBit
775 template <class V>
TestBit(const V a,const V bit)776 HWY_API svbool_t TestBit(const V a, const V bit) {
777 return Ne(And(a, bit), Zero(DFromV<V>()));
778 }
779
780 // ------------------------------ MaskFromVec (Ne)
781 template <class V>
MaskFromVec(const V v)782 HWY_API svbool_t MaskFromVec(const V v) {
783 return Ne(v, Zero(DFromV<V>()));
784 }
785
786 // ------------------------------ VecFromMask
787
788 template <class D, HWY_IF_NOT_FLOAT_D(D)>
VecFromMask(const D d,svbool_t mask)789 HWY_API VFromD<D> VecFromMask(const D d, svbool_t mask) {
790 const auto v0 = Zero(RebindToSigned<decltype(d)>());
791 return BitCast(d, detail::SubN(mask, v0, 1));
792 }
793
794 template <class D, HWY_IF_FLOAT_D(D)>
VecFromMask(const D d,svbool_t mask)795 HWY_API VFromD<D> VecFromMask(const D d, svbool_t mask) {
796 return BitCast(d, VecFromMask(RebindToUnsigned<D>(), mask));
797 }
798
799 // ================================================== MEMORY
800
801 // ------------------------------ Load/MaskedLoad/LoadDup128/Store/Stream
802
803 #define HWY_SVE_LOAD(BASE, CHAR, BITS, NAME, OP) \
804 template <size_t N> \
805 HWY_API HWY_SVE_V(BASE, BITS) \
806 NAME(HWY_SVE_D(BASE, BITS, N) d, \
807 const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
808 return sv##OP##_##CHAR##BITS(detail::Mask(d), p); \
809 }
810
811 #define HWY_SVE_MASKED_LOAD(BASE, CHAR, BITS, NAME, OP) \
812 template <size_t N> \
813 HWY_API HWY_SVE_V(BASE, BITS) \
814 NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N) d, \
815 const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
816 return sv##OP##_##CHAR##BITS(m, p); \
817 }
818
819 #define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, NAME, OP) \
820 template <size_t N> \
821 HWY_API HWY_SVE_V(BASE, BITS) \
822 NAME(HWY_SVE_D(BASE, BITS, N) d, \
823 const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
824 /* All-true predicate to load all 128 bits. */ \
825 return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), p); \
826 }
827
828 #define HWY_SVE_STORE(BASE, CHAR, BITS, NAME, OP) \
829 template <size_t N> \
830 HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N) d, \
831 HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
832 sv##OP##_##CHAR##BITS(detail::Mask(d), p, v); \
833 }
834
835 #define HWY_SVE_MASKED_STORE(BASE, CHAR, BITS, NAME, OP) \
836 template <size_t N> \
837 HWY_API void NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v, \
838 HWY_SVE_D(BASE, BITS, N) d, \
839 HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
840 sv##OP##_##CHAR##BITS(m, p, v); \
841 }
842
HWY_SVE_FOREACH(HWY_SVE_LOAD,Load,ld1)843 HWY_SVE_FOREACH(HWY_SVE_LOAD, Load, ld1)
844 HWY_SVE_FOREACH(HWY_SVE_MASKED_LOAD, MaskedLoad, ld1)
845 HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDup128, ld1rq)
846 HWY_SVE_FOREACH(HWY_SVE_STORE, Store, st1)
847 HWY_SVE_FOREACH(HWY_SVE_STORE, Stream, stnt1)
848 HWY_SVE_FOREACH(HWY_SVE_MASKED_STORE, MaskedStore, st1)
849
850 #undef HWY_SVE_LOAD
851 #undef HWY_SVE_MASKED_LOAD
852 #undef HWY_SVE_LOAD_DUP128
853 #undef HWY_SVE_STORE
854 #undef HWY_SVE_MASKED_STORE
855
856 // BF16 is the same as svuint16_t because BF16 is optional before v8.6.
857 template <size_t N>
858 HWY_API svuint16_t Load(Simd<bfloat16_t, N> d,
859 const bfloat16_t* HWY_RESTRICT p) {
860 return Load(RebindToUnsigned<decltype(d)>(),
861 reinterpret_cast<const uint16_t * HWY_RESTRICT>(p));
862 }
863
864 template <size_t N>
Store(svuint16_t v,Simd<bfloat16_t,N> d,bfloat16_t * HWY_RESTRICT p)865 HWY_API void Store(svuint16_t v, Simd<bfloat16_t, N> d,
866 bfloat16_t* HWY_RESTRICT p) {
867 Store(v, RebindToUnsigned<decltype(d)>(),
868 reinterpret_cast<uint16_t * HWY_RESTRICT>(p));
869 }
870
871 // ------------------------------ Load/StoreU
872
873 // SVE only requires lane alignment, not natural alignment of the entire
874 // vector.
875 template <class D>
LoadU(D d,const TFromD<D> * HWY_RESTRICT p)876 HWY_API VFromD<D> LoadU(D d, const TFromD<D>* HWY_RESTRICT p) {
877 return Load(d, p);
878 }
879
880 template <class V, class D>
StoreU(const V v,D d,TFromD<D> * HWY_RESTRICT p)881 HWY_API void StoreU(const V v, D d, TFromD<D>* HWY_RESTRICT p) {
882 Store(v, d, p);
883 }
884
// ------------------------------ ScatterOffset/Index

// Scatters active lanes of v to base + offset[i] BYTES (signed offsets).
#define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, NAME, OP)                   \
  template <size_t N>                                                        \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N) d,    \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,               \
                    HWY_SVE_V(int, BITS) offset) {                           \
    sv##OP##_s##BITS##offset_##CHAR##BITS(detail::Mask(d), base, offset, v); \
  }

// Scatters active lanes of v to base + index[i] ELEMENTS (signed indices).
#define HWY_SVE_SCATTER_INDEX(BASE, CHAR, BITS, NAME, OP)                 \
  template <size_t N>                                                     \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N) d, \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,            \
                    HWY_SVE_V(int, BITS) index) {                         \
    sv##OP##_s##BITS##index_##CHAR##BITS(detail::Mask(d), base, index, v); \
  }

// Scatters are only available for 32/64-bit lanes on SVE.
HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter)
HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_INDEX, ScatterIndex, st1_scatter)
#undef HWY_SVE_SCATTER_OFFSET
#undef HWY_SVE_SCATTER_INDEX
907
// ------------------------------ GatherOffset/Index

// Gathers from base + offset[i] BYTES (signed offsets) for active lanes.
#define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, NAME, OP)             \
  template <size_t N>                                                 \
  HWY_API HWY_SVE_V(BASE, BITS)                                       \
      NAME(HWY_SVE_D(BASE, BITS, N) d,                                \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,           \
           HWY_SVE_V(int, BITS) offset) {                             \
    return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::Mask(d), base, \
                                                 offset);             \
  }
// Gathers from base + index[i] ELEMENTS (signed indices) for active lanes.
#define HWY_SVE_GATHER_INDEX(BASE, CHAR, BITS, NAME, OP)                      \
  template <size_t N>                                                         \
  HWY_API HWY_SVE_V(BASE, BITS)                                               \
      NAME(HWY_SVE_D(BASE, BITS, N) d,                                        \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,                   \
           HWY_SVE_V(int, BITS) index) {                                      \
    return sv##OP##_s##BITS##index_##CHAR##BITS(detail::Mask(d), base, index); \
  }

// Gathers are only available for 32/64-bit lanes on SVE.
HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather)
HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_INDEX, GatherIndex, ld1_gather)
#undef HWY_SVE_GATHER_OFFSET
#undef HWY_SVE_GATHER_INDEX
932
// ------------------------------ StoreInterleaved3

// Stores v0..v2 interleaved to unaligned: v0[0] v1[0] v2[0] v0[1] ... via a
// structured st3 of a 3-vector tuple.
#define HWY_SVE_STORE3(BASE, CHAR, BITS, NAME, OP)                          \
  template <size_t N>                                                       \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1,     \
                    HWY_SVE_V(BASE, BITS) v2, HWY_SVE_D(BASE, BITS, N) d,   \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) {       \
    const sv##BASE##BITS##x3_t triple = svcreate3##_##CHAR##BITS(v0, v1, v2); \
    sv##OP##_##CHAR##BITS(detail::Mask(d), unaligned, triple);              \
  }
// Only u8 is required by the Highway API.
HWY_SVE_FOREACH_U08(HWY_SVE_STORE3, StoreInterleaved3, st3)

#undef HWY_SVE_STORE3

// ------------------------------ StoreInterleaved4

// As StoreInterleaved3, but with four vectors via st4.
#define HWY_SVE_STORE4(BASE, CHAR, BITS, NAME, OP)                        \
  template <size_t N>                                                     \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1,   \
                    HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3,   \
                    HWY_SVE_D(BASE, BITS, N) d,                           \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) {     \
    const sv##BASE##BITS##x4_t quad =                                     \
        svcreate4##_##CHAR##BITS(v0, v1, v2, v3);                         \
    sv##OP##_##CHAR##BITS(detail::Mask(d), unaligned, quad);              \
  }
HWY_SVE_FOREACH_U08(HWY_SVE_STORE4, StoreInterleaved4, st4)

#undef HWY_SVE_STORE4
962
// ================================================== CONVERT

// ------------------------------ PromoteTo

// Same sign
// Widens the LOWER half's lanes to double width via svunpklo (the OP).
// The argument vector has narrow lanes; the destination tag selects width.
#define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, NAME, OP)        \
  template <size_t N>                                         \
  HWY_API HWY_SVE_V(BASE, BITS)                               \
      NAME(HWY_SVE_D(BASE, BITS, N) /* tag */,                \
           VFromD<Simd<MakeNarrow<HWY_SVE_T(BASE, BITS)>,     \
                       HWY_LANES(HWY_SVE_T(BASE, BITS)) * 2>> \
               v) {                                           \
    return sv##OP##_##CHAR##BITS(v);                          \
  }

HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo)
HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo)
HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo)
981
982 // 2x
983 template <size_t N>
984 HWY_API svuint32_t PromoteTo(Simd<uint32_t, N> dto, svuint8_t vfrom) {
985 const RepartitionToWide<DFromV<decltype(vfrom)>> d2;
986 return PromoteTo(dto, PromoteTo(d2, vfrom));
987 }
988 template <size_t N>
PromoteTo(Simd<int32_t,N> dto,svint8_t vfrom)989 HWY_API svint32_t PromoteTo(Simd<int32_t, N> dto, svint8_t vfrom) {
990 const RepartitionToWide<DFromV<decltype(vfrom)>> d2;
991 return PromoteTo(dto, PromoteTo(d2, vfrom));
992 }
993 template <size_t N>
U32FromU8(svuint8_t v)994 HWY_API svuint32_t U32FromU8(svuint8_t v) {
995 return PromoteTo(Simd<uint32_t, N>(), v);
996 }
997
// Sign change
// u8 -> i16: widen as unsigned (zero-extend), then reinterpret as signed;
// the values fit because the source is only 8 bits.
template <size_t N>
HWY_API svint16_t PromoteTo(Simd<int16_t, N> dto, svuint8_t vfrom) {
  const RebindToUnsigned<decltype(dto)> du;
  return BitCast(dto, PromoteTo(du, vfrom));
}
// u16 -> i32: same zero-extend-then-reinterpret approach.
template <size_t N>
HWY_API svint32_t PromoteTo(Simd<int32_t, N> dto, svuint16_t vfrom) {
  const RebindToUnsigned<decltype(dto)> du;
  return BitCast(dto, PromoteTo(du, vfrom));
}
// u8 -> i32: zero-extend to u16, reinterpret as i16, then sign-extend i16 ->
// i32 (the i16 values are non-negative, so this equals zero-extension).
template <size_t N>
HWY_API svint32_t PromoteTo(Simd<int32_t, N> dto, svuint8_t vfrom) {
  const Repartition<uint16_t, DFromV<decltype(vfrom)>> du16;
  const Repartition<int16_t, decltype(du16)> di16;
  return PromoteTo(dto, BitCast(di16, PromoteTo(du16, vfrom)));
}
1015
// ------------------------------ PromoteTo F

// f16 -> f32 conversion of the lower half; the all-true predicate covers
// every lane (the _x form leaves inactive lanes undefined, none here).
template <size_t N>
HWY_API svfloat32_t PromoteTo(Simd<float32_t, N> /* d */, const svfloat16_t v) {
  return svcvt_f32_f16_x(detail::PTrue(Simd<float16_t, N>()), v);
}

// f32 -> f64 conversion of the lower half.
template <size_t N>
HWY_API svfloat64_t PromoteTo(Simd<float64_t, N> /* d */, const svfloat32_t v) {
  return svcvt_f64_f32_x(detail::PTrue(Simd<float32_t, N>()), v);
}

// i32 -> f64 conversion of the lower half.
template <size_t N>
HWY_API svfloat64_t PromoteTo(Simd<float64_t, N> /* d */, const svint32_t v) {
  return svcvt_f64_s32_x(detail::PTrue(Simd<int32_t, N>()), v);
}
1032
// For 16-bit Compress
namespace detail {
// PromoteUpperTo widens the UPPER half's lanes (svunpkhi), UI32 only.
HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi)
#undef HWY_SVE_PROMOTE_TO

// NOTE(review): this zero-extends the upper f16 lanes' BIT PATTERNS to 32
// bits and reinterprets them as f32; it is not a numeric f16->f32
// conversion. Presumably callers (16-bit Compress) only need the bits -
// confirm no caller expects converted values.
template <size_t N>
HWY_API svfloat32_t PromoteUpperTo(Simd<float, N> df, const svfloat16_t v) {
  const RebindToUnsigned<decltype(df)> du;
  const RepartitionToNarrow<decltype(du)> dn;
  return BitCast(df, PromoteUpperTo(du, BitCast(dn, v)));
}

}  // namespace detail
1046
1047 // ------------------------------ DemoteTo U
1048
1049 namespace detail {
1050
1051 // Saturates unsigned vectors to half/quarter-width TN.
1052 template <typename TN, class VU>
SaturateU(VU v)1053 VU SaturateU(VU v) {
1054 return detail::MinN(v, static_cast<TFromV<VU>>(LimitsMax<TN>()));
1055 }
1056
1057 // Saturates unsigned vectors to half/quarter-width TN.
1058 template <typename TN, class VI>
SaturateI(VI v)1059 VI SaturateI(VI v) {
1060 const DFromV<VI> di;
1061 return detail::MinN(detail::MaxN(v, LimitsMin<TN>()), LimitsMax<TN>());
1062 }
1063
1064 } // namespace detail
1065
// Demotes i16 to u8 with unsigned saturation: clamp negatives to zero,
// saturate to 255, then keep the even (narrowed) byte lanes.
template <size_t N>
HWY_API svuint8_t DemoteTo(Simd<uint8_t, N> dn, const svint16_t v) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  using TN = TFromD<decltype(dn)>;
  // First clamp negative numbers to zero and cast to unsigned.
  const svuint16_t clamped = BitCast(du, Max(Zero(di), v));
  // Saturate to unsigned-max and halve the width.
  const svuint8_t vn = BitCast(dn, detail::SaturateU<TN>(clamped));
  return svuzp1_u8(vn, vn);
}
1077
// Demotes i32 to u16 with unsigned saturation (same scheme as i16 -> u8).
template <size_t N>
HWY_API svuint16_t DemoteTo(Simd<uint16_t, N> dn, const svint32_t v) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  using TN = TFromD<decltype(dn)>;
  // First clamp negative numbers to zero and cast to unsigned.
  const svuint32_t clamped = BitCast(du, Max(Zero(di), v));
  // Saturate to unsigned-max and halve the width.
  const svuint16_t vn = BitCast(dn, detail::SaturateU<TN>(clamped));
  return svuzp1_u16(vn, vn);
}
1089
// Demotes i32 to u8 with unsigned saturation: clamp/saturate once, then pack
// even lanes twice (32 -> 16 -> 8 bits).
template <size_t N>
HWY_API svuint8_t DemoteTo(Simd<uint8_t, N> dn, const svint32_t v) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  const RepartitionToNarrow<decltype(du)> d2;
  using TN = TFromD<decltype(dn)>;
  // First clamp negative numbers to zero and cast to unsigned.
  const svuint32_t clamped = BitCast(du, Max(Zero(di), v));
  // Saturate to unsigned-max and quarter the width.
  const svuint16_t cast16 = BitCast(d2, detail::SaturateU<TN>(clamped));
  const svuint8_t x2 = BitCast(dn, svuzp1_u16(cast16, cast16));
  return svuzp1_u8(x2, x2);
}
1103
U8FromU32(const svuint32_t v)1104 HWY_API svuint8_t U8FromU32(const svuint32_t v) {
1105 const DFromV<svuint32_t> du32;
1106 const RepartitionToNarrow<decltype(du32)> du16;
1107 const RepartitionToNarrow<decltype(du16)> du8;
1108
1109 const svuint16_t cast16 = BitCast(du16, v);
1110 const svuint16_t x2 = svuzp1_u16(cast16, cast16);
1111 const svuint8_t cast8 = BitCast(du8, x2);
1112 return svuzp1_u8(cast8, cast8);
1113 }
1114
1115 // ------------------------------ DemoteTo I
1116
1117 template <size_t N>
DemoteTo(Simd<int8_t,N> dn,const svint16_t v)1118 HWY_API svint8_t DemoteTo(Simd<int8_t, N> dn, const svint16_t v) {
1119 const DFromV<decltype(v)> di;
1120 using TN = TFromD<decltype(dn)>;
1121 #if HWY_TARGET == HWY_SVE2
1122 const svint8_t vn = BitCast(dn, svqxtnb_s16(v));
1123 #else
1124 const svint8_t vn = BitCast(dn, detail::SaturateI<TN>(v));
1125 #endif
1126 return svuzp1_s8(vn, vn);
1127 }
1128
1129 template <size_t N>
DemoteTo(Simd<int16_t,N> dn,const svint32_t v)1130 HWY_API svint16_t DemoteTo(Simd<int16_t, N> dn, const svint32_t v) {
1131 const DFromV<decltype(v)> di;
1132 using TN = TFromD<decltype(dn)>;
1133 #if HWY_TARGET == HWY_SVE2
1134 const svint16_t vn = BitCast(dn, svqxtnb_s32(v));
1135 #else
1136 const svint16_t vn = BitCast(dn, detail::SaturateI<TN>(v));
1137 #endif
1138 return svuzp1_s16(vn, vn);
1139 }
1140
1141 template <size_t N>
DemoteTo(Simd<int8_t,N> dn,const svint32_t v)1142 HWY_API svint8_t DemoteTo(Simd<int8_t, N> dn, const svint32_t v) {
1143 const DFromV<decltype(v)> di;
1144 using TN = TFromD<decltype(dn)>;
1145 const RepartitionToWide<decltype(dn)> d2;
1146 #if HWY_TARGET == HWY_SVE2
1147 const svint16_t cast16 = BitCast(d2, svqxtnb_s16(svqxtnb_s32(v)));
1148 #else
1149 const svint16_t cast16 = BitCast(d2, detail::SaturateI<TN>(v));
1150 #endif
1151 const svint8_t v2 = BitCast(dn, svuzp1_s16(cast16, cast16));
1152 return BitCast(dn, svuzp1_s8(v2, v2));
1153 }
1154
// ------------------------------ ConcatEven/ConcatOdd

// WARNING: the upper half of these needs fixing up (uzp1/uzp2 use the
// full vector length, not rounded down to a power of two as we require).
namespace detail {

// ConcatEven/ConcatOdd: even/odd lanes of lo, followed by those of hi
// (uzp1/uzp2 over the full hardware vector - see WARNING above).
#define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, NAME, OP) \
  HWY_INLINE HWY_SVE_V(BASE, BITS)                              \
      NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \
    return sv##OP##_##CHAR##BITS(lo, hi);                       \
  }
HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEven, uzp1)
HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOdd, uzp2)
#undef HWY_SVE_CONCAT_EVERY_SECOND

// Used to slide up / shift whole register left; mask indicates which range
// to take from lo, and the rest is filled from hi starting at its lowest.
#define HWY_SVE_SPLICE(BASE, CHAR, BITS, NAME, OP)                         \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(                                      \
      HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo, svbool_t mask) { \
    return sv##OP##_##CHAR##BITS(mask, lo, hi);                            \
  }
HWY_SVE_FOREACH(HWY_SVE_SPLICE, Splice, splice)
#undef HWY_SVE_SPLICE

}  // namespace detail
1181
// Returns the odd lanes of lo in the lower half and the odd lanes of hi
// above them.
template <class D>
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
#if 0 // if we could assume VL is a power of two
  return detail::ConcatOdd(hi, lo);
#else
  // uzp2 acts on the full hardware vector length, so first compact each
  // input's odd lanes into its own lower half, then splice the two halves.
  const VFromD<D> hi_odd = detail::ConcatOdd(hi, hi);
  const VFromD<D> lo_odd = detail::ConcatOdd(lo, lo);
  return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2));
#endif
}

// Returns the even lanes of lo in the lower half and the even lanes of hi
// above them.
template <class D>
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
#if 0 // if we could assume VL is a power of two
  return detail::ConcatEven(hi, lo);
#else
  // Same two-step approach as ConcatOdd, using uzp1 (even lanes).
  const VFromD<D> hi_odd = detail::ConcatEven(hi, hi);
  const VFromD<D> lo_odd = detail::ConcatEven(lo, lo);
  return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2));
#endif
}
1203
// ------------------------------ DemoteTo F

// f32 -> f16 conversion (all-true predicate: every lane converted).
template <size_t N>
HWY_API svfloat16_t DemoteTo(Simd<float16_t, N> d, const svfloat32_t v) {
  return svcvt_f16_f32_x(detail::PTrue(d), v);
}

// f32 -> bf16 by truncation: keep the upper 16 bits of each f32, i.e. the
// odd u16 lane on little-endian. No rounding is performed.
template <size_t N>
HWY_API svuint16_t DemoteTo(Simd<bfloat16_t, N> d, const svfloat32_t v) {
  const svuint16_t halves = BitCast(Full<uint16_t>(), v);
  return detail::ConcatOdd(halves, halves);  // can ignore upper half of vec
}

// f64 -> f32 conversion.
template <size_t N>
HWY_API svfloat32_t DemoteTo(Simd<float32_t, N> d, const svfloat64_t v) {
  return svcvt_f32_f64_x(detail::PTrue(d), v);
}

// f64 -> i32 conversion.
template <size_t N>
HWY_API svint32_t DemoteTo(Simd<int32_t, N> d, const svfloat64_t v) {
  return svcvt_s32_f64_x(detail::PTrue(d), v);
}
1226
// ------------------------------ ConvertTo F

// Generates both directions: int -> float, and float -> int (truncating).
// The _x (dont-care) form is fine because PTRUE activates every lane used.
#define HWY_SVE_CONVERT(BASE, CHAR, BITS, NAME, OP)                     \
  template <size_t N>                                                   \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME(HWY_SVE_D(BASE, BITS, N) /* d */, HWY_SVE_V(int, BITS) v) {  \
    return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \
  }                                                                     \
  /* Truncates (rounds toward zero). */                                 \
  template <size_t N>                                                   \
  HWY_API HWY_SVE_V(int, BITS)                                          \
      NAME(HWY_SVE_D(int, BITS, N) /* d */, HWY_SVE_V(BASE, BITS) v) {  \
    return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \
  }

// API only requires f32 but we provide f64 for use by Iota.
HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt)
#undef HWY_SVE_CONVERT
1245
// ------------------------------ NearestInt (Round, ConvertTo)

// Converts to the signed integer type of the same width, rounding to nearest.
template <class VF, class DI = RebindToSigned<DFromV<VF>>>
HWY_API VFromD<DI> NearestInt(VF v) {
  // No single instruction, round then truncate.
  return ConvertTo(DI(), Round(v));
}
1253
// ------------------------------ Iota (Add, ConvertTo)

// Integer Iota: {first, first+1, first+2, ...} via svindex with step 1.
// d is a type tag only.
#define HWY_SVE_IOTA(BASE, CHAR, BITS, NAME, OP)                      \
  template <size_t N>                                                 \
  HWY_API HWY_SVE_V(BASE, BITS)                                       \
      NAME(HWY_SVE_D(BASE, BITS, N) d, HWY_SVE_T(BASE, BITS) first) { \
    return sv##OP##_##CHAR##BITS(first, 1);                           \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index)
#undef HWY_SVE_IOTA

// Float Iota: integer Iota converted to float, then first added to each lane.
template <class D, HWY_IF_FLOAT_D(D)>
HWY_API VFromD<D> Iota(const D d, TFromD<D> first) {
  const RebindToSigned<D> di;
  return detail::AddN(ConvertTo(d, Iota(di, 0)), first);
}
1271
// ================================================== COMBINE

namespace detail {

// Mask with the lower Lanes(d)/2 lanes active.
template <typename T, size_t N>
svbool_t MaskLowerHalf(Simd<T, N> d) {
  return FirstN(d, Lanes(d) / 2);
}
// Mask with the upper half of d's lanes active; bits at or above Lanes(d)
// remain clear.
template <typename T, size_t N>
svbool_t MaskUpperHalf(Simd<T, N> d) {
  // For Splice to work as intended, make sure bits above Lanes(d) are zero.
  return AndNot(MaskLowerHalf(d), detail::Mask(d));
}

// Right-shift vector pair by constexpr; can be used to slide down (=N) or up
// (=Lanes()-N).
#define HWY_SVE_EXT(BASE, CHAR, BITS, NAME, OP)                  \
  template <size_t kIndex>                                       \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \
    return sv##OP##_##CHAR##BITS(lo, hi, kIndex);                \
  }
HWY_SVE_FOREACH(HWY_SVE_EXT, Ext, ext)
#undef HWY_SVE_EXT

}  // namespace detail
1298
// ------------------------------ ConcatUpperLower
// Lower half of lo plus upper half of hi, each in its original position.
template <class D, class V>
HWY_API V ConcatUpperLower(const D d, const V hi, const V lo) {
  return IfThenElse(detail::MaskLowerHalf(d), lo, hi);
}

// ------------------------------ ConcatLowerLower
// lo's lower half at the bottom, hi's lower half spliced in above it.
template <class D, class V>
HWY_API V ConcatLowerLower(const D d, const V hi, const V lo) {
  return detail::Splice(hi, lo, detail::MaskLowerHalf(d));
}

// ------------------------------ ConcatLowerUpper
// lo's upper half moved to the bottom, hi's lower half above it.
template <class D, class V>
HWY_API V ConcatLowerUpper(const D d, const V hi, const V lo) {
  return detail::Splice(hi, lo, detail::MaskUpperHalf(d));
}

// ------------------------------ ConcatUpperUpper
// hi's upper half in place; lo's upper half moved down into the lower half.
template <class D, class V>
HWY_API V ConcatUpperUpper(const D d, const V hi, const V lo) {
  const svbool_t mask_upper = detail::MaskUpperHalf(d);
  // Splice with the upper-half mask moves lo's upper lanes to the bottom.
  const V lo_upper = detail::Splice(lo, lo, mask_upper);
  return IfThenElse(mask_upper, hi, lo_upper);
}

// ------------------------------ Combine
// Concatenates two half-vectors: lo in the lower half, hi above it.
template <class D, class V2>
HWY_API VFromD<D> Combine(const D d, const V2 hi, const V2 lo) {
  return ConcatLowerLower(d, hi, lo);
}

// ------------------------------ ZeroExtendVector

// lo's lanes in the lower half, zeros in the upper half.
template <class D, class V>
HWY_API V ZeroExtendVector(const D d, const V lo) {
  return Combine(d, Zero(Half<D>()), lo);
}
1337
// ------------------------------ Lower/UpperHalf

// The lower half of an SVE vector occupies the same register, so LowerHalf
// is a no-op (lanes above Lanes(d2) are simply ignored by callers).
template <class D2, class V>
HWY_API V LowerHalf(D2 /* tag */, const V v) {
  return v;
}

template <class V>
HWY_API V LowerHalf(const V v) {
  return v;
}

// Moves the upper half of v down into the lower lanes.
template <class D2, class V>
HWY_API V UpperHalf(const D2 d2, const V v) {
  return detail::Splice(v, v, detail::MaskUpperHalf(Twice<D2>()));
}
1354
// ================================================== SWIZZLE

// ------------------------------ GetLane

// Extracts lane 0: svlasta with an all-false predicate returns the element
// after the last active one, i.e. the first element.
#define HWY_SVE_GET_LANE(BASE, CHAR, BITS, NAME, OP)            \
  HWY_API HWY_SVE_T(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS(detail::PFalse(), v);          \
  }

HWY_SVE_FOREACH(HWY_SVE_GET_LANE, GetLane, lasta)
#undef HWY_SVE_GET_LANE
1366
// ------------------------------ OddEven

namespace detail {
// Insert (insr_n) prepends a scalar at lane 0, shifting all lanes up by one.
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVN, Insert, insr_n)
// trn1/trn2 interleave the even/odd-indexed lanes of their two inputs.
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2)
}  // namespace detail

// Returns even's lanes at even indices and odd's lanes at odd indices.
template <class V>
HWY_API V OddEven(const V odd, const V even) {
  // Shift even up one lane so even[2k] sits at odd index 2k+1 ...
  const auto even_in_odd = detail::Insert(even, 0);
  // ... then trn2 takes lane 2k+1 of each input: even[2k] and odd[2k+1].
  return detail::InterleaveOdd(even_in_odd, odd);
}
1380
// ------------------------------ OddEvenBlocks
// Returns even's 128-bit blocks at even block indices and odd's at odd
// block indices.
template <class V>
HWY_API V OddEvenBlocks(const V odd, const V even) {
  const RebindToUnsigned<DFromV<V>> du;
  // log2 of lanes-per-128-bit-block; shifting the lane index right by this
  // yields the block index.
  constexpr size_t kShift = CeilLog2(16 / sizeof(TFromV<V>));
  const auto idx_block = ShiftRight<kShift>(Iota(du, 0));
  const svbool_t is_even = Eq(detail::AndN(idx_block, 1), Zero(du));
  return IfThenElse(is_even, even, odd);
}
1390
// ------------------------------ SwapAdjacentBlocks

namespace detail {

// Number of lanes per 128-bit block for this tag type.
template <typename T, size_t N>
constexpr size_t LanesPerBlock(Simd<T, N> /* tag */) {
  // We might have a capped vector smaller than a block, so honor that.
  return HWY_MIN(16 / sizeof(T), N);
}

}  // namespace detail
1402
// Swaps each pair of adjacent 128-bit blocks: returns blocks 1,0,3,2,...
template <class V>
HWY_API V SwapAdjacentBlocks(const V v) {
  const DFromV<V> d;
  constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d);
  // Slide the vector down by one block (supplies the even result blocks) ...
  const V down = detail::Ext<kLanesPerBlock>(v, v);
  // ... and up by one block (supplies the odd result blocks).
  const V up = detail::Splice(v, v, FirstN(d, kLanesPerBlock));
  return OddEvenBlocks(up, down);
}
1411
// ------------------------------ TableLookupLanes

// Bitcasts lane indices to unsigned; debug builds assert all are in range.
template <class D, class VI>
HWY_API VFromD<RebindToUnsigned<D>> IndicesFromVec(D d, VI vec) {
  static_assert(sizeof(TFromD<D>) == sizeof(TFromV<VI>), "Index != lane");
  const RebindToUnsigned<D> du;
  const auto indices = BitCast(du, vec);
#if HWY_IS_DEBUG_BUILD
  HWY_DASSERT(AllTrue(du, Lt(indices, Set(du, Lanes(d)))));
#endif
  return indices;
}

// Loads indices from idx (one per lane of d) and validates via
// IndicesFromVec.
template <class D, typename TI>
HWY_API VFromD<RebindToUnsigned<D>> SetTableIndices(D d, const TI* idx) {
  static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane");
  return IndicesFromVec(d, LoadU(Rebind<TI, D>(), idx));
}
1430
// <32bit are not part of Highway API, but used in Broadcast.
// Full-vector lane permute: result[i] = v[idx[i]] (svtbl).
#define HWY_SVE_TABLE(BASE, CHAR, BITS, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS)                            \
      NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(uint, BITS) idx) { \
    return sv##OP##_##CHAR##BITS(v, idx);                  \
  }

HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl)
#undef HWY_SVE_TABLE
1440
// ------------------------------ Reverse

#if 0 // if we could assume VL is a power of two
#error "Update macro"
#endif
// Reverses lane order. svrev reverses the FULL hardware vector, so the
// result is then slid right past the hardware lanes we do not use.
#define HWY_SVE_REVERSE(BASE, CHAR, BITS, NAME, OP)                     \
  template <size_t N>                                                   \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME(Simd<HWY_SVE_T(BASE, BITS), N> d, HWY_SVE_V(BASE, BITS) v) { \
    const auto reversed = sv##OP##_##CHAR##BITS(v);                     \
    /* Shift right to remove extra (non-pow2 and remainder) lanes. */   \
    const size_t all_lanes =                                            \
        detail::AllHardwareLanes(hwy::SizeTag<BITS / 8>());             \
    /* TODO(janwas): on SVE2, use whilege. */                           \
    const svbool_t mask = Not(FirstN(d, all_lanes - Lanes(d)));         \
    return detail::Splice(reversed, reversed, mask);                    \
  }

HWY_SVE_FOREACH(HWY_SVE_REVERSE, Reverse, rev)
#undef HWY_SVE_REVERSE
1461
// ------------------------------ Compress (PromoteTo)

// Moves lanes selected by mask to the front (svcompact). Native only for
// 32/64-bit lanes; 16-bit lanes are handled below via promotion.
#define HWY_SVE_COMPRESS(BASE, CHAR, BITS, NAME, OP)                          \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \
    return sv##OP##_##CHAR##BITS(mask, v);                                    \
  }

HWY_SVE_FOREACH_UIF3264(HWY_SVE_COMPRESS, Compress, compact)
#undef HWY_SVE_COMPRESS
1471
// 16-bit lanes lack a native compact: promote both halves (and the mask) to
// 32 bits, compress there, demote back, and splice the two results.
template <class V, HWY_IF_LANE_SIZE_V(V, 2)>
HWY_API V Compress(V v, svbool_t mask16) {
  static_assert(!IsSame<V, svfloat16_t>(), "Must use overload");
  const DFromV<V> d16;

  // Promote vector and mask to 32-bit
  const RepartitionToWide<decltype(d16)> dw;
  const auto v32L = PromoteTo(dw, v);
  const auto v32H = detail::PromoteUpperTo(dw, v);
  const svbool_t mask32L = svunpklo_b(mask16);
  const svbool_t mask32H = svunpkhi_b(mask16);

  const auto compressedL = Compress(v32L, mask32L);
  const auto compressedH = Compress(v32H, mask32H);

  // Demote to 16-bit (already in range) - separately so we can splice
  const V evenL = BitCast(d16, compressedL);
  const V evenH = BitCast(d16, compressedH);
  const V v16L = detail::ConcatEven(evenL, evenL);  // only lower half needed
  const V v16H = detail::ConcatEven(evenH, evenH);

  // We need to combine two vectors of non-constexpr length, so the only option
  // is Splice, which requires us to synthesize a mask. NOTE: this function uses
  // full vectors (SV_ALL instead of SV_POW2), hence we need unmasked svcnt.
  const size_t countL = detail::CountTrueFull(dw, mask32L);
  const auto compressed_maskL = FirstN(d16, countL);
  return detail::Splice(v16H, v16L, compressed_maskL);
}
1500
// Must treat float16_t as integers so we can ConcatEven.
// Delegates to the 16-bit integer Compress via BitCast.
HWY_API svfloat16_t Compress(svfloat16_t v, svbool_t mask16) {
  const DFromV<decltype(v)> df;
  const RebindToSigned<decltype(df)> di;
  return BitCast(df, Compress(BitCast(di, v), mask16));
}
1507
1508 // ------------------------------ CompressStore
1509
1510 template <class V, class M, class D>
CompressStore(const V v,const M mask,const D d,TFromD<D> * HWY_RESTRICT unaligned)1511 HWY_API size_t CompressStore(const V v, const M mask, const D d,
1512 TFromD<D>* HWY_RESTRICT unaligned) {
1513 StoreU(Compress(v, mask), d, unaligned);
1514 return CountTrue(d, mask);
1515 }
1516
// ------------------------------ CompressBlendedStore

// Like CompressStore, but writes ONLY the CountTrue selected lanes, leaving
// the rest of the destination untouched. Returns the number written.
template <class V, class M, class D>
HWY_API size_t CompressBlendedStore(const V v, const M mask, const D d,
                                    TFromD<D>* HWY_RESTRICT unaligned) {
  const size_t count = CountTrue(d, mask);
  // Restrict the store to the first `count` lanes (the compressed results).
  const svbool_t store_mask = FirstN(d, count);
  MaskedStore(store_mask, Compress(v, mask), d, unaligned);
  return count;
}
1527
// ================================================== BLOCKWISE

// ------------------------------ CombineShiftRightBytes

namespace detail {

// For x86-compatible behaviour mandated by Highway API: TableLookupBytes
// offsets are implicitly relative to the start of their 128-bit block.
// Rounds each lane of iota0 down to the first lane of its 128-bit block.
template <class D, class V>
HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) {
  using T = MakeUnsigned<TFromD<D>>;
  return detail::AndNotN(static_cast<T>(LanesPerBlock(d) - 1), iota0);
}

// Mask that is true for the first kLanes lanes of EACH 128-bit block.
template <size_t kLanes, class D>
svbool_t FirstNPerBlock(D d) {
  const RebindToSigned<D> di;
  constexpr size_t kLanesPerBlock = detail::LanesPerBlock(di);
  // Lane index modulo lanes-per-block (power of two, so AND suffices).
  const auto idx_mod = detail::AndN(Iota(di, 0), kLanesPerBlock - 1);
  return detail::LtN(BitCast(di, idx_mod), kLanes);
}

}  // namespace detail
1551
// Within each 128-bit block, returns the concatenation hi:lo shifted right
// by kBytes (x86 palignr semantics).
template <size_t kBytes, class D, class V = VFromD<D>>
HWY_API V CombineShiftRightBytes(const D d, const V hi, const V lo) {
  const Repartition<uint8_t, decltype(d)> d8;
  const auto hi8 = BitCast(d8, hi);
  const auto lo8 = BitCast(d8, lo);
  // hi bytes slide up to make room; lo bytes slide down by kBytes.
  const auto hi_up = detail::Splice(hi8, hi8, FirstN(d8, 16 - kBytes));
  const auto lo_down = detail::Ext<kBytes>(lo8, lo8);
  // The lower 16 - kBytes bytes of each block come from lo, the rest from hi.
  const svbool_t is_lo = detail::FirstNPerBlock<16 - kBytes>(d8);
  return BitCast(d, IfThenElse(is_lo, lo_down, hi_up));
}
1562
// ------------------------------ Shuffle2301

// Swaps adjacent 32-bit lanes (1,0,3,2,...): revw reverses the 32-bit words
// within each 64-bit element.
#define HWY_SVE_SHUFFLE_2301(BASE, CHAR, BITS, NAME, OP)                   \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {            \
    const DFromV<decltype(v)> d;                                           \
    const svuint64_t vu64 = BitCast(Repartition<uint64_t, decltype(d)>(), v); \
    return BitCast(d, sv##OP##_u64_x(HWY_SVE_PTRUE(64), vu64));            \
  }

HWY_SVE_FOREACH_UI32(HWY_SVE_SHUFFLE_2301, Shuffle2301, revw)
#undef HWY_SVE_SHUFFLE_2301

// Float overload: reuse the integer implementation via BitCast.
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Shuffle2301(const V v) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Shuffle2301(BitCast(du, v)));
}
1581
// ------------------------------ Shuffle2103
// Rotates the four 32-bit lanes of each 128-bit block: result = 2,1,0,3.
template <class V>
HWY_API V Shuffle2103(const V v) {
  const DFromV<V> d;
  const Repartition<uint8_t, decltype(d)> d8;
  static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
  const svuint8_t v8 = BitCast(d8, v);
  // Shift right by 12 bytes == rotate left by one 32-bit lane.
  return BitCast(d, CombineShiftRightBytes<12>(d8, v8, v8));
}

// ------------------------------ Shuffle0321
// Rotates the four 32-bit lanes of each 128-bit block: result = 0,3,2,1.
template <class V>
HWY_API V Shuffle0321(const V v) {
  const DFromV<V> d;
  const Repartition<uint8_t, decltype(d)> d8;
  static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
  const svuint8_t v8 = BitCast(d8, v);
  // Shift right by 4 bytes == rotate right by one 32-bit lane.
  return BitCast(d, CombineShiftRightBytes<4>(d8, v8, v8));
}

// ------------------------------ Shuffle1032
// Swaps the 64-bit halves of each 128-bit block (32-bit lanes: 1,0,3,2 pairs
// exchanged as 64-bit units).
template <class V>
HWY_API V Shuffle1032(const V v) {
  const DFromV<V> d;
  const Repartition<uint8_t, decltype(d)> d8;
  static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
  const svuint8_t v8 = BitCast(d8, v);
  return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8));
}

// ------------------------------ Shuffle01
// Swaps adjacent 64-bit lanes within each 128-bit block.
template <class V>
HWY_API V Shuffle01(const V v) {
  const DFromV<V> d;
  const Repartition<uint8_t, decltype(d)> d8;
  static_assert(sizeof(TFromD<decltype(d)>) == 8, "Defined for 64-bit types");
  const svuint8_t v8 = BitCast(d8, v);
  return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8));
}
1621
1622 // ------------------------------ Shuffle0123
1623 template <class V>
Shuffle0123(const V v)1624 HWY_API V Shuffle0123(const V v) {
1625 return Shuffle2301(Shuffle1032(v));
1626 }
1627
// ------------------------------ TableLookupBytes

// Byte shuffle with x86 semantics: indices are relative to each lane's own
// 128-bit block, so add the block-start offsets before the full-vector tbl.
template <class V, class VI>
HWY_API VI TableLookupBytes(const V v, const VI idx) {
  const DFromV<VI> d;
  const Repartition<uint8_t, decltype(d)> du8;
  const auto offsets128 = detail::OffsetsOf128BitBlocks(du8, Iota(du8, 0));
  const auto idx8 = Add(BitCast(du8, idx), offsets128);
  return BitCast(d, TableLookupLanes(BitCast(du8, v), idx8));
}

// As TableLookupBytes, but indices with the sign (MSB) bit set yield zero.
template <class V, class VI>
HWY_API VI TableLookupBytesOr0(const V v, const VI idx) {
  const DFromV<VI> d;
  // Mask size must match vector type, so cast everything to this type.
  const Repartition<int8_t, decltype(d)> di8;

  auto idx8 = BitCast(di8, idx);
  const auto msb = Lt(idx8, Zero(di8));
  // Prevent overflow in table lookups (unnecessary if native)
#if defined(HWY_EMULATE_SVE)
  idx8 = IfThenZeroElse(msb, idx8);
#endif

  const auto lookup = TableLookupBytes(BitCast(di8, v), idx8);
  // Zero the lanes whose index had the MSB set.
  return BitCast(d, IfThenZeroElse(msb, lookup));
}
1655
// ------------------------------ Broadcast

// Broadcasts lane kLane of each 128-bit block to every lane of that block.
template <int kLane, class V>
HWY_API V Broadcast(const V v) {
  const DFromV<V> d;
  const RebindToUnsigned<decltype(d)> du;
  constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du);
  static_assert(0 <= kLane && kLane < kLanesPerBlock, "Invalid lane");
  // Each lane's index becomes its block's first lane, plus kLane.
  auto idx = detail::OffsetsOf128BitBlocks(du, Iota(du, 0));
  if (kLane != 0) {
    idx = detail::AddN(idx, kLane);
  }
  return TableLookupLanes(v, idx);
}
1670
1671 // ------------------------------ ShiftLeftLanes
1672
1673 template <size_t kLanes, class D, class V = VFromD<D>>
ShiftLeftLanes(D d,const V v)1674 HWY_API V ShiftLeftLanes(D d, const V v) {
1675 const RebindToSigned<decltype(d)> di;
1676 const auto zero = Zero(d);
1677 const auto shifted = detail::Splice(v, zero, FirstN(d, kLanes));
1678 // Match x86 semantics by zeroing lower lanes in 128-bit blocks
1679 return IfThenElse(detail::FirstNPerBlock<kLanes>(d), zero, shifted);
1680 }
1681
1682 template <size_t kLanes, class V>
ShiftLeftLanes(const V v)1683 HWY_API V ShiftLeftLanes(const V v) {
1684 return ShiftLeftLanes<kLanes>(DFromV<V>(), v);
1685 }
1686
1687 // ------------------------------ ShiftRightLanes
1688 template <size_t kLanes, typename T, size_t N, class V = VFromD<Simd<T, N>>>
ShiftRightLanes(Simd<T,N> d,V v)1689 HWY_API V ShiftRightLanes(Simd<T, N> d, V v) {
1690 const RebindToSigned<decltype(d)> di;
1691 // For partial vectors, clear upper lanes so we shift in zeros.
1692 if (N != HWY_LANES(T)) {
1693 v = IfThenElseZero(detail::Mask(d), v);
1694 }
1695
1696 const auto shifted = detail::Ext<kLanes>(v, v);
1697 // Match x86 semantics by zeroing upper lanes in 128-bit blocks
1698 constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d);
1699 const svbool_t mask = detail::FirstNPerBlock<kLanesPerBlock - kLanes>(d);
1700 return IfThenElseZero(mask, shifted);
1701 }
1702
1703 // ------------------------------ ShiftLeftBytes
1704
1705 template <int kBytes, class D, class V = VFromD<D>>
ShiftLeftBytes(const D d,const V v)1706 HWY_API V ShiftLeftBytes(const D d, const V v) {
1707 const Repartition<uint8_t, decltype(d)> d8;
1708 return BitCast(d, ShiftLeftLanes<kBytes>(BitCast(d8, v)));
1709 }
1710
1711 template <int kBytes, class V>
ShiftLeftBytes(const V v)1712 HWY_API V ShiftLeftBytes(const V v) {
1713 return ShiftLeftBytes<kBytes>(DFromV<V>(), v);
1714 }
1715
1716 // ------------------------------ ShiftRightBytes
1717 template <int kBytes, class D, class V = VFromD<D>>
ShiftRightBytes(const D d,const V v)1718 HWY_API V ShiftRightBytes(const D d, const V v) {
1719 const Repartition<uint8_t, decltype(d)> d8;
1720 return BitCast(d, ShiftRightLanes<kBytes>(d8, BitCast(d8, v)));
1721 }
1722
1723 // ------------------------------ InterleaveLower
1724
namespace detail {
// zip1: interleaves lanes from the lower halves of the two operands.
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLower, zip1)
// Do not use zip2 to implement PromoteUpperTo or similar because vectors may be
// non-powers of two, so getting the actual "upper half" requires MaskUpperHalf.
}  // namespace detail
1730
1731 template <class D, class V>
InterleaveLower(D d,const V a,const V b)1732 HWY_API V InterleaveLower(D d, const V a, const V b) {
1733 static_assert(IsSame<TFromD<D>, TFromV<V>>(), "D/V mismatch");
1734 // Move lower halves of blocks to lower half of vector.
1735 const Repartition<uint64_t, decltype(d)> d64;
1736 const auto a64 = BitCast(d64, a);
1737 const auto b64 = BitCast(d64, b);
1738 const auto a_blocks = detail::ConcatEven(a64, a64); // only lower half needed
1739 const auto b_blocks = detail::ConcatEven(b64, b64);
1740
1741 return detail::ZipLower(BitCast(d, a_blocks), BitCast(d, b_blocks));
1742 }
1743
1744 template <class V>
InterleaveLower(const V a,const V b)1745 HWY_API V InterleaveLower(const V a, const V b) {
1746 return InterleaveLower(DFromV<V>(), a, b);
1747 }
1748
// ------------------------------ InterleaveUpper

// Full vector: guaranteed to have at least one block
template <typename T, class V = VFromD<Full<T>>>
HWY_API V InterleaveUpper(Simd<T, HWY_LANES(T)> d, const V a, const V b) {
  // Move upper halves of blocks to lower half of vector.
  const Repartition<uint64_t, decltype(d)> d64;
  const auto a64 = BitCast(d64, a);
  const auto b64 = BitCast(d64, b);
  // ConcatOdd gathers the odd 64-bit units, i.e. each block's upper half.
  const auto a_blocks = detail::ConcatOdd(a64, a64);  // only lower half needed
  const auto b_blocks = detail::ConcatOdd(b64, b64);
  return detail::ZipLower(BitCast(d, a_blocks), BitCast(d, b_blocks));
}
1762
// Capped: less than one block. The "upper half" is then the upper half of the
// (sub-block) vector itself, so interleave those halves directly.
template <typename T, size_t N, HWY_IF_LE64(T, N), class V = VFromD<Simd<T, N>>>
HWY_API V InterleaveUpper(Simd<T, N> d, const V a, const V b) {
  static_assert(IsSame<T, TFromV<V>>(), "D/V mismatch");
  const Half<decltype(d)> d2;
  return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b));
}
1770
// Partial: the actual lane count is only known at runtime, so dispatch between
// the capped (<1 block) and full-vector implementations dynamically.
template <typename T, size_t N,
          hwy::EnableIf<(N < HWY_LANES(T) && N * sizeof(T) >= 16)>* = nullptr,
          class V = VFromD<Simd<T, N>>>
HWY_API V InterleaveUpper(Simd<T, N> d, const V a, const V b) {
  static_assert(IsSame<T, TFromV<V>>(), "D/V mismatch");
  // Less than one block: treat as capped
  if (Lanes(d) * sizeof(T) < 16) {
    const Half<decltype(d)> d2;
    return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b));
  }
  return InterleaveUpper(Full<T>(), a, b);
}
1784
1785 // ------------------------------ ZipLower
1786
1787 template <class V, class DW = RepartitionToWide<DFromV<V>>>
ZipLower(DW dw,V a,V b)1788 HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) {
1789 const RepartitionToNarrow<DW> dn;
1790 static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch");
1791 return BitCast(dw, InterleaveLower(dn, a, b));
1792 }
1793 template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>>
ZipLower(const V a,const V b)1794 HWY_API VFromD<DW> ZipLower(const V a, const V b) {
1795 return BitCast(DW(), InterleaveLower(D(), a, b));
1796 }
1797
1798 // ------------------------------ ZipUpper
1799 template <class V, class DW = RepartitionToWide<DFromV<V>>>
ZipUpper(DW dw,V a,V b)1800 HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) {
1801 const RepartitionToNarrow<DW> dn;
1802 static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch");
1803 return BitCast(dw, InterleaveUpper(dn, a, b));
1804 }
1805
// ================================================== REDUCE

// Defines a reduction NAME(d, v) that broadcasts the scalar result of the
// predicated SVE horizontal reduction `OP` to all lanes.
#define HWY_SVE_REDUCE(BASE, CHAR, BITS, NAME, OP)                \
  template <size_t N>                                             \
  HWY_API HWY_SVE_V(BASE, BITS)                                   \
      NAME(HWY_SVE_D(BASE, BITS, N) d, HWY_SVE_V(BASE, BITS) v) { \
    return Set(d, sv##OP##_##CHAR##BITS(detail::Mask(d), v));     \
  }

HWY_SVE_FOREACH(HWY_SVE_REDUCE, SumOfLanes, addv)
HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MinOfLanes, minv)
HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MaxOfLanes, maxv)
// Float min/max use the NaN-propagating-only-if-all-NaN variants:
// NaN if all are
HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MinOfLanes, minnmv)
HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanes, maxnmv)

#undef HWY_SVE_REDUCE
1823
1824 // ================================================== Ops with dependencies
1825
// ------------------------------ PromoteTo bfloat16 (ZipLower)

// Widens bfloat16 values (passed as their raw bits in svuint16_t) to float32:
// zipping with zero places each 16-bit value in the upper half of a 32-bit
// lane, which is exactly the f32 whose mantissa tail is zero.
template <size_t N>
HWY_API svfloat32_t PromoteTo(Simd<float32_t, N> df32, const svuint16_t v) {
  return BitCast(df32, detail::ZipLower(svdup_n_u16(0), v));
}
1832
// ------------------------------ ReorderDemote2To (OddEven)

// Demotes two f32 vectors to bfloat16 (returned as raw u16 bits), interleaved
// rather than concatenated: a's upper-16-bit truncations land in odd u16
// lanes, b's in even lanes.
template <size_t N>
HWY_API svuint16_t ReorderDemote2To(Simd<bfloat16_t, N> dbf16, svfloat32_t a,
                                    svfloat32_t b) {
  const RebindToUnsigned<decltype(dbf16)> du16;
  const Repartition<uint32_t, decltype(dbf16)> du32;
  // Truncating shift moves b's upper 16 bits (its bf16 bits) into the lower
  // half of each u32 lane, i.e. the even u16 lanes.
  const svuint32_t b_in_even = ShiftRight<16>(BitCast(du32, b));
  return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even)));
}
1843
1844 // ------------------------------ ZeroIfNegative (Lt, IfThenElse)
1845 template <class V>
ZeroIfNegative(const V v)1846 HWY_API V ZeroIfNegative(const V v) {
1847 const auto v0 = Zero(DFromV<V>());
1848 // We already have a zero constant, so avoid IfThenZeroElse.
1849 return IfThenElse(Lt(v, v0), v0, v);
1850 }
1851
1852 // ------------------------------ BroadcastSignBit (ShiftRight)
1853 template <class V>
BroadcastSignBit(const V v)1854 HWY_API V BroadcastSignBit(const V v) {
1855 return ShiftRight<sizeof(TFromV<V>) * 8 - 1>(v);
1856 }
1857
// ------------------------------ AverageRound (ShiftRight)

#if HWY_TARGET == HWY_SVE2
// Native rounding halving add.
HWY_SVE_FOREACH_U08(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd)
HWY_SVE_FOREACH_U16(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd)
#else
// Fallback: (a + b + 1) >> 1. NOTE(review): the intermediate sum can wrap for
// values near the type's maximum — presumably inputs stay in range; confirm.
template <class V>
V AverageRound(const V a, const V b) {
  return ShiftRight<1>(Add(Add(a, b), Set(DFromV<V>(), 1)));
}
#endif  // HWY_TARGET == HWY_SVE2
1869
// ------------------------------ LoadMaskBits (TestBit)

// Converts a bit-packed mask in memory to an svbool_t, for 8-bit lanes.
// `p` points to at least 8 readable bytes, not all of which need be valid.
template <class D, HWY_IF_LANE_SIZE_D(D, 1)>
HWY_INLINE svbool_t LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
  const RebindToUnsigned<D> du;
  const svuint8_t iota = Iota(du, 0);

  // Load correct number of bytes (bits/8) with 7 zeros after each.
  const svuint8_t bytes = BitCast(du, svld1ub_u64(detail::PTrue(d), bits));
  // Replicate bytes 8x such that each byte contains the bit that governs it.
  // AndNotN(7, iota) = iota & ~7, i.e. lane index rounded down to the byte.
  const svuint8_t rep8 = svtbl_u8(bytes, detail::AndNotN(7, iota));

  // 1, 2, 4, 8, 16, 32, 64, 128, 1, 2 ..
  const svuint8_t bit = Shl(Set(du, 1), detail::AndN(iota, 7));

  return TestBit(rep8, bit);
}
1888
// As above, for 16-bit lanes.
template <class D, HWY_IF_LANE_SIZE_D(D, 2)>
HWY_INLINE svbool_t LoadMaskBits(D /* tag */,
                                 const uint8_t* HWY_RESTRICT bits) {
  const RebindToUnsigned<D> du;
  const Repartition<uint8_t, D> du8;

  // There may be up to 128 bits; avoid reading past the end.
  const svuint8_t bytes = svld1(FirstN(du8, (Lanes(du) + 7) / 8), bits);

  // Replicate bytes 16x such that each lane contains the bit that governs it.
  // (Byte index = u8 lane index / 16, since each mask byte covers 8 u16 lanes
  // = 16 bytes of vector.)
  const svuint8_t rep16 = svtbl_u8(bytes, ShiftRight<4>(Iota(du8, 0)));

  // 1, 2, 4, 8, 16, 32, 64, 128, 1, 2 ..
  const svuint16_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7));

  return TestBit(BitCast(du, rep16), bit);
}
1906
// As above, for 32-bit lanes.
template <class D, HWY_IF_LANE_SIZE_D(D, 4)>
HWY_INLINE svbool_t LoadMaskBits(D /* tag */,
                                 const uint8_t* HWY_RESTRICT bits) {
  const RebindToUnsigned<D> du;
  const Repartition<uint8_t, D> du8;

  // Upper bound = 2048 bits / 32 bit = 64 bits; at least 8 bytes are readable,
  // so we can skip computing the actual length (Lanes(du)+7)/8.
  const svuint8_t bytes = svld1(FirstN(du8, 8), bits);

  // Replicate bytes 32x such that each lane contains the bit that governs it.
  const svuint8_t rep32 = svtbl_u8(bytes, ShiftRight<5>(Iota(du8, 0)));

  // 1, 2, 4, 8, 16, 32, 64, 128, 1, 2 ..
  const svuint32_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7));

  return TestBit(BitCast(du, rep32), bit);
}
1925
// As above, for 64-bit lanes: at most 32 mask bits, so a scalar u32 load plus
// broadcast suffices.
template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
HWY_INLINE svbool_t LoadMaskBits(D /* tag */,
                                 const uint8_t* HWY_RESTRICT bits) {
  const RebindToUnsigned<D> du;

  // Max 2048 bits = 32 lanes = 32 input bits; replicate those into each lane.
  // The "at least 8 byte" guarantee in quick_reference ensures this is safe.
  uint32_t mask_bits;
  CopyBytes<4>(bits, &mask_bits);
  const auto vbits = Set(du, mask_bits);

  // 2 ^ {0,1, .., 31}, will not have more lanes than that.
  const svuint64_t bit = Shl(Set(du, 1), Iota(du, 0));

  return TestBit(vbits, bit);
}
1942
// ------------------------------ StoreMaskBits

namespace detail {

// Returns mask ? 1 : 0 in BYTE lanes.
template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)>
HWY_API svuint8_t BoolFromMask(Simd<T, N> d, svbool_t m) {
  return svdup_n_u8_z(m, 1);
}
// 16-bit lanes: widen-free path — make 0/1 u16, then compact the even bytes.
template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)>
HWY_API svuint8_t BoolFromMask(Simd<T, N> d, svbool_t m) {
  const Repartition<uint8_t, decltype(d)> d8;
  const svuint8_t b16 = BitCast(d8, svdup_n_u16_z(m, 1));
  return detail::ConcatEven(b16, b16);  // only lower half needed
}
// 32-bit lanes: 0/1 u32, then truncate each u32 to a byte.
template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)>
HWY_API svuint8_t BoolFromMask(Simd<T, N> d, svbool_t m) {
  return U8FromU32(svdup_n_u32_z(m, 1));
}
// 64-bit lanes: 0/1 u64, compact even u32 halves, then truncate to bytes.
template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)>
HWY_API svuint8_t BoolFromMask(Simd<T, N> d, svbool_t m) {
  const Repartition<uint32_t, decltype(d)> d32;
  const svuint32_t b64 = BitCast(d32, svdup_n_u64_z(m, 1));
  return U8FromU32(detail::ConcatEven(b64, b64));  // only lower half needed
}

}  // namespace detail
1970
// Packs mask m into bits (one bit per lane) and returns the number of bytes
// written. `p` points to at least 8 writable bytes.
template <typename T, size_t N>
HWY_API size_t StoreMaskBits(Simd<T, N> d, svbool_t m, uint8_t* bits) {
  const Repartition<uint8_t, decltype(d)> d8;
  const Repartition<uint16_t, decltype(d)> d16;
  const Repartition<uint32_t, decltype(d)> d32;
  const Repartition<uint64_t, decltype(d)> d64;
  auto x = detail::BoolFromMask(d, m);
  // Compact bytes to bits. Could use SVE2 BDEP, but it's optional.
  // Each step folds the next byte's bit down: after the u16 shift the low
  // byte of each u16 holds 2 bits, after u32 each holds 4, after u64 each
  // low byte holds all 8 bits of its group.
  x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x))));
  x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x))));
  x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x))));

  const size_t num_bits = Lanes(d);
  const size_t num_bytes = (num_bits + 8 - 1) / 8;  // Round up, see below

  // Truncate to 8 bits and store (one byte per u64 element).
  svst1b_u64(FirstN(d64, num_bytes), bits, BitCast(d64, x));

  // Non-full byte, need to clear the undefined upper bits. Can happen for
  // capped/partial vectors or large T and small hardware vectors.
  if (num_bits < 8) {
    const int mask = (1 << num_bits) - 1;
    bits[0] = static_cast<uint8_t>(bits[0] & mask);
  }
  // Else: we wrote full bytes because num_bits is a power of two >= 8.

  return num_bytes;
}
2000
2001 // ------------------------------ CompressBits, CompressBitsStore (LoadMaskBits)
2002
2003 template <class V>
CompressBits(V v,const uint8_t * HWY_RESTRICT bits)2004 HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) {
2005 return Compress(v, LoadMaskBits(DFromV<V>(), bits));
2006 }
2007
2008 template <class D>
CompressBitsStore(VFromD<D> v,const uint8_t * HWY_RESTRICT bits,D d,TFromD<D> * HWY_RESTRICT unaligned)2009 HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits,
2010 D d, TFromD<D>* HWY_RESTRICT unaligned) {
2011 return CompressStore(v, LoadMaskBits(d, bits), d, unaligned);
2012 }
2013
// ------------------------------ MulEven (InterleaveEven)

#if HWY_TARGET == HWY_SVE2
namespace detail {
// Widening multiply of the even (bottom) lanes.
HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, MulEven, mullb)
}  // namespace detail
#endif

// Multiplies the even lanes of a and b, returning full-width products in the
// wide-lane vector type.
template <class V, class DW = RepartitionToWide<DFromV<V>>>
HWY_API VFromD<DW> MulEven(const V a, const V b) {
#if HWY_TARGET == HWY_SVE2
  return BitCast(DW(), detail::MulEven(a, b));
#else
  // Emulate: assemble each wide product from its low and high narrow halves.
  const auto lo = Mul(a, b);
  const auto hi = detail::MulHigh(a, b);
  return BitCast(DW(), detail::InterleaveEven(lo, hi));
#endif
}
2032
MulEven(const svuint64_t a,const svuint64_t b)2033 HWY_API svuint64_t MulEven(const svuint64_t a, const svuint64_t b) {
2034 const auto lo = Mul(a, b);
2035 const auto hi = detail::MulHigh(a, b);
2036 return detail::InterleaveEven(lo, hi);
2037 }
2038
MulOdd(const svuint64_t a,const svuint64_t b)2039 HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) {
2040 const auto lo = Mul(a, b);
2041 const auto hi = detail::MulHigh(a, b);
2042 return detail::InterleaveOdd(lo, hi);
2043 }
2044
// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower)

// Widens bf16 inputs a, b (raw u16 bits) to f32 and accumulates their
// products into sum0 (returned) and sum1 (updated in place). "Reorder" means
// the lane-to-accumulator assignment is unspecified; here it follows the
// zip-based widening below.
template <size_t N>
HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd<float, N> df32, svuint16_t a,
                                              svuint16_t b,
                                              const svfloat32_t sum0,
                                              svfloat32_t& sum1) {
  // TODO(janwas): svbfmlalb_f32 if __ARM_FEATURE_SVE_BF16.
  const Repartition<uint16_t, decltype(df32)> du16;
  const RebindToUnsigned<decltype(df32)> du32;
  const svuint16_t zero = Zero(du16);
  // Zipping zero below each bf16 yields the corresponding f32 bit pattern.
  const svuint32_t a0 = ZipLower(du32, zero, BitCast(du16, a));
  const svuint32_t a1 = ZipUpper(du32, zero, BitCast(du16, a));
  const svuint32_t b0 = ZipLower(du32, zero, BitCast(du16, b));
  const svuint32_t b1 = ZipUpper(du32, zero, BitCast(du16, b));
  sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1);
  return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0);
}
2063
2064 // ------------------------------ AESRound / CLMul
2065
2066 #if defined(__ARM_FEATURE_SVE2_AES)
2067
2068 // Per-target flag to prevent generic_ops-inl.h from defining AESRound.
2069 #ifdef HWY_NATIVE_AES
2070 #undef HWY_NATIVE_AES
2071 #else
2072 #define HWY_NATIVE_AES
2073 #endif
2074
AESRound(svuint8_t state,svuint8_t round_key)2075 HWY_API svuint8_t AESRound(svuint8_t state, svuint8_t round_key) {
2076 // NOTE: it is important that AESE and AESMC be consecutive instructions so
2077 // they can be fused. AESE includes AddRoundKey, which is a different ordering
2078 // than the AES-NI semantics we adopted, so XOR by 0 and later with the actual
2079 // round key (the compiler will hopefully optimize this for multiple rounds).
2080 const svuint8_t zero = svdup_n_u8(0);
2081 return Xor(vaesmcq_u8(vaeseq_u8(state, zero), round_key));
2082 }
2083
// Carry-less multiply of the even (lower) 64-bit lanes, producing 128-bit
// results across lane pairs.
HWY_API svuint64_t CLMulLower(const svuint64_t a, const svuint64_t b) {
  return svpmullb_pair(a, b);
}
2087
// Carry-less multiply of the odd (upper) 64-bit lanes.
HWY_API svuint64_t CLMulUpper(const svuint64_t a, const svuint64_t b) {
  return svpmullt_pair(a, b);
}
2091
2092 #endif // __ARM_FEATURE_SVE2_AES
2093
2094 // ================================================== END MACROS
namespace detail {  // for code folding
// Undefine the X-macro helpers declared near the top of this header so they
// do not leak into translation units that include it.
#undef HWY_IF_FLOAT_V
#undef HWY_IF_LANE_SIZE_V
#undef HWY_IF_SIGNED_V
#undef HWY_IF_UNSIGNED_V
#undef HWY_SVE_D
#undef HWY_SVE_FOREACH
#undef HWY_SVE_FOREACH_F
#undef HWY_SVE_FOREACH_F16
#undef HWY_SVE_FOREACH_F32
#undef HWY_SVE_FOREACH_F64
#undef HWY_SVE_FOREACH_I
#undef HWY_SVE_FOREACH_I08
#undef HWY_SVE_FOREACH_I16
#undef HWY_SVE_FOREACH_I32
#undef HWY_SVE_FOREACH_I64
#undef HWY_SVE_FOREACH_IF
#undef HWY_SVE_FOREACH_U
#undef HWY_SVE_FOREACH_U08
#undef HWY_SVE_FOREACH_U16
#undef HWY_SVE_FOREACH_U32
#undef HWY_SVE_FOREACH_U64
#undef HWY_SVE_FOREACH_UI
#undef HWY_SVE_FOREACH_UI08
#undef HWY_SVE_FOREACH_UI16
#undef HWY_SVE_FOREACH_UI32
#undef HWY_SVE_FOREACH_UI64
#undef HWY_SVE_FOREACH_UIF3264
#undef HWY_SVE_PTRUE
#undef HWY_SVE_RETV_ARGD
#undef HWY_SVE_RETV_ARGPV
#undef HWY_SVE_RETV_ARGPVN
#undef HWY_SVE_RETV_ARGPVV
#undef HWY_SVE_RETV_ARGV
#undef HWY_SVE_RETV_ARGVN
#undef HWY_SVE_RETV_ARGVV
#undef HWY_SVE_T
#undef HWY_SVE_V

}  // namespace detail
2135 // NOLINTNEXTLINE(google-readability-namespace-comments)
2136 } // namespace HWY_NAMESPACE
2137 } // namespace hwy
2138 HWY_AFTER_NAMESPACE();
2139