// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// 512-bit AVX512 vectors and operations.
// External include guard in highway.h - see comment there.

// WARNING: most operations do not cross 128-bit block boundaries. In
// particular, "Broadcast", pack and zip behavior may be surprising.

#include <immintrin.h>  // AVX2+
#if defined(_MSC_VER) && defined(__clang__)
// Including <immintrin.h> should be enough, but Clang's headers helpfully skip
// including these headers when _MSC_VER is defined, like when using clang-cl.
// Include these directly here.
#include <smmintrin.h>
#include <avxintrin.h>
#include <avx2intrin.h>
#include <f16cintrin.h>
#include <fmaintrin.h>
#include <avx512fintrin.h>
#include <avx512vlintrin.h>
#include <avx512bwintrin.h>
#include <avx512dqintrin.h>
#include <avx512vlbwintrin.h>
#include <avx512vldqintrin.h>
#endif

#include <stddef.h>
#include <stdint.h>

// For half-width vectors. Already includes base.h and shared-inl.h.
#include "hwy/ops/x86_256-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

template <typename T>
struct Raw512 {
  using type = __m512i;
};
template <>
struct Raw512<float> {
  using type = __m512;
};
template <>
struct Raw512<double> {
  using type = __m512d;
};

template <typename T>
using Full512 = Simd<T, 64 / sizeof(T)>;

template <typename T>
class Vec512 {
  using Raw = typename Raw512<T>::type;

 public:
  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
  HWY_INLINE Vec512& operator*=(const Vec512 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec512& operator/=(const Vec512 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec512& operator+=(const Vec512 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec512& operator-=(const Vec512 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec512& operator&=(const Vec512 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec512& operator|=(const Vec512 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec512& operator^=(const Vec512 other) {
    return *this = (*this ^ other);
  }

  Raw raw;
};

// Template arg: sizeof(lane type)
template <size_t size>
struct RawMask512 {};
template <>
struct RawMask512<1> {
  using type = __mmask64;
};
template <>
struct RawMask512<2> {
  using type = __mmask32;
};
template <>
struct RawMask512<4> {
  using type = __mmask16;
};
template <>
struct RawMask512<8> {
  using type = __mmask8;
};

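// Sizing sketch (illustrative comment only, not part of the API): Full512<T>
// always describes a full 64-byte register, so the lane count is
// 64 / sizeof(T), e.g. 16 lanes for float and 64 lanes for uint8_t. The
// compound assignments above forward to the non-member operators defined
// further below in this file, so for example:
//   Full512<float> d;                 // 16 lanes of f32
//   Vec512<float> v = Set(d, 1.0f);   // Set is defined below
//   v *= Set(d, 2.0f);                // requires operator* (f32 supports it)
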
// Mask register: one bit per lane.
template <typename T>
class Mask512 {
 public:
  using Raw = typename RawMask512<sizeof(T)>::type;
  Raw raw;
};

// ------------------------------ BitCast

namespace detail {

HWY_API __m512i BitCastToInteger(__m512i v) { return v; }
HWY_API __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); }
HWY_API __m512i BitCastToInteger(__m512d v) { return _mm512_castpd_si512(v); }

template <typename T>
HWY_API Vec512<uint8_t> BitCastToByte(Vec512<T> v) {
  return Vec512<uint8_t>{BitCastToInteger(v.raw)};
}

// Cannot rely on function overloading because return types differ.
template <typename T>
struct BitCastFromInteger512 {
  HWY_INLINE __m512i operator()(__m512i v) { return v; }
};
template <>
struct BitCastFromInteger512<float> {
  HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); }
};
template <>
struct BitCastFromInteger512<double> {
  HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); }
};

template <typename T>
HWY_API Vec512<T> BitCastFromByte(Full512<T> /* tag */, Vec512<uint8_t> v) {
  return Vec512<T>{BitCastFromInteger512<T>()(v.raw)};
}

}  // namespace detail

template <typename T, typename FromT>
HWY_API Vec512<T> BitCast(Full512<T> d, Vec512<FromT> v) {
  return detail::BitCastFromByte(d, detail::BitCastToByte(v));
}

// ------------------------------ Set

// Returns an all-zero vector.
template <typename T>
HWY_API Vec512<T> Zero(Full512<T> /* tag */) {
  return Vec512<T>{_mm512_setzero_si512()};
}
HWY_API Vec512<float> Zero(Full512<float> /* tag */) {
  return Vec512<float>{_mm512_setzero_ps()};
}
HWY_API Vec512<double> Zero(Full512<double> /* tag */) {
  return Vec512<double>{_mm512_setzero_pd()};
}

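// BitCast usage sketch (illustrative; lane bits are reinterpreted, never
// converted, and the cast intrinsics typically emit no instructions):
//   const Full512<float> df;
//   const Full512<uint32_t> du;
//   const Vec512<uint32_t> bits = BitCast(du, Zero(df));  // all-zero bits
//   const Vec512<float> back = BitCast(df, bits);         // round trip
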
// Returns a vector with all lanes set to "t".
HWY_API Vec512<uint8_t> Set(Full512<uint8_t> /* tag */, const uint8_t t) {
  return Vec512<uint8_t>{_mm512_set1_epi8(static_cast<char>(t))};  // NOLINT
}
HWY_API Vec512<uint16_t> Set(Full512<uint16_t> /* tag */, const uint16_t t) {
  return Vec512<uint16_t>{_mm512_set1_epi16(static_cast<short>(t))};  // NOLINT
}
HWY_API Vec512<uint32_t> Set(Full512<uint32_t> /* tag */, const uint32_t t) {
  return Vec512<uint32_t>{_mm512_set1_epi32(static_cast<int>(t))};
}
HWY_API Vec512<uint64_t> Set(Full512<uint64_t> /* tag */, const uint64_t t) {
  return Vec512<uint64_t>{
      _mm512_set1_epi64(static_cast<long long>(t))};  // NOLINT
}
HWY_API Vec512<int8_t> Set(Full512<int8_t> /* tag */, const int8_t t) {
  return Vec512<int8_t>{_mm512_set1_epi8(static_cast<char>(t))};  // NOLINT
}
HWY_API Vec512<int16_t> Set(Full512<int16_t> /* tag */, const int16_t t) {
  return Vec512<int16_t>{_mm512_set1_epi16(static_cast<short>(t))};  // NOLINT
}
HWY_API Vec512<int32_t> Set(Full512<int32_t> /* tag */, const int32_t t) {
  return Vec512<int32_t>{_mm512_set1_epi32(t)};
}
HWY_API Vec512<int64_t> Set(Full512<int64_t> /* tag */, const int64_t t) {
  return Vec512<int64_t>{
      _mm512_set1_epi64(static_cast<long long>(t))};  // NOLINT
}
HWY_API Vec512<float> Set(Full512<float> /* tag */, const float t) {
  return Vec512<float>{_mm512_set1_ps(t)};
}
HWY_API Vec512<double> Set(Full512<double> /* tag */, const double t) {
  return Vec512<double>{_mm512_set1_pd(t)};
}

HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized")

// Returns a vector with uninitialized elements.
template <typename T>
HWY_API Vec512<T> Undefined(Full512<T> /* tag */) {
  // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC
  // generate an XOR instruction.
  return Vec512<T>{_mm512_undefined_epi32()};
}
HWY_API Vec512<float> Undefined(Full512<float> /* tag */) {
  return Vec512<float>{_mm512_undefined_ps()};
}
HWY_API Vec512<double> Undefined(Full512<double> /* tag */) {
  return Vec512<double>{_mm512_undefined_pd()};
}

HWY_DIAGNOSTICS(pop)

// ================================================== LOGICAL

// ------------------------------ Not

template <typename T>
HWY_API Vec512<T> Not(const Vec512<T> v) {
  using TU = MakeUnsigned<T>;
  const __m512i vu = BitCast(Full512<TU>(), v).raw;
  return BitCast(Full512<T>(),
                 Vec512<TU>{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)});
}

// ------------------------------ And

template <typename T>
HWY_API Vec512<T> And(const Vec512<T> a, const Vec512<T> b) {
  return Vec512<T>{_mm512_and_si512(a.raw, b.raw)};
}

HWY_API Vec512<float> And(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_and_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> And(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_and_pd(a.raw, b.raw)};
}

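// Note on Not(v) above: VPTERNLOG with truth table 0x55 computes the
// complement of its third operand, and all three operands are v, so the
// result is ~v without materializing an all-ones constant. A rough
// equivalent, for exposition only (Xor is defined just below):
//   const Full512<uint32_t> d;
//   const auto not_v = Xor(v, Set(d, ~uint32_t{0}));  // same bits as Not(v)
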
// ------------------------------ AndNot

// Returns ~not_mask & mask.
template <typename T>
HWY_API Vec512<T> AndNot(const Vec512<T> not_mask, const Vec512<T> mask) {
  return Vec512<T>{_mm512_andnot_si512(not_mask.raw, mask.raw)};
}
HWY_API Vec512<float> AndNot(const Vec512<float> not_mask,
                             const Vec512<float> mask) {
  return Vec512<float>{_mm512_andnot_ps(not_mask.raw, mask.raw)};
}
HWY_API Vec512<double> AndNot(const Vec512<double> not_mask,
                              const Vec512<double> mask) {
  return Vec512<double>{_mm512_andnot_pd(not_mask.raw, mask.raw)};
}

// ------------------------------ Or

template <typename T>
HWY_API Vec512<T> Or(const Vec512<T> a, const Vec512<T> b) {
  return Vec512<T>{_mm512_or_si512(a.raw, b.raw)};
}

HWY_API Vec512<float> Or(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_or_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> Or(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_or_pd(a.raw, b.raw)};
}

// ------------------------------ Xor

template <typename T>
HWY_API Vec512<T> Xor(const Vec512<T> a, const Vec512<T> b) {
  return Vec512<T>{_mm512_xor_si512(a.raw, b.raw)};
}

HWY_API Vec512<float> Xor(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_xor_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> Xor(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_xor_pd(a.raw, b.raw)};
}

// ------------------------------ Operator overloads (internal-only if float)

template <typename T>
HWY_API Vec512<T> operator&(const Vec512<T> a, const Vec512<T> b) {
  return And(a, b);
}

template <typename T>
HWY_API Vec512<T> operator|(const Vec512<T> a, const Vec512<T> b) {
  return Or(a, b);
}

template <typename T>
HWY_API Vec512<T> operator^(const Vec512<T> a, const Vec512<T> b) {
  return Xor(a, b);
}

// ------------------------------ CopySign

template <typename T>
HWY_API Vec512<T> CopySign(const Vec512<T> magn, const Vec512<T> sign) {
  static_assert(IsFloat<T>(), "Only makes sense for floating-point");

  const Full512<T> d;
  const auto msb = SignBit(d);

  const Rebind<MakeUnsigned<T>, decltype(d)> du;
  // Truth table for msb, magn, sign | bitwise msb ? sign : mag
  //                  0    0     0   |  0
  //                  0    0     1   |  0
  //                  0    1     0   |  1
  //                  0    1     1   |  1
  //                  1    0     0   |  0
  //                  1    0     1   |  1
  //                  1    1     0   |  0
  //                  1    1     1   |  1
  // The lane size does not matter because we are not using predication.
  const __m512i out = _mm512_ternarylogic_epi32(
      BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC);
  return BitCast(d, decltype(Zero(du)){out});
}

template <typename T>
HWY_API Vec512<T> CopySignToAbs(const Vec512<T> abs, const Vec512<T> sign) {
  // AVX3 can also handle abs < 0, so no extra action needed.
  return CopySign(abs, sign);
}

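// CopySign usage sketch (illustrative values): the 0xAC truth table above
// takes the sign bit from `sign` and every other bit from `magn`, so:
//   const Full512<float> d;
//   CopySign(Set(d, 1.5f), Set(d, -0.0f));      // all lanes -1.5f
//   CopySignToAbs(Set(d, 1.5f), Set(d, 2.0f));  // all lanes +1.5f
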
// ------------------------------ FirstN

// Possibilities for constructing a bitmask of N ones:
// - kshift* only consider the lowest byte of the shift count, so they would
//   not correctly handle large n.
// - Scalar shifts >= 64 are UB.
// - BZHI has the desired semantics; we assume AVX-512 implies BMI2. However,
//   we need 64-bit masks for sizeof(T) == 1, so special-case 32-bit builds.

#if HWY_ARCH_X86_32
namespace detail {

// 32 bit mask is sufficient for lane size >= 2.
template <typename T, HWY_IF_NOT_LANE_SIZE(T, 1)>
HWY_API Mask512<T> FirstN(size_t n) {
  using Bits = typename Mask512<T>::Raw;
  return Mask512<T>{static_cast<Bits>(_bzhi_u32(~uint32_t(0), n))};
}

template <typename T, HWY_IF_LANE_SIZE(T, 1)>
HWY_API Mask512<T> FirstN(size_t n) {
  const uint64_t bits = n < 64 ? ((1ULL << n) - 1) : ~uint64_t(0);
  return Mask512<T>{static_cast<__mmask64>(bits)};
}

}  // namespace detail
#endif  // HWY_ARCH_X86_32

template <typename T>
HWY_API Mask512<T> FirstN(const Full512<T> /*tag*/, size_t n) {
#if HWY_ARCH_X86_64
  using Bits = typename Mask512<T>::Raw;
  return Mask512<T>{static_cast<Bits>(_bzhi_u64(~uint64_t(0), n))};
#else
  return detail::FirstN<T>(n);
#endif  // HWY_ARCH_X86_64
}

// ------------------------------ IfThenElse

// Returns mask ? yes : no.

namespace detail {

// Templates for signed/unsigned integer of a particular size.
template <typename T>
HWY_API Vec512<T> IfThenElse(hwy::SizeTag<1> /* tag */, const Mask512<T> mask,
                             const Vec512<T> yes, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_mov_epi8(no.raw, mask.raw, yes.raw)};
}
template <typename T>
HWY_API Vec512<T> IfThenElse(hwy::SizeTag<2> /* tag */, const Mask512<T> mask,
                             const Vec512<T> yes, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_mov_epi16(no.raw, mask.raw, yes.raw)};
}
template <typename T>
HWY_API Vec512<T> IfThenElse(hwy::SizeTag<4> /* tag */, const Mask512<T> mask,
                             const Vec512<T> yes, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_mov_epi32(no.raw, mask.raw, yes.raw)};
}
template <typename T>
HWY_API Vec512<T> IfThenElse(hwy::SizeTag<8> /* tag */, const Mask512<T> mask,
                             const Vec512<T> yes, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_mov_epi64(no.raw, mask.raw, yes.raw)};
}

}  // namespace detail

template <typename T>
HWY_API Vec512<T> IfThenElse(const Mask512<T> mask, const Vec512<T> yes,
                             const Vec512<T> no) {
  return detail::IfThenElse(hwy::SizeTag<sizeof(T)>(), mask, yes, no);
}
template <>
HWY_INLINE Vec512<float> IfThenElse(const Mask512<float> mask,
                                    const Vec512<float> yes,
                                    const Vec512<float> no) {
  return Vec512<float>{_mm512_mask_mov_ps(no.raw, mask.raw, yes.raw)};
}
template <>
HWY_INLINE Vec512<double> IfThenElse(const Mask512<double> mask,
                                     const Vec512<double> yes,
                                     const Vec512<double> no) {
  return Vec512<double>{_mm512_mask_mov_pd(no.raw, mask.raw, yes.raw)};
}

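// IfThenElse usage sketch (illustrative; vector operator< is defined in the
// COMPARE section further below):
//   const Full512<int32_t> d;
//   const Mask512<int32_t> positive = Zero(d) < v;
//   const Vec512<int32_t> kept = IfThenElse(positive, v, Zero(d));
// The mask lives in a k-register, so the select is a single masked move.
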
namespace detail {

template <typename T>
HWY_API Vec512<T> IfThenElseZero(hwy::SizeTag<1> /* tag */,
                                 const Mask512<T> mask, const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi8(mask.raw, yes.raw)};
}
template <typename T>
HWY_API Vec512<T> IfThenElseZero(hwy::SizeTag<2> /* tag */,
                                 const Mask512<T> mask, const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi16(mask.raw, yes.raw)};
}
template <typename T>
HWY_API Vec512<T> IfThenElseZero(hwy::SizeTag<4> /* tag */,
                                 const Mask512<T> mask, const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi32(mask.raw, yes.raw)};
}
template <typename T>
HWY_API Vec512<T> IfThenElseZero(hwy::SizeTag<8> /* tag */,
                                 const Mask512<T> mask, const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi64(mask.raw, yes.raw)};
}

}  // namespace detail

template <typename T>
HWY_API Vec512<T> IfThenElseZero(const Mask512<T> mask, const Vec512<T> yes) {
  return detail::IfThenElseZero(hwy::SizeTag<sizeof(T)>(), mask, yes);
}
template <>
HWY_INLINE Vec512<float> IfThenElseZero(const Mask512<float> mask,
                                        const Vec512<float> yes) {
  return Vec512<float>{_mm512_maskz_mov_ps(mask.raw, yes.raw)};
}
template <>
HWY_INLINE Vec512<double> IfThenElseZero(const Mask512<double> mask,
                                         const Vec512<double> yes) {
  return Vec512<double>{_mm512_maskz_mov_pd(mask.raw, yes.raw)};
}

namespace detail {

template <typename T>
HWY_API Vec512<T> IfThenZeroElse(hwy::SizeTag<1> /* tag */,
                                 const Mask512<T> mask, const Vec512<T> no) {
  // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16.
  return Vec512<T>{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_API Vec512<T> IfThenZeroElse(hwy::SizeTag<2> /* tag */,
                                 const Mask512<T> mask, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_API Vec512<T> IfThenZeroElse(hwy::SizeTag<4> /* tag */,
                                 const Mask512<T> mask, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_API Vec512<T> IfThenZeroElse(hwy::SizeTag<8> /* tag */,
                                 const Mask512<T> mask, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)};
}

}  // namespace detail

template <typename T>
HWY_API Vec512<T> IfThenZeroElse(const Mask512<T> mask, const Vec512<T> no) {
  return detail::IfThenZeroElse(hwy::SizeTag<sizeof(T)>(), mask, no);
}
template <>
HWY_INLINE Vec512<float> IfThenZeroElse(const Mask512<float> mask,
                                        const Vec512<float> no) {
  return Vec512<float>{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)};
}
template <>
HWY_INLINE Vec512<double> IfThenZeroElse(const Mask512<double> mask,
                                         const Vec512<double> no) {
  return Vec512<double>{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)};
}

template <typename T, HWY_IF_FLOAT(T)>
HWY_API Vec512<T> ZeroIfNegative(const Vec512<T> v) {
  // AVX3 MaskFromVec only looks at the MSB
  return IfThenZeroElse(MaskFromVec(v), v);
}

// ================================================== ARITHMETIC

// ------------------------------ Addition

// Unsigned
HWY_API Vec512<uint8_t> operator+(const Vec512<uint8_t> a,
                                  const Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_add_epi8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> operator+(const Vec512<uint16_t> a,
                                   const Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_add_epi16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> operator+(const Vec512<uint32_t> a,
                                   const Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_add_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> operator+(const Vec512<uint64_t> a,
                                   const Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_add_epi64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> operator+(const Vec512<int8_t> a,
                                 const Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_add_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> operator+(const Vec512<int16_t> a,
                                  const Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_add_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> operator+(const Vec512<int32_t> a,
                                  const Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_add_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> operator+(const Vec512<int64_t> a,
                                  const Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_add_epi64(a.raw, b.raw)};
}

// Float
HWY_API Vec512<float> operator+(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_add_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> operator+(const Vec512<double> a,
                                 const Vec512<double> b) {
  return Vec512<double>{_mm512_add_pd(a.raw, b.raw)};
}

// ------------------------------ Subtraction

// Unsigned
HWY_API Vec512<uint8_t> operator-(const Vec512<uint8_t> a,
                                  const Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_sub_epi8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> operator-(const Vec512<uint16_t> a,
                                   const Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_sub_epi16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> operator-(const Vec512<uint32_t> a,
                                   const Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_sub_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> operator-(const Vec512<uint64_t> a,
                                   const Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_sub_epi64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> operator-(const Vec512<int8_t> a,
                                 const Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_sub_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> operator-(const Vec512<int16_t> a,
                                  const Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_sub_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> operator-(const Vec512<int32_t> a,
                                  const Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_sub_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> operator-(const Vec512<int64_t> a,
                                  const Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_sub_epi64(a.raw, b.raw)};
}

// Float
HWY_API Vec512<float> operator-(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_sub_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> operator-(const Vec512<double> a,
                                 const Vec512<double> b) {
  return Vec512<double>{_mm512_sub_pd(a.raw, b.raw)};
}

// ------------------------------ Saturating addition

// Returns a + b clamped to the destination range.

// Unsigned
HWY_API Vec512<uint8_t> SaturatedAdd(const Vec512<uint8_t> a,
                                     const Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_adds_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> SaturatedAdd(const Vec512<uint16_t> a,
                                      const Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_adds_epu16(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> SaturatedAdd(const Vec512<int8_t> a,
                                    const Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_adds_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> SaturatedAdd(const Vec512<int16_t> a,
                                     const Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_adds_epi16(a.raw, b.raw)};
}

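// SaturatedAdd example (illustrative, u8): 200 + 100 clamps to 255 instead
// of wrapping to 44:
//   const Full512<uint8_t> d;
//   SaturatedAdd(Set(d, uint8_t{200}), Set(d, uint8_t{100}));  // all 255
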
// ------------------------------ Saturating subtraction

// Returns a - b clamped to the destination range.

// Unsigned
HWY_API Vec512<uint8_t> SaturatedSub(const Vec512<uint8_t> a,
                                     const Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_subs_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> SaturatedSub(const Vec512<uint16_t> a,
                                      const Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_subs_epu16(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> SaturatedSub(const Vec512<int8_t> a,
                                    const Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_subs_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> SaturatedSub(const Vec512<int16_t> a,
                                     const Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_subs_epi16(a.raw, b.raw)};
}

// ------------------------------ Average

// Returns (a + b + 1) / 2

// Unsigned
HWY_API Vec512<uint8_t> AverageRound(const Vec512<uint8_t> a,
                                     const Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_avg_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> AverageRound(const Vec512<uint16_t> a,
                                      const Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_avg_epu16(a.raw, b.raw)};
}

// ------------------------------ Absolute value

// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1.
HWY_API Vec512<int8_t> Abs(const Vec512<int8_t> v) {
#if HWY_COMPILER_MSVC
  // Workaround for incorrect codegen? (untested due to internal compiler error)
  const auto zero = Zero(Full512<int8_t>());
  return Vec512<int8_t>{_mm512_max_epi8(v.raw, (zero - v).raw)};
#else
  return Vec512<int8_t>{_mm512_abs_epi8(v.raw)};
#endif
}
HWY_API Vec512<int16_t> Abs(const Vec512<int16_t> v) {
  return Vec512<int16_t>{_mm512_abs_epi16(v.raw)};
}
HWY_API Vec512<int32_t> Abs(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_abs_epi32(v.raw)};
}
HWY_API Vec512<int64_t> Abs(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_abs_epi64(v.raw)};
}

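// Note on the Abs edge case (sketch, not a test): for int8_t,
//   Abs(Set(Full512<int8_t>(), int8_t{-128}))
// yields -128 again because +128 is not representable; this is the
// "LimitsMin() maps to LimitsMax() + 1" behavior documented above.
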
// These aren't native instructions, they also involve AND with constant.
HWY_API Vec512<float> Abs(const Vec512<float> v) {
  return Vec512<float>{_mm512_abs_ps(v.raw)};
}
HWY_API Vec512<double> Abs(const Vec512<double> v) {
  return Vec512<double>{_mm512_abs_pd(v.raw)};
}

// ------------------------------ ShiftLeft

template <int kBits>
HWY_API Vec512<uint16_t> ShiftLeft(const Vec512<uint16_t> v) {
  return Vec512<uint16_t>{_mm512_slli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint32_t> ShiftLeft(const Vec512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_slli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint64_t> ShiftLeft(const Vec512<uint64_t> v) {
  return Vec512<uint64_t>{_mm512_slli_epi64(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int16_t> ShiftLeft(const Vec512<int16_t> v) {
  return Vec512<int16_t>{_mm512_slli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int32_t> ShiftLeft(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_slli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int64_t> ShiftLeft(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_slli_epi64(v.raw, kBits)};
}

template <int kBits, typename T, HWY_IF_LANE_SIZE(T, 1)>
HWY_API Vec512<T> ShiftLeft(const Vec512<T> v) {
  const Full512<T> d8;
  const RepartitionToWide<decltype(d8)> d16;
  const auto shifted = BitCast(d8, ShiftLeft<kBits>(BitCast(d16, v)));
  return kBits == 1
             ? (v + v)
             : (shifted & Set(d8, static_cast<T>((0xFF << kBits) & 0xFF)));
}

// ------------------------------ ShiftRight

template <int kBits>
HWY_API Vec512<uint16_t> ShiftRight(const Vec512<uint16_t> v) {
  return Vec512<uint16_t>{_mm512_srli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint32_t> ShiftRight(const Vec512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_srli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint64_t> ShiftRight(const Vec512<uint64_t> v) {
  return Vec512<uint64_t>{_mm512_srli_epi64(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint8_t> ShiftRight(const Vec512<uint8_t> v) {
  const Full512<uint8_t> d8;
  // Use raw instead of BitCast to support N=1.
  const Vec512<uint8_t> shifted{ShiftRight<kBits>(Vec512<uint16_t>{v.raw}).raw};
  return shifted & Set(d8, 0xFF >> kBits);
}

template <int kBits>
HWY_API Vec512<int16_t> ShiftRight(const Vec512<int16_t> v) {
  return Vec512<int16_t>{_mm512_srai_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int32_t> ShiftRight(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_srai_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int64_t> ShiftRight(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_srai_epi64(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int8_t> ShiftRight(const Vec512<int8_t> v) {
  const Full512<int8_t> di;
  const Full512<uint8_t> du;
  const auto shifted = BitCast(di, ShiftRight<kBits>(BitCast(du, v)));
  const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits));
  return (shifted ^ shifted_sign) - shifted_sign;
}

// ------------------------------ ShiftLeftSame

HWY_API Vec512<uint16_t> ShiftLeftSame(const Vec512<uint16_t> v,
                                       const int bits) {
  return Vec512<uint16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<uint32_t> ShiftLeftSame(const Vec512<uint32_t> v,
                                       const int bits) {
  return Vec512<uint32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<uint64_t> ShiftLeftSame(const Vec512<uint64_t> v,
                                       const int bits) {
  return Vec512<uint64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int16_t> ShiftLeftSame(const Vec512<int16_t> v, const int bits) {
  return Vec512<int16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int32_t> ShiftLeftSame(const Vec512<int32_t> v, const int bits) {
  return Vec512<int32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int64_t> ShiftLeftSame(const Vec512<int64_t> v, const int bits) {
  return Vec512<int64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

template <typename T, HWY_IF_LANE_SIZE(T, 1)>
HWY_API Vec512<T> ShiftLeftSame(const Vec512<T> v, const int bits) {
  const Full512<T> d8;
  const RepartitionToWide<decltype(d8)> d16;
  const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits));
  return shifted & Set(d8, (0xFF << bits) & 0xFF);
}

// ------------------------------ ShiftRightSame

HWY_API Vec512<uint16_t> ShiftRightSame(const Vec512<uint16_t> v,
                                        const int bits) {
  return Vec512<uint16_t>{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<uint32_t> ShiftRightSame(const Vec512<uint32_t> v,
                                        const int bits) {
  return Vec512<uint32_t>{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<uint64_t> ShiftRightSame(const Vec512<uint64_t> v,
                                        const int bits) {
  return Vec512<uint64_t>{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<uint8_t> ShiftRightSame(Vec512<uint8_t> v, const int bits) {
  const Full512<uint8_t> d8;
  const RepartitionToWide<decltype(d8)> d16;
  const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits));
  return shifted & Set(d8, 0xFF >> bits);
}

HWY_API Vec512<int16_t> ShiftRightSame(const Vec512<int16_t> v,
                                       const int bits) {
  return Vec512<int16_t>{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int32_t> ShiftRightSame(const Vec512<int32_t> v,
                                       const int bits) {
  return Vec512<int32_t>{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<int64_t> ShiftRightSame(const Vec512<int64_t> v,
                                       const int bits) {
  return Vec512<int64_t>{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int8_t> ShiftRightSame(Vec512<int8_t> v, const int bits) {
  const Full512<int8_t> di;
  const Full512<uint8_t> du;
  const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits));
  const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits));
  return (shifted ^ shifted_sign) - shifted_sign;
}

// ------------------------------ Shl

HWY_API Vec512<uint16_t> operator<<(const Vec512<uint16_t> v,
                                    const Vec512<uint16_t> bits) {
  return Vec512<uint16_t>{_mm512_sllv_epi16(v.raw, bits.raw)};
}

HWY_API Vec512<uint32_t> operator<<(const Vec512<uint32_t> v,
                                    const Vec512<uint32_t> bits) {
  return Vec512<uint32_t>{_mm512_sllv_epi32(v.raw, bits.raw)};
}

HWY_API Vec512<uint64_t> operator<<(const Vec512<uint64_t> v,
                                    const Vec512<uint64_t> bits) {
  return Vec512<uint64_t>{_mm512_sllv_epi64(v.raw, bits.raw)};
}

// Signed left shift is the same as unsigned.
template <typename T, HWY_IF_SIGNED(T)>
HWY_API Vec512<T> operator<<(const Vec512<T> v, const Vec512<T> bits) {
  const Full512<T> di;
  const Full512<MakeUnsigned<T>> du;
  return BitCast(di, BitCast(du, v) << BitCast(du, bits));
}

// ------------------------------ Shr

HWY_API Vec512<uint16_t> operator>>(const Vec512<uint16_t> v,
                                    const Vec512<uint16_t> bits) {
  return Vec512<uint16_t>{_mm512_srlv_epi16(v.raw, bits.raw)};
}

HWY_API Vec512<uint32_t> operator>>(const Vec512<uint32_t> v,
                                    const Vec512<uint32_t> bits) {
  return Vec512<uint32_t>{_mm512_srlv_epi32(v.raw, bits.raw)};
}

HWY_API Vec512<uint64_t> operator>>(const Vec512<uint64_t> v,
                                    const Vec512<uint64_t> bits) {
  return Vec512<uint64_t>{_mm512_srlv_epi64(v.raw, bits.raw)};
}

HWY_API Vec512<int16_t> operator>>(const Vec512<int16_t> v,
                                   const Vec512<int16_t> bits) {
  return Vec512<int16_t>{_mm512_srav_epi16(v.raw, bits.raw)};
}

HWY_API Vec512<int32_t> operator>>(const Vec512<int32_t> v,
                                   const Vec512<int32_t> bits) {
  return Vec512<int32_t>{_mm512_srav_epi32(v.raw, bits.raw)};
}

HWY_API Vec512<int64_t> operator>>(const Vec512<int64_t> v,
                                   const Vec512<int64_t> bits) {
  return Vec512<int64_t>{_mm512_srav_epi64(v.raw, bits.raw)};
}

// ------------------------------ Minimum

// Unsigned
HWY_API Vec512<uint8_t> Min(const Vec512<uint8_t> a, const Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_min_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> Min(const Vec512<uint16_t> a,
                             const Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_min_epu16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> Min(const Vec512<uint32_t> a,
                             const Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_min_epu32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> Min(const Vec512<uint64_t> a,
                             const Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_min_epu64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> Min(const Vec512<int8_t> a, const Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_min_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> Min(const Vec512<int16_t> a, const Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_min_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> Min(const Vec512<int32_t> a, const Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_min_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> Min(const Vec512<int64_t> a, const Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_min_epi64(a.raw, b.raw)};
}

// Float
HWY_API Vec512<float> Min(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_min_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> Min(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_min_pd(a.raw, b.raw)};
}

// ------------------------------ Maximum

// Unsigned
HWY_API Vec512<uint8_t> Max(const Vec512<uint8_t> a, const Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_max_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> Max(const Vec512<uint16_t> a,
                             const Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_max_epu16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> Max(const Vec512<uint32_t> a,
                             const Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_max_epu32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> Max(const Vec512<uint64_t> a,
                             const Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_max_epu64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> Max(const Vec512<int8_t> a, const Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_max_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> Max(const Vec512<int16_t> a, const Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_max_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> Max(const Vec512<int32_t> a, const Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_max_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> Max(const Vec512<int64_t> a, const Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_max_epi64(a.raw, b.raw)};
}

// Float
HWY_API Vec512<float> Max(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_max_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> Max(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_max_pd(a.raw, b.raw)};
}

// ------------------------------ Integer multiplication

// Unsigned
HWY_API Vec512<uint16_t> operator*(const Vec512<uint16_t> a,
                                   const Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_mullo_epi16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> operator*(const Vec512<uint32_t> a,
                                   const Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_mullo_epi32(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int16_t> operator*(const Vec512<int16_t> a,
                                  const Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_mullo_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> operator*(const Vec512<int32_t> a,
                                  const Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_mullo_epi32(a.raw, b.raw)};
}

// Returns the upper 16 bits of a * b in each lane.
HWY_API Vec512<uint16_t> MulHigh(const Vec512<uint16_t> a,
                                 const Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_mulhi_epu16(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> MulHigh(const Vec512<int16_t> a,
                                const Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_mulhi_epi16(a.raw, b.raw)};
}

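// MulHigh sketch (illustrative, for caller-provided Vec512<uint16_t> a, b):
// combining operator* (low half) and MulHigh (high half) recovers the full
// 32-bit product of each pair of 16-bit lanes:
//   const auto lo = a * b;          // bits [15:0] of each product
//   const auto hi = MulHigh(a, b);  // bits [31:16] of each product
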
// Multiplies even lanes (0, 2 ..) and places the double-wide result into
// even and the upper half into its odd neighbor lane.
HWY_API Vec512<int64_t> MulEven(const Vec512<int32_t> a,
                                const Vec512<int32_t> b) {
  return Vec512<int64_t>{_mm512_mul_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> MulEven(const Vec512<uint32_t> a,
                                 const Vec512<uint32_t> b) {
  return Vec512<uint64_t>{_mm512_mul_epu32(a.raw, b.raw)};
}

// ------------------------------ Negate

template <typename T, HWY_IF_FLOAT(T)>
HWY_API Vec512<T> Neg(const Vec512<T> v) {
  return Xor(v, SignBit(Full512<T>()));
}

template <typename T, HWY_IF_NOT_FLOAT(T)>
HWY_API Vec512<T> Neg(const Vec512<T> v) {
  return Zero(Full512<T>()) - v;
}

// ------------------------------ Floating-point mul / div

HWY_API Vec512<float> operator*(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_mul_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> operator*(const Vec512<double> a,
                                 const Vec512<double> b) {
  return Vec512<double>{_mm512_mul_pd(a.raw, b.raw)};
}

HWY_API Vec512<float> operator/(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_div_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> operator/(const Vec512<double> a,
                                 const Vec512<double> b) {
  return Vec512<double>{_mm512_div_pd(a.raw, b.raw)};
}

// Approximate reciprocal
HWY_API Vec512<float> ApproximateReciprocal(const Vec512<float> v) {
  return Vec512<float>{_mm512_rcp14_ps(v.raw)};
}

// Absolute value of difference.
HWY_API Vec512<float> AbsDiff(const Vec512<float> a, const Vec512<float> b) {
  return Abs(a - b);
}

// ------------------------------ Floating-point multiply-add variants

// Returns mul * x + add
HWY_API Vec512<float> MulAdd(const Vec512<float> mul, const Vec512<float> x,
                             const Vec512<float> add) {
  return Vec512<float>{_mm512_fmadd_ps(mul.raw, x.raw, add.raw)};
}
HWY_API Vec512<double> MulAdd(const Vec512<double> mul, const Vec512<double> x,
                              const Vec512<double> add) {
  return Vec512<double>{_mm512_fmadd_pd(mul.raw, x.raw, add.raw)};
}

// Returns add - mul * x
HWY_API Vec512<float> NegMulAdd(const Vec512<float> mul, const Vec512<float> x,
                                const Vec512<float> add) {
  return Vec512<float>{_mm512_fnmadd_ps(mul.raw, x.raw, add.raw)};
}
HWY_API Vec512<double> NegMulAdd(const Vec512<double> mul,
                                 const Vec512<double> x,
                                 const Vec512<double> add) {
  return Vec512<double>{_mm512_fnmadd_pd(mul.raw, x.raw, add.raw)};
}

// Returns mul * x - sub
HWY_API Vec512<float> MulSub(const Vec512<float> mul, const Vec512<float> x,
                             const Vec512<float> sub) {
  return Vec512<float>{_mm512_fmsub_ps(mul.raw, x.raw, sub.raw)};
}
HWY_API Vec512<double> MulSub(const Vec512<double> mul, const Vec512<double> x,
                              const Vec512<double> sub) {
  return Vec512<double>{_mm512_fmsub_pd(mul.raw, x.raw, sub.raw)};
}

// Returns -mul * x - sub
HWY_API Vec512<float> NegMulSub(const Vec512<float> mul, const Vec512<float> x,
                                const Vec512<float> sub) {
  return Vec512<float>{_mm512_fnmsub_ps(mul.raw, x.raw, sub.raw)};
}
HWY_API Vec512<double> NegMulSub(const Vec512<double> mul,
                                 const Vec512<double> x,
                                 const Vec512<double> sub) {
  return Vec512<double>{_mm512_fnmsub_pd(mul.raw, x.raw, sub.raw)};
}

// ------------------------------ Floating-point square root

// Full precision square root
HWY_API Vec512<float> Sqrt(const Vec512<float> v) {
  return Vec512<float>{_mm512_sqrt_ps(v.raw)};
}
HWY_API Vec512<double> Sqrt(const Vec512<double> v) {
  return Vec512<double>{_mm512_sqrt_pd(v.raw)};
}

// Approximate reciprocal square root
HWY_API Vec512<float> ApproximateReciprocalSqrt(const Vec512<float> v) {
  return Vec512<float>{_mm512_rsqrt14_ps(v.raw)};
}

// ------------------------------ Floating-point rounding

// Work around warnings in the intrinsic definitions (passing -1 as a mask).
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")

// Toward nearest integer, tie to even
HWY_API Vec512<float> Round(const Vec512<float> v) {
  return Vec512<float>{_mm512_roundscale_ps(
      v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)};
}
HWY_API Vec512<double> Round(const Vec512<double> v) {
  return Vec512<double>{_mm512_roundscale_pd(
      v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)};
}

// Toward zero, aka truncate
HWY_API Vec512<float> Trunc(const Vec512<float> v) {
  return Vec512<float>{
      _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)};
}
HWY_API Vec512<double> Trunc(const Vec512<double> v) {
  return Vec512<double>{
      _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)};
}

// Toward +infinity, aka ceiling
HWY_API Vec512<float> Ceil(const Vec512<float> v) {
  return Vec512<float>{
      _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)};
}
HWY_API Vec512<double> Ceil(const Vec512<double> v) {
  return Vec512<double>{
      _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)};
}

// Toward -infinity, aka floor
HWY_API Vec512<float> Floor(const Vec512<float> v) {
  return Vec512<float>{
      _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)};
}
HWY_API Vec512<double> Floor(const Vec512<double> v) {
  return Vec512<double>{
      _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)};
}

HWY_DIAGNOSTICS(pop)

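// Rounding examples (illustrative lane values, d = Full512<float>):
//   Round(Set(d, 2.5f))   // 2.0f : ties to even
//   Trunc(Set(d, -1.7f))  // -1.0f: toward zero
//   Ceil(Set(d, -1.7f))   // -1.0f: toward +infinity
//   Floor(Set(d, -1.7f))  // -2.0f: toward -infinity
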
// ================================================== COMPARE

// Comparisons set a mask bit to 1 if the condition is true, else 0.

template <typename TFrom, typename TTo>
HWY_API Mask512<TTo> RebindMask(Full512<TTo> /*tag*/, Mask512<TFrom> m) {
  static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size");
  return Mask512<TTo>{m.raw};
}

namespace detail {

template <typename T>
HWY_API Mask512<T> TestBit(hwy::SizeTag<1> /*tag*/, const Vec512<T> v,
                           const Vec512<T> bit) {
  return Mask512<T>{_mm512_test_epi8_mask(v.raw, bit.raw)};
}
template <typename T>
HWY_API Mask512<T> TestBit(hwy::SizeTag<2> /*tag*/, const Vec512<T> v,
                           const Vec512<T> bit) {
  return Mask512<T>{_mm512_test_epi16_mask(v.raw, bit.raw)};
}
template <typename T>
HWY_API Mask512<T> TestBit(hwy::SizeTag<4> /*tag*/, const Vec512<T> v,
                           const Vec512<T> bit) {
  return Mask512<T>{_mm512_test_epi32_mask(v.raw, bit.raw)};
}
template <typename T>
HWY_API Mask512<T> TestBit(hwy::SizeTag<8> /*tag*/, const Vec512<T> v,
                           const Vec512<T> bit) {
  return Mask512<T>{_mm512_test_epi64_mask(v.raw, bit.raw)};
}

}  // namespace detail

template <typename T>
HWY_API Mask512<T> TestBit(const Vec512<T> v, const Vec512<T> bit) {
  static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported");
  return detail::TestBit(hwy::SizeTag<sizeof(T)>(), v, bit);
}

// ------------------------------ Equality

// Unsigned
HWY_API Mask512<uint8_t> operator==(const Vec512<uint8_t> a,
                                    const Vec512<uint8_t> b) {
  return Mask512<uint8_t>{_mm512_cmpeq_epi8_mask(a.raw, b.raw)};
}
HWY_API Mask512<uint16_t> operator==(const Vec512<uint16_t> a,
                                     const Vec512<uint16_t> b) {
  return Mask512<uint16_t>{_mm512_cmpeq_epi16_mask(a.raw, b.raw)};
}
HWY_API Mask512<uint32_t> operator==(const Vec512<uint32_t> a,
                                     const Vec512<uint32_t> b) {
  return Mask512<uint32_t>{_mm512_cmpeq_epi32_mask(a.raw, b.raw)};
}
HWY_API Mask512<uint64_t> operator==(const Vec512<uint64_t> a,
                                     const Vec512<uint64_t> b) {
  return Mask512<uint64_t>{_mm512_cmpeq_epi64_mask(a.raw, b.raw)};
}

// Signed
HWY_API Mask512<int8_t> operator==(const Vec512<int8_t> a,
                                   const Vec512<int8_t> b) {
  return Mask512<int8_t>{_mm512_cmpeq_epi8_mask(a.raw, b.raw)};
}
HWY_API Mask512<int16_t> operator==(const Vec512<int16_t> a,
                                    const Vec512<int16_t> b) {
  return Mask512<int16_t>{_mm512_cmpeq_epi16_mask(a.raw, b.raw)};
}
HWY_API Mask512<int32_t> operator==(const Vec512<int32_t> a,
                                    const Vec512<int32_t> b) {
  return Mask512<int32_t>{_mm512_cmpeq_epi32_mask(a.raw, b.raw)};
}
HWY_API Mask512<int64_t> operator==(const Vec512<int64_t> a,
                                    const Vec512<int64_t> b) {
  return Mask512<int64_t>{_mm512_cmpeq_epi64_mask(a.raw, b.raw)};
}

// Float
HWY_API Mask512<float> operator==(const Vec512<float> a,
                                  const Vec512<float> b) {
  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)};
}
HWY_API Mask512<double> operator==(const Vec512<double> a,
                                   const Vec512<double> b) {
  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)};
}

// ------------------------------ Strict inequality

// Signed/float <
HWY_API Mask512<int8_t> operator<(const Vec512<int8_t> a,
                                  const Vec512<int8_t> b) {
  return Mask512<int8_t>{_mm512_cmpgt_epi8_mask(b.raw, a.raw)};
}
HWY_API Mask512<int16_t> operator<(const Vec512<int16_t> a,
                                   const Vec512<int16_t> b) {
  return Mask512<int16_t>{_mm512_cmpgt_epi16_mask(b.raw, a.raw)};
}
HWY_API Mask512<int32_t> operator<(const Vec512<int32_t> a,
                                   const Vec512<int32_t> b) {
  return Mask512<int32_t>{_mm512_cmpgt_epi32_mask(b.raw, a.raw)};
}
HWY_API Mask512<int64_t> operator<(const Vec512<int64_t> a,
                                   const Vec512<int64_t> b) {
  return Mask512<int64_t>{_mm512_cmpgt_epi64_mask(b.raw, a.raw)};
}
HWY_API Mask512<float> operator<(const Vec512<float> a,
                                 const Vec512<float> b) {
  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_LT_OQ)};
}
HWY_API Mask512<double> operator<(const Vec512<double> a,
                                  const Vec512<double> b) {
  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_LT_OQ)};
}

// Signed/float >
HWY_API Mask512<int8_t> operator>(const Vec512<int8_t> a,
                                  const Vec512<int8_t> b) {
  return Mask512<int8_t>{_mm512_cmpgt_epi8_mask(a.raw, b.raw)};
}
HWY_API Mask512<int16_t> operator>(const Vec512<int16_t> a,
                                   const Vec512<int16_t> b) {
  return Mask512<int16_t>{_mm512_cmpgt_epi16_mask(a.raw, b.raw)};
}
HWY_API Mask512<int32_t> operator>(const Vec512<int32_t> a,
                                   const Vec512<int32_t> b) {
  return Mask512<int32_t>{_mm512_cmpgt_epi32_mask(a.raw, b.raw)};
}
HWY_API Mask512<int64_t> operator>(const Vec512<int64_t> a,
                                   const Vec512<int64_t> b) {
  return Mask512<int64_t>{_mm512_cmpgt_epi64_mask(a.raw, b.raw)};
}
HWY_API Mask512<float> operator>(const Vec512<float> a,
                                 const Vec512<float> b) {
  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)};
}
HWY_API Mask512<double> operator>(const Vec512<double> a,
                                  const Vec512<double> b) {
  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)};
}

// ------------------------------ Weak inequality

// Float <= >=
HWY_API Mask512<float> operator<=(const Vec512<float> a,
                                  const Vec512<float> b) {
  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_LE_OQ)};
}
HWY_API Mask512<double> operator<=(const Vec512<double> a,
                                   const Vec512<double> b) {
  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_LE_OQ)};
}
HWY_API Mask512<float> operator>=(const Vec512<float> a,
                                  const Vec512<float> b) {
  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)};
}
HWY_API Mask512<double> operator>=(const Vec512<double> a,
                                   const Vec512<double> b) {
  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)};
}

// ------------------------------ Mask

namespace detail {

template <typename T>
HWY_API Mask512<T> MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec512<T> v) {
  return Mask512<T>{_mm512_movepi8_mask(v.raw)};
}
template <typename T>
HWY_API Mask512<T> MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec512<T> v) {
  return Mask512<T>{_mm512_movepi16_mask(v.raw)};
}
template <typename T>
HWY_API Mask512<T> MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec512<T> v) {
  return Mask512<T>{_mm512_movepi32_mask(v.raw)};
}
template <typename T>
HWY_API Mask512<T> MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec512<T> v) {
  return Mask512<T>{_mm512_movepi64_mask(v.raw)};
}

}  // namespace detail

template <typename T>
HWY_API Mask512<T> MaskFromVec(const Vec512<T> v) {
  return detail::MaskFromVec(hwy::SizeTag<sizeof(T)>(), v);
}
// There do not seem to be native floating-point versions of these
// instructions.
HWY_API Mask512<float> MaskFromVec(const Vec512<float> v) {
  return Mask512<float>{MaskFromVec(BitCast(Full512<int32_t>(), v)).raw};
}
HWY_API Mask512<double> MaskFromVec(const Vec512<double> v) {
  return Mask512<double>{MaskFromVec(BitCast(Full512<int64_t>(), v)).raw};
}

HWY_API Vec512<uint8_t> VecFromMask(const Mask512<uint8_t> v) {
  return Vec512<uint8_t>{_mm512_movm_epi8(v.raw)};
}
HWY_API Vec512<int8_t> VecFromMask(const Mask512<int8_t> v) {
  return Vec512<int8_t>{_mm512_movm_epi8(v.raw)};
}

HWY_API Vec512<uint16_t> VecFromMask(const Mask512<uint16_t> v) {
  return Vec512<uint16_t>{_mm512_movm_epi16(v.raw)};
}
HWY_API Vec512<int16_t> VecFromMask(const Mask512<int16_t> v) {
  return Vec512<int16_t>{_mm512_movm_epi16(v.raw)};
}

HWY_API Vec512<uint32_t> VecFromMask(const Mask512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_movm_epi32(v.raw)};
}
HWY_API Vec512<int32_t> VecFromMask(const Mask512<int32_t> v) {
  return Vec512<int32_t>{_mm512_movm_epi32(v.raw)};
}
HWY_API Vec512<float> VecFromMask(const Mask512<float> v) {
  return Vec512<float>{_mm512_castsi512_ps(_mm512_movm_epi32(v.raw))};
}

HWY_API Vec512<uint64_t> VecFromMask(const Mask512<uint64_t> v) {
  return Vec512<uint64_t>{_mm512_movm_epi64(v.raw)};
}
HWY_API Vec512<int64_t> VecFromMask(const Mask512<int64_t> v) {
  return Vec512<int64_t>{_mm512_movm_epi64(v.raw)};
}
HWY_API Vec512<double> VecFromMask(const Mask512<double> v) {
  return Vec512<double>{_mm512_castsi512_pd(_mm512_movm_epi64(v.raw))};
}

template <typename T>
HWY_API Vec512<T> VecFromMask(Full512<T> /* tag */, const Mask512<T> v) {
  return VecFromMask(v);
}

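// Mask round-trip sketch (illustrative): VecFromMask expands each mask bit to
// an all-ones or all-zeros lane, so MaskFromVec(VecFromMask(m)) yields m:
//   const Full512<int32_t> d;
//   const Mask512<int32_t> m = v < Zero(d);
//   const Vec512<int32_t> lanes = VecFromMask(d, m);  // ~0 or 0 per lane
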
// ------------------------------ Mask logical

// For Clang and GCC, mask intrinsics (KORTEST) weren't added until recently.
#if !defined(HWY_COMPILER_HAS_MASK_INTRINSICS)
#if HWY_COMPILER_MSVC != 0 || HWY_COMPILER_GCC >= 700 || \
    HWY_COMPILER_CLANG >= 800
#define HWY_COMPILER_HAS_MASK_INTRINSICS 1
#else
#define HWY_COMPILER_HAS_MASK_INTRINSICS 0
#endif
#endif  // !defined(HWY_COMPILER_HAS_MASK_INTRINSICS)

namespace detail {

template <typename T>
HWY_API Mask512<T> Not(hwy::SizeTag<1> /*tag*/, const Mask512<T> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_knot_mask64(m.raw)};
#else
  return Mask512<T>{~m.raw};
#endif
}
template <typename T>
HWY_API Mask512<T> Not(hwy::SizeTag<2> /*tag*/, const Mask512<T> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_knot_mask32(m.raw)};
#else
  return Mask512<T>{~m.raw};
#endif
}
template <typename T>
HWY_API Mask512<T> Not(hwy::SizeTag<4> /*tag*/, const Mask512<T> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_knot_mask16(m.raw)};
#else
  return Mask512<T>{static_cast<uint16_t>(~m.raw & 0xFFFF)};
#endif
}
template <typename T>
HWY_API Mask512<T> Not(hwy::SizeTag<8> /*tag*/, const Mask512<T> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_knot_mask8(m.raw)};
#else
  return Mask512<T>{static_cast<uint8_t>(~m.raw & 0xFF)};
#endif
}

template <typename T>
HWY_API Mask512<T> And(hwy::SizeTag<1> /*tag*/, const Mask512<T> a,
                       const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kand_mask64(a.raw, b.raw)};
#else
  return Mask512<T>{a.raw & b.raw};
#endif
}
template <typename T>
HWY_API Mask512<T> And(hwy::SizeTag<2> /*tag*/, const Mask512<T> a,
                       const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kand_mask32(a.raw, b.raw)};
#else
  return Mask512<T>{a.raw & b.raw};
#endif
}
template <typename T>
HWY_API Mask512<T> And(hwy::SizeTag<4> /*tag*/, const Mask512<T> a,
                       const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kand_mask16(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint16_t>(a.raw & b.raw)};
#endif
}
template <typename T>
HWY_API Mask512<T> And(hwy::SizeTag<8> /*tag*/, const Mask512<T> a,
                       const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kand_mask8(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint8_t>(a.raw & b.raw)};
#endif
}

template <typename T>
HWY_API Mask512<T> AndNot(hwy::SizeTag<1> /*tag*/, const Mask512<T> a,
                          const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kandn_mask64(a.raw, b.raw)};
#else
  return Mask512<T>{~a.raw & b.raw};
#endif
}
template <typename T>
HWY_API Mask512<T> AndNot(hwy::SizeTag<2> /*tag*/, const Mask512<T> a,
                          const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kandn_mask32(a.raw, b.raw)};
#else
  return Mask512<T>{~a.raw & b.raw};
#endif
}
template <typename T>
HWY_API Mask512<T> AndNot(hwy::SizeTag<4> /*tag*/, const Mask512<T> a,
                          const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kandn_mask16(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint16_t>(~a.raw & b.raw)};
#endif
}
template <typename T>
HWY_API Mask512<T> AndNot(hwy::SizeTag<8> /*tag*/, const Mask512<T> a,
                          const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kandn_mask8(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint8_t>(~a.raw & b.raw)};
#endif
}

template <typename T>
HWY_API Mask512<T> Or(hwy::SizeTag<1> /*tag*/, const Mask512<T> a,
                      const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kor_mask64(a.raw, b.raw)};
#else
  return Mask512<T>{a.raw | b.raw};
#endif
}
template <typename T>
HWY_API Mask512<T> Or(hwy::SizeTag<2> /*tag*/, const Mask512<T> a,
                      const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kor_mask32(a.raw, b.raw)};
#else
  return Mask512<T>{a.raw | b.raw};
#endif
}
template <typename T>
HWY_API Mask512<T> Or(hwy::SizeTag<4> /*tag*/, const Mask512<T> a,
                      const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kor_mask16(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint16_t>(a.raw | b.raw)};
#endif
}
template <typename T>
HWY_API Mask512<T> Or(hwy::SizeTag<8> /*tag*/, const Mask512<T> a,
                      const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kor_mask8(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint8_t>(a.raw | b.raw)};
#endif
}

template <typename T>
HWY_API Mask512<T> Xor(hwy::SizeTag<1> /*tag*/, const Mask512<T> a,
                       const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kxor_mask64(a.raw, b.raw)};
#else
  return Mask512<T>{a.raw ^ b.raw};
#endif
}
template <typename T>
HWY_API Mask512<T> Xor(hwy::SizeTag<2> /*tag*/, const Mask512<T> a,
                       const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kxor_mask32(a.raw, b.raw)};
#else
  return Mask512<T>{a.raw ^ b.raw};
#endif
}
template <typename T>
HWY_API Mask512<T> Xor(hwy::SizeTag<4> /*tag*/, const Mask512<T> a,
                       const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kxor_mask16(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint16_t>(a.raw ^ b.raw)};
#endif
}
template <typename T>
HWY_API Mask512<T> Xor(hwy::SizeTag<8> /*tag*/, const Mask512<T> a,
                       const Mask512<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask512<T>{_kxor_mask8(a.raw, b.raw)};
#else
  return Mask512<T>{static_cast<uint8_t>(a.raw ^ b.raw)};
#endif
}

}  // namespace detail

template <typename T>
HWY_API Mask512<T> Not(const Mask512<T> m) {
  return detail::Not(hwy::SizeTag<sizeof(T)>(), m);
}

template <typename T>
HWY_API Mask512<T> And(const Mask512<T> a, Mask512<T> b) {
  return detail::And(hwy::SizeTag<sizeof(T)>(), a, b);
}

template <typename T>
HWY_API Mask512<T> AndNot(const Mask512<T> a, Mask512<T> b) {
  return detail::AndNot(hwy::SizeTag<sizeof(T)>(), a, b);
}

template <typename T>
HWY_API Mask512<T> Or(const Mask512<T> a, Mask512<T> b) {
  return detail::Or(hwy::SizeTag<sizeof(T)>(), a, b);
}

template <typename T>
HWY_API Mask512<T> Xor(const Mask512<T> a, Mask512<T> b) {
  return detail::Xor(hwy::SizeTag<sizeof(T)>(), a, b);
}

BroadcastSignBit (ShiftRight, compare, mask) 1666 1667 HWY_API Vec512<int8_t> BroadcastSignBit(const Vec512<int8_t> v) { 1668 return VecFromMask(v < Zero(Full512<int8_t>())); 1669 } 1670 1671 HWY_API Vec512<int16_t> BroadcastSignBit(const Vec512<int16_t> v) { 1672 return ShiftRight<15>(v); 1673 } 1674 1675 HWY_API Vec512<int32_t> BroadcastSignBit(const Vec512<int32_t> v) { 1676 return ShiftRight<31>(v); 1677 } 1678 1679 HWY_API Vec512<int64_t> BroadcastSignBit(const Vec512<int64_t> v) { 1680 return Vec512<int64_t>{_mm512_srai_epi64(v.raw, 63)}; 1681 } 1682 1683 // ================================================== MEMORY 1684 1685 // ------------------------------ Load 1686 1687 template <typename T> 1688 HWY_API Vec512<T> Load(Full512<T> /* tag */, const T* HWY_RESTRICT aligned) { 1689 return Vec512<T>{ 1690 _mm512_load_si512(reinterpret_cast<const __m512i*>(aligned))}; 1691 } 1692 HWY_API Vec512<float> Load(Full512<float> /* tag */, 1693 const float* HWY_RESTRICT aligned) { 1694 return Vec512<float>{_mm512_load_ps(aligned)}; 1695 } 1696 HWY_API Vec512<double> Load(Full512<double> /* tag */, 1697 const double* HWY_RESTRICT aligned) { 1698 return Vec512<double>{_mm512_load_pd(aligned)}; 1699 } 1700 1701 template <typename T> 1702 HWY_API Vec512<T> LoadU(Full512<T> /* tag */, const T* HWY_RESTRICT p) { 1703 return Vec512<T>{_mm512_loadu_si512(reinterpret_cast<const __m512i*>(p))}; 1704 } 1705 HWY_API Vec512<float> LoadU(Full512<float> /* tag */, 1706 const float* HWY_RESTRICT p) { 1707 return Vec512<float>{_mm512_loadu_ps(p)}; 1708 } 1709 HWY_API Vec512<double> LoadU(Full512<double> /* tag */, 1710 const double* HWY_RESTRICT p) { 1711 return Vec512<double>{_mm512_loadu_pd(p)}; 1712 } 1713 1714 // Loads 128 bit and duplicates into both 128-bit halves. This avoids the 1715 // 3-cycle cost of moving data between 128-bit halves and avoids port 5. 
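//
// Example (illustrative sketch, not part of the API): a common use is
// broadcasting one 16-byte lookup table so that the per-block
// TableLookupBytes (defined further below) sees the same table in every
// 128-bit block. `kNibblePop` and `v` are hypothetical caller-provided values;
// `v` is assumed to hold byte lanes in [0, 16).
//   const Full512<uint8_t> d;
//   alignas(16) static constexpr uint8_t kNibblePop[16] = {
//       0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4};
//   const Vec512<uint8_t> table = LoadDup128(d, kNibblePop);
//   const Vec512<uint8_t> counts = TableLookupBytes(table, v);  // popcount per nibble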
1716 template <typename T> 1717 HWY_API Vec512<T> LoadDup128(Full512<T> /* tag */, 1718 const T* const HWY_RESTRICT p) { 1719 // Clang 3.9 generates VINSERTF128 which is slower, but inline assembly leads 1720 // to "invalid output size for constraint" without -mavx512: 1721 // https://gcc.godbolt.org/z/-Jt_-F 1722 #if HWY_LOADDUP_ASM 1723 __m512i out; 1724 asm("vbroadcasti128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0])); 1725 return Vec512<T>{out}; 1726 #else 1727 const auto x4 = LoadU(Full128<T>(), p); 1728 return Vec512<T>{_mm512_broadcast_i32x4(x4.raw)}; 1729 #endif 1730 } 1731 HWY_API Vec512<float> LoadDup128(Full512<float> /* tag */, 1732 const float* const HWY_RESTRICT p) { 1733 #if HWY_LOADDUP_ASM 1734 __m512 out; 1735 asm("vbroadcastf128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0])); 1736 return Vec512<float>{out}; 1737 #else 1738 const __m128 x4 = _mm_loadu_ps(p); 1739 return Vec512<float>{_mm512_broadcast_f32x4(x4)}; 1740 #endif 1741 } 1742 1743 HWY_API Vec512<double> LoadDup128(Full512<double> /* tag */, 1744 const double* const HWY_RESTRICT p) { 1745 #if HWY_LOADDUP_ASM 1746 __m512d out; 1747 asm("vbroadcastf128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0])); 1748 return Vec512<double>{out}; 1749 #else 1750 const __m128d x2 = _mm_loadu_pd(p); 1751 return Vec512<double>{_mm512_broadcast_f64x2(x2)}; 1752 #endif 1753 } 1754 1755 // ------------------------------ Store 1756 1757 template <typename T> 1758 HWY_API void Store(const Vec512<T> v, Full512<T> /* tag */, 1759 T* HWY_RESTRICT aligned) { 1760 _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw); 1761 } 1762 HWY_API void Store(const Vec512<float> v, Full512<float> /* tag */, 1763 float* HWY_RESTRICT aligned) { 1764 _mm512_store_ps(aligned, v.raw); 1765 } 1766 HWY_API void Store(const Vec512<double> v, Full512<double> /* tag */, 1767 double* HWY_RESTRICT aligned) { 1768 _mm512_store_pd(aligned, v.raw); 1769 } 1770 1771 template <typename T> 1772 HWY_API void StoreU(const Vec512<T> v, Full512<T> /* tag */, 1773 T* HWY_RESTRICT p) { 1774 _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw); 1775 } 1776 HWY_API void StoreU(const Vec512<float> v, Full512<float> /* tag */, 1777 float* HWY_RESTRICT p) { 1778 _mm512_storeu_ps(p, v.raw); 1779 } 1780 HWY_API void StoreU(const Vec512<double> v, Full512<double>, 1781 double* HWY_RESTRICT p) { 1782 _mm512_storeu_pd(p, v.raw); 1783 } 1784 1785 // ------------------------------ Non-temporal stores 1786 1787 template <typename T> 1788 HWY_API void Stream(const Vec512<T> v, Full512<T> /* tag */, 1789 T* HWY_RESTRICT aligned) { 1790 _mm512_stream_si512(reinterpret_cast<__m512i*>(aligned), v.raw); 1791 } 1792 HWY_API void Stream(const Vec512<float> v, Full512<float> /* tag */, 1793 float* HWY_RESTRICT aligned) { 1794 _mm512_stream_ps(aligned, v.raw); 1795 } 1796 HWY_API void Stream(const Vec512<double> v, Full512<double>, 1797 double* HWY_RESTRICT aligned) { 1798 _mm512_stream_pd(aligned, v.raw); 1799 } 1800 1801 // ------------------------------ Scatter 1802 1803 // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
1804 HWY_DIAGNOSTICS(push) 1805 HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") 1806 1807 namespace detail { 1808 1809 template <typename T> 1810 HWY_API void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec512<T> v, 1811 Full512<T> /* tag */, T* HWY_RESTRICT base, 1812 const Vec512<int32_t> offset) { 1813 _mm512_i32scatter_epi32(base, offset.raw, v.raw, 1); 1814 } 1815 template <typename T> 1816 HWY_API void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec512<T> v, 1817 Full512<T> /* tag */, T* HWY_RESTRICT base, 1818 const Vec512<int32_t> index) { 1819 _mm512_i32scatter_epi32(base, index.raw, v.raw, 4); 1820 } 1821 1822 template <typename T> 1823 HWY_API void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec512<T> v, 1824 Full512<T> /* tag */, T* HWY_RESTRICT base, 1825 const Vec512<int64_t> offset) { 1826 _mm512_i64scatter_epi64(base, offset.raw, v.raw, 1); 1827 } 1828 template <typename T> 1829 HWY_API void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec512<T> v, 1830 Full512<T> /* tag */, T* HWY_RESTRICT base, 1831 const Vec512<int64_t> index) { 1832 _mm512_i64scatter_epi64(base, index.raw, v.raw, 8); 1833 } 1834 1835 } // namespace detail 1836 1837 template <typename T, typename Offset> 1838 HWY_API void ScatterOffset(Vec512<T> v, Full512<T> d, T* HWY_RESTRICT base, 1839 const Vec512<Offset> offset) { 1840 static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); 1841 return detail::ScatterOffset(hwy::SizeTag<sizeof(T)>(), v, d, base, offset); 1842 } 1843 template <typename T, typename Index> 1844 HWY_API void ScatterIndex(Vec512<T> v, Full512<T> d, T* HWY_RESTRICT base, 1845 const Vec512<Index> index) { 1846 static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); 1847 return detail::ScatterIndex(hwy::SizeTag<sizeof(T)>(), v, d, base, index); 1848 } 1849 1850 template <> 1851 HWY_INLINE void ScatterOffset<float>(Vec512<float> v, Full512<float> /* tag */, 1852 float* HWY_RESTRICT base, 1853 const Vec512<int32_t> offset) { 1854 _mm512_i32scatter_ps(base, offset.raw, v.raw, 1); 1855 } 1856 template <> 1857 HWY_INLINE void ScatterIndex<float>(Vec512<float> v, Full512<float> /* tag */, 1858 float* HWY_RESTRICT base, 1859 const Vec512<int32_t> index) { 1860 _mm512_i32scatter_ps(base, index.raw, v.raw, 4); 1861 } 1862 1863 template <> 1864 HWY_INLINE void ScatterOffset<double>(Vec512<double> v, 1865 Full512<double> /* tag */, 1866 double* HWY_RESTRICT base, 1867 const Vec512<int64_t> offset) { 1868 _mm512_i64scatter_pd(base, offset.raw, v.raw, 1); 1869 } 1870 template <> 1871 HWY_INLINE void ScatterIndex<double>(Vec512<double> v, 1872 Full512<double> /* tag */, 1873 double* HWY_RESTRICT base, 1874 const Vec512<int64_t> index) { 1875 _mm512_i64scatter_pd(base, index.raw, v.raw, 8); 1876 } 1877 1878 // ------------------------------ Gather 1879 1880 namespace detail { 1881 1882 template <typename T> 1883 HWY_API Vec512<T> GatherOffset(hwy::SizeTag<4> /* tag */, Full512<T> /* tag */, 1884 const T* HWY_RESTRICT base, 1885 const Vec512<int32_t> offset) { 1886 return Vec512<T>{_mm512_i32gather_epi32(offset.raw, base, 1)}; 1887 } 1888 template <typename T> 1889 HWY_API Vec512<T> GatherIndex(hwy::SizeTag<4> /* tag */, Full512<T> /* tag */, 1890 const T* HWY_RESTRICT base, 1891 const Vec512<int32_t> index) { 1892 return Vec512<T>{_mm512_i32gather_epi32(index.raw, base, 4)}; 1893 } 1894 1895 template <typename T> 1896 HWY_API Vec512<T> GatherOffset(hwy::SizeTag<8> /* tag */, Full512<T> /* tag */, 1897 const T* HWY_RESTRICT base, 1898 const Vec512<int64_t> offset) 
{ 1899 return Vec512<T>{_mm512_i64gather_epi64(offset.raw, base, 1)}; 1900 } 1901 template <typename T> 1902 HWY_API Vec512<T> GatherIndex(hwy::SizeTag<8> /* tag */, Full512<T> /* tag */, 1903 const T* HWY_RESTRICT base, 1904 const Vec512<int64_t> index) { 1905 return Vec512<T>{_mm512_i64gather_epi64(index.raw, base, 8)}; 1906 } 1907 1908 } // namespace detail 1909 1910 template <typename T, typename Offset> 1911 HWY_API Vec512<T> GatherOffset(Full512<T> d, const T* HWY_RESTRICT base, 1912 const Vec512<Offset> offset) { 1913 static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); 1914 return detail::GatherOffset(hwy::SizeTag<sizeof(T)>(), d, base, offset); 1915 } 1916 template <typename T, typename Index> 1917 HWY_API Vec512<T> GatherIndex(Full512<T> d, const T* HWY_RESTRICT base, 1918 const Vec512<Index> index) { 1919 static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); 1920 return detail::GatherIndex(hwy::SizeTag<sizeof(T)>(), d, base, index); 1921 } 1922 1923 template <> 1924 HWY_INLINE Vec512<float> GatherOffset<float>(Full512<float> /* tag */, 1925 const float* HWY_RESTRICT base, 1926 const Vec512<int32_t> offset) { 1927 return Vec512<float>{_mm512_i32gather_ps(offset.raw, base, 1)}; 1928 } 1929 template <> 1930 HWY_INLINE Vec512<float> GatherIndex<float>(Full512<float> /* tag */, 1931 const float* HWY_RESTRICT base, 1932 const Vec512<int32_t> index) { 1933 return Vec512<float>{_mm512_i32gather_ps(index.raw, base, 4)}; 1934 } 1935 1936 template <> 1937 HWY_INLINE Vec512<double> GatherOffset<double>(Full512<double> /* tag */, 1938 const double* HWY_RESTRICT base, 1939 const Vec512<int64_t> offset) { 1940 return Vec512<double>{_mm512_i64gather_pd(offset.raw, base, 1)}; 1941 } 1942 template <> 1943 HWY_INLINE Vec512<double> GatherIndex<double>(Full512<double> /* tag */, 1944 const double* HWY_RESTRICT base, 1945 const Vec512<int64_t> index) { 1946 return Vec512<double>{_mm512_i64gather_pd(index.raw, base, 8)}; 1947 } 1948 1949 HWY_DIAGNOSTICS(pop) 1950 1951 // ================================================== SWIZZLE 1952 1953 template <typename T> 1954 HWY_API T GetLane(const Vec512<T> v) { 1955 return GetLane(LowerHalf(v)); 1956 } 1957 1958 // ------------------------------ Extract half 1959 1960 template <typename T> 1961 HWY_API Vec256<T> LowerHalf(Vec512<T> v) { 1962 return Vec256<T>{_mm512_castsi512_si256(v.raw)}; 1963 } 1964 template <> 1965 HWY_INLINE Vec256<float> LowerHalf(Vec512<float> v) { 1966 return Vec256<float>{_mm512_castps512_ps256(v.raw)}; 1967 } 1968 template <> 1969 HWY_INLINE Vec256<double> LowerHalf(Vec512<double> v) { 1970 return Vec256<double>{_mm512_castpd512_pd256(v.raw)}; 1971 } 1972 1973 template <typename T> 1974 HWY_API Vec256<T> UpperHalf(Vec512<T> v) { 1975 return Vec256<T>{_mm512_extracti32x8_epi32(v.raw, 1)}; 1976 } 1977 template <> 1978 HWY_INLINE Vec256<float> UpperHalf(Vec512<float> v) { 1979 return Vec256<float>{_mm512_extractf32x8_ps(v.raw, 1)}; 1980 } 1981 template <> 1982 HWY_INLINE Vec256<double> UpperHalf(Vec512<double> v) { 1983 return Vec256<double>{_mm512_extractf64x4_pd(v.raw, 1)}; 1984 } 1985 1986 // ------------------------------ ZeroExtendVector 1987 1988 // Unfortunately the initial _mm512_castsi256_si512 intrinsic leaves the upper 1989 // bits undefined. Although it makes sense for them to be zero (EVEX encoded 1990 // instructions have that effect), a compiler could decide to optimize out code 1991 // that relies on this. 
1992 // 1993 // The newer _mm512_zextsi256_si512 intrinsic fixes this by specifying the 1994 // zeroing, but it is not available on GCC until 10.1. For older GCC, we can 1995 // still obtain the desired code thanks to pattern recognition; note that the 1996 // expensive insert instruction is not actually generated, see 1997 // https://gcc.godbolt.org/z/1MKGaP. 1998 1999 template <typename T> 2000 HWY_API Vec512<T> ZeroExtendVector(Vec256<T> lo) { 2001 #if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000) 2002 return Vec512<T>{_mm512_inserti32x8(_mm512_setzero_si512(), lo.raw, 0)}; 2003 #else 2004 return Vec512<T>{_mm512_zextsi256_si512(lo.raw)}; 2005 #endif 2006 } 2007 template <> 2008 HWY_INLINE Vec512<float> ZeroExtendVector(Vec256<float> lo) { 2009 #if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000) 2010 return Vec512<float>{_mm512_insertf32x8(_mm512_setzero_ps(), lo.raw, 0)}; 2011 #else 2012 return Vec512<float>{_mm512_zextps256_ps512(lo.raw)}; 2013 #endif 2014 } 2015 template <> 2016 HWY_INLINE Vec512<double> ZeroExtendVector(Vec256<double> lo) { 2017 #if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000) 2018 return Vec512<double>{_mm512_insertf64x4(_mm512_setzero_pd(), lo.raw, 0)}; 2019 #else 2020 return Vec512<double>{_mm512_zextpd256_pd512(lo.raw)}; 2021 #endif 2022 } 2023 2024 // ------------------------------ Combine 2025 2026 template <typename T> 2027 HWY_API Vec512<T> Combine(Vec256<T> hi, Vec256<T> lo) { 2028 const auto lo512 = ZeroExtendVector(lo); 2029 return Vec512<T>{_mm512_inserti32x8(lo512.raw, hi.raw, 1)}; 2030 } 2031 template <> 2032 HWY_INLINE Vec512<float> Combine(Vec256<float> hi, Vec256<float> lo) { 2033 const auto lo512 = ZeroExtendVector(lo); 2034 return Vec512<float>{_mm512_insertf32x8(lo512.raw, hi.raw, 1)}; 2035 } 2036 template <> 2037 HWY_INLINE Vec512<double> Combine(Vec256<double> hi, Vec256<double> lo) { 2038 const auto lo512 = ZeroExtendVector(lo); 2039 return Vec512<double>{_mm512_insertf64x4(lo512.raw, hi.raw, 1)}; 2040 } 2041 2042 // ------------------------------ Shift vector by constant #bytes 2043 2044 // 0x01..0F, kBytes = 1 => 0x02..0F00 2045 template <int kBytes, typename T> 2046 HWY_API Vec512<T> ShiftLeftBytes(const Vec512<T> v) { 2047 static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); 2048 return Vec512<T>{_mm512_bslli_epi128(v.raw, kBytes)}; 2049 } 2050 2051 template <int kLanes, typename T> 2052 HWY_API Vec512<T> ShiftLeftLanes(const Vec512<T> v) { 2053 const Full512<uint8_t> d8; 2054 const Full512<T> d; 2055 return BitCast(d, ShiftLeftBytes<kLanes * sizeof(T)>(BitCast(d8, v))); 2056 } 2057 2058 // 0x01..0F, kBytes = 1 => 0x0001..0E 2059 template <int kBytes, typename T> 2060 HWY_API Vec512<T> ShiftRightBytes(const Vec512<T> v) { 2061 static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); 2062 return Vec512<T>{_mm512_bsrli_epi128(v.raw, kBytes)}; 2063 } 2064 2065 template <int kLanes, typename T> 2066 HWY_API Vec512<T> ShiftRightLanes(const Vec512<T> v) { 2067 const Full512<uint8_t> d8; 2068 const Full512<T> d; 2069 return BitCast(d, ShiftRightBytes<kLanes * sizeof(T)>(BitCast(d8, v))); 2070 } 2071 2072 // ------------------------------ Extract from 2x 128-bit at constant offset 2073 2074 // Extracts 128 bits from <hi, lo> by skipping the least-significant kBytes. 
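//
// Example (illustrative sketch): the shift happens independently within each
// 128-bit block, so blocks never exchange bytes. With kBytes = 4, each block
// of the result holds lo's bytes [4, 16) in its lower lanes and hi's bytes
// [0, 4) in its upper lanes.
//   const Full512<uint8_t> d8;
//   const Vec512<uint8_t> lo = Iota(d8, 0);            // lane i holds i
//   const Vec512<uint8_t> hi = Set(d8, uint8_t{0x80});
//   const Vec512<uint8_t> r = CombineShiftRightBytes<4>(hi, lo);
//   // First block of r, from lane 0 upward: 4, 5, ..., 15, 0x80, 0x80, 0x80, 0x80.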
2075 template <int kBytes, typename T> 2076 HWY_API Vec512<T> CombineShiftRightBytes(const Vec512<T> hi, 2077 const Vec512<T> lo) { 2078 const Full512<uint8_t> d8; 2079 const Vec512<uint8_t> extracted_bytes{ 2080 _mm512_alignr_epi8(BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}; 2081 return BitCast(Full512<T>(), extracted_bytes); 2082 } 2083 2084 // ------------------------------ Broadcast/splat any lane 2085 2086 // Unsigned 2087 template <int kLane> 2088 HWY_API Vec512<uint16_t> Broadcast(const Vec512<uint16_t> v) { 2089 static_assert(0 <= kLane && kLane < 8, "Invalid lane"); 2090 if (kLane < 4) { 2091 const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); 2092 return Vec512<uint16_t>{_mm512_unpacklo_epi64(lo, lo)}; 2093 } else { 2094 const __m512i hi = 2095 _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); 2096 return Vec512<uint16_t>{_mm512_unpackhi_epi64(hi, hi)}; 2097 } 2098 } 2099 template <int kLane> 2100 HWY_API Vec512<uint32_t> Broadcast(const Vec512<uint32_t> v) { 2101 static_assert(0 <= kLane && kLane < 4, "Invalid lane"); 2102 constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); 2103 return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, perm)}; 2104 } 2105 template <int kLane> 2106 HWY_API Vec512<uint64_t> Broadcast(const Vec512<uint64_t> v) { 2107 static_assert(0 <= kLane && kLane < 2, "Invalid lane"); 2108 constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA; 2109 return Vec512<uint64_t>{_mm512_shuffle_epi32(v.raw, perm)}; 2110 } 2111 2112 // Signed 2113 template <int kLane> 2114 HWY_API Vec512<int16_t> Broadcast(const Vec512<int16_t> v) { 2115 static_assert(0 <= kLane && kLane < 8, "Invalid lane"); 2116 if (kLane < 4) { 2117 const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); 2118 return Vec512<int16_t>{_mm512_unpacklo_epi64(lo, lo)}; 2119 } else { 2120 const __m512i hi = 2121 _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); 2122 return Vec512<int16_t>{_mm512_unpackhi_epi64(hi, hi)}; 2123 } 2124 } 2125 template <int kLane> 2126 HWY_API Vec512<int32_t> Broadcast(const Vec512<int32_t> v) { 2127 static_assert(0 <= kLane && kLane < 4, "Invalid lane"); 2128 constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); 2129 return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, perm)}; 2130 } 2131 template <int kLane> 2132 HWY_API Vec512<int64_t> Broadcast(const Vec512<int64_t> v) { 2133 static_assert(0 <= kLane && kLane < 2, "Invalid lane"); 2134 constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA; 2135 return Vec512<int64_t>{_mm512_shuffle_epi32(v.raw, perm)}; 2136 } 2137 2138 // Float 2139 template <int kLane> 2140 HWY_API Vec512<float> Broadcast(const Vec512<float> v) { 2141 static_assert(0 <= kLane && kLane < 4, "Invalid lane"); 2142 constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); 2143 return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, perm)}; 2144 } 2145 template <int kLane> 2146 HWY_API Vec512<double> Broadcast(const Vec512<double> v) { 2147 static_assert(0 <= kLane && kLane < 2, "Invalid lane"); 2148 constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0xFF * kLane); 2149 return Vec512<double>{_mm512_shuffle_pd(v.raw, v.raw, perm)}; 2150 } 2151 2152 // ------------------------------ Hard-coded shuffles 2153 2154 // Notation: let Vec512<int32_t> have lanes 7,6,5,4,3,2,1,0 (0 is 2155 // least-significant). 
Shuffle0321 rotates four-lane blocks one lane to the 2156 // right (the previous least-significant lane is now most-significant => 2157 // 47650321). These could also be implemented via CombineShiftRightBytes but 2158 // the shuffle_abcd notation is more convenient. 2159 2160 // Swap 32-bit halves in 64-bit halves. 2161 HWY_API Vec512<uint32_t> Shuffle2301(const Vec512<uint32_t> v) { 2162 return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)}; 2163 } 2164 HWY_API Vec512<int32_t> Shuffle2301(const Vec512<int32_t> v) { 2165 return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)}; 2166 } 2167 HWY_API Vec512<float> Shuffle2301(const Vec512<float> v) { 2168 return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CDAB)}; 2169 } 2170 2171 // Swap 64-bit halves 2172 HWY_API Vec512<uint32_t> Shuffle1032(const Vec512<uint32_t> v) { 2173 return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; 2174 } 2175 HWY_API Vec512<int32_t> Shuffle1032(const Vec512<int32_t> v) { 2176 return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; 2177 } 2178 HWY_API Vec512<float> Shuffle1032(const Vec512<float> v) { 2179 // Shorter encoding than _mm512_permute_ps. 2180 return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_BADC)}; 2181 } 2182 HWY_API Vec512<uint64_t> Shuffle01(const Vec512<uint64_t> v) { 2183 return Vec512<uint64_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; 2184 } 2185 HWY_API Vec512<int64_t> Shuffle01(const Vec512<int64_t> v) { 2186 return Vec512<int64_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; 2187 } 2188 HWY_API Vec512<double> Shuffle01(const Vec512<double> v) { 2189 // Shorter encoding than _mm512_permute_pd. 2190 return Vec512<double>{_mm512_shuffle_pd(v.raw, v.raw, _MM_PERM_BBBB)}; 2191 } 2192 2193 // Rotate right 32 bits 2194 HWY_API Vec512<uint32_t> Shuffle0321(const Vec512<uint32_t> v) { 2195 return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; 2196 } 2197 HWY_API Vec512<int32_t> Shuffle0321(const Vec512<int32_t> v) { 2198 return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; 2199 } 2200 HWY_API Vec512<float> Shuffle0321(const Vec512<float> v) { 2201 return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ADCB)}; 2202 } 2203 // Rotate left 32 bits 2204 HWY_API Vec512<uint32_t> Shuffle2103(const Vec512<uint32_t> v) { 2205 return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; 2206 } 2207 HWY_API Vec512<int32_t> Shuffle2103(const Vec512<int32_t> v) { 2208 return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; 2209 } 2210 HWY_API Vec512<float> Shuffle2103(const Vec512<float> v) { 2211 return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CBAD)}; 2212 } 2213 2214 // Reverse 2215 HWY_API Vec512<uint32_t> Shuffle0123(const Vec512<uint32_t> v) { 2216 return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; 2217 } 2218 HWY_API Vec512<int32_t> Shuffle0123(const Vec512<int32_t> v) { 2219 return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; 2220 } 2221 HWY_API Vec512<float> Shuffle0123(const Vec512<float> v) { 2222 return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ABCD)}; 2223 } 2224 2225 // ------------------------------ TableLookupLanes 2226 2227 // Returned by SetTableIndices for use by TableLookupLanes. 
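//
// Example (illustrative sketch): unlike the block-local shuffles above,
// TableLookupLanes may move lanes across 128-bit blocks. Reversing all 16
// int32 lanes (`v` is a hypothetical input vector):
//   const Full512<int32_t> d;
//   alignas(64) static constexpr int32_t kRev[16] = {
//       15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
//   const auto idx = SetTableIndices(d, kRev);
//   const Vec512<int32_t> reversed = TableLookupLanes(v, idx);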
2228 template <typename T> 2229 struct Indices512 { 2230 __m512i raw; 2231 }; 2232 2233 template <typename T> 2234 HWY_API Indices512<T> SetTableIndices(const Full512<T>, const int32_t* idx) { 2235 #if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) 2236 const size_t N = 64 / sizeof(T); 2237 for (size_t i = 0; i < N; ++i) { 2238 HWY_DASSERT(0 <= idx[i] && idx[i] < static_cast<int32_t>(N)); 2239 } 2240 #endif 2241 return Indices512<T>{LoadU(Full512<int32_t>(), idx).raw}; 2242 } 2243 2244 HWY_API Vec512<uint32_t> TableLookupLanes(const Vec512<uint32_t> v, 2245 const Indices512<uint32_t> idx) { 2246 return Vec512<uint32_t>{_mm512_permutexvar_epi32(idx.raw, v.raw)}; 2247 } 2248 HWY_API Vec512<int32_t> TableLookupLanes(const Vec512<int32_t> v, 2249 const Indices512<int32_t> idx) { 2250 return Vec512<int32_t>{_mm512_permutexvar_epi32(idx.raw, v.raw)}; 2251 } 2252 HWY_API Vec512<float> TableLookupLanes(const Vec512<float> v, 2253 const Indices512<float> idx) { 2254 return Vec512<float>{_mm512_permutexvar_ps(idx.raw, v.raw)}; 2255 } 2256 2257 // ------------------------------ Interleave lanes 2258 2259 // Interleaves lanes from halves of the 128-bit blocks of "a" (which provides 2260 // the least-significant lane) and "b". To concatenate two half-width integers 2261 // into one, use ZipLower/Upper instead (also works with scalar). 2262 2263 HWY_API Vec512<uint8_t> InterleaveLower(const Vec512<uint8_t> a, 2264 const Vec512<uint8_t> b) { 2265 return Vec512<uint8_t>{_mm512_unpacklo_epi8(a.raw, b.raw)}; 2266 } 2267 HWY_API Vec512<uint16_t> InterleaveLower(const Vec512<uint16_t> a, 2268 const Vec512<uint16_t> b) { 2269 return Vec512<uint16_t>{_mm512_unpacklo_epi16(a.raw, b.raw)}; 2270 } 2271 HWY_API Vec512<uint32_t> InterleaveLower(const Vec512<uint32_t> a, 2272 const Vec512<uint32_t> b) { 2273 return Vec512<uint32_t>{_mm512_unpacklo_epi32(a.raw, b.raw)}; 2274 } 2275 HWY_API Vec512<uint64_t> InterleaveLower(const Vec512<uint64_t> a, 2276 const Vec512<uint64_t> b) { 2277 return Vec512<uint64_t>{_mm512_unpacklo_epi64(a.raw, b.raw)}; 2278 } 2279 2280 HWY_API Vec512<int8_t> InterleaveLower(const Vec512<int8_t> a, 2281 const Vec512<int8_t> b) { 2282 return Vec512<int8_t>{_mm512_unpacklo_epi8(a.raw, b.raw)}; 2283 } 2284 HWY_API Vec512<int16_t> InterleaveLower(const Vec512<int16_t> a, 2285 const Vec512<int16_t> b) { 2286 return Vec512<int16_t>{_mm512_unpacklo_epi16(a.raw, b.raw)}; 2287 } 2288 HWY_API Vec512<int32_t> InterleaveLower(const Vec512<int32_t> a, 2289 const Vec512<int32_t> b) { 2290 return Vec512<int32_t>{_mm512_unpacklo_epi32(a.raw, b.raw)}; 2291 } 2292 HWY_API Vec512<int64_t> InterleaveLower(const Vec512<int64_t> a, 2293 const Vec512<int64_t> b) { 2294 return Vec512<int64_t>{_mm512_unpacklo_epi64(a.raw, b.raw)}; 2295 } 2296 2297 HWY_API Vec512<float> InterleaveLower(const Vec512<float> a, 2298 const Vec512<float> b) { 2299 return Vec512<float>{_mm512_unpacklo_ps(a.raw, b.raw)}; 2300 } 2301 HWY_API Vec512<double> InterleaveLower(const Vec512<double> a, 2302 const Vec512<double> b) { 2303 return Vec512<double>{_mm512_unpacklo_pd(a.raw, b.raw)}; 2304 } 2305 2306 HWY_API Vec512<uint8_t> InterleaveUpper(const Vec512<uint8_t> a, 2307 const Vec512<uint8_t> b) { 2308 return Vec512<uint8_t>{_mm512_unpackhi_epi8(a.raw, b.raw)}; 2309 } 2310 HWY_API Vec512<uint16_t> InterleaveUpper(const Vec512<uint16_t> a, 2311 const Vec512<uint16_t> b) { 2312 return Vec512<uint16_t>{_mm512_unpackhi_epi16(a.raw, b.raw)}; 2313 } 2314 HWY_API Vec512<uint32_t> InterleaveUpper(const Vec512<uint32_t> a, 2315 const Vec512<uint32_t> 
b) { 2316 return Vec512<uint32_t>{_mm512_unpackhi_epi32(a.raw, b.raw)}; 2317 } 2318 HWY_API Vec512<uint64_t> InterleaveUpper(const Vec512<uint64_t> a, 2319 const Vec512<uint64_t> b) { 2320 return Vec512<uint64_t>{_mm512_unpackhi_epi64(a.raw, b.raw)}; 2321 } 2322 2323 HWY_API Vec512<int8_t> InterleaveUpper(const Vec512<int8_t> a, 2324 const Vec512<int8_t> b) { 2325 return Vec512<int8_t>{_mm512_unpackhi_epi8(a.raw, b.raw)}; 2326 } 2327 HWY_API Vec512<int16_t> InterleaveUpper(const Vec512<int16_t> a, 2328 const Vec512<int16_t> b) { 2329 return Vec512<int16_t>{_mm512_unpackhi_epi16(a.raw, b.raw)}; 2330 } 2331 HWY_API Vec512<int32_t> InterleaveUpper(const Vec512<int32_t> a, 2332 const Vec512<int32_t> b) { 2333 return Vec512<int32_t>{_mm512_unpackhi_epi32(a.raw, b.raw)}; 2334 } 2335 HWY_API Vec512<int64_t> InterleaveUpper(const Vec512<int64_t> a, 2336 const Vec512<int64_t> b) { 2337 return Vec512<int64_t>{_mm512_unpackhi_epi64(a.raw, b.raw)}; 2338 } 2339 2340 HWY_API Vec512<float> InterleaveUpper(const Vec512<float> a, 2341 const Vec512<float> b) { 2342 return Vec512<float>{_mm512_unpackhi_ps(a.raw, b.raw)}; 2343 } 2344 HWY_API Vec512<double> InterleaveUpper(const Vec512<double> a, 2345 const Vec512<double> b) { 2346 return Vec512<double>{_mm512_unpackhi_pd(a.raw, b.raw)}; 2347 } 2348 2349 // ------------------------------ Zip lanes 2350 2351 // Same as interleave_*, except that the return lanes are double-width integers; 2352 // this is necessary because the single-lane scalar cannot return two values. 2353 2354 HWY_API Vec512<uint16_t> ZipLower(const Vec512<uint8_t> a, 2355 const Vec512<uint8_t> b) { 2356 return Vec512<uint16_t>{_mm512_unpacklo_epi8(a.raw, b.raw)}; 2357 } 2358 HWY_API Vec512<uint32_t> ZipLower(const Vec512<uint16_t> a, 2359 const Vec512<uint16_t> b) { 2360 return Vec512<uint32_t>{_mm512_unpacklo_epi16(a.raw, b.raw)}; 2361 } 2362 HWY_API Vec512<uint64_t> ZipLower(const Vec512<uint32_t> a, 2363 const Vec512<uint32_t> b) { 2364 return Vec512<uint64_t>{_mm512_unpacklo_epi32(a.raw, b.raw)}; 2365 } 2366 2367 HWY_API Vec512<int16_t> ZipLower(const Vec512<int8_t> a, 2368 const Vec512<int8_t> b) { 2369 return Vec512<int16_t>{_mm512_unpacklo_epi8(a.raw, b.raw)}; 2370 } 2371 HWY_API Vec512<int32_t> ZipLower(const Vec512<int16_t> a, 2372 const Vec512<int16_t> b) { 2373 return Vec512<int32_t>{_mm512_unpacklo_epi16(a.raw, b.raw)}; 2374 } 2375 HWY_API Vec512<int64_t> ZipLower(const Vec512<int32_t> a, 2376 const Vec512<int32_t> b) { 2377 return Vec512<int64_t>{_mm512_unpacklo_epi32(a.raw, b.raw)}; 2378 } 2379 2380 HWY_API Vec512<uint16_t> ZipUpper(const Vec512<uint8_t> a, 2381 const Vec512<uint8_t> b) { 2382 return Vec512<uint16_t>{_mm512_unpackhi_epi8(a.raw, b.raw)}; 2383 } 2384 HWY_API Vec512<uint32_t> ZipUpper(const Vec512<uint16_t> a, 2385 const Vec512<uint16_t> b) { 2386 return Vec512<uint32_t>{_mm512_unpackhi_epi16(a.raw, b.raw)}; 2387 } 2388 HWY_API Vec512<uint64_t> ZipUpper(const Vec512<uint32_t> a, 2389 const Vec512<uint32_t> b) { 2390 return Vec512<uint64_t>{_mm512_unpackhi_epi32(a.raw, b.raw)}; 2391 } 2392 2393 HWY_API Vec512<int16_t> ZipUpper(const Vec512<int8_t> a, 2394 const Vec512<int8_t> b) { 2395 return Vec512<int16_t>{_mm512_unpackhi_epi8(a.raw, b.raw)}; 2396 } 2397 HWY_API Vec512<int32_t> ZipUpper(const Vec512<int16_t> a, 2398 const Vec512<int16_t> b) { 2399 return Vec512<int32_t>{_mm512_unpackhi_epi16(a.raw, b.raw)}; 2400 } 2401 HWY_API Vec512<int64_t> ZipUpper(const Vec512<int32_t> a, 2402 const Vec512<int32_t> b) { 2403 return Vec512<int64_t>{_mm512_unpackhi_epi32(a.raw, 
b.raw)}; 2404 } 2405 2406 // ------------------------------ Concat* halves 2407 2408 // hiH,hiL loH,loL |-> hiL,loL (= lower halves) 2409 template <typename T> 2410 HWY_API Vec512<T> ConcatLowerLower(const Vec512<T> hi, const Vec512<T> lo) { 2411 return Vec512<T>{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; 2412 } 2413 template <> 2414 HWY_INLINE Vec512<float> ConcatLowerLower(const Vec512<float> hi, 2415 const Vec512<float> lo) { 2416 return Vec512<float>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; 2417 } 2418 template <> 2419 HWY_INLINE Vec512<double> ConcatLowerLower(const Vec512<double> hi, 2420 const Vec512<double> lo) { 2421 return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BABA)}; 2422 } 2423 2424 // hiH,hiL loH,loL |-> hiH,loH (= upper halves) 2425 template <typename T> 2426 HWY_API Vec512<T> ConcatUpperUpper(const Vec512<T> hi, const Vec512<T> lo) { 2427 return Vec512<T>{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; 2428 } 2429 template <> 2430 HWY_INLINE Vec512<float> ConcatUpperUpper(const Vec512<float> hi, 2431 const Vec512<float> lo) { 2432 return Vec512<float>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; 2433 } 2434 template <> 2435 HWY_INLINE Vec512<double> ConcatUpperUpper(const Vec512<double> hi, 2436 const Vec512<double> lo) { 2437 return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_DCDC)}; 2438 } 2439 2440 // hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) 2441 template <typename T> 2442 HWY_API Vec512<T> ConcatLowerUpper(const Vec512<T> hi, const Vec512<T> lo) { 2443 return Vec512<T>{_mm512_shuffle_i32x4(lo.raw, hi.raw, 0x4E)}; 2444 } 2445 template <> 2446 HWY_INLINE Vec512<float> ConcatLowerUpper(const Vec512<float> hi, 2447 const Vec512<float> lo) { 2448 return Vec512<float>{_mm512_shuffle_f32x4(lo.raw, hi.raw, 0x4E)}; 2449 } 2450 template <> 2451 HWY_INLINE Vec512<double> ConcatLowerUpper(const Vec512<double> hi, 2452 const Vec512<double> lo) { 2453 return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, 0x4E)}; 2454 } 2455 2456 // hiH,hiL loH,loL |-> hiH,loL (= outer halves) 2457 template <typename T> 2458 HWY_API Vec512<T> ConcatUpperLower(const Vec512<T> hi, const Vec512<T> lo) { 2459 // There are no imm8 blend in AVX512. Use blend16 because 32-bit masks 2460 // are efficiently loaded from 32-bit regs. 2461 const __mmask32 mask = /*_cvtu32_mask32 */ (0x0000FFFF); 2462 return Vec512<T>{_mm512_mask_blend_epi16(mask, hi.raw, lo.raw)}; 2463 } 2464 template <> 2465 HWY_INLINE Vec512<float> ConcatUpperLower(const Vec512<float> hi, 2466 const Vec512<float> lo) { 2467 const __mmask16 mask = /*_cvtu32_mask16 */ (0x00FF); 2468 return Vec512<float>{_mm512_mask_blend_ps(mask, hi.raw, lo.raw)}; 2469 } 2470 template <> 2471 HWY_INLINE Vec512<double> ConcatUpperLower(const Vec512<double> hi, 2472 const Vec512<double> lo) { 2473 const __mmask8 mask = /*_cvtu32_mask8 */ (0x0F); 2474 return Vec512<double>{_mm512_mask_blend_pd(mask, hi.raw, lo.raw)}; 2475 } 2476 2477 // ------------------------------ Odd/even lanes 2478 2479 template <typename T> 2480 HWY_API Vec512<T> OddEven(const Vec512<T> a, const Vec512<T> b) { 2481 constexpr size_t s = sizeof(T); 2482 constexpr int shift = s == 1 ? 0 : s == 2 ? 32 : s == 4 ? 48 : 56; 2483 return IfThenElse(Mask512<T>{0x5555555555555555ull >> shift}, b, a); 2484 } 2485 2486 // ------------------------------ Shuffle bytes with variable indices 2487 2488 // Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. 2489 // lane indices in [0, 16). 
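//
// Example (illustrative sketch): indices are block-local, so one 16-entry
// table serves all four 128-bit blocks. Reversing the bytes within each block
// (`v` is a hypothetical Vec512<uint8_t>):
//   const Full512<uint8_t> d;
//   alignas(16) static constexpr uint8_t kRev[16] = {
//       15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
//   const auto rev = TableLookupBytes(v, LoadDup128(d, kRev));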
2490 template <typename T> 2491 HWY_API Vec512<T> TableLookupBytes(const Vec512<T> bytes, 2492 const Vec512<T> from) { 2493 return Vec512<T>{_mm512_shuffle_epi8(bytes.raw, from.raw)}; 2494 } 2495 2496 // ================================================== CONVERT 2497 2498 // ------------------------------ Promotions (part w/ narrow lanes -> full) 2499 2500 HWY_API Vec512<float> PromoteTo(Full512<float> /* tag */, 2501 const Vec256<float16_t> v) { 2502 return Vec512<float>{_mm512_cvtph_ps(v.raw)}; 2503 } 2504 2505 HWY_API Vec512<double> PromoteTo(Full512<double> /* tag */, Vec256<float> v) { 2506 return Vec512<double>{_mm512_cvtps_pd(v.raw)}; 2507 } 2508 2509 HWY_API Vec512<double> PromoteTo(Full512<double> /* tag */, Vec256<int32_t> v) { 2510 return Vec512<double>{_mm512_cvtepi32_pd(v.raw)}; 2511 } 2512 2513 // Unsigned: zero-extend. 2514 // Note: these have 3 cycle latency; if inputs are already split across the 2515 // 128 bit blocks (in their upper/lower halves), then Zip* would be faster. 2516 HWY_API Vec512<uint16_t> PromoteTo(Full512<uint16_t> /* tag */, 2517 Vec256<uint8_t> v) { 2518 return Vec512<uint16_t>{_mm512_cvtepu8_epi16(v.raw)}; 2519 } 2520 HWY_API Vec512<uint32_t> PromoteTo(Full512<uint32_t> /* tag */, 2521 Vec128<uint8_t> v) { 2522 return Vec512<uint32_t>{_mm512_cvtepu8_epi32(v.raw)}; 2523 } 2524 HWY_API Vec512<int16_t> PromoteTo(Full512<int16_t> /* tag */, 2525 Vec256<uint8_t> v) { 2526 return Vec512<int16_t>{_mm512_cvtepu8_epi16(v.raw)}; 2527 } 2528 HWY_API Vec512<int32_t> PromoteTo(Full512<int32_t> /* tag */, 2529 Vec128<uint8_t> v) { 2530 return Vec512<int32_t>{_mm512_cvtepu8_epi32(v.raw)}; 2531 } 2532 HWY_API Vec512<uint32_t> PromoteTo(Full512<uint32_t> /* tag */, 2533 Vec256<uint16_t> v) { 2534 return Vec512<uint32_t>{_mm512_cvtepu16_epi32(v.raw)}; 2535 } 2536 HWY_API Vec512<int32_t> PromoteTo(Full512<int32_t> /* tag */, 2537 Vec256<uint16_t> v) { 2538 return Vec512<int32_t>{_mm512_cvtepu16_epi32(v.raw)}; 2539 } 2540 HWY_API Vec512<uint64_t> PromoteTo(Full512<uint64_t> /* tag */, 2541 Vec256<uint32_t> v) { 2542 return Vec512<uint64_t>{_mm512_cvtepu32_epi64(v.raw)}; 2543 } 2544 2545 // Signed: replicate sign bit. 2546 // Note: these have 3 cycle latency; if inputs are already split across the 2547 // 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by 2548 // signed shift would be faster. 2549 HWY_API Vec512<int16_t> PromoteTo(Full512<int16_t> /* tag */, 2550 Vec256<int8_t> v) { 2551 return Vec512<int16_t>{_mm512_cvtepi8_epi16(v.raw)}; 2552 } 2553 HWY_API Vec512<int32_t> PromoteTo(Full512<int32_t> /* tag */, 2554 Vec128<int8_t> v) { 2555 return Vec512<int32_t>{_mm512_cvtepi8_epi32(v.raw)}; 2556 } 2557 HWY_API Vec512<int32_t> PromoteTo(Full512<int32_t> /* tag */, 2558 Vec256<int16_t> v) { 2559 return Vec512<int32_t>{_mm512_cvtepi16_epi32(v.raw)}; 2560 } 2561 HWY_API Vec512<int64_t> PromoteTo(Full512<int64_t> /* tag */, 2562 Vec256<int32_t> v) { 2563 return Vec512<int64_t>{_mm512_cvtepi32_epi64(v.raw)}; 2564 } 2565 2566 // ------------------------------ Demotions (full -> part w/ narrow lanes) 2567 2568 HWY_API Vec256<uint16_t> DemoteTo(Full256<uint16_t> /* tag */, 2569 const Vec512<int32_t> v) { 2570 const Vec512<uint16_t> u16{_mm512_packus_epi32(v.raw, v.raw)}; 2571 2572 // Compress even u64 lanes into 256 bit. 
2573 alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; 2574 const auto idx64 = Load(Full512<uint64_t>(), kLanes); 2575 const Vec512<uint16_t> even{_mm512_permutexvar_epi64(idx64.raw, u16.raw)}; 2576 return LowerHalf(even); 2577 } 2578 2579 HWY_API Vec256<int16_t> DemoteTo(Full256<int16_t> /* tag */, 2580 const Vec512<int32_t> v) { 2581 const Vec512<int16_t> i16{_mm512_packs_epi32(v.raw, v.raw)}; 2582 2583 // Compress even u64 lanes into 256 bit. 2584 alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; 2585 const auto idx64 = Load(Full512<uint64_t>(), kLanes); 2586 const Vec512<int16_t> even{_mm512_permutexvar_epi64(idx64.raw, i16.raw)}; 2587 return LowerHalf(even); 2588 } 2589 2590 HWY_API Vec128<uint8_t, 16> DemoteTo(Full128<uint8_t> /* tag */, 2591 const Vec512<int32_t> v) { 2592 const Vec512<uint16_t> u16{_mm512_packus_epi32(v.raw, v.raw)}; 2593 // packus treats the input as signed; we want unsigned. Clear the MSB to get 2594 // unsigned saturation to u8. 2595 const Vec512<int16_t> i16{ 2596 _mm512_and_si512(u16.raw, _mm512_set1_epi16(0x7FFF))}; 2597 const Vec512<uint8_t> u8{_mm512_packus_epi16(i16.raw, i16.raw)}; 2598 2599 alignas(16) static constexpr uint32_t kLanes[4] = {0, 4, 8, 12}; 2600 const auto idx32 = LoadDup128(Full512<uint32_t>(), kLanes); 2601 const Vec512<uint8_t> fixed{_mm512_permutexvar_epi32(idx32.raw, u8.raw)}; 2602 return LowerHalf(LowerHalf(fixed)); 2603 } 2604 2605 HWY_API Vec256<uint8_t> DemoteTo(Full256<uint8_t> /* tag */, 2606 const Vec512<int16_t> v) { 2607 const Vec512<uint8_t> u8{_mm512_packus_epi16(v.raw, v.raw)}; 2608 2609 // Compress even u64 lanes into 256 bit. 2610 alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; 2611 const auto idx64 = Load(Full512<uint64_t>(), kLanes); 2612 const Vec512<uint8_t> even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; 2613 return LowerHalf(even); 2614 } 2615 2616 HWY_API Vec128<int8_t, 16> DemoteTo(Full128<int8_t> /* tag */, 2617 const Vec512<int32_t> v) { 2618 const Vec512<int16_t> i16{_mm512_packs_epi32(v.raw, v.raw)}; 2619 const Vec512<int8_t> i8{_mm512_packs_epi16(i16.raw, i16.raw)}; 2620 2621 alignas(16) static constexpr uint32_t kLanes[16] = {0, 4, 8, 12, 0, 4, 8, 12, 2622 0, 4, 8, 12, 0, 4, 8, 12}; 2623 const auto idx32 = LoadDup128(Full512<uint32_t>(), kLanes); 2624 const Vec512<int8_t> fixed{_mm512_permutexvar_epi32(idx32.raw, i8.raw)}; 2625 return LowerHalf(LowerHalf(fixed)); 2626 } 2627 2628 HWY_API Vec256<int8_t> DemoteTo(Full256<int8_t> /* tag */, 2629 const Vec512<int16_t> v) { 2630 const Vec512<int8_t> u8{_mm512_packs_epi16(v.raw, v.raw)}; 2631 2632 // Compress even u64 lanes into 256 bit. 2633 alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; 2634 const auto idx64 = Load(Full512<uint64_t>(), kLanes); 2635 const Vec512<int8_t> even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; 2636 return LowerHalf(even); 2637 } 2638 2639 HWY_API Vec256<float16_t> DemoteTo(Full256<float16_t> /* tag */, 2640 const Vec512<float> v) { 2641 // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
2642 HWY_DIAGNOSTICS(push) 2643 HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") 2644 return Vec256<float16_t>{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; 2645 HWY_DIAGNOSTICS(pop) 2646 } 2647 2648 HWY_API Vec256<float> DemoteTo(Full256<float> /* tag */, 2649 const Vec512<double> v) { 2650 return Vec256<float>{_mm512_cvtpd_ps(v.raw)}; 2651 } 2652 2653 HWY_API Vec256<int32_t> DemoteTo(Full256<int32_t> /* tag */, 2654 const Vec512<double> v) { 2655 const auto clamped = detail::ClampF64ToI32Max(Full512<double>(), v); 2656 return Vec256<int32_t>{_mm512_cvttpd_epi32(clamped.raw)}; 2657 } 2658 2659 // For already range-limited input [0, 255]. 2660 HWY_API Vec128<uint8_t, 16> U8FromU32(const Vec512<uint32_t> v) { 2661 const Full512<uint32_t> d32; 2662 // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the 2663 // lowest 4 bytes. 2664 alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u, 2665 ~0u}; 2666 const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32)); 2667 // Gather the lowest 4 bytes of 4 128-bit blocks. 2668 alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12}; 2669 const Vec512<uint8_t> bytes{ 2670 _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)}; 2671 return LowerHalf(LowerHalf(bytes)); 2672 } 2673 2674 // ------------------------------ Convert integer <=> floating point 2675 2676 HWY_API Vec512<float> ConvertTo(Full512<float> /* tag */, 2677 const Vec512<int32_t> v) { 2678 return Vec512<float>{_mm512_cvtepi32_ps(v.raw)}; 2679 } 2680 2681 HWY_API Vec512<double> ConvertTo(Full512<double> /* tag */, 2682 const Vec512<int64_t> v) { 2683 return Vec512<double>{_mm512_cvtepi64_pd(v.raw)}; 2684 } 2685 2686 // Truncates (rounds toward zero). 2687 HWY_API Vec512<int32_t> ConvertTo(Full512<int32_t> d, const Vec512<float> v) { 2688 return detail::FixConversionOverflow(d, v, _mm512_cvttps_epi32(v.raw)); 2689 } 2690 HWY_API Vec512<int64_t> ConvertTo(Full512<int64_t> di, const Vec512<double> v) { 2691 return detail::FixConversionOverflow(di, v, _mm512_cvttpd_epi64(v.raw)); 2692 } 2693 2694 HWY_API Vec512<int32_t> NearestInt(const Vec512<float> v) { 2695 const Full512<int32_t> di; 2696 return detail::FixConversionOverflow(di, v, _mm512_cvtps_epi32(v.raw)); 2697 } 2698 2699 // ================================================== MISC 2700 2701 // Returns a vector with lane i=[0, N) set to "first" + i. 2702 template <typename T, typename T2> 2703 Vec512<T> Iota(const Full512<T> d, const T2 first) { 2704 HWY_ALIGN T lanes[64 / sizeof(T)]; 2705 for (size_t i = 0; i < 64 / sizeof(T); ++i) { 2706 lanes[i] = static_cast<T>(first + static_cast<T2>(i)); 2707 } 2708 return Load(d, lanes); 2709 } 2710 2711 // ------------------------------ Mask 2712 2713 // Beware: the suffix indicates the number of mask bits, not lane size! 
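//
// Example (illustrative sketch): a comparison yields Mask512<float>, whose raw
// type is __mmask16, i.e. 16 mask bits for the 16 float lanes. `in` is a
// hypothetical aligned array of at least 16 floats.
//   const Full512<float> d;
//   const Mask512<float> m = Load(d, in) > Zero(d);
//   if (!AllFalse(m)) {
//     const size_t num_positive = CountTrue(m);  // in [1, 16]
//   }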
2714 2715 namespace detail { 2716 2717 template <typename T> 2718 HWY_API bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask512<T> v) { 2719 #if HWY_COMPILER_HAS_MASK_INTRINSICS 2720 return _kortestz_mask64_u8(v.raw, v.raw); 2721 #else 2722 return v.raw == 0; 2723 #endif 2724 } 2725 template <typename T> 2726 HWY_API bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask512<T> v) { 2727 #if HWY_COMPILER_HAS_MASK_INTRINSICS 2728 return _kortestz_mask32_u8(v.raw, v.raw); 2729 #else 2730 return v.raw == 0; 2731 #endif 2732 } 2733 template <typename T> 2734 HWY_API bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask512<T> v) { 2735 #if HWY_COMPILER_HAS_MASK_INTRINSICS 2736 return _kortestz_mask16_u8(v.raw, v.raw); 2737 #else 2738 return v.raw == 0; 2739 #endif 2740 } 2741 template <typename T> 2742 HWY_API bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask512<T> v) { 2743 #if HWY_COMPILER_HAS_MASK_INTRINSICS 2744 return _kortestz_mask8_u8(v.raw, v.raw); 2745 #else 2746 return v.raw == 0; 2747 #endif 2748 } 2749 2750 } // namespace detail 2751 2752 template <typename T> 2753 HWY_API bool AllFalse(const Mask512<T> v) { 2754 return detail::AllFalse(hwy::SizeTag<sizeof(T)>(), v); 2755 } 2756 2757 namespace detail { 2758 2759 template <typename T> 2760 HWY_API bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask512<T> v) { 2761 #if HWY_COMPILER_HAS_MASK_INTRINSICS 2762 return _kortestc_mask64_u8(v.raw, v.raw); 2763 #else 2764 return v.raw == 0xFFFFFFFFFFFFFFFFull; 2765 #endif 2766 } 2767 template <typename T> 2768 HWY_API bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask512<T> v) { 2769 #if HWY_COMPILER_HAS_MASK_INTRINSICS 2770 return _kortestc_mask32_u8(v.raw, v.raw); 2771 #else 2772 return v.raw == 0xFFFFFFFFull; 2773 #endif 2774 } 2775 template <typename T> 2776 HWY_API bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask512<T> v) { 2777 #if HWY_COMPILER_HAS_MASK_INTRINSICS 2778 return _kortestc_mask16_u8(v.raw, v.raw); 2779 #else 2780 return v.raw == 0xFFFFull; 2781 #endif 2782 } 2783 template <typename T> 2784 HWY_API bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask512<T> v) { 2785 #if HWY_COMPILER_HAS_MASK_INTRINSICS 2786 return _kortestc_mask8_u8(v.raw, v.raw); 2787 #else 2788 return v.raw == 0xFFull; 2789 #endif 2790 } 2791 2792 } // namespace detail 2793 2794 template <typename T> 2795 HWY_API bool AllTrue(const Mask512<T> v) { 2796 return detail::AllTrue(hwy::SizeTag<sizeof(T)>(), v); 2797 } 2798 2799 template <typename T> 2800 HWY_INLINE size_t StoreMaskBits(const Mask512<T> mask, uint8_t* p) { 2801 const size_t kNumBytes = 8 / sizeof(T); 2802 CopyBytes<kNumBytes>(&mask.raw, p); 2803 return kNumBytes; 2804 } 2805 2806 template <typename T> 2807 HWY_API size_t CountTrue(const Mask512<T> mask) { 2808 return PopCount(mask.raw); 2809 } 2810 2811 // ------------------------------ Compress 2812 2813 HWY_API Vec512<uint32_t> Compress(Vec512<uint32_t> v, 2814 const Mask512<uint32_t> mask) { 2815 return Vec512<uint32_t>{_mm512_maskz_compress_epi32(mask.raw, v.raw)}; 2816 } 2817 HWY_API Vec512<int32_t> Compress(Vec512<int32_t> v, 2818 const Mask512<int32_t> mask) { 2819 return Vec512<int32_t>{_mm512_maskz_compress_epi32(mask.raw, v.raw)}; 2820 } 2821 2822 HWY_API Vec512<uint64_t> Compress(Vec512<uint64_t> v, 2823 const Mask512<uint64_t> mask) { 2824 return Vec512<uint64_t>{_mm512_maskz_compress_epi64(mask.raw, v.raw)}; 2825 } 2826 HWY_API Vec512<int64_t> Compress(Vec512<int64_t> v, 2827 const Mask512<int64_t> mask) { 2828 return Vec512<int64_t>{_mm512_maskz_compress_epi64(mask.raw, v.raw)}; 2829 } 2830 2831 HWY_API 
Vec512<float> Compress(Vec512<float> v, const Mask512<float> mask) { 2832 return Vec512<float>{_mm512_maskz_compress_ps(mask.raw, v.raw)}; 2833 } 2834 2835 HWY_API Vec512<double> Compress(Vec512<double> v, const Mask512<double> mask) { 2836 return Vec512<double>{_mm512_maskz_compress_pd(mask.raw, v.raw)}; 2837 } 2838 2839 namespace detail { 2840 2841 // Ignore IDE redefinition error for these two functions: if this header is 2842 // included, then the functions weren't actually defined in x86_256-inl.h. 2843 template <typename T> 2844 HWY_API Vec256<T> Compress(hwy::SizeTag<2> /*tag*/, Vec256<T> v, 2845 const uint64_t mask_bits) { 2846 using D = Full256<T>; 2847 const Rebind<uint16_t, D> du; 2848 const Rebind<int32_t, D> dw; // 512-bit, not 256! 2849 const auto vu16 = BitCast(du, v); // (required for float16_t inputs) 2850 const Mask512<int32_t> mask{static_cast<__mmask16>(mask_bits)}; 2851 return BitCast(D(), DemoteTo(du, Compress(PromoteTo(dw, vu16), mask))); 2852 } 2853 2854 } // namespace detail 2855 2856 template <typename T> 2857 HWY_API Vec256<T> Compress(Vec256<T> v, const Mask256<T> mask) { 2858 return detail::Compress(hwy::SizeTag<sizeof(T)>(), v, 2859 detail::BitsFromMask(mask)); 2860 } 2861 2862 // Expands to 32-bit, compresses, concatenate demoted halves. 2863 template <typename T, HWY_IF_LANE_SIZE(T, 2)> 2864 HWY_API Vec512<T> Compress(Vec512<T> v, const Mask512<T> mask) { 2865 using D = Full512<T>; 2866 const Rebind<uint16_t, D> du; 2867 const Repartition<int32_t, D> dw; 2868 const auto vu16 = BitCast(du, v); // (required for float16_t inputs) 2869 const auto promoted0 = PromoteTo(dw, LowerHalf(vu16)); 2870 const auto promoted1 = PromoteTo(dw, UpperHalf(vu16)); 2871 2872 const Mask512<int32_t> mask0{static_cast<__mmask16>(mask.raw & 0xFFFF)}; 2873 const Mask512<int32_t> mask1{static_cast<__mmask16>(mask.raw >> 16)}; 2874 const auto compressed0 = Compress(promoted0, mask0); 2875 const auto compressed1 = Compress(promoted1, mask1); 2876 2877 const Half<decltype(du)> dh; 2878 const auto demoted0 = ZeroExtendVector(DemoteTo(dh, compressed0)); 2879 const auto demoted1 = ZeroExtendVector(DemoteTo(dh, compressed1)); 2880 2881 // Concatenate into single vector by shifting upper with writemask. 2882 const size_t num0 = CountTrue(mask0); 2883 const __mmask32 m_upper = ~((1u << num0) - 1); 2884 alignas(64) uint16_t iota[64] = { 2885 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2886 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2887 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 2888 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; 2889 const auto idx = LoadU(du, iota + 32 - num0); 2890 return Vec512<T>{_mm512_mask_permutexvar_epi16(demoted0.raw, m_upper, idx.raw, 2891 demoted1.raw)}; 2892 } 2893 2894 // ------------------------------ CompressStore 2895 2896 template <typename T> 2897 HWY_API size_t CompressStore(Vec256<T> v, const Mask256<T> mask, Full256<T> d, 2898 T* HWY_RESTRICT aligned) { 2899 const uint64_t mask_bits = detail::BitsFromMask(mask); 2900 Store(detail::Compress(hwy::SizeTag<sizeof(T)>(), v, mask_bits), d, aligned); 2901 return PopCount(mask_bits); 2902 } 2903 2904 template <typename T, HWY_IF_LANE_SIZE(T, 2)> 2905 HWY_API size_t CompressStore(Vec512<T> v, const Mask512<T> mask, Full512<T> d, 2906 T* HWY_RESTRICT aligned) { 2907 // NOTE: it is tempting to split inputs into two halves for 16-bit lanes, but 2908 // using StoreU to concatenate the results would cause page faults if 2909 // `aligned` is the last valid vector. 
Instead rely on in-register splicing. 2910 Store(Compress(v, mask), d, aligned); 2911 return CountTrue(mask); 2912 } 2913 2914 HWY_API size_t CompressStore(Vec512<uint32_t> v, const Mask512<uint32_t> mask, 2915 Full512<uint32_t> /* tag */, 2916 uint32_t* HWY_RESTRICT aligned) { 2917 _mm512_mask_compressstoreu_epi32(aligned, mask.raw, v.raw); 2918 return CountTrue(mask); 2919 } 2920 HWY_API size_t CompressStore(Vec512<int32_t> v, const Mask512<int32_t> mask, 2921 Full512<int32_t> /* tag */, 2922 int32_t* HWY_RESTRICT aligned) { 2923 _mm512_mask_compressstoreu_epi32(aligned, mask.raw, v.raw); 2924 return CountTrue(mask); 2925 } 2926 2927 HWY_API size_t CompressStore(Vec512<uint64_t> v, const Mask512<uint64_t> mask, 2928 Full512<uint64_t> /* tag */, 2929 uint64_t* HWY_RESTRICT aligned) { 2930 _mm512_mask_compressstoreu_epi64(aligned, mask.raw, v.raw); 2931 return CountTrue(mask); 2932 } 2933 HWY_API size_t CompressStore(Vec512<int64_t> v, const Mask512<int64_t> mask, 2934 Full512<int64_t> /* tag */, 2935 int64_t* HWY_RESTRICT aligned) { 2936 _mm512_mask_compressstoreu_epi64(aligned, mask.raw, v.raw); 2937 return CountTrue(mask); 2938 } 2939 2940 HWY_API size_t CompressStore(Vec512<float> v, const Mask512<float> mask, 2941 Full512<float> /* tag */, 2942 float* HWY_RESTRICT aligned) { 2943 _mm512_mask_compressstoreu_ps(aligned, mask.raw, v.raw); 2944 return CountTrue(mask); 2945 } 2946 2947 HWY_API size_t CompressStore(Vec512<double> v, const Mask512<double> mask, 2948 Full512<double> /* tag */, 2949 double* HWY_RESTRICT aligned) { 2950 _mm512_mask_compressstoreu_pd(aligned, mask.raw, v.raw); 2951 return CountTrue(mask); 2952 } 2953 2954 // ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, 2955 // TableLookupBytes) 2956 2957 HWY_API void StoreInterleaved3(const Vec512<uint8_t> a, const Vec512<uint8_t> b, 2958 const Vec512<uint8_t> c, Full512<uint8_t> d, 2959 uint8_t* HWY_RESTRICT unaligned) { 2960 const auto k5 = Set(d, 5); 2961 const auto k6 = Set(d, 6); 2962 2963 // Shuffle (a,b,c) vector bytes to (MSB on left): r5, bgr[4:0]. 2964 // 0x80 so lanes to be filled from other vectors are 0 for blending. 2965 alignas(16) static constexpr uint8_t tbl_r0[16] = { 2966 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // 2967 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; 2968 alignas(16) static constexpr uint8_t tbl_g0[16] = { 2969 0x80, 0, 0x80, 0x80, 1, 0x80, // 2970 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; 2971 const auto shuf_r0 = LoadDup128(d, tbl_r0); 2972 const auto shuf_g0 = LoadDup128(d, tbl_g0); // cannot reuse r0 due to 5 2973 const auto shuf_b0 = CombineShiftRightBytes<15>(shuf_g0, shuf_g0); 2974 const auto r0 = TableLookupBytes(a, shuf_r0); // 5..4..3..2..1..0 2975 const auto g0 = TableLookupBytes(b, shuf_g0); // ..4..3..2..1..0. 2976 const auto b0 = TableLookupBytes(c, shuf_b0); // .4..3..2..1..0.. 2977 const auto i = (r0 | g0 | b0).raw; // low byte in each 128bit: 30 20 10 00 2978 2979 // Second vector: g10,r10, bgr[9:6], b5,g5 2980 const auto shuf_r1 = shuf_b0 + k6; // .A..9..8..7..6.. 2981 const auto shuf_g1 = shuf_r0 + k5; // A..9..8..7..6..5 2982 const auto shuf_b1 = shuf_g0 + k5; // ..9..8..7..6..5. 2983 const auto r1 = TableLookupBytes(a, shuf_r1); 2984 const auto g1 = TableLookupBytes(b, shuf_g1); 2985 const auto b1 = TableLookupBytes(c, shuf_b1); 2986 const auto j = (r1 | g1 | b1).raw; // low byte in each 128bit: 35 25 15 05 2987 2988 // Third vector: bgr[15:11], b10 2989 const auto shuf_r2 = shuf_b1 + k6; // ..F..E..D..C..B. 
2990 const auto shuf_g2 = shuf_r1 + k5; // .F..E..D..C..B.. 2991 const auto shuf_b2 = shuf_g1 + k5; // F..E..D..C..B..A 2992 const auto r2 = TableLookupBytes(a, shuf_r2); 2993 const auto g2 = TableLookupBytes(b, shuf_g2); 2994 const auto b2 = TableLookupBytes(c, shuf_b2); 2995 const auto k = (r2 | g2 | b2).raw; // low byte in each 128bit: 3A 2A 1A 0A 2996 2997 // To obtain 10 0A 05 00 in one vector, transpose "rows" into "columns". 2998 const auto k3_k0_i3_i0 = _mm512_shuffle_i64x2(i, k, _MM_SHUFFLE(3, 0, 3, 0)); 2999 const auto i1_i2_j0_j1 = _mm512_shuffle_i64x2(j, i, _MM_SHUFFLE(1, 2, 0, 1)); 3000 const auto j2_j3_k1_k2 = _mm512_shuffle_i64x2(k, j, _MM_SHUFFLE(2, 3, 1, 2)); 3001 3002 // Alternating order, most-significant 128 bits from the second arg. 3003 const __mmask8 m = 0xCC; 3004 const auto i1_k0_j0_i0 = _mm512_mask_blend_epi64(m, k3_k0_i3_i0, i1_i2_j0_j1); 3005 const auto j2_i2_k1_j1 = _mm512_mask_blend_epi64(m, i1_i2_j0_j1, j2_j3_k1_k2); 3006 const auto k3_j3_i3_k2 = _mm512_mask_blend_epi64(m, j2_j3_k1_k2, k3_k0_i3_i0); 3007 3008 StoreU(Vec512<uint8_t>{i1_k0_j0_i0}, d, unaligned + 0 * 64); // 10 0A 05 00 3009 StoreU(Vec512<uint8_t>{j2_i2_k1_j1}, d, unaligned + 1 * 64); // 25 20 1A 15 3010 StoreU(Vec512<uint8_t>{k3_j3_i3_k2}, d, unaligned + 2 * 64); // 3A 35 30 2A 3011 } 3012 3013 // ------------------------------ StoreInterleaved4 3014 3015 HWY_API void StoreInterleaved4(const Vec512<uint8_t> v0, 3016 const Vec512<uint8_t> v1, 3017 const Vec512<uint8_t> v2, 3018 const Vec512<uint8_t> v3, Full512<uint8_t> d, 3019 uint8_t* HWY_RESTRICT unaligned) { 3020 // let a,b,c,d denote v0..3. 3021 const auto ba0 = ZipLower(v0, v1); // b7 a7 .. b0 a0 3022 const auto dc0 = ZipLower(v2, v3); // d7 c7 .. d0 c0 3023 const auto ba8 = ZipUpper(v0, v1); 3024 const auto dc8 = ZipUpper(v2, v3); 3025 const auto i = ZipLower(ba0, dc0).raw; // 4x128bit: d..a3 d..a0 3026 const auto j = ZipUpper(ba0, dc0).raw; // 4x128bit: d..a7 d..a4 3027 const auto k = ZipLower(ba8, dc8).raw; // 4x128bit: d..aB d..a8 3028 const auto l = ZipUpper(ba8, dc8).raw; // 4x128bit: d..aF d..aC 3029 // 128-bit blocks were independent until now; transpose 4x4. 3030 const auto j1_j0_i1_i0 = _mm512_shuffle_i64x2(i, j, _MM_SHUFFLE(1, 0, 1, 0)); 3031 const auto l1_l0_k1_k0 = _mm512_shuffle_i64x2(k, l, _MM_SHUFFLE(1, 0, 1, 0)); 3032 const auto j3_j2_i3_i2 = _mm512_shuffle_i64x2(i, j, _MM_SHUFFLE(3, 2, 3, 2)); 3033 const auto l3_l2_k3_k2 = _mm512_shuffle_i64x2(k, l, _MM_SHUFFLE(3, 2, 3, 2)); 3034 constexpr int k20 = _MM_SHUFFLE(2, 0, 2, 0); 3035 constexpr int k31 = _MM_SHUFFLE(3, 1, 3, 1); 3036 const auto l0_k0_j0_i0 = _mm512_shuffle_i64x2(j1_j0_i1_i0, l1_l0_k1_k0, k20); 3037 const auto l1_k1_j1_i1 = _mm512_shuffle_i64x2(j1_j0_i1_i0, l1_l0_k1_k0, k31); 3038 const auto l2_k2_j2_i2 = _mm512_shuffle_i64x2(j3_j2_i3_i2, l3_l2_k3_k2, k20); 3039 const auto l3_k3_j3_i3 = _mm512_shuffle_i64x2(j3_j2_i3_i2, l3_l2_k3_k2, k31); 3040 StoreU(Vec512<uint8_t>{l0_k0_j0_i0}, d, unaligned + 0 * 64); 3041 StoreU(Vec512<uint8_t>{l1_k1_j1_i1}, d, unaligned + 1 * 64); 3042 StoreU(Vec512<uint8_t>{l2_k2_j2_i2}, d, unaligned + 2 * 64); 3043 StoreU(Vec512<uint8_t>{l3_k3_j3_i3}, d, unaligned + 3 * 64); 3044 } 3045 3046 // ------------------------------ Reductions 3047 3048 // Returns the sum in each lane. 
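//
// Example (illustrative sketch): horizontal sum of an aligned float array;
// `pa` and `count` are hypothetical caller-provided values, with `count` a
// multiple of the 16-lane vector size.
//   const Full512<float> d;
//   float total = 0.0f;
//   for (size_t i = 0; i < count; i += 16) {
//     total += GetLane(SumOfLanes(Load(d, pa + i)));
//   }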
3049 HWY_API Vec512<int32_t> SumOfLanes(const Vec512<int32_t> v) { 3050 return Set(Full512<int32_t>(), _mm512_reduce_add_epi32(v.raw)); 3051 } 3052 HWY_API Vec512<int64_t> SumOfLanes(const Vec512<int64_t> v) { 3053 return Set(Full512<int64_t>(), _mm512_reduce_add_epi64(v.raw)); 3054 } 3055 HWY_API Vec512<uint32_t> SumOfLanes(const Vec512<uint32_t> v) { 3056 return BitCast(Full512<uint32_t>(), 3057 SumOfLanes(BitCast(Full512<int32_t>(), v))); 3058 } 3059 HWY_API Vec512<uint64_t> SumOfLanes(const Vec512<uint64_t> v) { 3060 return BitCast(Full512<uint64_t>(), 3061 SumOfLanes(BitCast(Full512<int64_t>(), v))); 3062 } 3063 HWY_API Vec512<float> SumOfLanes(const Vec512<float> v) { 3064 return Set(Full512<float>(), _mm512_reduce_add_ps(v.raw)); 3065 } 3066 HWY_API Vec512<double> SumOfLanes(const Vec512<double> v) { 3067 return Set(Full512<double>(), _mm512_reduce_add_pd(v.raw)); 3068 } 3069 3070 // Returns the minimum in each lane. 3071 HWY_API Vec512<int32_t> MinOfLanes(const Vec512<int32_t> v) { 3072 return Set(Full512<int32_t>(), _mm512_reduce_min_epi32(v.raw)); 3073 } 3074 HWY_API Vec512<int64_t> MinOfLanes(const Vec512<int64_t> v) { 3075 return Set(Full512<int64_t>(), _mm512_reduce_min_epi64(v.raw)); 3076 } 3077 HWY_API Vec512<uint32_t> MinOfLanes(const Vec512<uint32_t> v) { 3078 return Set(Full512<uint32_t>(), _mm512_reduce_min_epu32(v.raw)); 3079 } 3080 HWY_API Vec512<uint64_t> MinOfLanes(const Vec512<uint64_t> v) { 3081 return Set(Full512<uint64_t>(), _mm512_reduce_min_epu64(v.raw)); 3082 } 3083 HWY_API Vec512<float> MinOfLanes(const Vec512<float> v) { 3084 return Set(Full512<float>(), _mm512_reduce_min_ps(v.raw)); 3085 } 3086 HWY_API Vec512<double> MinOfLanes(const Vec512<double> v) { 3087 return Set(Full512<double>(), _mm512_reduce_min_pd(v.raw)); 3088 } 3089 3090 // Returns the maximum in each lane. 3091 HWY_API Vec512<int32_t> MaxOfLanes(const Vec512<int32_t> v) { 3092 return Set(Full512<int32_t>(), _mm512_reduce_max_epi32(v.raw)); 3093 } 3094 HWY_API Vec512<int64_t> MaxOfLanes(const Vec512<int64_t> v) { 3095 return Set(Full512<int64_t>(), _mm512_reduce_max_epi64(v.raw)); 3096 } 3097 HWY_API Vec512<uint32_t> MaxOfLanes(const Vec512<uint32_t> v) { 3098 return Set(Full512<uint32_t>(), _mm512_reduce_max_epu32(v.raw)); 3099 } 3100 HWY_API Vec512<uint64_t> MaxOfLanes(const Vec512<uint64_t> v) { 3101 return Set(Full512<uint64_t>(), _mm512_reduce_max_epu64(v.raw)); 3102 } 3103 HWY_API Vec512<float> MaxOfLanes(const Vec512<float> v) { 3104 return Set(Full512<float>(), _mm512_reduce_max_ps(v.raw)); 3105 } 3106 HWY_API Vec512<double> MaxOfLanes(const Vec512<double> v) { 3107 return Set(Full512<double>(), _mm512_reduce_max_pd(v.raw)); 3108 } 3109 3110 // NOLINTNEXTLINE(google-readability-namespace-comments) 3111 } // namespace HWY_NAMESPACE 3112 } // namespace hwy 3113 HWY_AFTER_NAMESPACE(); 3114