1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT license.
3 
4 #include "seal/encryptor.h"
5 #include "seal/modulus.h"
6 #include "seal/randomtostd.h"
7 #include "seal/util/common.h"
8 #include "seal/util/iterator.h"
9 #include "seal/util/polyarithsmallmod.h"
10 #include "seal/util/rlwe.h"
11 #include "seal/util/scalingvariant.h"
12 #include <algorithm>
13 #include <stdexcept>
14 
15 using namespace std;
16 using namespace seal::util;
17 
18 namespace seal
19 {
Encryptor(const SEALContext & context,const PublicKey & public_key)20     Encryptor::Encryptor(const SEALContext &context, const PublicKey &public_key) : context_(context)
21     {
22         // Verify parameters
23         if (!context_.parameters_set())
24         {
25             throw invalid_argument("encryption parameters are not set correctly");
26         }
27 
28         set_public_key(public_key);
29 
30         auto &parms = context_.key_context_data()->parms();
31         auto &coeff_modulus = parms.coeff_modulus();
32         size_t coeff_count = parms.poly_modulus_degree();
33         size_t coeff_modulus_size = coeff_modulus.size();
34 
35         // Quick sanity check
36         if (!product_fits_in(coeff_count, coeff_modulus_size, size_t(2)))
37         {
38             throw logic_error("invalid parameters");
39         }
40     }
41 
Encryptor(const SEALContext & context,const SecretKey & secret_key)42     Encryptor::Encryptor(const SEALContext &context, const SecretKey &secret_key) : context_(context)
43     {
44         // Verify parameters
45         if (!context_.parameters_set())
46         {
47             throw invalid_argument("encryption parameters are not set correctly");
48         }
49 
50         set_secret_key(secret_key);
51 
52         auto &parms = context_.key_context_data()->parms();
53         auto &coeff_modulus = parms.coeff_modulus();
54         size_t coeff_count = parms.poly_modulus_degree();
55         size_t coeff_modulus_size = coeff_modulus.size();
56 
57         // Quick sanity check
58         if (!product_fits_in(coeff_count, coeff_modulus_size, size_t(2)))
59         {
60             throw logic_error("invalid parameters");
61         }
62     }
63 
Encryptor(const SEALContext & context,const PublicKey & public_key,const SecretKey & secret_key)64     Encryptor::Encryptor(const SEALContext &context, const PublicKey &public_key, const SecretKey &secret_key)
65         : context_(context)
66     {
67         // Verify parameters
68         if (!context_.parameters_set())
69         {
70             throw invalid_argument("encryption parameters are not set correctly");
71         }
72 
73         set_public_key(public_key);
74         set_secret_key(secret_key);
75 
76         auto &parms = context_.key_context_data()->parms();
77         auto &coeff_modulus = parms.coeff_modulus();
78         size_t coeff_count = parms.poly_modulus_degree();
79         size_t coeff_modulus_size = coeff_modulus.size();
80 
81         // Quick sanity check
82         if (!product_fits_in(coeff_count, coeff_modulus_size, size_t(2)))
83         {
84             throw logic_error("invalid parameters");
85         }
86     }
87 
encrypt_zero_internal(parms_id_type parms_id,bool is_asymmetric,bool save_seed,Ciphertext & destination,MemoryPoolHandle pool) const88     void Encryptor::encrypt_zero_internal(
89         parms_id_type parms_id, bool is_asymmetric, bool save_seed, Ciphertext &destination,
90         MemoryPoolHandle pool) const
91     {
92         // Verify parameters.
93         if (!pool)
94         {
95             throw invalid_argument("pool is uninitialized");
96         }
97 
98         auto context_data_ptr = context_.get_context_data(parms_id);
99         if (!context_data_ptr)
100         {
101             throw invalid_argument("parms_id is not valid for encryption parameters");
102         }
103 
104         auto &context_data = *context_.get_context_data(parms_id);
105         auto &parms = context_data.parms();
106         size_t coeff_modulus_size = parms.coeff_modulus().size();
107         size_t coeff_count = parms.poly_modulus_degree();
108         bool is_ntt_form = false;
109 
110         if (parms.scheme() == scheme_type::ckks)
111         {
112             is_ntt_form = true;
113         }
114         else if (parms.scheme() != scheme_type::bfv)
115         {
116             throw invalid_argument("unsupported scheme");
117         }
118 
119         // Resize destination and save results
120         destination.resize(context_, parms_id, 2);
121 
122         // If asymmetric key encryption
123         if (is_asymmetric)
124         {
125             auto prev_context_data_ptr = context_data.prev_context_data();
126             if (prev_context_data_ptr)
127             {
128                 // Requires modulus switching
129                 auto &prev_context_data = *prev_context_data_ptr;
130                 auto &prev_parms_id = prev_context_data.parms_id();
131                 auto rns_tool = prev_context_data.rns_tool();
132 
133                 // Zero encryption without modulus switching
134                 Ciphertext temp(pool);
135                 util::encrypt_zero_asymmetric(public_key_, context_, prev_parms_id, is_ntt_form, temp);
136 
137                 // Modulus switching
138                 SEAL_ITERATE(iter(temp, destination), temp.size(), [&](auto I) {
139                     if (is_ntt_form)
140                     {
141                         rns_tool->divide_and_round_q_last_ntt_inplace(
142                             get<0>(I), prev_context_data.small_ntt_tables(), pool);
143                     }
144                     else
145                     {
146                         rns_tool->divide_and_round_q_last_inplace(get<0>(I), pool);
147                     }
148                     set_poly(get<0>(I), coeff_count, coeff_modulus_size, get<1>(I));
149                 });
150 
151                 destination.is_ntt_form() = is_ntt_form;
152                 destination.scale() = temp.scale();
153                 destination.parms_id() = parms_id;
154             }
155             else
156             {
157                 // Does not require modulus switching
158                 util::encrypt_zero_asymmetric(public_key_, context_, parms_id, is_ntt_form, destination);
159             }
160         }
161         else
162         {
163             // Does not require modulus switching
164             util::encrypt_zero_symmetric(secret_key_, context_, parms_id, is_ntt_form, save_seed, destination);
165         }
166     }
167 
encrypt_internal(const Plaintext & plain,bool is_asymmetric,bool save_seed,Ciphertext & destination,MemoryPoolHandle pool) const168     void Encryptor::encrypt_internal(
169         const Plaintext &plain, bool is_asymmetric, bool save_seed, Ciphertext &destination,
170         MemoryPoolHandle pool) const
171     {
172         // Minimal verification that the keys are set
173         if (is_asymmetric)
174         {
175             if (!is_metadata_valid_for(public_key_, context_))
176             {
177                 throw logic_error("public key is not set");
178             }
179         }
180         else
181         {
182             if (!is_metadata_valid_for(secret_key_, context_))
183             {
184                 throw logic_error("secret key is not set");
185             }
186         }
187 
188         // Verify that plain is valid.
189         if (!is_metadata_valid_for(plain, context_) || !is_buffer_valid(plain))
190         {
191             throw invalid_argument("plain is not valid for encryption parameters");
192         }
193 
194         auto scheme = context_.key_context_data()->parms().scheme();
195         if (scheme == scheme_type::bfv)
196         {
197             if (plain.is_ntt_form())
198             {
199                 throw invalid_argument("plain cannot be in NTT form");
200             }
201 
202             encrypt_zero_internal(context_.first_parms_id(), is_asymmetric, save_seed, destination, pool);
203 
204             // Multiply plain by scalar coeff_div_plaintext and reposition if in upper-half.
205             // Result gets added into the c_0 term of ciphertext (c_0,c_1).
206             multiply_add_plain_with_scaling_variant(plain, *context_.first_context_data(), *iter(destination));
207         }
208         else if (scheme == scheme_type::ckks)
209         {
210             if (!plain.is_ntt_form())
211             {
212                 throw invalid_argument("plain must be in NTT form");
213             }
214 
215             auto context_data_ptr = context_.get_context_data(plain.parms_id());
216             if (!context_data_ptr)
217             {
218                 throw invalid_argument("plain is not valid for encryption parameters");
219             }
220             encrypt_zero_internal(plain.parms_id(), is_asymmetric, save_seed, destination, pool);
221 
222             auto &parms = context_.get_context_data(plain.parms_id())->parms();
223             auto &coeff_modulus = parms.coeff_modulus();
224             size_t coeff_modulus_size = coeff_modulus.size();
225             size_t coeff_count = parms.poly_modulus_degree();
226 
227             // The plaintext gets added into the c_0 term of ciphertext (c_0,c_1).
228             ConstRNSIter plain_iter(plain.data(), coeff_count);
229             RNSIter destination_iter = *iter(destination);
230             add_poly_coeffmod(destination_iter, plain_iter, coeff_modulus_size, coeff_modulus, destination_iter);
231 
232             destination.scale() = plain.scale();
233         }
234         else
235         {
236             throw invalid_argument("unsupported scheme");
237         }
238     }
239 } // namespace seal
240