1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT license.
3 
4 #include "seal/keygenerator.h"
5 #include "seal/randomgen.h"
6 #include <algorithm>
7 #include <array>
8 #include <cstdint>
9 #include <memory>
10 #include <numeric>
11 #include <set>
12 #include <sstream>
13 #include <thread>
14 #include "gtest/gtest.h"
15 
16 using namespace seal;
17 using namespace std;
18 
19 namespace sealtest
20 {
21     namespace
22     {
23         class SequentialRandomGenerator : public UniformRandomGenerator
24         {
25         public:
SequentialRandomGenerator(const prng_seed_type & seed)26             SequentialRandomGenerator(const prng_seed_type &seed) : UniformRandomGenerator(seed)
27             {}
28 
SequentialRandomGenerator()29             SequentialRandomGenerator() : UniformRandomGenerator({})
30             {}
31 
32             ~SequentialRandomGenerator() override = default;
33 
34         protected:
refill_buffer()35             void refill_buffer() override
36             {
37                 iota(reinterpret_cast<uint8_t *>(buffer_begin_), reinterpret_cast<uint8_t *>(buffer_end_), value);
38 
39                 value = static_cast<uint8_t>(static_cast<size_t>(value) + buffer_size_);
40             }
41 
type() const42             SEAL_NODISCARD prng_type type() const noexcept override
43             {
44                 return prng_type::unknown;
45             }
46 
47         private:
48             uint8_t value = 0;
49         };
50 
51         class SequentialRandomGeneratorFactory : public UniformRandomGeneratorFactory
52         {
53         private:
create_impl(SEAL_MAYBE_UNUSED prng_seed_type seed)54             SEAL_NODISCARD auto create_impl(SEAL_MAYBE_UNUSED prng_seed_type seed)
55                 -> shared_ptr<UniformRandomGenerator> override
56             {
57                 return make_shared<SequentialRandomGenerator>();
58             }
59         };
60     } // namespace
61 
TEST(RandomGenerator,UniformRandomCreateDefault)62     TEST(RandomGenerator, UniformRandomCreateDefault)
63     {
64         shared_ptr<UniformRandomGenerator> generator(UniformRandomGeneratorFactory::DefaultFactory()->create());
65         ASSERT_TRUE(UniformRandomGeneratorFactory::DefaultFactory()->use_random_seed());
66 
67         bool lower_half = false;
68         bool upper_half = false;
69         bool even = false;
70         bool odd = false;
71         for (int i = 0; i < 20; ++i)
72         {
73             uint32_t value = generator->generate();
74             if (value < UINT32_MAX / 2)
75             {
76                 lower_half = true;
77             }
78             else
79             {
80                 upper_half = true;
81             }
82             if ((value % 2) == 0)
83             {
84                 even = true;
85             }
86             else
87             {
88                 odd = true;
89             }
90         }
91         ASSERT_TRUE(lower_half);
92         ASSERT_TRUE(upper_half);
93         ASSERT_TRUE(even);
94         ASSERT_TRUE(odd);
95     }
96 
TEST(RandomGenerator,RandomGeneratorFactorySeed)97     TEST(RandomGenerator, RandomGeneratorFactorySeed)
98     {
99         shared_ptr<UniformRandomGeneratorFactory> factory(make_shared<Blake2xbPRNGFactory>());
100         ASSERT_TRUE(factory->use_random_seed());
101 
102         factory = make_shared<Blake2xbPRNGFactory>(prng_seed_type{});
103         ASSERT_FALSE(factory->use_random_seed());
104         ASSERT_EQ(prng_seed_type{}, factory->default_seed());
105 
106         factory = make_shared<Blake2xbPRNGFactory>(prng_seed_type{ 1, 2, 3, 4, 5, 6, 7, 8 });
107         ASSERT_FALSE(factory->use_random_seed());
108         ASSERT_EQ(prng_seed_type({ 1, 2, 3, 4, 5, 6, 7, 8 }), factory->default_seed());
109 
110         factory = make_shared<Blake2xbPRNGFactory>();
111         ASSERT_TRUE(factory->use_random_seed());
112     }
113 
TEST(RandomGenerator,SequentialRandomGenerator)114     TEST(RandomGenerator, SequentialRandomGenerator)
115     {
116         unique_ptr<UniformRandomGenerator> sgen = make_unique<SequentialRandomGenerator>();
117         array<uint8_t, 4096> value_list;
118         iota(value_list.begin(), value_list.end(), 0);
119 
120         array<uint8_t, 4096> compare_list;
121         sgen->generate(4096, reinterpret_cast<seal_byte *>(compare_list.data()));
122 
123         ASSERT_TRUE(equal(value_list.cbegin(), value_list.cend(), compare_list.cbegin()));
124     }
125 
TEST(RandomGenerator,RandomUInt64)126     TEST(RandomGenerator, RandomUInt64)
127     {
128         set<uint64_t> values;
129         size_t count = 100;
130         for (size_t i = 0; i < count; i++)
131         {
132             values.emplace(random_uint64());
133         }
134         ASSERT_EQ(count, values.size());
135     }
136 
TEST(RandomGenerator,SeededRNG)137     TEST(RandomGenerator, SeededRNG)
138     {
139         auto generator1(UniformRandomGeneratorFactory::DefaultFactory()->create({}));
140 
141         array<uint32_t, 20> values1;
142         generator1->generate(sizeof(values1), reinterpret_cast<seal_byte *>(values1.data()));
143 
144         auto generator2(UniformRandomGeneratorFactory::DefaultFactory()->create({ 1 }));
145         array<uint32_t, 20> values2;
146         generator2->generate(sizeof(values2), reinterpret_cast<seal_byte *>(values2.data()));
147 
148         auto generator3(UniformRandomGeneratorFactory::DefaultFactory()->create({ 1 }));
149         array<uint32_t, 20> values3;
150         generator3->generate(sizeof(values3), reinterpret_cast<seal_byte *>(values3.data()));
151 
152         for (size_t i = 0; i < values1.size(); i++)
153         {
154             ASSERT_NE(values1[i], values2[i]);
155             ASSERT_EQ(values2[i], values3[i]);
156         }
157 
158         uint32_t val1, val2, val3;
159         val1 = generator1->generate();
160         val2 = generator2->generate();
161         val3 = generator3->generate();
162         ASSERT_NE(val1, val2);
163         ASSERT_EQ(val2, val3);
164     }
165 
TEST(RandomGenerator,RandomSeededRNG)166     TEST(RandomGenerator, RandomSeededRNG)
167     {
168         auto generator1(UniformRandomGeneratorFactory::DefaultFactory()->create());
169         array<uint32_t, 20> values1;
170         generator1->generate(sizeof(values1), reinterpret_cast<seal_byte *>(values1.data()));
171 
172         auto generator2(UniformRandomGeneratorFactory::DefaultFactory()->create());
173         array<uint32_t, 20> values2;
174         generator2->generate(sizeof(values2), reinterpret_cast<seal_byte *>(values2.data()));
175 
176         auto seed3 = generator2->seed();
177         auto generator3(UniformRandomGeneratorFactory::DefaultFactory()->create(seed3));
178         array<uint32_t, 20> values3;
179         generator3->generate(sizeof(values3), reinterpret_cast<seal_byte *>(values3.data()));
180 
181         for (size_t i = 0; i < values1.size(); i++)
182         {
183             ASSERT_NE(values1[i], values2[i]);
184             ASSERT_EQ(values2[i], values3[i]);
185         }
186 
187         uint32_t val1, val2, val3;
188         val1 = generator1->generate();
189         val2 = generator2->generate();
190         val3 = generator3->generate();
191         ASSERT_NE(val1, val2);
192         ASSERT_EQ(val2, val3);
193     }
194 
TEST(RandomGenerator,MultiThreaded)195     TEST(RandomGenerator, MultiThreaded)
196     {
197         constexpr size_t thread_count = 2;
198         constexpr size_t numbers_per_thread = 50;
199         array<uint64_t, thread_count * numbers_per_thread> results;
200 
201         auto generator(UniformRandomGeneratorFactory::DefaultFactory()->create());
202 
203         vector<thread> th_vec;
204         for (size_t i = 0; i < thread_count; i++)
205         {
206             auto th_func = [&, generator, i]() {
207                 generator->generate(
208                     sizeof(uint64_t) * numbers_per_thread,
209                     reinterpret_cast<seal_byte *>(results.data() + numbers_per_thread * i));
210             };
211             th_vec.emplace_back(th_func);
212         }
213 
214         for (auto &th : th_vec)
215         {
216             th.join();
217         }
218 
219         auto seed = generator->seed();
220         auto generator2(UniformRandomGeneratorFactory::DefaultFactory()->create(seed));
221         for (size_t i = 0; i < thread_count * numbers_per_thread; i++)
222         {
223             uint64_t value = 0;
224             generator2->generate(sizeof(value), reinterpret_cast<seal_byte *>(&value));
225             ASSERT_TRUE(find(results.begin(), results.end(), value) != results.end());
226         }
227     }
228 
TEST(RandomGenerator,UniformRandomGeneratorInfo)229     TEST(RandomGenerator, UniformRandomGeneratorInfo)
230     {
231         UniformRandomGeneratorInfo info;
232         ASSERT_EQ(prng_type::unknown, info.type());
233         ASSERT_TRUE(info.has_valid_prng_type());
234         prng_seed_type seed_arr = {};
235         ASSERT_EQ(seed_arr, info.seed());
236 
237         seed_arr = { 1, 2, 3, 4, 5, 6, 7, 8 };
238         {
239             shared_ptr<UniformRandomGenerator> rg(make_unique<Blake2xbPRNG>(seed_arr));
240             info = rg->info();
241 
242             ASSERT_EQ(prng_type::blake2xb, info.type());
243             ASSERT_TRUE(info.has_valid_prng_type());
244             ASSERT_EQ(seed_arr, info.seed());
245 
246             auto rg2 = info.make_prng();
247             ASSERT_TRUE(rg2);
248             for (int i = 0; i < 100; i++)
249             {
250                 ASSERT_EQ(rg->generate(), rg2->generate());
251             }
252         }
253         {
254             shared_ptr<UniformRandomGenerator> rg(make_unique<Shake256PRNG>(seed_arr));
255             info = rg->info();
256 
257             ASSERT_EQ(prng_type::shake256, info.type());
258             ASSERT_TRUE(info.has_valid_prng_type());
259             ASSERT_EQ(seed_arr, info.seed());
260 
261             auto rg2 = info.make_prng();
262             ASSERT_TRUE(rg2);
263             for (int i = 0; i < 100; i++)
264             {
265                 ASSERT_EQ(rg->generate(), rg2->generate());
266             }
267         }
268         {
269             shared_ptr<UniformRandomGenerator> rg(make_unique<SequentialRandomGenerator>(seed_arr));
270             info = rg->info();
271 
272             ASSERT_EQ(prng_type::unknown, info.type());
273             ASSERT_TRUE(info.has_valid_prng_type());
274             ASSERT_EQ(seed_arr, info.seed());
275 
276             auto rg2 = info.make_prng();
277             ASSERT_FALSE(rg2);
278         }
279     }
280 
TEST(RandomGenerator,UniformRandomGeneratorInfoSaveLoad)281     TEST(RandomGenerator, UniformRandomGeneratorInfoSaveLoad)
282     {
283         UniformRandomGeneratorInfo info, info2;
284         stringstream ss;
285         auto size = info.save(ss, compr_mode_type::none);
286         ASSERT_EQ(size, info.save_size(compr_mode_type::none));
287         info2.load(ss);
288         ASSERT_TRUE(info == info2);
289 
290         prng_seed_type seed_arr = { 1, 2, 3, 4, 5, 6, 7, 8 };
291         {
292             shared_ptr<UniformRandomGenerator> rg(make_unique<Blake2xbPRNG>(seed_arr));
293             info = rg->info();
294             info.save(ss);
295             info2.load(ss);
296             ASSERT_TRUE(info == info2);
297         }
298         {
299             shared_ptr<UniformRandomGenerator> rg(make_unique<Shake256PRNG>(seed_arr));
300             info = rg->info();
301             info.save(ss);
302             info2.load(ss);
303             ASSERT_TRUE(info == info2);
304         }
305     }
306 } // namespace sealtest
307