#ifndef NTL_PD__H
#define NTL_PD__H

#include <NTL/tools.h>
#include <immintrin.h>

NTL_OPEN_NNS


template<int N>
struct PD {
private:
   PD();
};


// FIXME: should distinguish more carefully:
//   AVX512DQ for long/double conversions
//   AVX512VL for certain ops applied to shorter types:
//     long/double conversions and mask ops
// may need to emulate long/double conversions for non-AVX512DQ targets



//=================== PD<8> implementation ===============

#ifdef NTL_HAVE_AVX512F

template<>
struct PD<8> {
   __m512d data;

   enum { size = 8 };

   PD() { }
   PD(double x) : data(_mm512_set1_pd(x)) { }
   PD(__m512d _data) : data(_data) { }

   PD(double d0, double d1, double d2, double d3,
      double d4, double d5, double d6, double d7)
      : data(_mm512_set_pd(d7, d6, d5, d4, d3, d2, d1, d0)) { }

   static PD load(const double *p) { return _mm512_load_pd(p); }

   // load from unaligned address
   static PD loadu(const double *p) { return _mm512_loadu_pd(p); }
};

inline void
load(PD<8>& x, const double *p)
{ x = PD<8>::load(p); }

// load from unaligned address
inline void
loadu(PD<8>& x, const double *p)
{ x = PD<8>::loadu(p); }

inline void
store(double *p, PD<8> a)
{ _mm512_store_pd(p, a.data); }

// store to unaligned address
inline void
storeu(double *p, PD<8> a)
{ _mm512_storeu_pd(p, a.data); }

// load and convert
inline void
load(PD<8>& x, const long *p)
{ __m512i a = _mm512_load_epi64(p); x = _mm512_cvtepi64_pd(a); }

// load unaligned and convert
inline void
loadu(PD<8>& x, const long *p)
{ __m512i a = _mm512_loadu_si512(p); x = _mm512_cvtepi64_pd(a); }

// convert and store
inline void
store(long *p, PD<8> a)
{ __m512i b = _mm512_cvtpd_epi64(a.data); _mm512_store_epi64(p, b); }

// convert and store unaligned
inline void
storeu(long *p, PD<8> a)
{ __m512i b = _mm512_cvtpd_epi64(a.data); _mm512_storeu_si512(p, b); }


// swap even/odd slots
// e.g., 01234567 -> 10325476
inline PD<8>
swap2(PD<8> a)
{ return _mm512_permute_pd(a.data, 0x55); }

// swap even/odd slot-pairs
// e.g., 01234567 -> 23016745
inline PD<8>
swap4(PD<8> a)
{ return _mm512_permutex_pd(a.data, 0x4e); }

// 01234567 -> 00224466
inline PD<8>
dup2even(PD<8> a)
{ return _mm512_permute_pd(a.data, 0); }

// 01234567 -> 11335577
inline PD<8>
dup2odd(PD<8> a)
{ return _mm512_permute_pd(a.data, 0xff); }

// 01234567 -> 01014545
inline PD<8>
dup4even(PD<8> a)
{ return _mm512_permutex_pd(a.data, 0x44); }

// 01234567 -> 23236767
inline PD<8>
dup4odd(PD<8> a)
{ return _mm512_permutex_pd(a.data, 0xee); }

// blend even/odd slots
// 01234567, 89abcdef -> 092b4d6f
inline PD<8>
blend2(PD<8> a, PD<8> b)
{ return _mm512_mask_blend_pd(0xaa, a.data, b.data); }
// FIXME: why isn't there an intrinsic that doesn't require a mask register?

// blend even/odd slot-pairs
// 01234567, 89abcdef -> 01ab45ef
inline PD<8>
blend4(PD<8> a, PD<8> b)
{ return _mm512_mask_blend_pd(0xcc, a.data, b.data); }
// FIXME: why isn't there an intrinsic that doesn't require a mask register?
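
// Illustrative sketch (not part of the original header): the shuffle helpers
// above compose into other useful permutations.  For example, reversing each
// 4-slot half can be done by swapping slot-pairs and then swapping within
// each pair:
//    01234567 --swap4--> 23016745 --swap2--> 32107654
// The function name is hypothetical.
inline PD<8>
reverse_within_halves_example(PD<8> a)
{ return swap2(swap4(a)); }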

// res[i] = a[i] < b[i] ? a[i] : a[i]-b[i]
inline PD<8>
correct_excess(PD<8> a, PD<8> b)
{
   __mmask8 k = _mm512_cmp_pd_mask(a.data, b.data, _CMP_GE_OQ);
   return _mm512_mask_sub_pd(a.data, k, a.data, b.data);
}

// res[i] = a[i] >= 0 ? a[i] : a[i]+b[i]
inline PD<8>
correct_deficit(PD<8> a, PD<8> b)
{
   __mmask8 k = _mm512_cmp_pd_mask(a.data, _mm512_setzero_pd(), _CMP_LT_OQ);
   return _mm512_mask_add_pd(a.data, k, a.data, b.data);
}

inline void
clear(PD<8>& x)
{ x.data = _mm512_setzero_pd(); }

inline PD<8>
operator+(PD<8> a, PD<8> b)
{ return _mm512_add_pd(a.data, b.data); }

inline PD<8>
operator-(PD<8> a, PD<8> b)
{ return _mm512_sub_pd(a.data, b.data); }

inline PD<8>
operator*(PD<8> a, PD<8> b)
{ return _mm512_mul_pd(a.data, b.data); }

inline PD<8>
operator/(PD<8> a, PD<8> b)
{ return _mm512_div_pd(a.data, b.data); }

inline PD<8>&
operator+=(PD<8>& a, PD<8> b)
{ a = a + b; return a; }

inline PD<8>&
operator-=(PD<8>& a, PD<8> b)
{ a = a - b; return a; }

inline PD<8>&
operator*=(PD<8>& a, PD<8> b)
{ a = a * b; return a; }

inline PD<8>&
operator/=(PD<8>& a, PD<8> b)
{ a = a / b; return a; }

// a*b+c (fused)
inline PD<8>
fused_muladd(PD<8> a, PD<8> b, PD<8> c)
{ return _mm512_fmadd_pd(a.data, b.data, c.data); }

// a*b-c (fused)
inline PD<8>
fused_mulsub(PD<8> a, PD<8> b, PD<8> c)
{ return _mm512_fmsub_pd(a.data, b.data, c.data); }

// -a*b+c (fused)
inline PD<8>
fused_negmuladd(PD<8> a, PD<8> b, PD<8> c)
{ return _mm512_fnmadd_pd(a.data, b.data, c.data); }
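
// Illustrative sketch (not part of the original header): correct_excess and
// correct_deficit are the single conditional-correction steps typically used
// for packed modular arithmetic.  Assuming every slot of a and b lies in
// [0, p), hypothetical packed AddMod/SubMod routines could look like this:

inline PD<8>
add_mod_example(PD<8> a, PD<8> b, PD<8> p)
{ return correct_excess(a + b, p); }   // a+b lies in [0, 2p); reduce to [0, p)

inline PD<8>
sub_mod_example(PD<8> a, PD<8> b, PD<8> p)
{ return correct_deficit(a - b, p); }  // a-b lies in (-p, p); reduce to [0, p)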

#endif

//=================== PD<4> implementation ===============

#if (defined(NTL_HAVE_AVX2) && defined(NTL_HAVE_FMA))

template<>
struct PD<4> {
   __m256d data;

   enum { size = 4 };

   PD() { }
   PD(double x) : data(_mm256_set1_pd(x)) { }
   PD(__m256d _data) : data(_data) { }
   PD(double d0, double d1, double d2, double d3)
      : data(_mm256_set_pd(d3, d2, d1, d0)) { }

   static PD load(const double *p) { return _mm256_load_pd(p); }

   // load from unaligned address
   static PD loadu(const double *p) { return _mm256_loadu_pd(p); }
};

inline void
load(PD<4>& x, const double *p)
{ x = PD<4>::load(p); }

// load from unaligned address
inline void
loadu(PD<4>& x, const double *p)
{ x = PD<4>::loadu(p); }

inline void
store(double *p, PD<4> a)
{ _mm256_store_pd(p, a.data); }

// store to unaligned address
inline void
storeu(double *p, PD<4> a)
{ _mm256_storeu_pd(p, a.data); }




// The following assume all numbers are integers
// in the range [0, 2^52).  The idea is taken from here:
// https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx


// Some of the Intel intrinsics for loading and storing packed
// integers from memory require casting between long* and __m256i*.
// Strictly speaking, this can break strict aliasing rules, but
// this is hopefully not a problem.
// See discussion here:
// https://stackoverflow.com/questions/24787268/how-to-implement-mm-storeu-epi64-without-aliasing-problems


// load and convert
inline void
load(PD<4>& x, const long *p)
{
#ifdef NTL_HAVE_AVX512F
   __m256i a = _mm256_load_si256((const __m256i*)p);
   x = _mm256_cvtepi64_pd(a);
#else
   __m256i a = _mm256_load_si256((const __m256i*)p);
   a = _mm256_or_si256(a, _mm256_castpd_si256(_mm256_set1_pd(1L << 52)));
   x = _mm256_sub_pd(_mm256_castsi256_pd(a), _mm256_set1_pd(1L << 52));
#endif
}

// load unaligned and convert
inline void
loadu(PD<4>& x, const long *p)
{
#ifdef NTL_HAVE_AVX512F
   __m256i a = _mm256_loadu_si256((const __m256i*)p); x = _mm256_cvtepi64_pd(a);
#else
   __m256i a = _mm256_loadu_si256((const __m256i*)p);
   a = _mm256_or_si256(a, _mm256_castpd_si256(_mm256_set1_pd(1L << 52)));
   x = _mm256_sub_pd(_mm256_castsi256_pd(a), _mm256_set1_pd(1L << 52));
#endif
}

// convert and store
inline void
store(long *p, PD<4> a)
{
#ifdef NTL_HAVE_AVX512F
   __m256i b = _mm256_cvtpd_epi64(a.data);
#ifdef __clang__
   _mm256_store_si256((__m256i*)p, b);
#else
   // clang doesn't define this...why??
   _mm256_store_epi64(p, b);
#endif
#else
   __m256d x = a.data;
   x = _mm256_add_pd(x, _mm256_set1_pd(1L << 52));
   __m256i b = _mm256_xor_si256(
                  _mm256_castpd_si256(x),
                  _mm256_castpd_si256(_mm256_set1_pd(1L << 52)));
   _mm256_store_si256((__m256i*)p, b);
#endif
}

// convert and store unaligned
inline void
storeu(long *p, PD<4> a)
{
#ifdef NTL_HAVE_AVX512F
   __m256i b = _mm256_cvtpd_epi64(a.data);
   _mm256_storeu_si256((__m256i*)p, b);
#else
   __m256d x = a.data;
   x = _mm256_add_pd(x, _mm256_set1_pd(1L << 52));
   __m256i b = _mm256_xor_si256(
                  _mm256_castpd_si256(x),
                  _mm256_castpd_si256(_mm256_set1_pd(1L << 52)));
   _mm256_storeu_si256((__m256i*)p, b);
#endif
}
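
// Illustrative note (not part of the original header): for an integer n with
// 0 <= n < 2^52, the double 2^52 + n has biased exponent 0x433 and mantissa
// bits exactly n.  So OR-ing n with the bit pattern of the double 2^52 gives
// the double 2^52 + n, and subtracting 2^52 recovers n exactly; the store
// path inverts this with an add followed by an XOR.  A small usage sketch of
// the converting loads/stores (the function name is hypothetical):
inline void
long_to_double4_example(double *dst, const long *src)
// dst[i] = (double) src[i] for i = 0..3, assuming 0 <= src[i] < 2^52
{
   PD<4> t;
   loadu(t, src);    // load four longs and convert using the trick above
   storeu(dst, t);   // plain packed-double store
}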

// swap even/odd slots
// e.g., 0123 -> 1032
inline PD<4>
swap2(PD<4> a)
{ return _mm256_permute_pd(a.data, 0x5); }

// 0123 -> 0022
inline PD<4>
dup2even(PD<4> a)
{ return _mm256_permute_pd(a.data, 0); }

// 0123 -> 1133
inline PD<4>
dup2odd(PD<4> a)
{ return _mm256_permute_pd(a.data, 0xf); }

// blend even/odd slots
// 0123, 4567 -> 0527
inline PD<4>
blend2(PD<4> a, PD<4> b)
{ return _mm256_blend_pd(a.data, b.data, 0xa); }

// res[i] = a[i] < b[i] ? a[i] : a[i]-b[i]
inline PD<4>
correct_excess(PD<4> a, PD<4> b)
{
#ifdef NTL_HAVE_AVX512F
   __mmask8 k = _mm256_cmp_pd_mask(a.data, b.data, _CMP_GE_OQ);
   return _mm256_mask_sub_pd(a.data, k, a.data, b.data);
#else
   __m256d mask = _mm256_cmp_pd(a.data, b.data, _CMP_GE_OQ);
   __m256d corrected = _mm256_sub_pd(a.data, b.data);
   return _mm256_blendv_pd(a.data, corrected, mask);
#endif
}

// res[i] = a[i] >= 0 ? a[i] : a[i]+b[i]
inline PD<4>
correct_deficit(PD<4> a, PD<4> b)
{
#ifdef NTL_HAVE_AVX512F
   __mmask8 k = _mm256_cmp_pd_mask(a.data, _mm256_setzero_pd(), _CMP_LT_OQ);
   return _mm256_mask_add_pd(a.data, k, a.data, b.data);
#else
   __m256d mask = _mm256_cmp_pd(a.data, _mm256_setzero_pd(), _CMP_LT_OQ);
   __m256d corrected = _mm256_add_pd(a.data, b.data);
   return _mm256_blendv_pd(a.data, corrected, mask);
#endif
}

inline void
clear(PD<4>& x)
{ x.data = _mm256_setzero_pd(); }

inline PD<4>
operator+(PD<4> a, PD<4> b)
{ return _mm256_add_pd(a.data, b.data); }

inline PD<4>
operator-(PD<4> a, PD<4> b)
{ return _mm256_sub_pd(a.data, b.data); }

inline PD<4>
operator*(PD<4> a, PD<4> b)
{ return _mm256_mul_pd(a.data, b.data); }

inline PD<4>
operator/(PD<4> a, PD<4> b)
{ return _mm256_div_pd(a.data, b.data); }

inline PD<4>&
operator+=(PD<4>& a, PD<4> b)
{ a = a + b; return a; }

inline PD<4>&
operator-=(PD<4>& a, PD<4> b)
{ a = a - b; return a; }

inline PD<4>&
operator*=(PD<4>& a, PD<4> b)
{ a = a * b; return a; }

inline PD<4>&
operator/=(PD<4>& a, PD<4> b)
{ a = a / b; return a; }

// a*b+c (fused)
inline PD<4>
fused_muladd(PD<4> a, PD<4> b, PD<4> c)
{ return _mm256_fmadd_pd(a.data, b.data, c.data); }

// a*b-c (fused)
inline PD<4>
fused_mulsub(PD<4> a, PD<4> b, PD<4> c)
{ return _mm256_fmsub_pd(a.data, b.data, c.data); }

// -a*b+c (fused)
inline PD<4>
fused_negmuladd(PD<4> a, PD<4> b, PD<4> c)
{ return _mm256_fnmadd_pd(a.data, b.data, c.data); }
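
// Illustrative usage sketch (not part of the original header): the fused
// operations applied across arrays of doubles.  The function name and the
// assumption that n is a multiple of 4 are hypothetical.
inline void
muladd_arrays_example(double *z, const double *x, const double *y,
                      const double *w, long n)
// z[i] = x[i]*y[i] + w[i] for i = 0..n-1, assuming 4 divides n
{
   for (long i = 0; i < n; i += 4) {
      PD<4> a, b, c;
      loadu(a, x + i);
      loadu(b, y + i);
      loadu(c, w + i);
      storeu(z + i, fused_muladd(a, b, c));
   }
}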


//=================== PD<2> implementation ===============


template<>
struct PD<2> {
   __m128d data;

   enum { size = 2 };

   PD() { }
   PD(double x) : data(_mm_set1_pd(x)) { }
   PD(__m128d _data) : data(_data) { }
   PD(double d0, double d1)
      : data(_mm_set_pd(d1, d0)) { }

   static PD load(const double *p) { return _mm_load_pd(p); }

   // load from unaligned address
   static PD loadu(const double *p) { return _mm_loadu_pd(p); }
};

inline void
load(PD<2>& x, const double *p)
{ x = PD<2>::load(p); }

// load from unaligned address
inline void
loadu(PD<2>& x, const double *p)
{ x = PD<2>::loadu(p); }

inline void
store(double *p, PD<2> a)
{ _mm_store_pd(p, a.data); }

// store to unaligned address
inline void
storeu(double *p, PD<2> a)
{ _mm_storeu_pd(p, a.data); }




// The following assume all numbers are integers
// in the range [0, 2^52).  The idea is taken from here:
// https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx

// load and convert
inline void
load(PD<2>& x, const long *p)
{
#ifdef NTL_HAVE_AVX512F
   __m128i a = _mm_load_si128((const __m128i*)p);
   x = _mm_cvtepi64_pd(a);
#else
   __m128i a = _mm_load_si128((const __m128i*)p);
   a = _mm_or_si128(a, _mm_castpd_si128(_mm_set1_pd(1L << 52)));
   x = _mm_sub_pd(_mm_castsi128_pd(a), _mm_set1_pd(1L << 52));
#endif
}

// load unaligned and convert
inline void
loadu(PD<2>& x, const long *p)
{
#ifdef NTL_HAVE_AVX512F
   __m128i a = _mm_loadu_si128((const __m128i*)p); x = _mm_cvtepi64_pd(a);
#else
   __m128i a = _mm_loadu_si128((const __m128i*)p);
   a = _mm_or_si128(a, _mm_castpd_si128(_mm_set1_pd(1L << 52)));
   x = _mm_sub_pd(_mm_castsi128_pd(a), _mm_set1_pd(1L << 52));
#endif
}

// convert and store
inline void
store(long *p, PD<2> a)
{
#ifdef NTL_HAVE_AVX512F
   __m128i b = _mm_cvtpd_epi64(a.data);
#ifdef __clang__
   _mm_store_si128((__m128i*)p, b);
#else
   // clang doesn't define this...why??
   _mm_store_epi64(p, b);
#endif
#else
   __m128d x = a.data;
   x = _mm_add_pd(x, _mm_set1_pd(1L << 52));
   __m128i b = _mm_xor_si128(
                  _mm_castpd_si128(x),
                  _mm_castpd_si128(_mm_set1_pd(1L << 52)));
   _mm_store_si128((__m128i*)p, b);
#endif
}

// convert and store unaligned
inline void
storeu(long *p, PD<2> a)
{
#ifdef NTL_HAVE_AVX512F
   __m128i b = _mm_cvtpd_epi64(a.data);
   _mm_storeu_si128((__m128i*)p, b);
#else
   __m128d x = a.data;
   x = _mm_add_pd(x, _mm_set1_pd(1L << 52));
   __m128i b = _mm_xor_si128(
                  _mm_castpd_si128(x),
                  _mm_castpd_si128(_mm_set1_pd(1L << 52)));
   _mm_storeu_si128((__m128i*)p, b);
#endif
}


// res[i] = a[i] < b[i] ? a[i] : a[i]-b[i]
inline PD<2>
correct_excess(PD<2> a, PD<2> b)
{
#ifdef NTL_HAVE_AVX512F
   __mmask8 k = _mm_cmp_pd_mask(a.data, b.data, _CMP_GE_OQ);
   return _mm_mask_sub_pd(a.data, k, a.data, b.data);
#else
   __m128d mask = _mm_cmp_pd(a.data, b.data, _CMP_GE_OQ);
   __m128d corrected = _mm_sub_pd(a.data, b.data);
   return _mm_blendv_pd(a.data, corrected, mask);
#endif
}

// res[i] = a[i] >= 0 ? a[i] : a[i]+b[i]
inline PD<2>
correct_deficit(PD<2> a, PD<2> b)
{
#ifdef NTL_HAVE_AVX512F
   __mmask8 k = _mm_cmp_pd_mask(a.data, _mm_setzero_pd(), _CMP_LT_OQ);
   return _mm_mask_add_pd(a.data, k, a.data, b.data);
#else
   __m128d mask = _mm_cmp_pd(a.data, _mm_setzero_pd(), _CMP_LT_OQ);
   __m128d corrected = _mm_add_pd(a.data, b.data);
   return _mm_blendv_pd(a.data, corrected, mask);
#endif
}

inline void
clear(PD<2>& x)
{ x.data = _mm_setzero_pd(); }

inline PD<2>
operator+(PD<2> a, PD<2> b)
{ return _mm_add_pd(a.data, b.data); }

inline PD<2>
operator-(PD<2> a, PD<2> b)
{ return _mm_sub_pd(a.data, b.data); }

inline PD<2>
operator*(PD<2> a, PD<2> b)
{ return _mm_mul_pd(a.data, b.data); }

inline PD<2>
operator/(PD<2> a, PD<2> b)
{ return _mm_div_pd(a.data, b.data); }

inline PD<2>&
operator+=(PD<2>& a, PD<2> b)
{ a = a + b; return a; }

inline PD<2>&
operator-=(PD<2>& a, PD<2> b)
{ a = a - b; return a; }

inline PD<2>&
operator*=(PD<2>& a, PD<2> b)
{ a = a * b; return a; }

inline PD<2>&
operator/=(PD<2>& a, PD<2> b)
{ a = a / b; return a; }

// a*b+c (fused)
inline PD<2>
fused_muladd(PD<2> a, PD<2> b, PD<2> c)
{ return _mm_fmadd_pd(a.data, b.data, c.data); }

// a*b-c (fused)
inline PD<2>
fused_mulsub(PD<2> a, PD<2> b, PD<2> c)
{ return _mm_fmsub_pd(a.data, b.data, c.data); }

// -a*b+c (fused)
inline PD<2>
fused_negmuladd(PD<2> a, PD<2> b, PD<2> c)
{ return _mm_fnmadd_pd(a.data, b.data, c.data); }




//================== PD<8>/PD<4> conversions ================

#ifdef NTL_HAVE_AVX512F

// 0123, 4567 -> 01234567
inline PD<8>
join(PD<4> a, PD<4> b)
{
   __m512d c = _mm512_castpd256_pd512(a.data);
   return _mm512_insertf64x4(c, b.data, 1);
}

// 01234567 -> 0123
inline PD<4>
get_lo(PD<8> a)
{ return _mm512_extractf64x4_pd(a.data, 0); }

// 01234567 -> 4567
inline PD<4>
get_hi(PD<8> a)
{ return _mm512_extractf64x4_pd(a.data, 1); }

#endif

//================== PD<4>/PD<2> conversions ================

// 01, 23 -> 0123
inline PD<4>
join(PD<2> a, PD<2> b)
#if 0
// some versions of gcc are buggy and don't define this function
{ return _mm256_set_m128d(b.data, a.data); }
#else
{ return _mm256_insertf128_pd(_mm256_castpd128_pd256(a.data), b.data, 1); }
#endif


// 0123 -> 01
inline PD<2>
get_lo(PD<4> a)
{ return _mm256_extractf128_pd(a.data, 0); }

// 0123 -> 23
inline PD<2>
get_hi(PD<4> a)
{ return _mm256_extractf128_pd(a.data, 1); }
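
// Illustrative sketch (not part of the original header): the join/get_lo/
// get_hi conversions let code move between vector widths, e.g. finishing a
// computation by repeatedly halving the width.  A hypothetical horizontal
// sum of the eight slots of a PD<8> (only meaningful when NTL_HAVE_AVX512F
// is defined):
#ifdef NTL_HAVE_AVX512F
inline double
sum8_example(PD<8> a)
{
   PD<4> b = get_lo(a) + get_hi(a);   // four pairwise partial sums
   PD<2> c = get_lo(b) + get_hi(b);   // two partial sums
   double buf[2];
   storeu(buf, c);
   return buf[0] + buf[1];
}
#endif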

#endif


NTL_CLOSE_NNS


#endif