1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3
4 #include "eltwise/eltwise-fma-mod-avx512.hpp"
5
6 #include <immintrin.h>
7
8 #include "hexl/eltwise/eltwise-fma-mod.hpp"
9 #include "hexl/number-theory/number-theory.hpp"
10 #include "hexl/util/check.hpp"
11 #include "util/avx512-util.hpp"
12
13 namespace intel {
14 namespace hexl {
15
#ifdef HEXL_HAS_AVX512IFMA
// Explicit instantiations for the AVX512-IFMA path (BitShift = 52), one per
// supported InputModFactor in {1, 2, 4, 8}.
template void EltwiseFMAModAVX512<52, 1>(uint64_t* result, const uint64_t* arg1,
                                         uint64_t arg2, const uint64_t* arg3,
                                         uint64_t n, uint64_t modulus);
template void EltwiseFMAModAVX512<52, 2>(uint64_t* result, const uint64_t* arg1,
                                         uint64_t arg2, const uint64_t* arg3,
                                         uint64_t n, uint64_t modulus);
template void EltwiseFMAModAVX512<52, 4>(uint64_t* result, const uint64_t* arg1,
                                         uint64_t arg2, const uint64_t* arg3,
                                         uint64_t n, uint64_t modulus);
template void EltwiseFMAModAVX512<52, 8>(uint64_t* result, const uint64_t* arg1,
                                         uint64_t arg2, const uint64_t* arg3,
                                         uint64_t n, uint64_t modulus);
#endif
30
#ifdef HEXL_HAS_AVX512DQ
// Explicit instantiations for the AVX512-DQ path (BitShift = 64), one per
// supported InputModFactor in {1, 2, 4, 8}.
template void EltwiseFMAModAVX512<64, 1>(uint64_t* result, const uint64_t* arg1,
                                         uint64_t arg2, const uint64_t* arg3,
                                         uint64_t n, uint64_t modulus);
template void EltwiseFMAModAVX512<64, 2>(uint64_t* result, const uint64_t* arg1,
                                         uint64_t arg2, const uint64_t* arg3,
                                         uint64_t n, uint64_t modulus);
template void EltwiseFMAModAVX512<64, 4>(uint64_t* result, const uint64_t* arg1,
                                         uint64_t arg2, const uint64_t* arg3,
                                         uint64_t n, uint64_t modulus);
template void EltwiseFMAModAVX512<64, 8>(uint64_t* result, const uint64_t* arg1,
                                         uint64_t arg2, const uint64_t* arg3,
                                         uint64_t n, uint64_t modulus);

#endif
46
47 #ifdef HEXL_HAS_AVX512DQ
48
49 /// uses Shoup's modular multiplication. See Algorithm 4 of
50 /// https://arxiv.org/pdf/2012.01968.pdf
51 template <int BitShift, int InputModFactor>
EltwiseFMAModAVX512(uint64_t * result,const uint64_t * arg1,uint64_t arg2,const uint64_t * arg3,uint64_t n,uint64_t modulus)52 void EltwiseFMAModAVX512(uint64_t* result, const uint64_t* arg1, uint64_t arg2,
53 const uint64_t* arg3, uint64_t n, uint64_t modulus) {
54 HEXL_CHECK(modulus < MaximumValue(BitShift),
55 "Modulus " << modulus << " exceeds bit shift bound "
56 << MaximumValue(BitShift));
57 HEXL_CHECK(modulus != 0, "Require modulus != 0");
58
59 HEXL_CHECK(arg1, "arg1 == nullptr");
60 HEXL_CHECK(result, "result == nullptr");
61
62 HEXL_CHECK_BOUNDS(arg1, n, InputModFactor * modulus,
63 "arg1 exceeds bound " << (InputModFactor * modulus));
64 HEXL_CHECK_BOUNDS(&arg2, 1, InputModFactor * modulus,
65 "arg2 exceeds bound " << (InputModFactor * modulus));
66 HEXL_CHECK(BitShift == 52 || BitShift == 64,
67 "Invalid bitshift " << BitShift << "; need 52 or 64");
68
69 uint64_t n_mod_8 = n % 8;
70 if (n_mod_8 != 0) {
71 EltwiseFMAModNative<InputModFactor>(result, arg1, arg2, arg3, n_mod_8,
72 modulus);
73 arg1 += n_mod_8;
74 if (arg3 != nullptr) {
75 arg3 += n_mod_8;
76 }
77 result += n_mod_8;
78 n -= n_mod_8;
79 }
80
81 uint64_t twice_modulus = 2 * modulus;
82 uint64_t four_times_modulus = 4 * modulus;
83 arg2 = ReduceMod<InputModFactor>(arg2, modulus, &twice_modulus,
84 &four_times_modulus);
85 uint64_t arg2_barr = MultiplyFactor(arg2, BitShift, modulus).BarrettFactor();
86
87 __m512i varg2_barr = _mm512_set1_epi64(static_cast<int64_t>(arg2_barr));
88
89 __m512i vmodulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));
90 __m512i vneg_modulus = _mm512_set1_epi64(-static_cast<int64_t>(modulus));
91 __m512i v2_modulus = _mm512_set1_epi64(static_cast<int64_t>(2 * modulus));
92 __m512i v4_modulus = _mm512_set1_epi64(static_cast<int64_t>(4 * modulus));
93 const __m512i* vp_arg1 = reinterpret_cast<const __m512i*>(arg1);
94 __m512i varg2 = _mm512_set1_epi64(static_cast<int64_t>(arg2));
95 varg2 = _mm512_hexl_small_mod_epu64<InputModFactor>(varg2, vmodulus,
96 &v2_modulus, &v4_modulus);
97
98 __m512i* vp_result = reinterpret_cast<__m512i*>(result);
99
100 if (arg3) {
101 const __m512i* vp_arg3 = reinterpret_cast<const __m512i*>(arg3);
102 HEXL_LOOP_UNROLL_8
103 for (size_t i = n / 8; i > 0; --i) {
104 __m512i varg1 = _mm512_loadu_si512(vp_arg1);
105 __m512i varg3 = _mm512_loadu_si512(vp_arg3);
106
107 varg1 = _mm512_hexl_small_mod_epu64<InputModFactor>(
108 varg1, vmodulus, &v2_modulus, &v4_modulus);
109 varg3 = _mm512_hexl_small_mod_epu64<InputModFactor>(
110 varg3, vmodulus, &v2_modulus, &v4_modulus);
111
112 __m512i va_times_b = _mm512_hexl_mullo_epi<BitShift>(varg1, varg2);
113 __m512i vq = _mm512_hexl_mulhi_epi<BitShift>(varg1, varg2_barr);
114
115 // Compute vq in [0, 2 * p) where p is the modulus
116 // a * b - q * p
117 vq = _mm512_hexl_mullo_add_lo_epi<BitShift>(va_times_b, vq, vneg_modulus);
118
119 // Add arg3, bringing vq to [0, 3 * p)
120 vq = _mm512_add_epi64(vq, varg3);
121 // Reduce to [0, p)
122 vq = _mm512_hexl_small_mod_epu64<4>(vq, vmodulus, &v2_modulus);
123
124 _mm512_storeu_si512(vp_result, vq);
125
126 ++vp_arg1;
127 ++vp_result;
128 ++vp_arg3;
129 }
130 } else { // arg3 == nullptr
131 HEXL_LOOP_UNROLL_8
132 for (size_t i = n / 8; i > 0; --i) {
133 __m512i varg1 = _mm512_loadu_si512(vp_arg1);
134 varg1 = _mm512_hexl_small_mod_epu64<InputModFactor>(
135 varg1, vmodulus, &v2_modulus, &v4_modulus);
136
137 __m512i va_times_b = _mm512_hexl_mullo_epi<BitShift>(varg1, varg2);
138 __m512i vq = _mm512_hexl_mulhi_epi<BitShift>(varg1, varg2_barr);
139
140 // Compute vq in [0, 2 * p) where p is the modulus
141 // a * b - q * p
142 vq = _mm512_hexl_mullo_add_lo_epi<BitShift>(va_times_b, vq, vneg_modulus);
143 // Conditional Barrett subtraction
144 vq = _mm512_hexl_small_mod_epu64(vq, vmodulus);
145 _mm512_storeu_si512(vp_result, vq);
146
147 ++vp_arg1;
148 ++vp_result;
149 }
150 }
151 }
152
153 #endif
154
155 } // namespace hexl
156 } // namespace intel
157