1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT license.
3 
#include "seal/seal.h"
#include "seal/util/rlwe.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <random>
#include <vector>

// The SEAL_COMPILER* macros come from the SEAL headers above, so those must be included first:
// in an #if, an undefined identifier evaluates to 0, which would make the GCC branch below
// match on every compiler.
#if (SEAL_COMPILER == SEAL_COMPILER_GCC)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wconversion"
#elif (SEAL_COMPILER == SEAL_COMPILER_CLANG)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wconversion"
#endif
#include "benchmark/benchmark.h"
#if (SEAL_COMPILER == SEAL_COMPILER_GCC)
#pragma GCC diagnostic pop
#elif (SEAL_COMPILER == SEAL_COMPILER_CLANG)
#pragma clang diagnostic pop
#endif
20 
21 namespace sealbench
22 {
23     /**
24     Class BMEnv contains a set of required precomputed/preconstructed objects to setup a benchmark case.
25     A global BMEnv object is only initialized when a benchmark case for a EncryptionParameters is requested.
26     Since benchmark cases for the same parameters are registered together, this avoids heavy precomputation.
27     */
28     class BMEnv
29     {
30     public:
31         BMEnv() = delete;
32 
33         // Allow insecure parameters for experimental purposes.
34         // DO NOT USE THIS AS AN EXAMPLE.
BMEnv(const seal::EncryptionParameters & parms)35         BMEnv(const seal::EncryptionParameters &parms)
36             : parms_(parms), context_(parms_, true, seal::sec_level_type::none)
37         {
38             keygen_ = std::make_shared<seal::KeyGenerator>(context_);
39             sk_ = keygen_->secret_key();
40             keygen_->create_public_key(pk_);
41             if (context_.using_keyswitching())
42             {
43                 keygen_->create_relin_keys(rlk_);
44                 galois_elts_all_ = context_.key_context_data()->galois_tool()->get_elts_from_steps({ 1 });
45                 galois_elts_all_.emplace_back(2 * static_cast<uint32_t>(parms_.poly_modulus_degree()) - 1);
46                 // galois_elts_all_ = context_.key_context_data()->galois_tool()->get_elts_all();
47                 keygen_->create_galois_keys(galois_elts_all_, glk_);
48             }
49 
50             encryptor_ = std::make_shared<seal::Encryptor>(context_, pk_, sk_);
51             decryptor_ = std::make_shared<seal::Decryptor>(context_, sk_);
52             if (parms_.scheme() == seal::scheme_type::bfv)
53             {
54                 batch_encoder_ = std::make_shared<seal::BatchEncoder>(context_);
55             }
56             else if (parms_.scheme() == seal::scheme_type::ckks)
57             {
58                 ckks_encoder_ = std::make_shared<seal::CKKSEncoder>(context_);
59             }
60             evaluator_ = std::make_shared<seal::Evaluator>(context_);
61 
62             pt_.resize(std::size_t(2));
63             for (std::size_t i = 0; i < 2; i++)
64             {
65                 pt_[i].resize(parms_.poly_modulus_degree());
66             }
67 
68             ct_.resize(std::size_t(3));
69             for (std::size_t i = 0; i < 3; i++)
70             {
71                 ct_[i].resize(context_, std::size_t(2));
72             }
73         }
74 
75         /**
76         Getter methods.
77         */
parms()78         SEAL_NODISCARD const seal::EncryptionParameters &parms() const
79         {
80             return parms_;
81         }
82 
context()83         SEAL_NODISCARD const seal::SEALContext &context() const
84         {
85             return context_;
86         }
87 
keygen()88         SEAL_NODISCARD std::shared_ptr<seal::KeyGenerator> keygen()
89         {
90             return keygen_;
91         }
92 
encryptor()93         SEAL_NODISCARD std::shared_ptr<seal::Encryptor> encryptor()
94         {
95             return encryptor_;
96         }
97 
decryptor()98         SEAL_NODISCARD std::shared_ptr<seal::Decryptor> decryptor()
99         {
100             return decryptor_;
101         }
102 
batch_encoder()103         SEAL_NODISCARD std::shared_ptr<seal::BatchEncoder> batch_encoder()
104         {
105             return batch_encoder_;
106         }
107 
ckks_encoder()108         SEAL_NODISCARD std::shared_ptr<seal::CKKSEncoder> ckks_encoder()
109         {
110             return ckks_encoder_;
111         }
112 
evaluator()113         SEAL_NODISCARD std::shared_ptr<seal::Evaluator> evaluator()
114         {
115             return evaluator_;
116         }
117 
sk()118         SEAL_NODISCARD seal::SecretKey &sk()
119         {
120             return sk_;
121         }
122 
sk()123         SEAL_NODISCARD const seal::SecretKey &sk() const
124         {
125             return sk_;
126         }
127 
pk()128         SEAL_NODISCARD seal::PublicKey &pk()
129         {
130             return pk_;
131         }
132 
pk()133         SEAL_NODISCARD const seal::PublicKey &pk() const
134         {
135             return pk_;
136         }
137 
rlk()138         SEAL_NODISCARD seal::RelinKeys &rlk()
139         {
140             return rlk_;
141         }
142 
rlk()143         SEAL_NODISCARD const seal::RelinKeys &rlk() const
144         {
145             return rlk_;
146         }
147 
glk()148         SEAL_NODISCARD seal::GaloisKeys &glk()
149         {
150             return glk_;
151         }
152 
glk()153         SEAL_NODISCARD const seal::GaloisKeys &glk() const
154         {
155             return glk_;
156         }
157 
galois_elts_all()158         SEAL_NODISCARD const std::vector<std::uint32_t> &galois_elts_all() const
159         {
160             return galois_elts_all_;
161         }
162 
msg_uint64()163         SEAL_NODISCARD std::vector<std::uint64_t> &msg_uint64()
164         {
165             return msg_uint64_;
166         }
167 
msg_double()168         SEAL_NODISCARD std::vector<double> &msg_double()
169         {
170             return msg_double_;
171         }
172 
pt()173         SEAL_NODISCARD std::vector<seal::Plaintext> &pt()
174         {
175             return pt_;
176         }
177 
ct()178         SEAL_NODISCARD std::vector<seal::Ciphertext> &ct()
179         {
180             return ct_;
181         }
182 
183         /**
184         In most cases, the scale is chosen half as large as the second last prime (or the last if there is only one).
185         This avoids "scale out of bound" error in ciphertext/plaintext multiplications.
186         */
safe_scale()187         SEAL_NODISCARD double safe_scale()
188         {
189             return pow(2.0, (context_.first_context_data()->parms().coeff_modulus().end() - 1)->bit_count() / 2 - 1);
190         }
191 
192         /**
193         Fill a buffer with a number of random values that are uniformly samples from 0 ~ modulus - 1.
194         */
randomize_array_mod(std::uint64_t * data,std::size_t count,const seal::Modulus & modulus)195         void randomize_array_mod(std::uint64_t *data, std::size_t count, const seal::Modulus &modulus)
196         {
197             // For the purpose of benchmark, avoid using seal::UniformRandomGenerator, as it degrades
198             // performance with HEXL on some systems, due to AVX512 transitions.
199             // See https://travisdowns.github.io/blog/2020/01/17/avxfreq1.html#voltage-only-transitions.
200             // This method is not used for random number generation in Microsoft SEAL.
201             std::random_device rd;
202             std::mt19937_64 generator(rd());
203             std::uniform_int_distribution<std::uint64_t> dist(0, modulus.value() - 1);
204             std::generate(data, data + count, [&]() { return dist(generator); });
205         }
206 
207         /**
208         Sample an RNS polynomial from uniform distribution.
209         */
randomize_poly_rns(std::uint64_t * data,const seal::EncryptionParameters & parms)210         void randomize_poly_rns(std::uint64_t *data, const seal::EncryptionParameters &parms)
211         {
212             std::size_t coeff_count = parms.poly_modulus_degree();
213             std::vector<seal::Modulus> coeff_modulus = parms.coeff_modulus();
214             for (auto &i : coeff_modulus)
215             {
216                 randomize_array_mod(data, coeff_count, i);
217                 data += coeff_count;
218             }
219         }
220 
221         /**
222         Create a uniform random ciphertext in BFV using the highest-level parameters.
223         */
randomize_ct_bfv(seal::Ciphertext & ct)224         void randomize_ct_bfv(seal::Ciphertext &ct)
225         {
226             if (ct.parms_id() != context_.first_parms_id())
227             {
228                 ct.resize(context_, std::size_t(2));
229             }
230             auto &parms = context_.first_context_data()->parms();
231             for (std::size_t i = 0; i < ct.size(); i++)
232             {
233                 randomize_poly_rns(ct.data(i), parms);
234             }
235             ct.is_ntt_form() = false;
236         }
237 
238         /**
239         Create a uniform random ciphertext in CKKS using the highest-level parameters.
240         */
randomize_ct_ckks(seal::Ciphertext & ct)241         void randomize_ct_ckks(seal::Ciphertext &ct)
242         {
243             if (ct.parms_id() != context_.first_parms_id())
244             {
245                 ct.resize(context_, std::size_t(2));
246             }
247             auto &parms = context_.first_context_data()->parms();
248             for (std::size_t i = 0; i < ct.size(); i++)
249             {
250                 randomize_poly_rns(ct.data(i), parms);
251             }
252             ct.is_ntt_form() = true;
253         }
254 
255         /**
256         Create a uniform random plaintext (single modulus) in BFV.
257         */
randomize_pt_bfv(seal::Plaintext & pt)258         void randomize_pt_bfv(seal::Plaintext &pt)
259         {
260             pt.resize(parms_.poly_modulus_degree());
261             pt.parms_id() = seal::parms_id_zero;
262             randomize_array_mod(pt.data(), parms_.poly_modulus_degree(), parms_.plain_modulus());
263         }
264 
265         /**
266         Create a uniform random plaintext (RNS poly) in CKKS.
267         */
randomize_pt_ckks(seal::Plaintext & pt)268         void randomize_pt_ckks(seal::Plaintext &pt)
269         {
270             auto &parms = context_.first_context_data()->parms();
271             if (pt.coeff_count() != parms.poly_modulus_degree() * parms.coeff_modulus().size())
272             {
273                 pt.parms_id() = seal::parms_id_zero;
274                 pt.resize(parms.poly_modulus_degree() * parms.coeff_modulus().size());
275             }
276             if (pt.parms_id() != context_.first_parms_id())
277             {
278                 pt.parms_id() = context_.first_parms_id();
279             }
280             randomize_poly_rns(pt.data(), parms);
281         }
282 
283         /**
284         Create a vector of slot_count uniform random integers modulo plain_modululs.
285         */
randomize_message_uint64(std::vector<std::uint64_t> & msg)286         void randomize_message_uint64(std::vector<std::uint64_t> &msg)
287         {
288             msg.resize(batch_encoder_->slot_count());
289             randomize_array_mod(msg.data(), batch_encoder_->slot_count(), parms_.plain_modulus());
290         }
291 
292         /**
293         Create a vector of slot_count uniform random double precision values in [0, 1).
294         */
randomize_message_double(std::vector<double> & msg)295         void randomize_message_double(std::vector<double> &msg)
296         {
297             msg.resize(ckks_encoder_->slot_count());
298             std::generate(msg.begin(), msg.end(), []() { return static_cast<double>(std::rand()) / RAND_MAX; });
299         }
300 
301     private:
302         seal::EncryptionParameters parms_;
303         seal::SEALContext context_;
304         std::shared_ptr<seal::KeyGenerator> keygen_{ nullptr };
305         std::shared_ptr<seal::Encryptor> encryptor_{ nullptr };
306         std::shared_ptr<seal::Decryptor> decryptor_{ nullptr };
307         std::shared_ptr<seal::BatchEncoder> batch_encoder_{ nullptr };
308         std::shared_ptr<seal::CKKSEncoder> ckks_encoder_{ nullptr };
309         std::shared_ptr<seal::Evaluator> evaluator_{ nullptr };
310 
311         /**
312         The following data members are created as input/output containers for benchmark cases.
313         This avoids repeated and unnecessary allocation/deallocation in benchmark runs.
314         */
315         seal::SecretKey sk_;
316         seal::PublicKey pk_;
317         seal::RelinKeys rlk_;
318         seal::GaloisKeys glk_;
319         std::vector<std::uint32_t> galois_elts_all_;
320         std::vector<std::uint64_t> msg_uint64_;
321         std::vector<double> msg_double_;
322         std::vector<seal::Plaintext> pt_;
323         std::vector<seal::Ciphertext> ct_;
324     }; // namespace BMEnv
325 
    /*
    Benchmark case declarations. Every case takes the google-benchmark state and a shared BMEnv holding the
    precomputed objects for one set of encryption parameters. The definitions are not in this file
    (presumably in the per-category benchmark sources; confirm). Signatures here must match those
    definitions exactly.
    */

    // NTT benchmark cases: forward and inverse transforms, plus *_low_level and *_low_level_lazy variants
    // (see the definitions for how the variants differ).
    void bm_util_ntt_forward(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_util_ntt_inverse(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_util_ntt_forward_low_level(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_util_ntt_inverse_low_level(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_util_ntt_forward_low_level_lazy(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_util_ntt_inverse_low_level_lazy(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);

    // KeyGen benchmark cases: secret, public, relinearization and Galois key generation.
    void bm_keygen_secret(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_keygen_public(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_keygen_relin(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_keygen_galois(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);

    // BFV-specific benchmark cases (expect a BMEnv built with scheme_type::bfv, which has a batch encoder).
    void bm_bfv_encrypt_secret(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_encrypt_public(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_decrypt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_encode_batch(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_decode_batch(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_add_ct(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_add_pt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_negate(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_sub_ct(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_sub_pt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_mul_ct(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_mul_pt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_square(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_modswitch_inplace(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_relin_inplace(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_rotate_rows(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_bfv_rotate_cols(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);

    // CKKS-specific benchmark cases (expect a BMEnv built with scheme_type::ckks, which has a CKKS encoder).
    void bm_ckks_encrypt_secret(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_encrypt_public(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_decrypt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_encode_double(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_decode_double(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_add_ct(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_add_pt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_negate(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_sub_ct(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_sub_pt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_mul_ct(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_mul_pt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_square(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_rescale_inplace(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_relin_inplace(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
    void bm_ckks_rotate(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
376 } // namespace sealbench
377