// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Benoit Steiner (benoit.steiner.goog@gmail.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_PACKET_MATH_AVX512_H
#define EIGEN_PACKET_MATH_AVX512_H

namespace Eigen {

namespace internal {

#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
#endif

#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
#endif

#ifdef EIGEN_VECTORIZE_FMA
#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
#endif

typedef __m512 Packet16f;
typedef __m512i Packet16i;
typedef __m512d Packet8d;

template <>
struct is_arithmetic<__m512> {
  enum { value = true };
};
template <>
struct is_arithmetic<__m512i> {
  enum { value = true };
};
template <>
struct is_arithmetic<__m512d> {
  enum { value = true };
};

template<> struct packet_traits<float> : default_packet_traits
{
  typedef Packet16f type;
  typedef Packet8f half;
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = 16,
    HasHalfPacket = 1,
    HasBlend = 0,
#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
#ifdef EIGEN_VECTORIZE_AVX512DQ
    HasLog = 1,
#endif
    HasExp = 1,
    HasSqrt = EIGEN_FAST_MATH,
    HasRsqrt = EIGEN_FAST_MATH,
#endif
    HasDiv = 1
  };
};
template<> struct packet_traits<double> : default_packet_traits
{
  typedef Packet8d type;
  typedef Packet4d half;
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = 8,
    HasHalfPacket = 1,
#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
    HasSqrt = EIGEN_FAST_MATH,
    HasRsqrt = EIGEN_FAST_MATH,
#endif
    HasDiv = 1
  };
};

/* TODO Implement AVX512 for integers
template<> struct packet_traits<int> : default_packet_traits
{
  typedef Packet16i type;
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = 16
  };
};
*/

template <>
struct unpacket_traits<Packet16f> {
  typedef float type;
  typedef Packet8f half;
  typedef Packet16i integer_packet;
  enum { size = 16, alignment = Aligned64 };
};
template <>
struct unpacket_traits<Packet8d> {
  typedef double type;
  typedef Packet4d half;
  enum { size = 8, alignment = Aligned64 };
};
template <>
struct unpacket_traits<Packet16i> {
  typedef int type;
  typedef Packet8i half;
  enum { size = 16, alignment = Aligned64 };
};

template <>
EIGEN_STRONG_INLINE Packet16f pset1<Packet16f>(const float& from) {
  return _mm512_set1_ps(from);
}
template <>
EIGEN_STRONG_INLINE Packet8d pset1<Packet8d>(const double& from) {
  return _mm512_set1_pd(from);
}
template <>
EIGEN_STRONG_INLINE Packet16i pset1<Packet16i>(const int& from) {
  return _mm512_set1_epi32(from);
}

template <>
EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
  return _mm512_broadcastss_ps(_mm_load_ps1(from));
}
template <>
EIGEN_STRONG_INLINE Packet8d pload1<Packet8d>(const double* from) {
  return _mm512_set1_pd(*from);
}
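
// Illustrative usage of the broadcast primitives above (a sketch; these are
// Eigen-internal helpers, not a public API):
//   float s = 3.0f;
//   Packet16f p = pset1<Packet16f>(s);    // all 16 lanes equal 3.0f
//   Packet16f q = pload1<Packet16f>(&s);  // same result, loaded from memory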

template <>
EIGEN_STRONG_INLINE Packet16f plset<Packet16f>(const float& a) {
  return _mm512_add_ps(
      _mm512_set1_ps(a),
      _mm512_set_ps(15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
                    6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f));
}
template <>
EIGEN_STRONG_INLINE Packet8d plset<Packet8d>(const double& a) {
  return _mm512_add_pd(_mm512_set1_pd(a),
                       _mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0));
}

template <>
EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a,
                                              const Packet16f& b) {
  return _mm512_add_ps(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a,
                                            const Packet8d& b) {
  return _mm512_add_pd(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16i padd<Packet16i>(const Packet16i& a,
                                              const Packet16i& b) {
  return _mm512_add_epi32(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16f psub<Packet16f>(const Packet16f& a,
                                              const Packet16f& b) {
  return _mm512_sub_ps(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d psub<Packet8d>(const Packet8d& a,
                                            const Packet8d& b) {
  return _mm512_sub_pd(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16i psub<Packet16i>(const Packet16i& a,
                                              const Packet16i& b) {
  return _mm512_sub_epi32(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16f pnegate(const Packet16f& a) {
  return _mm512_sub_ps(_mm512_set1_ps(0.0f), a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pnegate(const Packet8d& a) {
  return _mm512_sub_pd(_mm512_set1_pd(0.0), a);
}

template <>
EIGEN_STRONG_INLINE Packet16f pconj(const Packet16f& a) {
  return a;
}
template <>
EIGEN_STRONG_INLINE Packet8d pconj(const Packet8d& a) {
  return a;
}
template <>
EIGEN_STRONG_INLINE Packet16i pconj(const Packet16i& a) {
  return a;
}

template <>
EIGEN_STRONG_INLINE Packet16f pmul<Packet16f>(const Packet16f& a,
                                              const Packet16f& b) {
  return _mm512_mul_ps(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmul<Packet8d>(const Packet8d& a,
                                            const Packet8d& b) {
  return _mm512_mul_pd(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a,
                                              const Packet16i& b) {
  // Note: _mm512_mul_epi32 only multiplies the even 32-bit lanes, producing
  // 64-bit results; _mm512_mullo_epi32 is the element-wise 32-bit multiply.
  return _mm512_mullo_epi32(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16f pdiv<Packet16f>(const Packet16f& a,
                                              const Packet16f& b) {
  return _mm512_div_ps(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d pdiv<Packet8d>(const Packet8d& a,
                                            const Packet8d& b) {
  return _mm512_div_pd(a, b);
}

#ifdef EIGEN_VECTORIZE_FMA
template <>
EIGEN_STRONG_INLINE Packet16f pmadd(const Packet16f& a, const Packet16f& b,
                                    const Packet16f& c) {
  return _mm512_fmadd_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b,
                                   const Packet8d& c) {
  return _mm512_fmadd_pd(a, b, c);
}
#endif
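
// The min/max implementations below reverse the intrinsic's arguments; the
// effect on NaN, illustrated (a sketch):
//   pmin(NaN, x) -> NaN   and   pmin(x, NaN) -> x
// which matches std::min(a, b) = (b < a) ? b : a, where any comparison
// involving NaN is false, so the first argument is returned.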

template <>
EIGEN_STRONG_INLINE Packet16f pmin<Packet16f>(const Packet16f& a,
                                              const Packet16f& b) {
  // Arguments are reversed to match the NaN propagation behavior of std::min.
  return _mm512_min_ps(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmin<Packet8d>(const Packet8d& a,
                                            const Packet8d& b) {
  // Arguments are reversed to match the NaN propagation behavior of std::min.
  return _mm512_min_pd(b, a);
}

template <>
EIGEN_STRONG_INLINE Packet16f pmax<Packet16f>(const Packet16f& a,
                                              const Packet16f& b) {
  // Arguments are reversed to match the NaN propagation behavior of std::max.
  return _mm512_max_ps(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmax<Packet8d>(const Packet8d& a,
                                            const Packet8d& b) {
  // Arguments are reversed to match the NaN propagation behavior of std::max.
  return _mm512_max_pd(b, a);
}

#ifdef EIGEN_VECTORIZE_AVX512DQ
template<int I_> EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) { return _mm512_extractf32x8_ps(x, I_); }
template<int I_> EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) { return _mm512_extractf64x2_pd(x, I_); }
EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1); }
#else
// AVX512F does not define _mm512_extractf32x8_ps to extract a __m256 from a __m512
template<int I_> EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) {
  return _mm256_castsi256_ps(_mm512_extracti64x4_epi64(_mm512_castps_si512(x), I_));
}

// AVX512F does not define _mm512_extractf64x2_pd to extract a __m128 from a __m512
template<int I_> EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) {
  return _mm_castsi128_pd(_mm512_extracti32x4_epi32(_mm512_castpd_si512(x), I_));
}

EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
  return _mm512_castsi512_ps(_mm512_inserti64x4(_mm512_castsi256_si512(_mm256_castps_si256(a)),
                                                _mm256_castps_si256(b), 1));
}
#endif
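
// The lane-splitting helpers above, illustrated (a sketch):
//   Packet16f x = ...;
//   Packet8f lo = extract256<0>(x);  // lanes 0..7
//   Packet8f hi = extract256<1>(x);  // lanes 8..15
//   Packet16f y = cat256(lo, hi);    // y == x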

// Helper for the bit-packing step of the low-precision comparisons:
// it packs the flags from 16 32-bit lanes down to 16 16-bit lanes.
EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) {
  // Split the data into small pieces and handle them with AVX instructions
  // to guarantee the internal order of the vector.
  // Operation:
  //   dst[15:0]    := Saturate16(rf[31:0])
  //   dst[31:16]   := Saturate16(rf[63:32])
  //   ...
  //   dst[255:240] := Saturate16(rf[511:480])
  __m256i lo = _mm256_castps_si256(extract256<0>(rf));
  __m256i hi = _mm256_castps_si256(extract256<1>(rf));
  __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0),
                                      _mm256_extractf128_si256(lo, 1));
  __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0),
                                      _mm256_extractf128_si256(hi, 1));
  return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
}

template <>
EIGEN_STRONG_INLINE Packet16i pand<Packet16i>(const Packet16i& a,
                                              const Packet16i& b) {
  return _mm512_and_si512(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16f pand<Packet16f>(const Packet16f& a,
                                              const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_and_ps(a, b);
#else
  return _mm512_castsi512_ps(pand(_mm512_castps_si512(a), _mm512_castps_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pand<Packet8d>(const Packet8d& a,
                                            const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_and_pd(a, b);
#else
  Packet8d res = _mm512_undefined_pd();
  Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0);
  Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0);
  res = _mm512_insertf64x4(res, _mm256_and_pd(lane0_a, lane0_b), 0);

  Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
  Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
  return _mm512_insertf64x4(res, _mm256_and_pd(lane1_a, lane1_b), 1);
#endif
}

template <>
EIGEN_STRONG_INLINE Packet16i por<Packet16i>(const Packet16i& a, const Packet16i& b) {
  return _mm512_or_si512(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16f por<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_or_ps(a, b);
#else
  return _mm512_castsi512_ps(por(_mm512_castps_si512(a), _mm512_castps_si512(b)));
#endif
}

template <>
EIGEN_STRONG_INLINE Packet8d por<Packet8d>(const Packet8d& a,
                                           const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_or_pd(a, b);
#else
  return _mm512_castsi512_pd(por(_mm512_castpd_si512(a), _mm512_castpd_si512(b)));
#endif
}

template <>
EIGEN_STRONG_INLINE Packet16i pxor<Packet16i>(const Packet16i& a, const Packet16i& b) {
  return _mm512_xor_si512(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16f pxor<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_xor_ps(a, b);
#else
  return _mm512_castsi512_ps(pxor(_mm512_castps_si512(a), _mm512_castps_si512(b)));
#endif
}

template <>
EIGEN_STRONG_INLINE Packet8d pxor<Packet8d>(const Packet8d& a, const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_xor_pd(a, b);
#else
  return _mm512_castsi512_pd(pxor(_mm512_castpd_si512(a), _mm512_castpd_si512(b)));
#endif
}
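
// The bitwise helpers treat packets as raw bit patterns, which gives the
// usual mask-select idiom, illustrated (a sketch; `m` holds all-ones or
// all-zeros per lane, and pandnot(a, b) == a & ~b as defined just below):
//   Packet16f sel = por(pand(m, a), pandnot(b, m));  // a where m is set, else b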

template <>
EIGEN_STRONG_INLINE Packet16i pandnot<Packet16i>(const Packet16i& a, const Packet16i& b) {
  // The intrinsic's arguments are swapped: _mm512_andnot_si512(x, y) computes
  // (~x) & y, while pandnot(a, b) must compute a & (~b).
  return _mm512_andnot_si512(b, a);
}

template <>
EIGEN_STRONG_INLINE Packet16f pandnot<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_andnot_ps(b, a);
#else
  return _mm512_castsi512_ps(pandnot(_mm512_castps_si512(a), _mm512_castps_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a, const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_andnot_pd(b, a);
#else
  return _mm512_castsi512_pd(pandnot(_mm512_castpd_si512(a), _mm512_castpd_si512(b)));
#endif
}

template<int N> EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) {
  return _mm512_srai_epi32(a, N);
}

template<int N> EIGEN_STRONG_INLINE Packet16i plogical_shift_right(Packet16i a) {
  return _mm512_srli_epi32(a, N);
}

template<int N> EIGEN_STRONG_INLINE Packet16i plogical_shift_left(Packet16i a) {
  return _mm512_slli_epi32(a, N);
}

template <>
EIGEN_STRONG_INLINE Packet16f pload<Packet16f>(const float* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ps(from);
}
template <>
EIGEN_STRONG_INLINE Packet8d pload<Packet8d>(const double* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_pd(from);
}
template <>
EIGEN_STRONG_INLINE Packet16i pload<Packet16i>(const int* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512(
      reinterpret_cast<const __m512i*>(from));
}

template <>
EIGEN_STRONG_INLINE Packet16f ploadu<Packet16f>(const float* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_ps(from);
}
template <>
EIGEN_STRONG_INLINE Packet8d ploadu<Packet8d>(const double* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_pd(from);
}
template <>
EIGEN_STRONG_INLINE Packet16i ploadu<Packet16i>(const int* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512(
      reinterpret_cast<const __m512i*>(from));
}
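
// Aligned vs. unaligned access, illustrated (a sketch, assuming Eigen's
// EIGEN_ALIGN64 alignment macro):
//   EIGEN_ALIGN64 float buf[32];
//   Packet16f p = pload<Packet16f>(buf);      // requires 64-byte alignment
//   Packet16f q = ploadu<Packet16f>(buf + 1); // no alignment requirement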

// Loads 8 floats from memory and returns the packet
// {a0, a0, a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
template <>
EIGEN_STRONG_INLINE Packet16f ploaddup<Packet16f>(const float* from) {
  // An unaligned load is required here, as there is no requirement
  // on the alignment of the input pointer 'from'.
  __m256i low_half = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
  __m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half));
  __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0));
  return pairs;
}

#ifdef EIGEN_VECTORIZE_AVX512DQ
// FIXME: this does not look optimal; better to load a Packet4d and shuffle...
// Loads 4 doubles from memory and returns the packet
// {a0, a0, a1, a1, a2, a2, a3, a3}
template <>
EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
  __m512d x = _mm512_setzero_pd();
  x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[0]), 0);
  x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[1]), 1);
  x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[2]), 2);
  x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[3]), 3);
  return x;
}
#else
template <>
EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
  __m512d x = _mm512_setzero_pd();
  x = _mm512_mask_broadcastsd_pd(x, 0x3 << 0, _mm_load_sd(from + 0));
  x = _mm512_mask_broadcastsd_pd(x, 0x3 << 2, _mm_load_sd(from + 1));
  x = _mm512_mask_broadcastsd_pd(x, 0x3 << 4, _mm_load_sd(from + 2));
  x = _mm512_mask_broadcastsd_pd(x, 0x3 << 6, _mm_load_sd(from + 3));
  return x;
}
#endif

// Loads 4 floats from memory and returns the packet
// {a0, a0, a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3}
template <>
EIGEN_STRONG_INLINE Packet16f ploadquad<Packet16f>(const float* from) {
  Packet16f tmp = _mm512_castps128_ps512(ploadu<Packet4f>(from));
  const Packet16i scatter_mask = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
  return _mm512_permutexvar_ps(scatter_mask, tmp);
}

// Loads 2 doubles from memory and returns the packet
// {a0, a0, a0, a0, a1, a1, a1, a1}
template <>
EIGEN_STRONG_INLINE Packet8d ploadquad<Packet8d>(const double* from) {
  __m256d lane0 = _mm256_set1_pd(*from);
  __m256d lane1 = _mm256_set1_pd(*(from + 1));
  __m512d tmp = _mm512_undefined_pd();
  tmp = _mm512_insertf64x4(tmp, lane0, 0);
  return _mm512_insertf64x4(tmp, lane1, 1);
}

template <>
EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet16f& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ps(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet8d& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_pd(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet16i& from) {
  // Use the aligned store here: pstore requires an aligned destination.
  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_si512(reinterpret_cast<__m512i*>(to),
                                               from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet16f& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_ps(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_pd(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet16i& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512(
      reinterpret_cast<__m512i*>(to), from);
}
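
// The strided gathers below, illustrated (a sketch):
//   float buf[64];
//   Packet16f p = pgather<float, Packet16f>(buf, 4);
//   // p == {buf[0], buf[4], buf[8], ..., buf[60]}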

template <>
EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from,
                                                             Index stride) {
  Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
  Packet16i stride_multiplier =
      _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
  Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);

  return _mm512_i32gather_ps(indices, from, 4);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const double* from,
                                                            Index stride) {
  Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);

  return _mm512_i32gather_pd(indices, from, 8);
}

template <>
EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to,
                                                         const Packet16f& from,
                                                         Index stride) {
  Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
  Packet16i stride_multiplier =
      _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
  Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
  _mm512_i32scatter_ps(to, indices, from, 4);
}
template <>
EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to,
                                                         const Packet8d& from,
                                                         Index stride) {
  Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
  _mm512_i32scatter_pd(to, indices, from, 8);
}

template <>
EIGEN_STRONG_INLINE void pstore1<Packet16f>(float* to, const float& a) {
  Packet16f pa = pset1<Packet16f>(a);
  pstore(to, pa);
}
template <>
EIGEN_STRONG_INLINE void pstore1<Packet8d>(double* to, const double& a) {
  Packet8d pa = pset1<Packet8d>(a);
  pstore(to, pa);
}
template <>
EIGEN_STRONG_INLINE void pstore1<Packet16i>(int* to, const int& a) {
  Packet16i pa = pset1<Packet16i>(a);
  pstore(to, pa);
}

template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }

template <>
EIGEN_STRONG_INLINE float pfirst<Packet16f>(const Packet16f& a) {
  return _mm_cvtss_f32(_mm512_extractf32x4_ps(a, 0));
}
template <>
EIGEN_STRONG_INLINE double pfirst<Packet8d>(const Packet8d& a) {
  return _mm_cvtsd_f64(_mm256_extractf128_pd(_mm512_extractf64x4_pd(a, 0), 0));
}
template <>
EIGEN_STRONG_INLINE int pfirst<Packet16i>(const Packet16i& a) {
  return _mm_extract_epi32(_mm512_extracti32x4_epi32(a, 0), 0);
}

template<> EIGEN_STRONG_INLINE Packet16f preverse(const Packet16f& a)
{
  return _mm512_permutexvar_ps(_mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), a);
}

template<> EIGEN_STRONG_INLINE Packet8d preverse(const Packet8d& a)
{
  return _mm512_permutexvar_pd(_mm512_set_epi32(0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7), a);
}

template<> EIGEN_STRONG_INLINE Packet16f pabs(const Packet16f& a)
{
  // There is no _mm512_abs_ps intrinsic in AVX512F, so clear the sign bit by hand.
  return _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(a), _mm512_set1_epi32(0x7fffffff)));
}
template <>
EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
  // There is no _mm512_abs_pd intrinsic in AVX512F, so clear the sign bit by hand.
  return _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(a),
                                              _mm512_set1_epi64(0x7fffffffffffffff)));
}
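
// The absolute-value bit trick above, illustrated (a sketch):
//   0x7fffffff masks off the IEEE-754 sign bit of each float lane, so
//   pabs(pset1<Packet16f>(-2.5f)) yields all lanes equal to 2.5f.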

#ifdef EIGEN_VECTORIZE_AVX512DQ
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
  __m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0); \
  __m256 OUTPUT##_1 = _mm512_extractf32x8_ps(INPUT, 1)
#else
// AVX512F does not define _mm512_extractf32x8_ps to extract a __m256 from a __m512
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT)                \
  __m256 OUTPUT##_0 = _mm256_insertf128_ps(                     \
      _mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 0)), \
      _mm512_extractf32x4_ps(INPUT, 1), 1);                     \
  __m256 OUTPUT##_1 = _mm256_insertf128_ps(                     \
      _mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 2)), \
      _mm512_extractf32x4_ps(INPUT, 3), 1);
#endif

#ifdef EIGEN_VECTORIZE_AVX512DQ
#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
  OUTPUT = _mm512_insertf32x8(_mm512_castps256_ps512(INPUTA), INPUTB, 1);
#else
#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB)                    \
  OUTPUT = _mm512_undefined_ps();                                          \
  OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 0), 0); \
  OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 1), 1); \
  OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 0), 2); \
  OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 1), 3);
#endif

template <>
EIGEN_STRONG_INLINE float predux<Packet16f>(const Packet16f& a) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  __m256 lane0 = _mm512_extractf32x8_ps(a, 0);
  __m256 lane1 = _mm512_extractf32x8_ps(a, 1);
  Packet8f x = _mm256_add_ps(lane0, lane1);
  return predux<Packet8f>(x);
#else
  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
  __m128 sum = _mm_add_ps(_mm_add_ps(lane0, lane1), _mm_add_ps(lane2, lane3));
  sum = _mm_hadd_ps(sum, sum);
  sum = _mm_hadd_ps(sum, _mm_permute_ps(sum, 1));
  return _mm_cvtss_f32(sum);
#endif
}
template <>
EIGEN_STRONG_INLINE double predux<Packet8d>(const Packet8d& a) {
  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
  __m256d sum = _mm256_add_pd(lane0, lane1);
  __m256d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1));
  return _mm_cvtsd_f64(_mm256_castpd256_pd128(_mm256_hadd_pd(tmp0, tmp0)));
}

template <>
EIGEN_STRONG_INLINE Packet8f predux_downto4<Packet16f>(const Packet16f& a) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
  Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
  return padd(lane0, lane1);
#else
  Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
  Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
  Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
  Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
  Packet4f sum0 = padd(lane0, lane2);
  Packet4f sum1 = padd(lane1, lane3);
  return _mm256_insertf128_ps(_mm256_castps128_ps256(sum0), sum1, 1);
#endif
}
template <>
EIGEN_STRONG_INLINE Packet4d predux_downto4<Packet8d>(const Packet8d& a) {
  Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
  Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
  Packet4d res = padd(lane0, lane1);
  return res;
}
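
// Horizontal reduction, illustrated (a sketch):
//   predux(plset<Packet16f>(1.0f)) == 1 + 2 + ... + 16 == 136.0f
// while predux_downto4 only folds the two 256-bit halves together, returning
// a Packet8f whose lane i equals a[i] + a[i + 8].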

template <>
EIGEN_STRONG_INLINE float predux_mul<Packet16f>(const Packet16f& a) {
//#ifdef EIGEN_VECTORIZE_AVX512DQ
#if 0
  // Disabled branch kept for reference; it mixes 256-bit and 128-bit intrinsics.
  Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
  Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
  Packet8f res = pmul(lane0, lane1);
  res = pmul(res, _mm256_permute2f128_ps(res, res, 1));
  res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
  return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
#else
  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
  __m128 res = pmul(pmul(lane0, lane1), pmul(lane2, lane3));
  res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
  return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
#endif
}
template <>
EIGEN_STRONG_INLINE double predux_mul<Packet8d>(const Packet8d& a) {
  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
  __m256d res = pmul(lane0, lane1);
  res = pmul(res, _mm256_permute2f128_pd(res, res, 1));
  return pfirst(pmul(res, _mm256_shuffle_pd(res, res, 1)));
}

template <>
EIGEN_STRONG_INLINE float predux_min<Packet16f>(const Packet16f& a) {
  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
  __m128 res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3));
  res = _mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
  return pfirst(_mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
}
template <>
EIGEN_STRONG_INLINE double predux_min<Packet8d>(const Packet8d& a) {
  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
  __m256d res = _mm256_min_pd(lane0, lane1);
  res = _mm256_min_pd(res, _mm256_permute2f128_pd(res, res, 1));
  return pfirst(_mm256_min_pd(res, _mm256_shuffle_pd(res, res, 1)));
}

template <>
EIGEN_STRONG_INLINE float predux_max<Packet16f>(const Packet16f& a) {
  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
  __m128 res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3));
  res = _mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
  return pfirst(_mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
}

template <>
EIGEN_STRONG_INLINE double predux_max<Packet8d>(const Packet8d& a) {
  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
  __m256d res = _mm256_max_pd(lane0, lane1);
  res = _mm256_max_pd(res, _mm256_permute2f128_pd(res, res, 1));
  return pfirst(_mm256_max_pd(res, _mm256_shuffle_pd(res, res, 1)));
}
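
// The multiplicative and min/max reductions follow the same halving pattern
// as predux, illustrated (a sketch):
//   predux_min(plset<Packet16f>(1.0f)) == 1.0f
//   predux_max(plset<Packet16f>(1.0f)) == 16.0f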

// preduxp: given 16 packets, returns a single packet whose i-th element is
// the full horizontal sum of vecs[i].
template<> EIGEN_STRONG_INLINE Packet16f preduxp<Packet16f>(const Packet16f* vecs)
{
  EIGEN_EXTRACT_8f_FROM_16f(vecs[0], vecs0);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[1], vecs1);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[2], vecs2);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[3], vecs3);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[4], vecs4);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[5], vecs5);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[6], vecs6);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[7], vecs7);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[8], vecs8);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[9], vecs9);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[10], vecs10);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[11], vecs11);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[12], vecs12);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[13], vecs13);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[14], vecs14);
  EIGEN_EXTRACT_8f_FROM_16f(vecs[15], vecs15);

  __m256 hsum1 = _mm256_hadd_ps(vecs0_0, vecs1_0);
  __m256 hsum2 = _mm256_hadd_ps(vecs2_0, vecs3_0);
  __m256 hsum3 = _mm256_hadd_ps(vecs4_0, vecs5_0);
  __m256 hsum4 = _mm256_hadd_ps(vecs6_0, vecs7_0);

  __m256 hsum5 = _mm256_hadd_ps(hsum1, hsum1);
  __m256 hsum6 = _mm256_hadd_ps(hsum2, hsum2);
  __m256 hsum7 = _mm256_hadd_ps(hsum3, hsum3);
  __m256 hsum8 = _mm256_hadd_ps(hsum4, hsum4);

  __m256 perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
  __m256 perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
  __m256 perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
  __m256 perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);

  __m256 sum1 = _mm256_add_ps(perm1, hsum5);
  __m256 sum2 = _mm256_add_ps(perm2, hsum6);
  __m256 sum3 = _mm256_add_ps(perm3, hsum7);
  __m256 sum4 = _mm256_add_ps(perm4, hsum8);

  __m256 blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
  __m256 blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);

  __m256 final = _mm256_blend_ps(blend1, blend2, 0xf0);

  hsum1 = _mm256_hadd_ps(vecs0_1, vecs1_1);
  hsum2 = _mm256_hadd_ps(vecs2_1, vecs3_1);
  hsum3 = _mm256_hadd_ps(vecs4_1, vecs5_1);
  hsum4 = _mm256_hadd_ps(vecs6_1, vecs7_1);

  hsum5 = _mm256_hadd_ps(hsum1, hsum1);
  hsum6 = _mm256_hadd_ps(hsum2, hsum2);
  hsum7 = _mm256_hadd_ps(hsum3, hsum3);
  hsum8 = _mm256_hadd_ps(hsum4, hsum4);

  perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
  perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
  perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
  perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);

  sum1 = _mm256_add_ps(perm1, hsum5);
  sum2 = _mm256_add_ps(perm2, hsum6);
  sum3 = _mm256_add_ps(perm3, hsum7);
  sum4 = _mm256_add_ps(perm4, hsum8);

  blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
  blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);

  final = padd(final, _mm256_blend_ps(blend1, blend2, 0xf0));

  hsum1 = _mm256_hadd_ps(vecs8_0, vecs9_0);
  hsum2 = _mm256_hadd_ps(vecs10_0, vecs11_0);
  hsum3 = _mm256_hadd_ps(vecs12_0, vecs13_0);
  hsum4 = _mm256_hadd_ps(vecs14_0, vecs15_0);

  hsum5 = _mm256_hadd_ps(hsum1, hsum1);
  hsum6 = _mm256_hadd_ps(hsum2, hsum2);
  hsum7 = _mm256_hadd_ps(hsum3, hsum3);
  hsum8 = _mm256_hadd_ps(hsum4, hsum4);

  perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
  perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
  perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
  perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);

  sum1 = _mm256_add_ps(perm1, hsum5);
  sum2 = _mm256_add_ps(perm2, hsum6);
  sum3 = _mm256_add_ps(perm3, hsum7);
  sum4 = _mm256_add_ps(perm4, hsum8);

  blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
  blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);

  __m256 final_1 = _mm256_blend_ps(blend1, blend2, 0xf0);

  hsum1 = _mm256_hadd_ps(vecs8_1, vecs9_1);
  hsum2 = _mm256_hadd_ps(vecs10_1, vecs11_1);
  hsum3 = _mm256_hadd_ps(vecs12_1, vecs13_1);
  hsum4 = _mm256_hadd_ps(vecs14_1, vecs15_1);

  hsum5 = _mm256_hadd_ps(hsum1, hsum1);
  hsum6 = _mm256_hadd_ps(hsum2, hsum2);
  hsum7 = _mm256_hadd_ps(hsum3, hsum3);
  hsum8 = _mm256_hadd_ps(hsum4, hsum4);

  perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
  perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
  perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
  perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);

  sum1 = _mm256_add_ps(perm1, hsum5);
  sum2 = _mm256_add_ps(perm2, hsum6);
  sum3 = _mm256_add_ps(perm3, hsum7);
  sum4 = _mm256_add_ps(perm4, hsum8);

  blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
  blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);

  final_1 = padd(final_1, _mm256_blend_ps(blend1, blend2, 0xf0));

  __m512 final_output;

  EIGEN_INSERT_8f_INTO_16f(final_output, final, final_1);
  return final_output;
}

template<> EIGEN_STRONG_INLINE Packet8d preduxp<Packet8d>(const Packet8d* vecs)
{
  Packet4d vecs0_0 = _mm512_extractf64x4_pd(vecs[0], 0);
  Packet4d vecs0_1 = _mm512_extractf64x4_pd(vecs[0], 1);

  Packet4d vecs1_0 = _mm512_extractf64x4_pd(vecs[1], 0);
  Packet4d vecs1_1 = _mm512_extractf64x4_pd(vecs[1], 1);

  Packet4d vecs2_0 = _mm512_extractf64x4_pd(vecs[2], 0);
  Packet4d vecs2_1 = _mm512_extractf64x4_pd(vecs[2], 1);

  Packet4d vecs3_0 = _mm512_extractf64x4_pd(vecs[3], 0);
  Packet4d vecs3_1 = _mm512_extractf64x4_pd(vecs[3], 1);

  Packet4d vecs4_0 = _mm512_extractf64x4_pd(vecs[4], 0);
  Packet4d vecs4_1 = _mm512_extractf64x4_pd(vecs[4], 1);

  Packet4d vecs5_0 = _mm512_extractf64x4_pd(vecs[5], 0);
  Packet4d vecs5_1 = _mm512_extractf64x4_pd(vecs[5], 1);

  Packet4d vecs6_0 = _mm512_extractf64x4_pd(vecs[6], 0);
  Packet4d vecs6_1 = _mm512_extractf64x4_pd(vecs[6], 1);

  Packet4d vecs7_0 = _mm512_extractf64x4_pd(vecs[7], 0);
  Packet4d vecs7_1 = _mm512_extractf64x4_pd(vecs[7], 1);

  Packet4d tmp0, tmp1;

  tmp0 = _mm256_hadd_pd(vecs0_0, vecs1_0);
  tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));

  tmp1 = _mm256_hadd_pd(vecs2_0, vecs3_0);
  tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));

  __m256d final_0 = _mm256_blend_pd(tmp0, tmp1, 0xC);

  tmp0 = _mm256_hadd_pd(vecs0_1, vecs1_1);
  tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));

  tmp1 = _mm256_hadd_pd(vecs2_1, vecs3_1);
  tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));

  final_0 = padd(final_0, _mm256_blend_pd(tmp0, tmp1, 0xC));

  tmp0 = _mm256_hadd_pd(vecs4_0, vecs5_0);
  tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));

  tmp1 = _mm256_hadd_pd(vecs6_0, vecs7_0);
  tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));

  __m256d final_1 = _mm256_blend_pd(tmp0, tmp1, 0xC);

  tmp0 = _mm256_hadd_pd(vecs4_1, vecs5_1);
  tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));

  tmp1 = _mm256_hadd_pd(vecs6_1, vecs7_1);
  tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));

  final_1 = padd(final_1, _mm256_blend_pd(tmp0, tmp1, 0xC));

  // Assemble the two 256-bit halves. The original initializer referenced
  // final_output before it was initialized; start from _mm512_undefined_pd()
  // instead, whose remaining lanes are overwritten by the second insert.
  __m512d final_output = _mm512_insertf64x4(_mm512_undefined_pd(), final_0, 0);

  return _mm512_insertf64x4(final_output, final_1, 1);
}

#define PACK_OUTPUT(OUTPUT, INPUT, INDEX, STRIDE) \
  EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[INDEX], INPUT[INDEX + STRIDE]);
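
// ptranspose below performs an in-place 16x16 transpose, illustrated (a sketch):
//   PacketBlock<Packet16f, 16> block;  // lane c of block.packet[r] is element (r, c)
//   ptranspose(block);                 // lane c of block.packet[r] is now element (c, r)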

EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 16>& kernel) {
  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
  __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);
  __m512 T8 = _mm512_unpacklo_ps(kernel.packet[8], kernel.packet[9]);
  __m512 T9 = _mm512_unpackhi_ps(kernel.packet[8], kernel.packet[9]);
  __m512 T10 = _mm512_unpacklo_ps(kernel.packet[10], kernel.packet[11]);
  __m512 T11 = _mm512_unpackhi_ps(kernel.packet[10], kernel.packet[11]);
  __m512 T12 = _mm512_unpacklo_ps(kernel.packet[12], kernel.packet[13]);
  __m512 T13 = _mm512_unpackhi_ps(kernel.packet[12], kernel.packet[13]);
  __m512 T14 = _mm512_unpacklo_ps(kernel.packet[14], kernel.packet[15]);
  __m512 T15 = _mm512_unpackhi_ps(kernel.packet[14], kernel.packet[15]);
  __m512 S0 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S1 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S2 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S3 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S4 = _mm512_shuffle_ps(T4, T6, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S5 = _mm512_shuffle_ps(T4, T6, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S6 = _mm512_shuffle_ps(T5, T7, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S7 = _mm512_shuffle_ps(T5, T7, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S8 = _mm512_shuffle_ps(T8, T10, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S9 = _mm512_shuffle_ps(T8, T10, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S10 = _mm512_shuffle_ps(T9, T11, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S11 = _mm512_shuffle_ps(T9, T11, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S12 = _mm512_shuffle_ps(T12, T14, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S13 = _mm512_shuffle_ps(T12, T14, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S14 = _mm512_shuffle_ps(T13, T15, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S15 = _mm512_shuffle_ps(T13, T15, _MM_SHUFFLE(3, 2, 3, 2));

  EIGEN_EXTRACT_8f_FROM_16f(S0, S0);
  EIGEN_EXTRACT_8f_FROM_16f(S1, S1);
  EIGEN_EXTRACT_8f_FROM_16f(S2, S2);
  EIGEN_EXTRACT_8f_FROM_16f(S3, S3);
  EIGEN_EXTRACT_8f_FROM_16f(S4, S4);
  EIGEN_EXTRACT_8f_FROM_16f(S5, S5);
  EIGEN_EXTRACT_8f_FROM_16f(S6, S6);
  EIGEN_EXTRACT_8f_FROM_16f(S7, S7);
  EIGEN_EXTRACT_8f_FROM_16f(S8, S8);
  EIGEN_EXTRACT_8f_FROM_16f(S9, S9);
  EIGEN_EXTRACT_8f_FROM_16f(S10, S10);
  EIGEN_EXTRACT_8f_FROM_16f(S11, S11);
  EIGEN_EXTRACT_8f_FROM_16f(S12, S12);
  EIGEN_EXTRACT_8f_FROM_16f(S13, S13);
  EIGEN_EXTRACT_8f_FROM_16f(S14, S14);
  EIGEN_EXTRACT_8f_FROM_16f(S15, S15);

  PacketBlock<Packet8f, 32> tmp;

  tmp.packet[0] = _mm256_permute2f128_ps(S0_0, S4_0, 0x20);
  tmp.packet[1] = _mm256_permute2f128_ps(S1_0, S5_0, 0x20);
  tmp.packet[2] = _mm256_permute2f128_ps(S2_0, S6_0, 0x20);
  tmp.packet[3] = _mm256_permute2f128_ps(S3_0, S7_0, 0x20);
  tmp.packet[4] = _mm256_permute2f128_ps(S0_0, S4_0, 0x31);
  tmp.packet[5] = _mm256_permute2f128_ps(S1_0, S5_0, 0x31);
  tmp.packet[6] = _mm256_permute2f128_ps(S2_0, S6_0, 0x31);
  tmp.packet[7] = _mm256_permute2f128_ps(S3_0, S7_0, 0x31);

  tmp.packet[8] = _mm256_permute2f128_ps(S0_1, S4_1, 0x20);
  tmp.packet[9] = _mm256_permute2f128_ps(S1_1, S5_1, 0x20);
  tmp.packet[10] = _mm256_permute2f128_ps(S2_1, S6_1, 0x20);
  tmp.packet[11] = _mm256_permute2f128_ps(S3_1, S7_1, 0x20);
  tmp.packet[12] = _mm256_permute2f128_ps(S0_1, S4_1, 0x31);
  tmp.packet[13] = _mm256_permute2f128_ps(S1_1, S5_1, 0x31);
  tmp.packet[14] = _mm256_permute2f128_ps(S2_1, S6_1, 0x31);
  tmp.packet[15] = _mm256_permute2f128_ps(S3_1, S7_1, 0x31);

  // Second set of __m256 outputs
  tmp.packet[16] = _mm256_permute2f128_ps(S8_0, S12_0, 0x20);
  tmp.packet[17] = _mm256_permute2f128_ps(S9_0, S13_0, 0x20);
  tmp.packet[18] = _mm256_permute2f128_ps(S10_0, S14_0, 0x20);
  tmp.packet[19] = _mm256_permute2f128_ps(S11_0, S15_0, 0x20);
  tmp.packet[20] = _mm256_permute2f128_ps(S8_0, S12_0, 0x31);
  tmp.packet[21] = _mm256_permute2f128_ps(S9_0, S13_0, 0x31);
  tmp.packet[22] = _mm256_permute2f128_ps(S10_0, S14_0, 0x31);
  tmp.packet[23] = _mm256_permute2f128_ps(S11_0, S15_0, 0x31);

  tmp.packet[24] = _mm256_permute2f128_ps(S8_1, S12_1, 0x20);
  tmp.packet[25] = _mm256_permute2f128_ps(S9_1, S13_1, 0x20);
  tmp.packet[26] = _mm256_permute2f128_ps(S10_1, S14_1, 0x20);
  tmp.packet[27] = _mm256_permute2f128_ps(S11_1, S15_1, 0x20);
  tmp.packet[28] = _mm256_permute2f128_ps(S8_1, S12_1, 0x31);
  tmp.packet[29] = _mm256_permute2f128_ps(S9_1, S13_1, 0x31);
  tmp.packet[30] = _mm256_permute2f128_ps(S10_1, S14_1, 0x31);
  tmp.packet[31] = _mm256_permute2f128_ps(S11_1, S15_1, 0x31);

  // Pack them into the output
  PACK_OUTPUT(kernel.packet, tmp.packet, 0, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 1, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 2, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 3, 16);

  PACK_OUTPUT(kernel.packet, tmp.packet, 4, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 5, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 6, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 7, 16);

  PACK_OUTPUT(kernel.packet, tmp.packet, 8, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 9, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 10, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 11, 16);

  PACK_OUTPUT(kernel.packet, tmp.packet, 12, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 13, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 14, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 15, 16);
}

#define PACK_OUTPUT_2(OUTPUT, INPUT, INDEX, STRIDE)       \
  EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], \
                           INPUT[2 * INDEX + STRIDE]);

EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);

  __m512 S0 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S1 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S2 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S3 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));

  EIGEN_EXTRACT_8f_FROM_16f(S0, S0);
  EIGEN_EXTRACT_8f_FROM_16f(S1, S1);
  EIGEN_EXTRACT_8f_FROM_16f(S2, S2);
  EIGEN_EXTRACT_8f_FROM_16f(S3, S3);

  PacketBlock<Packet8f, 8> tmp;

  tmp.packet[0] = _mm256_permute2f128_ps(S0_0, S1_0, 0x20);
  tmp.packet[1] = _mm256_permute2f128_ps(S2_0, S3_0, 0x20);
  tmp.packet[2] = _mm256_permute2f128_ps(S0_0, S1_0, 0x31);
  tmp.packet[3] = _mm256_permute2f128_ps(S2_0, S3_0, 0x31);

  tmp.packet[4] = _mm256_permute2f128_ps(S0_1, S1_1, 0x20);
  tmp.packet[5] = _mm256_permute2f128_ps(S2_1, S3_1, 0x20);
  tmp.packet[6] = _mm256_permute2f128_ps(S0_1, S1_1, 0x31);
  tmp.packet[7] = _mm256_permute2f128_ps(S2_1, S3_1, 0x31);

  PACK_OUTPUT_2(kernel.packet, tmp.packet, 0, 1);
  PACK_OUTPUT_2(kernel.packet, tmp.packet, 1, 1);
  PACK_OUTPUT_2(kernel.packet, tmp.packet, 2, 1);
  PACK_OUTPUT_2(kernel.packet, tmp.packet, 3, 1);
}

#define PACK_OUTPUT_SQ_D(OUTPUT, INPUT, INDEX, STRIDE)                \
  OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[INDEX], 0); \
  OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[INDEX + STRIDE], 1);

#define PACK_OUTPUT_D(OUTPUT, INPUT, INDEX, STRIDE)                         \
  OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX)], 0); \
  OUTPUT[INDEX] =                                                           \
      _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX) + STRIDE], 1);

EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 4>& kernel) {
  __m512d T0 = _mm512_shuffle_pd(kernel.packet[0], kernel.packet[1], 0);
  __m512d T1 = _mm512_shuffle_pd(kernel.packet[0], kernel.packet[1], 0xff);
  __m512d T2 = _mm512_shuffle_pd(kernel.packet[2], kernel.packet[3], 0);
  __m512d T3 = _mm512_shuffle_pd(kernel.packet[2], kernel.packet[3], 0xff);

  PacketBlock<Packet4d, 8> tmp;

  tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
                                         _mm512_extractf64x4_pd(T2, 0), 0x20);
  tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
                                         _mm512_extractf64x4_pd(T3, 0), 0x20);
  tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
                                         _mm512_extractf64x4_pd(T2, 0), 0x31);
  tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
                                         _mm512_extractf64x4_pd(T3, 0), 0x31);

  tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
                                         _mm512_extractf64x4_pd(T2, 1), 0x20);
  tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
                                         _mm512_extractf64x4_pd(T3, 1), 0x20);
  tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
                                         _mm512_extractf64x4_pd(T2, 1), 0x31);
  tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
                                         _mm512_extractf64x4_pd(T3, 1), 0x31);

  PACK_OUTPUT_D(kernel.packet, tmp.packet, 0, 1);
  PACK_OUTPUT_D(kernel.packet, tmp.packet, 1, 1);
  PACK_OUTPUT_D(kernel.packet, tmp.packet, 2, 1);
  PACK_OUTPUT_D(kernel.packet, tmp.packet, 3, 1);
}

EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 8>& kernel) {
  __m512d T0 = _mm512_unpacklo_pd(kernel.packet[0], kernel.packet[1]);
  __m512d T1 = _mm512_unpackhi_pd(kernel.packet[0], kernel.packet[1]);
  __m512d T2 = _mm512_unpacklo_pd(kernel.packet[2], kernel.packet[3]);
  __m512d T3 = _mm512_unpackhi_pd(kernel.packet[2], kernel.packet[3]);
  __m512d T4 = _mm512_unpacklo_pd(kernel.packet[4], kernel.packet[5]);
  __m512d T5 = _mm512_unpackhi_pd(kernel.packet[4], kernel.packet[5]);
  __m512d T6 = _mm512_unpacklo_pd(kernel.packet[6], kernel.packet[7]);
  __m512d T7 = _mm512_unpackhi_pd(kernel.packet[6], kernel.packet[7]);

  PacketBlock<Packet4d, 16> tmp;

  tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
                                         _mm512_extractf64x4_pd(T2, 0), 0x20);
  tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
                                         _mm512_extractf64x4_pd(T3, 0), 0x20);
  tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
                                         _mm512_extractf64x4_pd(T2, 0), 0x31);
  tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
                                         _mm512_extractf64x4_pd(T3, 0), 0x31);

  tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
                                         _mm512_extractf64x4_pd(T2, 1), 0x20);
  tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
                                         _mm512_extractf64x4_pd(T3, 1), 0x20);
  tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
                                         _mm512_extractf64x4_pd(T2, 1), 0x31);
  tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
                                         _mm512_extractf64x4_pd(T3, 1), 0x31);

  tmp.packet[8] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0),
                                         _mm512_extractf64x4_pd(T6, 0), 0x20);
  tmp.packet[9] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0),
                                         _mm512_extractf64x4_pd(T7, 0), 0x20);
  tmp.packet[10] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0),
                                          _mm512_extractf64x4_pd(T6, 0), 0x31);
  tmp.packet[11] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0),
                                          _mm512_extractf64x4_pd(T7, 0), 0x31);

  tmp.packet[12] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1),
                                          _mm512_extractf64x4_pd(T6, 1), 0x20);
  tmp.packet[13] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1),
                                          _mm512_extractf64x4_pd(T7, 1), 0x20);
  tmp.packet[14] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1),
                                          _mm512_extractf64x4_pd(T6, 1), 0x31);
  tmp.packet[15] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1),
                                          _mm512_extractf64x4_pd(T7, 1), 0x31);

  PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 0, 8);
  PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 1, 8);
  PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 2, 8);
  PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 3, 8);

  PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 4, 8);
  PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 5, 8);
  PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 6, 8);
  PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 7, 8);
}
template <>
EIGEN_STRONG_INLINE Packet16f pblend(const Selector<16>& /*ifPacket*/,
                                     const Packet16f& /*thenPacket*/,
                                     const Packet16f& /*elsePacket*/) {
  assert(false && "To be implemented");
  return Packet16f();
}
template <>
EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket,
                                    const Packet8d& thenPacket,
                                    const Packet8d& elsePacket) {
  __mmask8 m = (ifPacket.select[0]   )
             | (ifPacket.select[1]<<1)
             | (ifPacket.select[2]<<2)
             | (ifPacket.select[3]<<3)
             | (ifPacket.select[4]<<4)
             | (ifPacket.select[5]<<5)
             | (ifPacket.select[6]<<6)
             | (ifPacket.select[7]<<7);
  return _mm512_mask_blend_pd(m, elsePacket, thenPacket);
}

template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
  return _mm512_cvttps_epi32(a);
}

template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
  return _mm512_cvtepi32_ps(a);
}
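
// The casts above, illustrated (a sketch): float-to-int truncates toward zero,
// and int-to-float converts exactly for values of small magnitude:
//   pcast<Packet16f, Packet16i>(pset1<Packet16f>(-1.7f));  // all lanes == -1
//   pcast<Packet16i, Packet16f>(pset1<Packet16i>(3));      // all lanes == 3.0f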

template <int Offset>
struct palign_impl<Offset, Packet16f> {
  static EIGEN_STRONG_INLINE void run(Packet16f& first,
                                      const Packet16f& second) {
    if (Offset != 0) {
      __m512i first_idx = _mm512_set_epi32(
          Offset + 15, Offset + 14, Offset + 13, Offset + 12, Offset + 11,
          Offset + 10, Offset + 9, Offset + 8, Offset + 7, Offset + 6,
          Offset + 5, Offset + 4, Offset + 3, Offset + 2, Offset + 1, Offset);

      __m512i second_idx =
          _mm512_set_epi32(Offset - 1, Offset - 2, Offset - 3, Offset - 4,
                           Offset - 5, Offset - 6, Offset - 7, Offset - 8,
                           Offset - 9, Offset - 10, Offset - 11, Offset - 12,
                           Offset - 13, Offset - 14, Offset - 15, Offset - 16);

      unsigned short mask = 0xFFFF;
      mask <<= (16 - Offset);

      first = _mm512_permutexvar_ps(first_idx, first);
      Packet16f tmp = _mm512_permutexvar_ps(second_idx, second);
      first = _mm512_mask_blend_ps(mask, first, tmp);
    }
  }
};
template <int Offset>
struct palign_impl<Offset, Packet8d> {
  static EIGEN_STRONG_INLINE void run(Packet8d& first, const Packet8d& second) {
    if (Offset != 0) {
      __m512i first_idx = _mm512_set_epi32(
          0, Offset + 7, 0, Offset + 6, 0, Offset + 5, 0, Offset + 4, 0,
          Offset + 3, 0, Offset + 2, 0, Offset + 1, 0, Offset);

      __m512i second_idx = _mm512_set_epi32(
          0, Offset - 1, 0, Offset - 2, 0, Offset - 3, 0, Offset - 4, 0,
          Offset - 5, 0, Offset - 6, 0, Offset - 7, 0, Offset - 8);

      unsigned char mask = 0xFF;
      mask <<= (8 - Offset);

      first = _mm512_permutexvar_pd(first_idx, first);
      Packet8d tmp = _mm512_permutexvar_pd(second_idx, second);
      first = _mm512_mask_blend_pd(mask, first, tmp);
    }
  }
};

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_PACKET_MATH_AVX512_H