1 // Copyright (c) Microsoft Corporation. All rights reserved. 2 // Licensed under the MIT license. 3 4 #include "seal/util/galois.h" 5 #include "seal/util/numth.h" 6 #include "seal/util/uintcore.h" 7 8 using namespace std; 9 10 namespace seal 11 { 12 namespace util 13 { 14 // Required for C++14 compliance: static constexpr member variables are not necessarily inlined so need to 15 // ensure symbol is created. 16 constexpr uint32_t GaloisTool::generator_; 17 generate_table_ntt(uint32_t galois_elt,Pointer<uint32_t> & result) const18 void GaloisTool::generate_table_ntt(uint32_t galois_elt, Pointer<uint32_t> &result) const 19 { 20 #ifdef SEAL_DEBUG 21 if (!(galois_elt & 1) || (galois_elt >= 2 * (uint64_t(1) << coeff_count_power_))) 22 { 23 throw invalid_argument("Galois element is not valid"); 24 } 25 #endif 26 ReaderLock reader_lock(permutation_tables_locker_.acquire_read()); 27 if (result) 28 { 29 return; 30 } 31 reader_lock.unlock(); 32 33 auto temp(allocate<uint32_t>(coeff_count_, pool_)); 34 auto temp_ptr = temp.get(); 35 36 uint32_t coeff_count_minus_one = safe_cast<uint32_t>(coeff_count_) - 1; 37 for (size_t i = coeff_count_; i < coeff_count_ << 1; i++) 38 { 39 uint32_t reversed = reverse_bits<uint32_t>(safe_cast<uint32_t>(i), coeff_count_power_ + 1); 40 uint64_t index_raw = (static_cast<uint64_t>(galois_elt) * static_cast<uint64_t>(reversed)) >> 1; 41 index_raw &= static_cast<uint64_t>(coeff_count_minus_one); 42 *temp_ptr++ = reverse_bits<uint32_t>(static_cast<uint32_t>(index_raw), coeff_count_power_); 43 } 44 45 WriterLock writer_lock(permutation_tables_locker_.acquire_write()); 46 if (result) 47 { 48 return; 49 } 50 result.acquire(move(temp)); 51 } 52 get_elt_from_step(int step) const53 uint32_t GaloisTool::get_elt_from_step(int step) const 54 { 55 uint32_t n = safe_cast<uint32_t>(coeff_count_); 56 uint32_t m32 = mul_safe(n, uint32_t(2)); 57 uint64_t m = static_cast<uint64_t>(m32); 58 59 if (step == 0) 60 { 61 return static_cast<uint32_t>(m - 1); 62 } 63 else 64 { 65 // Extract sign of steps. When steps is positive, the rotation 66 // is to the left; when steps is negative, it is to the right. 67 bool sign = step < 0; 68 uint32_t pos_step = safe_cast<uint32_t>(abs(step)); 69 70 if (pos_step >= (n >> 1)) 71 { 72 throw invalid_argument("step count too large"); 73 } 74 75 pos_step &= m32 - 1; 76 if (sign) 77 { 78 step = safe_cast<int>(n >> 1) - safe_cast<int>(pos_step); 79 } 80 else 81 { 82 step = safe_cast<int>(pos_step); 83 } 84 85 // Construct Galois element for row rotation 86 uint64_t gen = static_cast<uint64_t>(generator_); 87 uint64_t galois_elt = 1; 88 while (step--) 89 { 90 galois_elt *= gen; 91 galois_elt &= m - 1; 92 } 93 return static_cast<uint32_t>(galois_elt); 94 } 95 } 96 get_elts_from_steps(const vector<int> & steps) const97 vector<uint32_t> GaloisTool::get_elts_from_steps(const vector<int> &steps) const 98 { 99 vector<uint32_t> galois_elts; 100 transform(steps.begin(), steps.end(), back_inserter(galois_elts), [&](auto s) { 101 return this->get_elt_from_step(s); 102 }); 103 return galois_elts; 104 } 105 get_elts_all() const106 vector<uint32_t> GaloisTool::get_elts_all() const noexcept 107 { 108 uint32_t m = safe_cast<uint32_t>(static_cast<uint64_t>(coeff_count_) << 1); 109 vector<uint32_t> galois_elts{}; 110 111 // Generate Galois keys for m - 1 (X -> X^{m-1}) 112 galois_elts.push_back(m - 1); 113 114 // Generate Galois key for power of generator_ mod m (X -> X^{3^k}) and 115 // for negative power of generator_ mod m (X -> X^{-3^k}) 116 uint64_t pos_power = generator_; 117 uint64_t neg_power = 0; 118 try_invert_uint_mod(generator_, m, neg_power); 119 for (int i = 0; i < coeff_count_power_ - 1; i++) 120 { 121 galois_elts.push_back(static_cast<uint32_t>(pos_power)); 122 pos_power *= pos_power; 123 pos_power &= (m - 1); 124 125 galois_elts.push_back(static_cast<uint32_t>(neg_power)); 126 neg_power *= neg_power; 127 neg_power &= (m - 1); 128 } 129 130 return galois_elts; 131 } 132 initialize(int coeff_count_power)133 void GaloisTool::initialize(int coeff_count_power) 134 { 135 if ((coeff_count_power < get_power_of_two(SEAL_POLY_MOD_DEGREE_MIN)) || 136 coeff_count_power > get_power_of_two(SEAL_POLY_MOD_DEGREE_MAX)) 137 { 138 throw invalid_argument("coeff_count_power out of range"); 139 } 140 141 coeff_count_power_ = coeff_count_power; 142 coeff_count_ = size_t(1) << coeff_count_power_; 143 144 // Capacity for coeff_count_ number of tables 145 permutation_tables_ = allocate<Pointer<uint32_t>>(coeff_count_, pool_); 146 } 147 apply_galois(ConstCoeffIter operand,uint32_t galois_elt,const Modulus & modulus,CoeffIter result) const148 void GaloisTool::apply_galois( 149 ConstCoeffIter operand, uint32_t galois_elt, const Modulus &modulus, CoeffIter result) const 150 { 151 #ifdef SEAL_DEBUG 152 if (!operand) 153 { 154 throw invalid_argument("operand"); 155 } 156 if (!result) 157 { 158 throw invalid_argument("result"); 159 } 160 if (operand == result) 161 { 162 throw invalid_argument("result cannot point to the same value as operand"); 163 } 164 // Verify coprime conditions. 165 if (!(galois_elt & 1) || (galois_elt >= 2 * (uint64_t(1) << coeff_count_power_))) 166 { 167 throw invalid_argument("Galois element is not valid"); 168 } 169 if (modulus.is_zero()) 170 { 171 throw invalid_argument("modulus"); 172 } 173 #endif 174 const uint64_t modulus_value = modulus.value(); 175 const uint64_t coeff_count_minus_one = coeff_count_ - 1; 176 uint64_t index_raw = 0; 177 for (uint64_t i = 0; i <= coeff_count_minus_one; i++, ++operand, index_raw += galois_elt) 178 { 179 uint64_t index = index_raw & coeff_count_minus_one; 180 uint64_t result_value = *operand; 181 if ((index_raw >> coeff_count_power_) & 1) 182 { 183 // Explicit inline 184 // result[index] = negate_uint_mod(result[index], modulus); 185 int64_t non_zero = (result_value != 0); 186 result_value = (modulus_value - result_value) & static_cast<uint64_t>(-non_zero); 187 } 188 result[index] = result_value; 189 } 190 } 191 apply_galois_ntt(ConstCoeffIter operand,uint32_t galois_elt,CoeffIter result) const192 void GaloisTool::apply_galois_ntt(ConstCoeffIter operand, uint32_t galois_elt, CoeffIter result) const 193 { 194 #ifdef SEAL_DEBUG 195 if (!operand) 196 { 197 throw invalid_argument("operand"); 198 } 199 if (!result) 200 { 201 throw invalid_argument("result"); 202 } 203 if (operand == result) 204 { 205 throw invalid_argument("result cannot point to the same value as operand"); 206 } 207 // Verify coprime conditions. 208 if (!(galois_elt & 1) || (galois_elt >= 2 * (uint64_t(1) << coeff_count_power_))) 209 { 210 throw invalid_argument("Galois element is not valid"); 211 } 212 #endif 213 generate_table_ntt(galois_elt, permutation_tables_[GetIndexFromElt(galois_elt)]); 214 auto table = iter(permutation_tables_[GetIndexFromElt(galois_elt)]); 215 216 // Perform permutation. 217 SEAL_ITERATE(iter(table, result), coeff_count_, [&](auto I) { get<1>(I) = operand[get<0>(I)]; }); 218 } 219 } // namespace util 220 } // namespace seal 221