1 // Copyright (c) Microsoft Corporation. All rights reserved. 2 // Licensed under the MIT license. 3 4 #include "seal/encryptor.h" 5 #include "seal/modulus.h" 6 #include "seal/randomtostd.h" 7 #include "seal/util/common.h" 8 #include "seal/util/iterator.h" 9 #include "seal/util/polyarithsmallmod.h" 10 #include "seal/util/rlwe.h" 11 #include "seal/util/scalingvariant.h" 12 #include <algorithm> 13 #include <stdexcept> 14 15 using namespace std; 16 using namespace seal::util; 17 18 namespace seal 19 { Encryptor(const SEALContext & context,const PublicKey & public_key)20 Encryptor::Encryptor(const SEALContext &context, const PublicKey &public_key) : context_(context) 21 { 22 // Verify parameters 23 if (!context_.parameters_set()) 24 { 25 throw invalid_argument("encryption parameters are not set correctly"); 26 } 27 28 set_public_key(public_key); 29 30 auto &parms = context_.key_context_data()->parms(); 31 auto &coeff_modulus = parms.coeff_modulus(); 32 size_t coeff_count = parms.poly_modulus_degree(); 33 size_t coeff_modulus_size = coeff_modulus.size(); 34 35 // Quick sanity check 36 if (!product_fits_in(coeff_count, coeff_modulus_size, size_t(2))) 37 { 38 throw logic_error("invalid parameters"); 39 } 40 } 41 Encryptor(const SEALContext & context,const SecretKey & secret_key)42 Encryptor::Encryptor(const SEALContext &context, const SecretKey &secret_key) : context_(context) 43 { 44 // Verify parameters 45 if (!context_.parameters_set()) 46 { 47 throw invalid_argument("encryption parameters are not set correctly"); 48 } 49 50 set_secret_key(secret_key); 51 52 auto &parms = context_.key_context_data()->parms(); 53 auto &coeff_modulus = parms.coeff_modulus(); 54 size_t coeff_count = parms.poly_modulus_degree(); 55 size_t coeff_modulus_size = coeff_modulus.size(); 56 57 // Quick sanity check 58 if (!product_fits_in(coeff_count, coeff_modulus_size, size_t(2))) 59 { 60 throw logic_error("invalid parameters"); 61 } 62 } 63 Encryptor(const SEALContext & context,const PublicKey & public_key,const SecretKey & secret_key)64 Encryptor::Encryptor(const SEALContext &context, const PublicKey &public_key, const SecretKey &secret_key) 65 : context_(context) 66 { 67 // Verify parameters 68 if (!context_.parameters_set()) 69 { 70 throw invalid_argument("encryption parameters are not set correctly"); 71 } 72 73 set_public_key(public_key); 74 set_secret_key(secret_key); 75 76 auto &parms = context_.key_context_data()->parms(); 77 auto &coeff_modulus = parms.coeff_modulus(); 78 size_t coeff_count = parms.poly_modulus_degree(); 79 size_t coeff_modulus_size = coeff_modulus.size(); 80 81 // Quick sanity check 82 if (!product_fits_in(coeff_count, coeff_modulus_size, size_t(2))) 83 { 84 throw logic_error("invalid parameters"); 85 } 86 } 87 encrypt_zero_internal(parms_id_type parms_id,bool is_asymmetric,bool save_seed,Ciphertext & destination,MemoryPoolHandle pool) const88 void Encryptor::encrypt_zero_internal( 89 parms_id_type parms_id, bool is_asymmetric, bool save_seed, Ciphertext &destination, 90 MemoryPoolHandle pool) const 91 { 92 // Verify parameters. 93 if (!pool) 94 { 95 throw invalid_argument("pool is uninitialized"); 96 } 97 98 auto context_data_ptr = context_.get_context_data(parms_id); 99 if (!context_data_ptr) 100 { 101 throw invalid_argument("parms_id is not valid for encryption parameters"); 102 } 103 104 auto &context_data = *context_.get_context_data(parms_id); 105 auto &parms = context_data.parms(); 106 size_t coeff_modulus_size = parms.coeff_modulus().size(); 107 size_t coeff_count = parms.poly_modulus_degree(); 108 bool is_ntt_form = false; 109 110 if (parms.scheme() == scheme_type::ckks) 111 { 112 is_ntt_form = true; 113 } 114 else if (parms.scheme() != scheme_type::bfv) 115 { 116 throw invalid_argument("unsupported scheme"); 117 } 118 119 // Resize destination and save results 120 destination.resize(context_, parms_id, 2); 121 122 // If asymmetric key encryption 123 if (is_asymmetric) 124 { 125 auto prev_context_data_ptr = context_data.prev_context_data(); 126 if (prev_context_data_ptr) 127 { 128 // Requires modulus switching 129 auto &prev_context_data = *prev_context_data_ptr; 130 auto &prev_parms_id = prev_context_data.parms_id(); 131 auto rns_tool = prev_context_data.rns_tool(); 132 133 // Zero encryption without modulus switching 134 Ciphertext temp(pool); 135 util::encrypt_zero_asymmetric(public_key_, context_, prev_parms_id, is_ntt_form, temp); 136 137 // Modulus switching 138 SEAL_ITERATE(iter(temp, destination), temp.size(), [&](auto I) { 139 if (is_ntt_form) 140 { 141 rns_tool->divide_and_round_q_last_ntt_inplace( 142 get<0>(I), prev_context_data.small_ntt_tables(), pool); 143 } 144 else 145 { 146 rns_tool->divide_and_round_q_last_inplace(get<0>(I), pool); 147 } 148 set_poly(get<0>(I), coeff_count, coeff_modulus_size, get<1>(I)); 149 }); 150 151 destination.is_ntt_form() = is_ntt_form; 152 destination.scale() = temp.scale(); 153 destination.parms_id() = parms_id; 154 } 155 else 156 { 157 // Does not require modulus switching 158 util::encrypt_zero_asymmetric(public_key_, context_, parms_id, is_ntt_form, destination); 159 } 160 } 161 else 162 { 163 // Does not require modulus switching 164 util::encrypt_zero_symmetric(secret_key_, context_, parms_id, is_ntt_form, save_seed, destination); 165 } 166 } 167 encrypt_internal(const Plaintext & plain,bool is_asymmetric,bool save_seed,Ciphertext & destination,MemoryPoolHandle pool) const168 void Encryptor::encrypt_internal( 169 const Plaintext &plain, bool is_asymmetric, bool save_seed, Ciphertext &destination, 170 MemoryPoolHandle pool) const 171 { 172 // Minimal verification that the keys are set 173 if (is_asymmetric) 174 { 175 if (!is_metadata_valid_for(public_key_, context_)) 176 { 177 throw logic_error("public key is not set"); 178 } 179 } 180 else 181 { 182 if (!is_metadata_valid_for(secret_key_, context_)) 183 { 184 throw logic_error("secret key is not set"); 185 } 186 } 187 188 // Verify that plain is valid. 189 if (!is_metadata_valid_for(plain, context_) || !is_buffer_valid(plain)) 190 { 191 throw invalid_argument("plain is not valid for encryption parameters"); 192 } 193 194 auto scheme = context_.key_context_data()->parms().scheme(); 195 if (scheme == scheme_type::bfv) 196 { 197 if (plain.is_ntt_form()) 198 { 199 throw invalid_argument("plain cannot be in NTT form"); 200 } 201 202 encrypt_zero_internal(context_.first_parms_id(), is_asymmetric, save_seed, destination, pool); 203 204 // Multiply plain by scalar coeff_div_plaintext and reposition if in upper-half. 205 // Result gets added into the c_0 term of ciphertext (c_0,c_1). 206 multiply_add_plain_with_scaling_variant(plain, *context_.first_context_data(), *iter(destination)); 207 } 208 else if (scheme == scheme_type::ckks) 209 { 210 if (!plain.is_ntt_form()) 211 { 212 throw invalid_argument("plain must be in NTT form"); 213 } 214 215 auto context_data_ptr = context_.get_context_data(plain.parms_id()); 216 if (!context_data_ptr) 217 { 218 throw invalid_argument("plain is not valid for encryption parameters"); 219 } 220 encrypt_zero_internal(plain.parms_id(), is_asymmetric, save_seed, destination, pool); 221 222 auto &parms = context_.get_context_data(plain.parms_id())->parms(); 223 auto &coeff_modulus = parms.coeff_modulus(); 224 size_t coeff_modulus_size = coeff_modulus.size(); 225 size_t coeff_count = parms.poly_modulus_degree(); 226 227 // The plaintext gets added into the c_0 term of ciphertext (c_0,c_1). 228 ConstRNSIter plain_iter(plain.data(), coeff_count); 229 RNSIter destination_iter = *iter(destination); 230 add_poly_coeffmod(destination_iter, plain_iter, coeff_modulus_size, coeff_modulus, destination_iter); 231 232 destination.scale() = plain.scale(); 233 } 234 else 235 { 236 throw invalid_argument("unsupported scheme"); 237 } 238 } 239 } // namespace seal 240