// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <immintrin.h>
#include <stdint.h>

#include "eltwise/eltwise-cmp-sub-mod-internal.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "hexl/util/check.hpp"
#include "util/avx512-util.hpp"

namespace intel {
namespace hexl {

#ifdef HEXL_HAS_AVX512DQ
18 template <int BitShift>
EltwiseCmpSubModAVX512(uint64_t * result,const uint64_t * operand1,uint64_t n,uint64_t modulus,CMPINT cmp,uint64_t bound,uint64_t diff)19 void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1,
20 uint64_t n, uint64_t modulus, CMPINT cmp,
21 uint64_t bound, uint64_t diff) {
22 HEXL_CHECK(result != nullptr, "Require result != nullptr");
23 HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr");
24 HEXL_CHECK(n != 0, "Require n != 0")
25 HEXL_CHECK(modulus > 1, "Require modulus > 1");
26 HEXL_CHECK(diff != 0, "Require diff != 0");
27 HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus);
28
29 uint64_t n_mod_8 = n % 8;
30 if (n_mod_8 != 0) {
31 EltwiseCmpSubModNative(result, operand1, n_mod_8, modulus, cmp, bound,
32 diff);
33 operand1 += n_mod_8;
34 result += n_mod_8;
35 n -= n_mod_8;
36 }
37 HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus);
38
39 const __m512i* v_op_ptr = reinterpret_cast<const __m512i*>(operand1);
40 __m512i* v_result_ptr = reinterpret_cast<__m512i*>(result);
41 __m512i v_bound = _mm512_set1_epi64(static_cast<int64_t>(bound));
42 __m512i v_diff = _mm512_set1_epi64(static_cast<int64_t>(diff));
43 __m512i v_modulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));
44
45 uint64_t mu = MultiplyFactor(1, BitShift, modulus).BarrettFactor();
46 __m512i v_mu = _mm512_set1_epi64(static_cast<int64_t>(mu));
47
48 // Multi-word Barrett reduction precomputation
49 constexpr int64_t beta = -2;
50 const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2
51 uint64_t prod_right_shift = ceil_log_mod + beta;
52 __m512i v_neg_mod = _mm512_set1_epi64(-static_cast<int64_t>(modulus));
53
54 uint64_t alpha = BitShift - 2;
55 uint64_t mu_64 =
56 MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - BitShift), BitShift,
57 modulus)
58 .BarrettFactor();
59
60 if (BitShift == 64) {
61 // Single-worded Barrett reduction.
62 mu_64 = MultiplyFactor(1, 64, modulus).BarrettFactor();
63 }
64
65 __m512i v_mu_64 = _mm512_set1_epi64(static_cast<int64_t>(mu_64));
66
67 for (size_t i = n / 8; i > 0; --i) {
68 __m512i v_op = _mm512_loadu_si512(v_op_ptr);
69 __mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp));
70
71 v_op = _mm512_hexl_barrett_reduce64<BitShift, 1>(
72 v_op, v_modulus, v_mu_64, v_mu, prod_right_shift, v_neg_mod);
73
74 __m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus);
75 v_to_add = _mm512_sub_epi64(v_to_add, v_diff);
76 v_to_add = _mm512_mask_set1_epi64(v_to_add, op_le_cmp, 0);
77
78 v_op = _mm512_add_epi64(v_op, v_to_add);
79 _mm512_storeu_si512(v_result_ptr, v_op);
80 ++v_op_ptr;
81 ++v_result_ptr;
82 }
83 }
#endif

}  // namespace hexl
}  // namespace intel