1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT license.
3 
4 #include "seal/util/galois.h"
5 #include "seal/util/numth.h"
6 #include "seal/util/uintcore.h"
7 
8 using namespace std;
9 
10 namespace seal
11 {
12     namespace util
13     {
        // Namespace-scope definition of the static constexpr member generator_.
        // Under C++14 the in-class `static constexpr` member is only a declaration; if generator_ is
        // odr-used anywhere (e.g. bound to a reference), a definition must exist so the symbol is emitted.
        // From C++17 on such members are implicitly inline and this line is redundant, but still valid.
        constexpr uint32_t GaloisTool::generator_;
17 
generate_table_ntt(uint32_t galois_elt,Pointer<uint32_t> & result) const18         void GaloisTool::generate_table_ntt(uint32_t galois_elt, Pointer<uint32_t> &result) const
19         {
20 #ifdef SEAL_DEBUG
21             if (!(galois_elt & 1) || (galois_elt >= 2 * (uint64_t(1) << coeff_count_power_)))
22             {
23                 throw invalid_argument("Galois element is not valid");
24             }
25 #endif
26             ReaderLock reader_lock(permutation_tables_locker_.acquire_read());
27             if (result)
28             {
29                 return;
30             }
31             reader_lock.unlock();
32 
33             auto temp(allocate<uint32_t>(coeff_count_, pool_));
34             auto temp_ptr = temp.get();
35 
36             uint32_t coeff_count_minus_one = safe_cast<uint32_t>(coeff_count_) - 1;
37             for (size_t i = coeff_count_; i < coeff_count_ << 1; i++)
38             {
39                 uint32_t reversed = reverse_bits<uint32_t>(safe_cast<uint32_t>(i), coeff_count_power_ + 1);
40                 uint64_t index_raw = (static_cast<uint64_t>(galois_elt) * static_cast<uint64_t>(reversed)) >> 1;
41                 index_raw &= static_cast<uint64_t>(coeff_count_minus_one);
42                 *temp_ptr++ = reverse_bits<uint32_t>(static_cast<uint32_t>(index_raw), coeff_count_power_);
43             }
44 
45             WriterLock writer_lock(permutation_tables_locker_.acquire_write());
46             if (result)
47             {
48                 return;
49             }
50             result.acquire(move(temp));
51         }
52 
get_elt_from_step(int step) const53         uint32_t GaloisTool::get_elt_from_step(int step) const
54         {
55             uint32_t n = safe_cast<uint32_t>(coeff_count_);
56             uint32_t m32 = mul_safe(n, uint32_t(2));
57             uint64_t m = static_cast<uint64_t>(m32);
58 
59             if (step == 0)
60             {
61                 return static_cast<uint32_t>(m - 1);
62             }
63             else
64             {
65                 // Extract sign of steps. When steps is positive, the rotation
66                 // is to the left; when steps is negative, it is to the right.
67                 bool sign = step < 0;
68                 uint32_t pos_step = safe_cast<uint32_t>(abs(step));
69 
70                 if (pos_step >= (n >> 1))
71                 {
72                     throw invalid_argument("step count too large");
73                 }
74 
75                 pos_step &= m32 - 1;
76                 if (sign)
77                 {
78                     step = safe_cast<int>(n >> 1) - safe_cast<int>(pos_step);
79                 }
80                 else
81                 {
82                     step = safe_cast<int>(pos_step);
83                 }
84 
85                 // Construct Galois element for row rotation
86                 uint64_t gen = static_cast<uint64_t>(generator_);
87                 uint64_t galois_elt = 1;
88                 while (step--)
89                 {
90                     galois_elt *= gen;
91                     galois_elt &= m - 1;
92                 }
93                 return static_cast<uint32_t>(galois_elt);
94             }
95         }
96 
get_elts_from_steps(const vector<int> & steps) const97         vector<uint32_t> GaloisTool::get_elts_from_steps(const vector<int> &steps) const
98         {
99             vector<uint32_t> galois_elts;
100             transform(steps.begin(), steps.end(), back_inserter(galois_elts), [&](auto s) {
101                 return this->get_elt_from_step(s);
102             });
103             return galois_elts;
104         }
105 
get_elts_all() const106         vector<uint32_t> GaloisTool::get_elts_all() const noexcept
107         {
108             uint32_t m = safe_cast<uint32_t>(static_cast<uint64_t>(coeff_count_) << 1);
109             vector<uint32_t> galois_elts{};
110 
111             // Generate Galois keys for m - 1 (X -> X^{m-1})
112             galois_elts.push_back(m - 1);
113 
114             // Generate Galois key for power of generator_ mod m (X -> X^{3^k}) and
115             // for negative power of generator_ mod m (X -> X^{-3^k})
116             uint64_t pos_power = generator_;
117             uint64_t neg_power = 0;
118             try_invert_uint_mod(generator_, m, neg_power);
119             for (int i = 0; i < coeff_count_power_ - 1; i++)
120             {
121                 galois_elts.push_back(static_cast<uint32_t>(pos_power));
122                 pos_power *= pos_power;
123                 pos_power &= (m - 1);
124 
125                 galois_elts.push_back(static_cast<uint32_t>(neg_power));
126                 neg_power *= neg_power;
127                 neg_power &= (m - 1);
128             }
129 
130             return galois_elts;
131         }
132 
initialize(int coeff_count_power)133         void GaloisTool::initialize(int coeff_count_power)
134         {
135             if ((coeff_count_power < get_power_of_two(SEAL_POLY_MOD_DEGREE_MIN)) ||
136                 coeff_count_power > get_power_of_two(SEAL_POLY_MOD_DEGREE_MAX))
137             {
138                 throw invalid_argument("coeff_count_power out of range");
139             }
140 
141             coeff_count_power_ = coeff_count_power;
142             coeff_count_ = size_t(1) << coeff_count_power_;
143 
144             // Capacity for coeff_count_ number of tables
145             permutation_tables_ = allocate<Pointer<uint32_t>>(coeff_count_, pool_);
146         }
147 
apply_galois(ConstCoeffIter operand,uint32_t galois_elt,const Modulus & modulus,CoeffIter result) const148         void GaloisTool::apply_galois(
149             ConstCoeffIter operand, uint32_t galois_elt, const Modulus &modulus, CoeffIter result) const
150         {
151 #ifdef SEAL_DEBUG
152             if (!operand)
153             {
154                 throw invalid_argument("operand");
155             }
156             if (!result)
157             {
158                 throw invalid_argument("result");
159             }
160             if (operand == result)
161             {
162                 throw invalid_argument("result cannot point to the same value as operand");
163             }
164             // Verify coprime conditions.
165             if (!(galois_elt & 1) || (galois_elt >= 2 * (uint64_t(1) << coeff_count_power_)))
166             {
167                 throw invalid_argument("Galois element is not valid");
168             }
169             if (modulus.is_zero())
170             {
171                 throw invalid_argument("modulus");
172             }
173 #endif
174             const uint64_t modulus_value = modulus.value();
175             const uint64_t coeff_count_minus_one = coeff_count_ - 1;
176             uint64_t index_raw = 0;
177             for (uint64_t i = 0; i <= coeff_count_minus_one; i++, ++operand, index_raw += galois_elt)
178             {
179                 uint64_t index = index_raw & coeff_count_minus_one;
180                 uint64_t result_value = *operand;
181                 if ((index_raw >> coeff_count_power_) & 1)
182                 {
183                     // Explicit inline
184                     // result[index] = negate_uint_mod(result[index], modulus);
185                     int64_t non_zero = (result_value != 0);
186                     result_value = (modulus_value - result_value) & static_cast<uint64_t>(-non_zero);
187                 }
188                 result[index] = result_value;
189             }
190         }
191 
apply_galois_ntt(ConstCoeffIter operand,uint32_t galois_elt,CoeffIter result) const192         void GaloisTool::apply_galois_ntt(ConstCoeffIter operand, uint32_t galois_elt, CoeffIter result) const
193         {
194 #ifdef SEAL_DEBUG
195             if (!operand)
196             {
197                 throw invalid_argument("operand");
198             }
199             if (!result)
200             {
201                 throw invalid_argument("result");
202             }
203             if (operand == result)
204             {
205                 throw invalid_argument("result cannot point to the same value as operand");
206             }
207             // Verify coprime conditions.
208             if (!(galois_elt & 1) || (galois_elt >= 2 * (uint64_t(1) << coeff_count_power_)))
209             {
210                 throw invalid_argument("Galois element is not valid");
211             }
212 #endif
213             generate_table_ntt(galois_elt, permutation_tables_[GetIndexFromElt(galois_elt)]);
214             auto table = iter(permutation_tables_[GetIndexFromElt(galois_elt)]);
215 
216             // Perform permutation.
217             SEAL_ITERATE(iter(table, result), coeff_count_, [&](auto I) { get<1>(I) = operand[get<0>(I)]; });
218         }
219     } // namespace util
220 } // namespace seal
221