1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3
4 #pragma once
5
6 #ifdef HEXL_USE_MSVC
7
8 #define NOMINMAX // Avoid errors with std::min/std::max
9 #undef min
10 #undef max
11
12 #include <immintrin.h>
13 #include <intrin.h>
14 #include <stdint.h>
15
16 #include <cmath>
17
18 #include "hexl/util/check.hpp"
19
20 #pragma intrinsic(_addcarry_u64, _BitScanReverse64, _subborrow_u64, _udiv128, \
21 _umul128)
22
23 #undef TRUE
24 #undef FALSE
25
26 namespace intel {
27 namespace hexl {
28
BarrettReduce128(uint64_t input_hi,uint64_t input_lo,uint64_t modulus)29 inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo,
30 uint64_t modulus) {
31 HEXL_CHECK(modulus != 0, "modulus == 0")
32 uint64_t remainder;
33 _udiv128(input_hi, input_lo, modulus, &remainder);
34
35 return remainder;
36 }
37
38 // Multiplies x * y as 128-bit integer.
39 // @param prod_hi Stores high 64 bits of product
40 // @param prod_lo Stores low 64 bits of product
MultiplyUInt64(uint64_t x,uint64_t y,uint64_t * prod_hi,uint64_t * prod_lo)41 inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi,
42 uint64_t* prod_lo) {
43 *prod_lo = _umul128(x, y, prod_hi);
44 }
45
46 // Return the high 128 minus BitShift bits of the 128-bit product x * y
47 template <int BitShift>
MultiplyUInt64Hi(uint64_t x,uint64_t y)48 inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) {
49 HEXL_CHECK(BitShift == 52 || BitShift == 64,
50 "Invalid BitShift " << BitShift << "; expected 52 or 64");
51 uint64_t prod_hi;
52 uint64_t prod_lo = _umul128(x, y, &prod_hi);
53 uint64_t result_hi;
54 uint64_t result_lo;
55 RightShift128(&result_hi, &result_lo, prod_hi, prod_lo, BitShift);
56 return result_lo;
57 }
58
59 /// @brief Computes Left Shift op as 128-bit unsigned integer
60 /// @param[out] result_hi Stores high 64 bits of result
61 /// @param[out] result_lo Stores low 64 bits of result
62 /// @param[in] op_hi Stores high 64 bits of input
63 /// @param[in] op_lo Stores low 64 bits of input
LeftShift128(uint64_t * result_hi,uint64_t * result_lo,const uint64_t op_hi,const uint64_t op_lo,const uint64_t shift_value)64 inline void LeftShift128(uint64_t* result_hi, uint64_t* result_lo,
65 const uint64_t op_hi, const uint64_t op_lo,
66 const uint64_t shift_value) {
67 HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr");
68 HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr");
69 HEXL_CHECK(shift_value <= 128,
70 "shift_value cannot be greater than 128 " << shift_value);
71
72 if (shift_value == 0) {
73 *result_hi = op_hi;
74 *result_lo = op_lo;
75 } else if (shift_value == 64) {
76 *result_hi = op_lo;
77 *result_lo = 0ULL;
78 } else if (shift_value == 128) {
79 *result_hi = 0ULL;
80 *result_lo = 0ULL;
81 } else if (shift_value >= 1 && shift_value <= 63) {
82 *result_hi = (op_hi << shift_value) | (op_lo >> (64 - shift_value));
83 *result_lo = op_lo << shift_value;
84 } else if (shift_value >= 65 && shift_value < 128) {
85 *result_hi = op_lo << (shift_value - 64);
86 *result_lo = 0ULL;
87 }
88 }
89
90 /// @brief Computes Right Shift op as 128-bit unsigned integer
91 /// @param[out] result_hi Stores high 64 bits of result
92 /// @param[out] result_lo Stores low 64 bits of result
93 /// @param[in] op_hi Stores high 64 bits of input
94 /// @param[in] op_lo Stores low 64 bits of input
RightShift128(uint64_t * result_hi,uint64_t * result_lo,const uint64_t op_hi,const uint64_t op_lo,const uint64_t shift_value)95 inline void RightShift128(uint64_t* result_hi, uint64_t* result_lo,
96 const uint64_t op_hi, const uint64_t op_lo,
97 const uint64_t shift_value) {
98 HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr");
99 HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr");
100 HEXL_CHECK(shift_value <= 128,
101 "shift_value cannot be greater than 128 " << shift_value);
102
103 if (shift_value == 0) {
104 *result_hi = op_hi;
105 *result_lo = op_lo;
106 } else if (shift_value == 64) {
107 *result_hi = 0ULL;
108 *result_lo = op_hi;
109 } else if (shift_value == 128) {
110 *result_hi = 0ULL;
111 *result_lo = 0ULL;
112 } else if (shift_value >= 1 && shift_value <= 63) {
113 *result_hi = op_hi >> shift_value;
114 *result_lo = (op_hi << (64 - shift_value)) | (op_lo >> shift_value);
115 } else if (shift_value >= 65 && shift_value < 128) {
116 *result_hi = 0ULL;
117 *result_lo = op_hi >> (shift_value - 64);
118 }
119 }
120
121 /// @brief Adds op1 + op2 as 128-bit integer
122 /// @param[out] result_hi Stores high 64 bits of result
123 /// @param[out] result_lo Stores low 64 bits of result
AddWithCarry128(uint64_t * result_hi,uint64_t * result_lo,const uint64_t op1_hi,const uint64_t op1_lo,const uint64_t op2_hi,const uint64_t op2_lo)124 inline void AddWithCarry128(uint64_t* result_hi, uint64_t* result_lo,
125 const uint64_t op1_hi, const uint64_t op1_lo,
126 const uint64_t op2_hi, const uint64_t op2_lo) {
127 HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr");
128 HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr");
129
130 // first 64bit block
131 *result_lo = op1_lo + op2_lo;
132 unsigned char carry = static_cast<unsigned char>(*result_lo < op1_lo);
133
134 // second 64bit block
135 _addcarry_u64(carry, op1_hi, op2_hi, result_hi);
136 }
137
138 /// @brief Subtracts op1 - op2 as 128-bit integer
139 /// @param[out] result_hi Stores high 64 bits of result
140 /// @param[out] result_lo Stores low 64 bits of result
SubWithCarry128(uint64_t * result_hi,uint64_t * result_lo,const uint64_t op1_hi,const uint64_t op1_lo,const uint64_t op2_hi,const uint64_t op2_lo)141 inline void SubWithCarry128(uint64_t* result_hi, uint64_t* result_lo,
142 const uint64_t op1_hi, const uint64_t op1_lo,
143 const uint64_t op2_hi, const uint64_t op2_lo) {
144 HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr");
145 HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr");
146
147 unsigned char borrow;
148
149 // first 64bit block
150 *result_lo = op1_lo - op2_lo;
151 borrow = static_cast<unsigned char>(op2_lo > op1_lo);
152
153 // second 64bit block
154 _subborrow_u64(borrow, op1_hi, op2_hi, result_hi);
155 }
156
157 /// @brief Computes and returns significant bit count
158 /// @param[in] value Input element at most 128 bits long
SignificantBitLength(const uint64_t * value)159 inline uint64_t SignificantBitLength(const uint64_t* value) {
160 HEXL_CHECK(value != nullptr, "Require value != nullptr");
161
162 unsigned long count = 0; // NOLINT(runtime/int)
163
164 // second 64bit block
165 _BitScanReverse64(&count, *(value + 1));
166 if (count >= 0 && *(value + 1) > 0) {
167 return static_cast<uint64_t>(count) + 1 + 64;
168 }
169
170 // first 64bit block
171 _BitScanReverse64(&count, *value);
172 if (count >= 0 && *(value) > 0) {
173 return static_cast<uint64_t>(count) + 1;
174 }
175 return 0;
176 }
177
178 /// @brief Checks if input is negative number
179 /// @param[in] input Input element to check for sign
CheckSign(const uint64_t * input)180 inline bool CheckSign(const uint64_t* input) {
181 HEXL_CHECK(input != nullptr, "Require input != nullptr");
182
183 uint64_t input_temp[2]{0, 0};
184 RightShift128(&input_temp[1], &input_temp[0], input[1], input[0], 127);
185 return (input_temp[0] == 1);
186 }
187
188 /// @brief Divides numerator by denominator
189 /// @param[out] quotient Stores quotient as two 64-bit blocks after division
190 /// @param[in] numerator
191 /// @param[in] denominator
DivideUInt128UInt64(uint64_t * quotient,const uint64_t * numerator,const uint64_t denominator)192 inline void DivideUInt128UInt64(uint64_t* quotient, const uint64_t* numerator,
193 const uint64_t denominator) {
194 HEXL_CHECK(quotient != nullptr, "Require quotient != nullptr");
195 HEXL_CHECK(numerator != nullptr, "Require numerator != nullptr");
196 HEXL_CHECK(denominator != 0, "denominator cannot be 0 " << denominator);
197
198 // get bit count of divisor
199 uint64_t numerator_bits = SignificantBitLength(numerator);
200 const uint64_t numerator_bits_const = numerator_bits;
201 const uint64_t uint_128_bit = 128ULL;
202
203 uint64_t MASK[2]{0x0000000000000001, 0x0000000000000000};
204 uint64_t remainder[2]{0, 0};
205 uint64_t quotient_temp[2]{0, 0};
206 uint64_t denominator_temp[2]{denominator, 0};
207
208 quotient[0] = numerator[0];
209 quotient[1] = numerator[1];
210
211 // align numerator
212 LeftShift128("ient[1], "ient[0], quotient[1], quotient[0],
213 (uint_128_bit - numerator_bits_const));
214
215 while (numerator_bits) {
216 // if remainder is negative
217 if (CheckSign(remainder)) {
218 LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1);
219 RightShift128("ient_temp[1], "ient_temp[0], quotient[1],
220 quotient[0], (uint_128_bit - 1));
221 remainder[0] = remainder[0] | quotient_temp[0];
222 LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1);
223 // remainder=remainder+denominator_temp
224 AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0],
225 denominator_temp[1], denominator_temp[0]);
226 } else { // if remainder is positive
227 LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1);
228 RightShift128("ient_temp[1], "ient_temp[0], quotient[1],
229 quotient[0], (uint_128_bit - 1));
230 remainder[0] = remainder[0] | quotient_temp[0];
231 LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1);
232 // remainder=remainder-denominator_temp
233 SubWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0],
234 denominator_temp[1], denominator_temp[0]);
235 }
236
237 // if remainder is positive set MSB of quotient[0]=1
238 if (!CheckSign(remainder)) {
239 MASK[0] = 0x0000000000000001;
240 MASK[1] = 0x0000000000000000;
241 LeftShift128(&MASK[1], &MASK[0], MASK[1], MASK[0],
242 (uint_128_bit - numerator_bits_const));
243 quotient[0] = quotient[0] | MASK[0];
244 quotient[1] = quotient[1] | MASK[1];
245 }
246 quotient_temp[0] = 0;
247 quotient_temp[1] = 0;
248 numerator_bits--;
249 }
250
251 if (CheckSign(remainder)) {
252 // remainder=remainder+denominator_temp
253 AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0],
254 denominator_temp[1], denominator_temp[0]);
255 }
256 RightShift128("ient[1], "ient[0], quotient[1], quotient[0],
257 (uint_128_bit - numerator_bits_const));
258 }
259
260 /// @brief Returns low of dividing numerator by denominator
261 /// @param[in] numerator_hi Stores high 64 bit of numerator
262 /// @param[in] numerator_lo Stores low 64 bit of numerator
263 /// @param[in] denominator Stores denominator
DivideUInt128UInt64Lo(const uint64_t numerator_hi,const uint64_t numerator_lo,const uint64_t denominator)264 inline uint64_t DivideUInt128UInt64Lo(const uint64_t numerator_hi,
265 const uint64_t numerator_lo,
266 const uint64_t denominator) {
267 uint64_t numerator[2]{numerator_lo, numerator_hi};
268 uint64_t quotient[2]{0, 0};
269
270 DivideUInt128UInt64(quotient, numerator, denominator);
271 return quotient[0];
272 }
273
274 // Returns most-significant bit of the input
MSB(uint64_t input)275 inline uint64_t MSB(uint64_t input) {
276 unsigned long index{0}; // NOLINT(runtime/int)
277 _BitScanReverse64(&index, input);
278 return index;
279 }
280
281 #define HEXL_LOOP_UNROLL_4 \
282 {}
283 #define HEXL_LOOP_UNROLL_8 \
284 {}
285
286 #endif
287
288 } // namespace hexl
289 } // namespace intel
290