1 // Copyright (c) Microsoft Corporation. All rights reserved. 2 // Licensed under the MIT license. 3 4 #include "seal/keygenerator.h" 5 #include "seal/randomgen.h" 6 #include <algorithm> 7 #include <array> 8 #include <cstdint> 9 #include <memory> 10 #include <numeric> 11 #include <set> 12 #include <sstream> 13 #include <thread> 14 #include "gtest/gtest.h" 15 16 using namespace seal; 17 using namespace std; 18 19 namespace sealtest 20 { 21 namespace 22 { 23 class SequentialRandomGenerator : public UniformRandomGenerator 24 { 25 public: SequentialRandomGenerator(const prng_seed_type & seed)26 SequentialRandomGenerator(const prng_seed_type &seed) : UniformRandomGenerator(seed) 27 {} 28 SequentialRandomGenerator()29 SequentialRandomGenerator() : UniformRandomGenerator({}) 30 {} 31 32 ~SequentialRandomGenerator() override = default; 33 34 protected: refill_buffer()35 void refill_buffer() override 36 { 37 iota(reinterpret_cast<uint8_t *>(buffer_begin_), reinterpret_cast<uint8_t *>(buffer_end_), value); 38 39 value = static_cast<uint8_t>(static_cast<size_t>(value) + buffer_size_); 40 } 41 type() const42 SEAL_NODISCARD prng_type type() const noexcept override 43 { 44 return prng_type::unknown; 45 } 46 47 private: 48 uint8_t value = 0; 49 }; 50 51 class SequentialRandomGeneratorFactory : public UniformRandomGeneratorFactory 52 { 53 private: create_impl(SEAL_MAYBE_UNUSED prng_seed_type seed)54 SEAL_NODISCARD auto create_impl(SEAL_MAYBE_UNUSED prng_seed_type seed) 55 -> shared_ptr<UniformRandomGenerator> override 56 { 57 return make_shared<SequentialRandomGenerator>(); 58 } 59 }; 60 } // namespace 61 TEST(RandomGenerator,UniformRandomCreateDefault)62 TEST(RandomGenerator, UniformRandomCreateDefault) 63 { 64 shared_ptr<UniformRandomGenerator> generator(UniformRandomGeneratorFactory::DefaultFactory()->create()); 65 ASSERT_TRUE(UniformRandomGeneratorFactory::DefaultFactory()->use_random_seed()); 66 67 bool lower_half = false; 68 bool upper_half = false; 69 bool even = false; 70 bool odd = false; 71 for (int i = 0; i < 20; ++i) 72 { 73 uint32_t value = generator->generate(); 74 if (value < UINT32_MAX / 2) 75 { 76 lower_half = true; 77 } 78 else 79 { 80 upper_half = true; 81 } 82 if ((value % 2) == 0) 83 { 84 even = true; 85 } 86 else 87 { 88 odd = true; 89 } 90 } 91 ASSERT_TRUE(lower_half); 92 ASSERT_TRUE(upper_half); 93 ASSERT_TRUE(even); 94 ASSERT_TRUE(odd); 95 } 96 TEST(RandomGenerator,RandomGeneratorFactorySeed)97 TEST(RandomGenerator, RandomGeneratorFactorySeed) 98 { 99 shared_ptr<UniformRandomGeneratorFactory> factory(make_shared<Blake2xbPRNGFactory>()); 100 ASSERT_TRUE(factory->use_random_seed()); 101 102 factory = make_shared<Blake2xbPRNGFactory>(prng_seed_type{}); 103 ASSERT_FALSE(factory->use_random_seed()); 104 ASSERT_EQ(prng_seed_type{}, factory->default_seed()); 105 106 factory = make_shared<Blake2xbPRNGFactory>(prng_seed_type{ 1, 2, 3, 4, 5, 6, 7, 8 }); 107 ASSERT_FALSE(factory->use_random_seed()); 108 ASSERT_EQ(prng_seed_type({ 1, 2, 3, 4, 5, 6, 7, 8 }), factory->default_seed()); 109 110 factory = make_shared<Blake2xbPRNGFactory>(); 111 ASSERT_TRUE(factory->use_random_seed()); 112 } 113 TEST(RandomGenerator,SequentialRandomGenerator)114 TEST(RandomGenerator, SequentialRandomGenerator) 115 { 116 unique_ptr<UniformRandomGenerator> sgen = make_unique<SequentialRandomGenerator>(); 117 array<uint8_t, 4096> value_list; 118 iota(value_list.begin(), value_list.end(), 0); 119 120 array<uint8_t, 4096> compare_list; 121 sgen->generate(4096, reinterpret_cast<seal_byte *>(compare_list.data())); 122 123 ASSERT_TRUE(equal(value_list.cbegin(), value_list.cend(), compare_list.cbegin())); 124 } 125 TEST(RandomGenerator,RandomUInt64)126 TEST(RandomGenerator, RandomUInt64) 127 { 128 set<uint64_t> values; 129 size_t count = 100; 130 for (size_t i = 0; i < count; i++) 131 { 132 values.emplace(random_uint64()); 133 } 134 ASSERT_EQ(count, values.size()); 135 } 136 TEST(RandomGenerator,SeededRNG)137 TEST(RandomGenerator, SeededRNG) 138 { 139 auto generator1(UniformRandomGeneratorFactory::DefaultFactory()->create({})); 140 141 array<uint32_t, 20> values1; 142 generator1->generate(sizeof(values1), reinterpret_cast<seal_byte *>(values1.data())); 143 144 auto generator2(UniformRandomGeneratorFactory::DefaultFactory()->create({ 1 })); 145 array<uint32_t, 20> values2; 146 generator2->generate(sizeof(values2), reinterpret_cast<seal_byte *>(values2.data())); 147 148 auto generator3(UniformRandomGeneratorFactory::DefaultFactory()->create({ 1 })); 149 array<uint32_t, 20> values3; 150 generator3->generate(sizeof(values3), reinterpret_cast<seal_byte *>(values3.data())); 151 152 for (size_t i = 0; i < values1.size(); i++) 153 { 154 ASSERT_NE(values1[i], values2[i]); 155 ASSERT_EQ(values2[i], values3[i]); 156 } 157 158 uint32_t val1, val2, val3; 159 val1 = generator1->generate(); 160 val2 = generator2->generate(); 161 val3 = generator3->generate(); 162 ASSERT_NE(val1, val2); 163 ASSERT_EQ(val2, val3); 164 } 165 TEST(RandomGenerator,RandomSeededRNG)166 TEST(RandomGenerator, RandomSeededRNG) 167 { 168 auto generator1(UniformRandomGeneratorFactory::DefaultFactory()->create()); 169 array<uint32_t, 20> values1; 170 generator1->generate(sizeof(values1), reinterpret_cast<seal_byte *>(values1.data())); 171 172 auto generator2(UniformRandomGeneratorFactory::DefaultFactory()->create()); 173 array<uint32_t, 20> values2; 174 generator2->generate(sizeof(values2), reinterpret_cast<seal_byte *>(values2.data())); 175 176 auto seed3 = generator2->seed(); 177 auto generator3(UniformRandomGeneratorFactory::DefaultFactory()->create(seed3)); 178 array<uint32_t, 20> values3; 179 generator3->generate(sizeof(values3), reinterpret_cast<seal_byte *>(values3.data())); 180 181 for (size_t i = 0; i < values1.size(); i++) 182 { 183 ASSERT_NE(values1[i], values2[i]); 184 ASSERT_EQ(values2[i], values3[i]); 185 } 186 187 uint32_t val1, val2, val3; 188 val1 = generator1->generate(); 189 val2 = generator2->generate(); 190 val3 = generator3->generate(); 191 ASSERT_NE(val1, val2); 192 ASSERT_EQ(val2, val3); 193 } 194 TEST(RandomGenerator,MultiThreaded)195 TEST(RandomGenerator, MultiThreaded) 196 { 197 constexpr size_t thread_count = 2; 198 constexpr size_t numbers_per_thread = 50; 199 array<uint64_t, thread_count * numbers_per_thread> results; 200 201 auto generator(UniformRandomGeneratorFactory::DefaultFactory()->create()); 202 203 vector<thread> th_vec; 204 for (size_t i = 0; i < thread_count; i++) 205 { 206 auto th_func = [&, generator, i]() { 207 generator->generate( 208 sizeof(uint64_t) * numbers_per_thread, 209 reinterpret_cast<seal_byte *>(results.data() + numbers_per_thread * i)); 210 }; 211 th_vec.emplace_back(th_func); 212 } 213 214 for (auto &th : th_vec) 215 { 216 th.join(); 217 } 218 219 auto seed = generator->seed(); 220 auto generator2(UniformRandomGeneratorFactory::DefaultFactory()->create(seed)); 221 for (size_t i = 0; i < thread_count * numbers_per_thread; i++) 222 { 223 uint64_t value = 0; 224 generator2->generate(sizeof(value), reinterpret_cast<seal_byte *>(&value)); 225 ASSERT_TRUE(find(results.begin(), results.end(), value) != results.end()); 226 } 227 } 228 TEST(RandomGenerator,UniformRandomGeneratorInfo)229 TEST(RandomGenerator, UniformRandomGeneratorInfo) 230 { 231 UniformRandomGeneratorInfo info; 232 ASSERT_EQ(prng_type::unknown, info.type()); 233 ASSERT_TRUE(info.has_valid_prng_type()); 234 prng_seed_type seed_arr = {}; 235 ASSERT_EQ(seed_arr, info.seed()); 236 237 seed_arr = { 1, 2, 3, 4, 5, 6, 7, 8 }; 238 { 239 shared_ptr<UniformRandomGenerator> rg(make_unique<Blake2xbPRNG>(seed_arr)); 240 info = rg->info(); 241 242 ASSERT_EQ(prng_type::blake2xb, info.type()); 243 ASSERT_TRUE(info.has_valid_prng_type()); 244 ASSERT_EQ(seed_arr, info.seed()); 245 246 auto rg2 = info.make_prng(); 247 ASSERT_TRUE(rg2); 248 for (int i = 0; i < 100; i++) 249 { 250 ASSERT_EQ(rg->generate(), rg2->generate()); 251 } 252 } 253 { 254 shared_ptr<UniformRandomGenerator> rg(make_unique<Shake256PRNG>(seed_arr)); 255 info = rg->info(); 256 257 ASSERT_EQ(prng_type::shake256, info.type()); 258 ASSERT_TRUE(info.has_valid_prng_type()); 259 ASSERT_EQ(seed_arr, info.seed()); 260 261 auto rg2 = info.make_prng(); 262 ASSERT_TRUE(rg2); 263 for (int i = 0; i < 100; i++) 264 { 265 ASSERT_EQ(rg->generate(), rg2->generate()); 266 } 267 } 268 { 269 shared_ptr<UniformRandomGenerator> rg(make_unique<SequentialRandomGenerator>(seed_arr)); 270 info = rg->info(); 271 272 ASSERT_EQ(prng_type::unknown, info.type()); 273 ASSERT_TRUE(info.has_valid_prng_type()); 274 ASSERT_EQ(seed_arr, info.seed()); 275 276 auto rg2 = info.make_prng(); 277 ASSERT_FALSE(rg2); 278 } 279 } 280 TEST(RandomGenerator,UniformRandomGeneratorInfoSaveLoad)281 TEST(RandomGenerator, UniformRandomGeneratorInfoSaveLoad) 282 { 283 UniformRandomGeneratorInfo info, info2; 284 stringstream ss; 285 auto size = info.save(ss, compr_mode_type::none); 286 ASSERT_EQ(size, info.save_size(compr_mode_type::none)); 287 info2.load(ss); 288 ASSERT_TRUE(info == info2); 289 290 prng_seed_type seed_arr = { 1, 2, 3, 4, 5, 6, 7, 8 }; 291 { 292 shared_ptr<UniformRandomGenerator> rg(make_unique<Blake2xbPRNG>(seed_arr)); 293 info = rg->info(); 294 info.save(ss); 295 info2.load(ss); 296 ASSERT_TRUE(info == info2); 297 } 298 { 299 shared_ptr<UniformRandomGenerator> rg(make_unique<Shake256PRNG>(seed_arr)); 300 info = rg->info(); 301 info.save(ss); 302 info2.load(ss); 303 ASSERT_TRUE(info == info2); 304 } 305 } 306 } // namespace sealtest 307