1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 
4 #include "eltwise/eltwise-fma-mod-avx512.hpp"
5 
6 #include <immintrin.h>
7 
8 #include "hexl/eltwise/eltwise-fma-mod.hpp"
9 #include "hexl/number-theory/number-theory.hpp"
10 #include "hexl/util/check.hpp"
11 #include "util/avx512-util.hpp"
12 
13 namespace intel {
14 namespace hexl {
15 
#ifdef HEXL_HAS_AVX512IFMA
// Explicit instantiations of the 52-bit kernel (only built when AVX512-IFMA is
// available), one per supported input modulus factor. The template itself is
// defined further down in this file; presumably the header only declares it,
// so these are the specializations other translation units can link against.
template void EltwiseFMAModAVX512<52, 1>(uint64_t*, const uint64_t*, uint64_t,
                                         const uint64_t*, uint64_t, uint64_t);
template void EltwiseFMAModAVX512<52, 2>(uint64_t*, const uint64_t*, uint64_t,
                                         const uint64_t*, uint64_t, uint64_t);
template void EltwiseFMAModAVX512<52, 4>(uint64_t*, const uint64_t*, uint64_t,
                                         const uint64_t*, uint64_t, uint64_t);
template void EltwiseFMAModAVX512<52, 8>(uint64_t*, const uint64_t*, uint64_t,
                                         const uint64_t*, uint64_t, uint64_t);
#endif  // HEXL_HAS_AVX512IFMA
30 
#ifdef HEXL_HAS_AVX512DQ
// Explicit instantiations of the 64-bit kernel (only built when AVX512-DQ is
// available), one per supported input modulus factor. The template itself is
// defined further down in this file; presumably the header only declares it,
// so these are the specializations other translation units can link against.
template void EltwiseFMAModAVX512<64, 1>(uint64_t*, const uint64_t*, uint64_t,
                                         const uint64_t*, uint64_t, uint64_t);
template void EltwiseFMAModAVX512<64, 2>(uint64_t*, const uint64_t*, uint64_t,
                                         const uint64_t*, uint64_t, uint64_t);
template void EltwiseFMAModAVX512<64, 4>(uint64_t*, const uint64_t*, uint64_t,
                                         const uint64_t*, uint64_t, uint64_t);
template void EltwiseFMAModAVX512<64, 8>(uint64_t*, const uint64_t*, uint64_t,
                                         const uint64_t*, uint64_t, uint64_t);
#endif  // HEXL_HAS_AVX512DQ
46 
#ifdef HEXL_HAS_AVX512DQ

/// @brief Computes result[i] = (arg1[i] * arg2 + arg3[i]) mod modulus, or
/// result[i] = (arg1[i] * arg2) mod modulus when arg3 is nullptr.
///
/// Uses Shoup's modular multiplication. See Algorithm 4 of
/// https://arxiv.org/pdf/2012.01968.pdf
/// The leading n % 8 elements are handled by EltwiseFMAModNative; the rest are
/// processed 8 lanes at a time with AVX512.
///
/// @tparam BitShift Width of the vector multiply primitives; must be 52 or 64.
/// @tparam InputModFactor Inputs arg1, arg2 and arg3 must lie in
/// [0, InputModFactor * modulus); must be 1, 2, 4 or 8.
/// @param[out] result Output of n elements, each in [0, modulus).
/// @param[in] arg1 Input vector of n elements.
/// @param[in] arg2 Scalar multiplier.
/// @param[in] arg3 Optional addend vector of n elements; may be nullptr.
/// @param[in] n Number of elements.
/// @param[in] modulus Non-zero modulus below MaximumValue(BitShift), small
/// enough that 4 * modulus and InputModFactor * modulus fit in 64 bits.
template <int BitShift, int InputModFactor>
void EltwiseFMAModAVX512(uint64_t* result, const uint64_t* arg1, uint64_t arg2,
                         const uint64_t* arg3, uint64_t n, uint64_t modulus) {
  static_assert(BitShift == 52 || BitShift == 64,
                "Invalid bitshift; need 52 or 64");
  static_assert(InputModFactor == 1 || InputModFactor == 2 ||
                    InputModFactor == 4 || InputModFactor == 8,
                "InputModFactor must be 1, 2, 4, or 8");

  HEXL_CHECK(modulus < MaximumValue(BitShift),
             "Modulus " << modulus << " exceeds bit shift bound "
                        << MaximumValue(BitShift));
  HEXL_CHECK(modulus != 0, "Require modulus != 0");
  // The reduction constants below go up to 4 * modulus, the input bound is
  // InputModFactor * modulus, and the arg3 sum reaches 3 * modulus; all must
  // fit in 64 bits. For BitShift == 64 the MaximumValue check above presumably
  // does not guarantee this on its own.
  HEXL_CHECK(modulus < (InputModFactor == 8 ? (1ULL << 61) : (1ULL << 62)),
             "Modulus " << modulus << " too large for InputModFactor "
                        << InputModFactor);

  HEXL_CHECK(arg1, "arg1 == nullptr");
  HEXL_CHECK(result, "result == nullptr");

  HEXL_CHECK_BOUNDS(arg1, n, InputModFactor * modulus,
                    "arg1 exceeds bound " << (InputModFactor * modulus));
  HEXL_CHECK_BOUNDS(&arg2, 1, InputModFactor * modulus,
                    "arg2 exceeds bound " << (InputModFactor * modulus));
  // arg3 goes through the same InputModFactor reduction as arg1, so it must
  // satisfy the same bound.
  if (arg3 != nullptr) {
    HEXL_CHECK_BOUNDS(arg3, n, InputModFactor * modulus,
                      "arg3 exceeds bound " << (InputModFactor * modulus));
  }

  // Process the leading n % 8 elements with the scalar kernel so the vector
  // loops below only ever see whole 8-lane chunks.
  uint64_t n_mod_8 = n % 8;
  if (n_mod_8 != 0) {
    EltwiseFMAModNative<InputModFactor>(result, arg1, arg2, arg3, n_mod_8,
                                        modulus);
    arg1 += n_mod_8;
    if (arg3 != nullptr) {
      arg3 += n_mod_8;
    }
    result += n_mod_8;
    n -= n_mod_8;
  }

  uint64_t twice_modulus = 2 * modulus;
  uint64_t four_times_modulus = 4 * modulus;
  // Reduce the scalar multiplier once; the Shoup factor is computed from this
  // reduced value.
  arg2 = ReduceMod<InputModFactor>(arg2, modulus, &twice_modulus,
                                   &four_times_modulus);
  uint64_t arg2_barr = MultiplyFactor(arg2, BitShift, modulus).BarrettFactor();

  __m512i varg2_barr = _mm512_set1_epi64(static_cast<int64_t>(arg2_barr));
  // varg2 must hold exactly the value arg2_barr was derived from, so arg2 is
  // broadcast as-is (it was already reduced above; no second reduction).
  __m512i varg2 = _mm512_set1_epi64(static_cast<int64_t>(arg2));

  __m512i vmodulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));
  __m512i vneg_modulus = _mm512_set1_epi64(-static_cast<int64_t>(modulus));
  __m512i v2_modulus = _mm512_set1_epi64(static_cast<int64_t>(twice_modulus));
  __m512i v4_modulus =
      _mm512_set1_epi64(static_cast<int64_t>(four_times_modulus));

  const __m512i* vp_arg1 = reinterpret_cast<const __m512i*>(arg1);
  __m512i* vp_result = reinterpret_cast<__m512i*>(result);

  // The arg3 test is hoisted out of the loop; each branch has its own loop.
  if (arg3) {
    const __m512i* vp_arg3 = reinterpret_cast<const __m512i*>(arg3);
    HEXL_LOOP_UNROLL_8
    for (size_t i = n / 8; i > 0; --i) {
      __m512i varg1 = _mm512_loadu_si512(vp_arg1);
      __m512i varg3 = _mm512_loadu_si512(vp_arg3);

      varg1 = _mm512_hexl_small_mod_epu64<InputModFactor>(
          varg1, vmodulus, &v2_modulus, &v4_modulus);
      varg3 = _mm512_hexl_small_mod_epu64<InputModFactor>(
          varg3, vmodulus, &v2_modulus, &v4_modulus);

      __m512i va_times_b = _mm512_hexl_mullo_epi<BitShift>(varg1, varg2);
      __m512i vq = _mm512_hexl_mulhi_epi<BitShift>(varg1, varg2_barr);

      // Compute vq in [0, 2 * p) where p is the modulus
      // a * b - q * p
      vq = _mm512_hexl_mullo_add_lo_epi<BitShift>(va_times_b, vq, vneg_modulus);

      // Add arg3, bringing vq to [0, 3 * p)
      vq = _mm512_add_epi64(vq, varg3);
      // Reduce to [0, p)
      vq = _mm512_hexl_small_mod_epu64<4>(vq, vmodulus, &v2_modulus);

      _mm512_storeu_si512(vp_result, vq);

      ++vp_arg1;
      ++vp_result;
      ++vp_arg3;
    }
  } else {  // arg3 == nullptr
    HEXL_LOOP_UNROLL_8
    for (size_t i = n / 8; i > 0; --i) {
      __m512i varg1 = _mm512_loadu_si512(vp_arg1);
      varg1 = _mm512_hexl_small_mod_epu64<InputModFactor>(
          varg1, vmodulus, &v2_modulus, &v4_modulus);

      __m512i va_times_b = _mm512_hexl_mullo_epi<BitShift>(varg1, varg2);
      __m512i vq = _mm512_hexl_mulhi_epi<BitShift>(varg1, varg2_barr);

      // Compute vq in [0, 2 * p) where p is the modulus
      // a * b - q * p
      vq = _mm512_hexl_mullo_add_lo_epi<BitShift>(va_times_b, vq, vneg_modulus);
      // Conditional subtraction bringing vq from [0, 2 * p) into [0, p)
      vq = _mm512_hexl_small_mod_epu64(vq, vmodulus);
      _mm512_storeu_si512(vp_result, vq);

      ++vp_arg1;
      ++vp_result;
    }
  }
}

#endif  // HEXL_HAS_AVX512DQ
154 
155 }  // namespace hexl
156 }  // namespace intel
157