// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

#include "eltwise/eltwise-cmp-sub-mod-internal.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "hexl/util/check.hpp"
#include "util/avx512-util.hpp"

namespace intel {
namespace hexl {

#ifdef HEXL_HAS_AVX512DQ
/// @brief AVX512 implementation of element-wise conditional modular
/// subtraction.
/// @param[out] result Stores the output; must have space for n entries
/// @param[in] operand1 Input vector of n 64-bit integers
/// @param[in] n Number of elements
/// @param[in] modulus Modulus; must be > 1
/// @param[in] cmp Comparison operation applied between each element and bound
/// @param[in] bound Value each element is compared against
/// @param[in] diff Value subtracted (mod modulus) when the comparison holds;
/// must satisfy 0 < diff < modulus
/// @details For each lane op: the lane is Barrett-reduced mod modulus; if
/// (op cmp bound) held on the un-reduced input, diff is then subtracted mod
/// modulus, otherwise the reduced value is stored unchanged.
template <int BitShift>
void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1,
                            uint64_t n, uint64_t modulus, CMPINT cmp,
                            uint64_t bound, uint64_t diff) {
  HEXL_CHECK(result != nullptr, "Require result != nullptr");
  HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr");
  HEXL_CHECK(n != 0, "Require n != 0");  // fixed: missing ';' in pasted source
  HEXL_CHECK(modulus > 1, "Require modulus > 1");
  HEXL_CHECK(diff != 0, "Require diff != 0");
  HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus);

  // Process the first n % 8 elements with the scalar fallback so that the
  // main loop below always consumes full 8-lane vectors.
  uint64_t n_mod_8 = n % 8;
  if (n_mod_8 != 0) {
    EltwiseCmpSubModNative(result, operand1, n_mod_8, modulus, cmp, bound,
                           diff);
    operand1 += n_mod_8;
    result += n_mod_8;
    n -= n_mod_8;
  }

  const __m512i* v_op_ptr = reinterpret_cast<const __m512i*>(operand1);
  __m512i* v_result_ptr = reinterpret_cast<__m512i*>(result);
  __m512i v_bound = _mm512_set1_epi64(static_cast<int64_t>(bound));
  __m512i v_diff = _mm512_set1_epi64(static_cast<int64_t>(diff));
  __m512i v_modulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));

  // Barrett factor mu = floor(2^BitShift / modulus)
  uint64_t mu = MultiplyFactor(1, BitShift, modulus).BarrettFactor();
  __m512i v_mu = _mm512_set1_epi64(static_cast<int64_t>(mu));

  // Multi-word Barrett reduction precomputation
  constexpr int64_t beta = -2;
  const uint64_t ceil_log_mod = Log2(modulus) + 1;  // "n" from Algorithm 2
  uint64_t prod_right_shift = ceil_log_mod + beta;
  __m512i v_neg_mod = _mm512_set1_epi64(-static_cast<int64_t>(modulus));

  uint64_t alpha = BitShift - 2;
  uint64_t mu_64 =
      MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - BitShift), BitShift,
                     modulus)
          .BarrettFactor();

  if (BitShift == 64) {
    // Single-worded Barrett reduction.
    mu_64 = MultiplyFactor(1, 64, modulus).BarrettFactor();
  }

  __m512i v_mu_64 = _mm512_set1_epi64(static_cast<int64_t>(mu_64));

  for (size_t i = n / 8; i > 0; --i) {
    __m512i v_op = _mm512_loadu_si512(v_op_ptr);
    // Mask of lanes where the comparison does NOT hold; those lanes keep
    // their (Barrett-reduced) value and skip the subtraction below.
    __mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp));

    v_op = _mm512_hexl_barrett_reduce64<BitShift, 1>(
        v_op, v_modulus, v_mu_64, v_mu, prod_right_shift, v_neg_mod);

    // v_to_add = (v_op < diff) ? (modulus - diff) : -diff, i.e. the additive
    // form of "subtract diff mod modulus"; zeroed on lanes where the
    // comparison failed.
    __m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus);
    v_to_add = _mm512_sub_epi64(v_to_add, v_diff);
    v_to_add = _mm512_mask_set1_epi64(v_to_add, op_le_cmp, 0);

    v_op = _mm512_add_epi64(v_op, v_to_add);
    _mm512_storeu_si512(v_result_ptr, v_op);
    ++v_op_ptr;
    ++v_result_ptr;
  }
}
#endif

}  // namespace hexl
}  // namespace intel
88