1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 
4 #pragma once
5 
6 #ifdef HEXL_USE_MSVC
7 
8 #define NOMINMAX  // Avoid errors with std::min/std::max
9 #undef min
10 #undef max
11 
12 #include <immintrin.h>
13 #include <intrin.h>
14 #include <stdint.h>
15 
16 #include <cmath>
17 
18 #include "hexl/util/check.hpp"
19 
20 #pragma intrinsic(_addcarry_u64, _BitScanReverse64, _subborrow_u64, _udiv128, \
21                   _umul128)
22 
23 #undef TRUE
24 #undef FALSE
25 
26 namespace intel {
27 namespace hexl {
28 
BarrettReduce128(uint64_t input_hi,uint64_t input_lo,uint64_t modulus)29 inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo,
30                                  uint64_t modulus) {
31   HEXL_CHECK(modulus != 0, "modulus == 0")
32   uint64_t remainder;
33   _udiv128(input_hi, input_lo, modulus, &remainder);
34 
35   return remainder;
36 }
37 
38 // Multiplies x * y as 128-bit integer.
39 // @param prod_hi Stores high 64 bits of product
40 // @param prod_lo Stores low 64 bits of product
MultiplyUInt64(uint64_t x,uint64_t y,uint64_t * prod_hi,uint64_t * prod_lo)41 inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi,
42                            uint64_t* prod_lo) {
43   *prod_lo = _umul128(x, y, prod_hi);
44 }
45 
46 // Return the high 128 minus BitShift bits of the 128-bit product x * y
47 template <int BitShift>
MultiplyUInt64Hi(uint64_t x,uint64_t y)48 inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) {
49   HEXL_CHECK(BitShift == 52 || BitShift == 64,
50              "Invalid BitShift " << BitShift << "; expected 52 or 64");
51   uint64_t prod_hi;
52   uint64_t prod_lo = _umul128(x, y, &prod_hi);
53   uint64_t result_hi;
54   uint64_t result_lo;
55   RightShift128(&result_hi, &result_lo, prod_hi, prod_lo, BitShift);
56   return result_lo;
57 }
58 
59 /// @brief Computes Left Shift op as 128-bit unsigned integer
60 /// @param[out] result_hi Stores high 64 bits of result
61 /// @param[out] result_lo Stores low 64 bits of result
62 /// @param[in] op_hi Stores high 64 bits of input
63 /// @param[in] op_lo Stores low 64 bits of input
LeftShift128(uint64_t * result_hi,uint64_t * result_lo,const uint64_t op_hi,const uint64_t op_lo,const uint64_t shift_value)64 inline void LeftShift128(uint64_t* result_hi, uint64_t* result_lo,
65                          const uint64_t op_hi, const uint64_t op_lo,
66                          const uint64_t shift_value) {
67   HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr");
68   HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr");
69   HEXL_CHECK(shift_value <= 128,
70              "shift_value cannot be greater than 128 " << shift_value);
71 
72   if (shift_value == 0) {
73     *result_hi = op_hi;
74     *result_lo = op_lo;
75   } else if (shift_value == 64) {
76     *result_hi = op_lo;
77     *result_lo = 0ULL;
78   } else if (shift_value == 128) {
79     *result_hi = 0ULL;
80     *result_lo = 0ULL;
81   } else if (shift_value >= 1 && shift_value <= 63) {
82     *result_hi = (op_hi << shift_value) | (op_lo >> (64 - shift_value));
83     *result_lo = op_lo << shift_value;
84   } else if (shift_value >= 65 && shift_value < 128) {
85     *result_hi = op_lo << (shift_value - 64);
86     *result_lo = 0ULL;
87   }
88 }
89 
90 /// @brief Computes Right Shift op as 128-bit unsigned integer
91 /// @param[out] result_hi Stores high 64 bits of result
92 /// @param[out] result_lo Stores low 64 bits of result
93 /// @param[in] op_hi Stores high 64 bits of input
94 /// @param[in] op_lo Stores low 64 bits of input
RightShift128(uint64_t * result_hi,uint64_t * result_lo,const uint64_t op_hi,const uint64_t op_lo,const uint64_t shift_value)95 inline void RightShift128(uint64_t* result_hi, uint64_t* result_lo,
96                           const uint64_t op_hi, const uint64_t op_lo,
97                           const uint64_t shift_value) {
98   HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr");
99   HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr");
100   HEXL_CHECK(shift_value <= 128,
101              "shift_value cannot be greater than 128 " << shift_value);
102 
103   if (shift_value == 0) {
104     *result_hi = op_hi;
105     *result_lo = op_lo;
106   } else if (shift_value == 64) {
107     *result_hi = 0ULL;
108     *result_lo = op_hi;
109   } else if (shift_value == 128) {
110     *result_hi = 0ULL;
111     *result_lo = 0ULL;
112   } else if (shift_value >= 1 && shift_value <= 63) {
113     *result_hi = op_hi >> shift_value;
114     *result_lo = (op_hi << (64 - shift_value)) | (op_lo >> shift_value);
115   } else if (shift_value >= 65 && shift_value < 128) {
116     *result_hi = 0ULL;
117     *result_lo = op_hi >> (shift_value - 64);
118   }
119 }
120 
121 /// @brief Adds op1 + op2 as 128-bit integer
122 /// @param[out] result_hi Stores high 64 bits of result
123 /// @param[out] result_lo Stores low 64 bits of result
AddWithCarry128(uint64_t * result_hi,uint64_t * result_lo,const uint64_t op1_hi,const uint64_t op1_lo,const uint64_t op2_hi,const uint64_t op2_lo)124 inline void AddWithCarry128(uint64_t* result_hi, uint64_t* result_lo,
125                             const uint64_t op1_hi, const uint64_t op1_lo,
126                             const uint64_t op2_hi, const uint64_t op2_lo) {
127   HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr");
128   HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr");
129 
130   // first 64bit block
131   *result_lo = op1_lo + op2_lo;
132   unsigned char carry = static_cast<unsigned char>(*result_lo < op1_lo);
133 
134   // second 64bit block
135   _addcarry_u64(carry, op1_hi, op2_hi, result_hi);
136 }
137 
138 /// @brief Subtracts op1 - op2 as 128-bit integer
139 /// @param[out] result_hi Stores high 64 bits of result
140 /// @param[out] result_lo Stores low 64 bits of result
SubWithCarry128(uint64_t * result_hi,uint64_t * result_lo,const uint64_t op1_hi,const uint64_t op1_lo,const uint64_t op2_hi,const uint64_t op2_lo)141 inline void SubWithCarry128(uint64_t* result_hi, uint64_t* result_lo,
142                             const uint64_t op1_hi, const uint64_t op1_lo,
143                             const uint64_t op2_hi, const uint64_t op2_lo) {
144   HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr");
145   HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr");
146 
147   unsigned char borrow;
148 
149   // first 64bit block
150   *result_lo = op1_lo - op2_lo;
151   borrow = static_cast<unsigned char>(op2_lo > op1_lo);
152 
153   // second 64bit block
154   _subborrow_u64(borrow, op1_hi, op2_hi, result_hi);
155 }
156 
157 /// @brief Computes and returns significant bit count
158 /// @param[in] value Input element at most 128 bits long
SignificantBitLength(const uint64_t * value)159 inline uint64_t SignificantBitLength(const uint64_t* value) {
160   HEXL_CHECK(value != nullptr, "Require value != nullptr");
161 
162   unsigned long count = 0;  // NOLINT(runtime/int)
163 
164   // second 64bit block
165   _BitScanReverse64(&count, *(value + 1));
166   if (count >= 0 && *(value + 1) > 0) {
167     return static_cast<uint64_t>(count) + 1 + 64;
168   }
169 
170   // first 64bit block
171   _BitScanReverse64(&count, *value);
172   if (count >= 0 && *(value) > 0) {
173     return static_cast<uint64_t>(count) + 1;
174   }
175   return 0;
176 }
177 
178 /// @brief Checks if input is negative number
179 /// @param[in] input Input element to check for sign
CheckSign(const uint64_t * input)180 inline bool CheckSign(const uint64_t* input) {
181   HEXL_CHECK(input != nullptr, "Require input != nullptr");
182 
183   uint64_t input_temp[2]{0, 0};
184   RightShift128(&input_temp[1], &input_temp[0], input[1], input[0], 127);
185   return (input_temp[0] == 1);
186 }
187 
188 /// @brief Divides numerator by denominator
189 /// @param[out] quotient Stores quotient as two 64-bit blocks after division
190 /// @param[in] numerator
191 /// @param[in] denominator
DivideUInt128UInt64(uint64_t * quotient,const uint64_t * numerator,const uint64_t denominator)192 inline void DivideUInt128UInt64(uint64_t* quotient, const uint64_t* numerator,
193                                 const uint64_t denominator) {
194   HEXL_CHECK(quotient != nullptr, "Require quotient != nullptr");
195   HEXL_CHECK(numerator != nullptr, "Require numerator != nullptr");
196   HEXL_CHECK(denominator != 0, "denominator cannot be 0 " << denominator);
197 
198   // get bit count of divisor
199   uint64_t numerator_bits = SignificantBitLength(numerator);
200   const uint64_t numerator_bits_const = numerator_bits;
201   const uint64_t uint_128_bit = 128ULL;
202 
203   uint64_t MASK[2]{0x0000000000000001, 0x0000000000000000};
204   uint64_t remainder[2]{0, 0};
205   uint64_t quotient_temp[2]{0, 0};
206   uint64_t denominator_temp[2]{denominator, 0};
207 
208   quotient[0] = numerator[0];
209   quotient[1] = numerator[1];
210 
211   // align numerator
212   LeftShift128(&quotient[1], &quotient[0], quotient[1], quotient[0],
213                (uint_128_bit - numerator_bits_const));
214 
215   while (numerator_bits) {
216     // if remainder is negative
217     if (CheckSign(remainder)) {
218       LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1);
219       RightShift128(&quotient_temp[1], &quotient_temp[0], quotient[1],
220                     quotient[0], (uint_128_bit - 1));
221       remainder[0] = remainder[0] | quotient_temp[0];
222       LeftShift128(&quotient[1], &quotient[0], quotient[1], quotient[0], 1);
223       // remainder=remainder+denominator_temp
224       AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0],
225                       denominator_temp[1], denominator_temp[0]);
226     } else {  // if remainder is positive
227       LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1);
228       RightShift128(&quotient_temp[1], &quotient_temp[0], quotient[1],
229                     quotient[0], (uint_128_bit - 1));
230       remainder[0] = remainder[0] | quotient_temp[0];
231       LeftShift128(&quotient[1], &quotient[0], quotient[1], quotient[0], 1);
232       // remainder=remainder-denominator_temp
233       SubWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0],
234                       denominator_temp[1], denominator_temp[0]);
235     }
236 
237     // if remainder is positive set MSB of quotient[0]=1
238     if (!CheckSign(remainder)) {
239       MASK[0] = 0x0000000000000001;
240       MASK[1] = 0x0000000000000000;
241       LeftShift128(&MASK[1], &MASK[0], MASK[1], MASK[0],
242                    (uint_128_bit - numerator_bits_const));
243       quotient[0] = quotient[0] | MASK[0];
244       quotient[1] = quotient[1] | MASK[1];
245     }
246     quotient_temp[0] = 0;
247     quotient_temp[1] = 0;
248     numerator_bits--;
249   }
250 
251   if (CheckSign(remainder)) {
252     // remainder=remainder+denominator_temp
253     AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0],
254                     denominator_temp[1], denominator_temp[0]);
255   }
256   RightShift128(&quotient[1], &quotient[0], quotient[1], quotient[0],
257                 (uint_128_bit - numerator_bits_const));
258 }
259 
260 /// @brief Returns low of dividing numerator by denominator
261 /// @param[in] numerator_hi Stores high 64 bit of numerator
262 /// @param[in] numerator_lo Stores low 64 bit of numerator
263 /// @param[in] denominator Stores denominator
DivideUInt128UInt64Lo(const uint64_t numerator_hi,const uint64_t numerator_lo,const uint64_t denominator)264 inline uint64_t DivideUInt128UInt64Lo(const uint64_t numerator_hi,
265                                       const uint64_t numerator_lo,
266                                       const uint64_t denominator) {
267   uint64_t numerator[2]{numerator_lo, numerator_hi};
268   uint64_t quotient[2]{0, 0};
269 
270   DivideUInt128UInt64(quotient, numerator, denominator);
271   return quotient[0];
272 }
273 
274 // Returns most-significant bit of the input
MSB(uint64_t input)275 inline uint64_t MSB(uint64_t input) {
276   unsigned long index{0};  // NOLINT(runtime/int)
277   _BitScanReverse64(&index, input);
278   return index;
279 }
280 
281 #define HEXL_LOOP_UNROLL_4 \
282   {}
283 #define HEXL_LOOP_UNROLL_8 \
284   {}
285 
286 #endif
287 
288 }  // namespace hexl
289 }  // namespace intel
290