1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT license.
3 
4 #include "seal/randomgen.h"
5 #include "seal/util/blake2.h"
6 #include "seal/util/common.h"
7 #include "seal/util/fips202.h"
8 #include <algorithm>
9 #include <iostream>
10 #include <random>
11 #if (SEAL_SYSTEM == SEAL_SYSTEM_WINDOWS)
12 #include <Windows.h>
13 #include <bcrypt.h>
14 #pragma comment(lib, "bcrypt")
15 #endif
16 
17 using namespace std;
18 using namespace seal::util;
19 
#if (SEAL_SYSTEM == SEAL_SYSTEM_WINDOWS)

// Export name of RtlGenRandom in ADVAPI32.DLL. random_bytes() looks it up with GetProcAddress
// and uses it as a fallback when BCryptGenRandom fails.
constexpr auto RTL_GENRANDOM = "SystemFunction036";

// Preserve error codes to diagnose in case of failure.
// last_bcrypt_error: status of the most recent failed BCryptGenRandom call made by random_bytes().
// last_genrandom_error: GetLastError() value recorded when loading ADVAPI32.DLL or calling
// RtlGenRandom fails.
// NOTE(review): these are plain globals with external linkage, written without synchronization.
// They are presumably meant to be inspected from a debugger; confirm before changing linkage.
NTSTATUS last_bcrypt_error = 0;
DWORD last_genrandom_error = 0;

#endif
29 
30 namespace seal
31 {
random_bytes(seal_byte * buf,size_t count)32     void random_bytes(seal_byte *buf, size_t count)
33     {
34 #if SEAL_SYSTEM == SEAL_SYSTEM_UNIX_LIKE
35         random_device rd("/dev/urandom");
36         while (count >= 4)
37         {
38             *reinterpret_cast<uint32_t *>(buf) = rd();
39             buf += 4;
40             count -= 4;
41         }
42         if (count)
43         {
44             uint32_t last = rd();
45             memcpy(buf, &last, count);
46         }
47 #elif SEAL_SYSTEM == SEAL_SYSTEM_WINDOWS
48         NTSTATUS status = BCryptGenRandom(
49             NULL, reinterpret_cast<unsigned char *>(buf), safe_cast<ULONG>(count), BCRYPT_USE_SYSTEM_PREFERRED_RNG);
50 
51         if (BCRYPT_SUCCESS(status))
52         {
53             return;
54         }
55 
56         last_bcrypt_error = status;
57 
58         HMODULE hAdvApi = LoadLibraryA("ADVAPI32.DLL");
59         if (!hAdvApi)
60         {
61             last_genrandom_error = GetLastError();
62             throw runtime_error("Failed to load ADVAPI32.DLL");
63         }
64 
65         BOOLEAN(APIENTRY * RtlGenRandom)
66         (void *, ULONG) = (BOOLEAN(APIENTRY *)(void *, ULONG))GetProcAddress(hAdvApi, RTL_GENRANDOM);
67 
68         BOOLEAN genrand_result = FALSE;
69         if (RtlGenRandom)
70         {
71             genrand_result = RtlGenRandom(buf, bytes_per_uint64);
72         }
73 
74         DWORD dwError = GetLastError();
75         FreeLibrary(hAdvApi);
76 
77         if (!genrand_result)
78         {
79             last_genrandom_error = dwError;
80             throw runtime_error("Failed to call RtlGenRandom");
81         }
82 #elif SEAL_SYSTEM == SEAL_SYSTEM_OTHER
83 #warning "SECURITY WARNING: System detection failed; falling back to a potentially insecure randomness source!"
84         random_device rd;
85         while (count >= 4)
86         {
87             *reinterpret_cast<uint32_t *>(buf) = rd();
88             buf += 4;
89             count -= 4;
90         }
91         if (count)
92         {
93             uint32_t last = rd();
94             memcpy(buf, &last, count);
95         }
96 #endif
97     }
98 
save_members(ostream & stream) const99     void UniformRandomGeneratorInfo::save_members(ostream &stream) const
100     {
101         // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit
102         auto old_except_mask = stream.exceptions();
103         try
104         {
105             stream.exceptions(ios_base::badbit | ios_base::failbit);
106 
107             stream.write(reinterpret_cast<const char *>(&type_), sizeof(prng_type));
108             stream.write(reinterpret_cast<const char *>(seed_.data()), prng_seed_byte_count);
109         }
110         catch (const ios_base::failure &)
111         {
112             stream.exceptions(old_except_mask);
113             throw runtime_error("I/O error");
114         }
115         catch (...)
116         {
117             stream.exceptions(old_except_mask);
118             throw;
119         }
120         stream.exceptions(old_except_mask);
121     }
122 
load_members(istream & stream,SEAL_MAYBE_UNUSED SEALVersion version)123     void UniformRandomGeneratorInfo::load_members(istream &stream, SEAL_MAYBE_UNUSED SEALVersion version)
124     {
125         // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit
126         auto old_except_mask = stream.exceptions();
127         try
128         {
129             stream.exceptions(ios_base::badbit | ios_base::failbit);
130 
131             UniformRandomGeneratorInfo info;
132 
133             // Read the PRNG type
134             stream.read(reinterpret_cast<char *>(&info.type_), sizeof(prng_type));
135             if (!info.has_valid_prng_type())
136             {
137                 throw logic_error("prng_type is invalid");
138             }
139 
140             // Read the seed data
141             stream.read(reinterpret_cast<char *>(info.seed_.data()), prng_seed_byte_count);
142 
143             swap(*this, info);
144 
145             stream.exceptions(old_except_mask);
146         }
147         catch (const ios_base::failure &)
148         {
149             stream.exceptions(old_except_mask);
150             throw runtime_error("I/O error");
151         }
152         catch (...)
153         {
154             stream.exceptions(old_except_mask);
155             throw;
156         }
157         stream.exceptions(old_except_mask);
158     }
159 
make_prng() const160     shared_ptr<UniformRandomGenerator> UniformRandomGeneratorInfo::make_prng() const
161     {
162         switch (type_)
163         {
164         case prng_type::blake2xb:
165             return make_shared<Blake2xbPRNG>(seed_);
166 
167         case prng_type::shake256:
168             return make_shared<Shake256PRNG>(seed_);
169 
170         case prng_type::unknown:
171             return nullptr;
172         }
173         return nullptr;
174     }
175 
generate(size_t byte_count,seal_byte * destination)176     void UniformRandomGenerator::generate(size_t byte_count, seal_byte *destination)
177     {
178         lock_guard<mutex> lock(mutex_);
179         while (byte_count)
180         {
181             size_t current_bytes = min(byte_count, static_cast<size_t>(distance(buffer_head_, buffer_end_)));
182             copy_n(buffer_head_, current_bytes, destination);
183             buffer_head_ += current_bytes;
184             destination += current_bytes;
185             byte_count -= current_bytes;
186 
187             if (buffer_head_ == buffer_end_)
188             {
189                 refill_buffer();
190                 buffer_head_ = buffer_begin_;
191             }
192         }
193     }
194 
DefaultFactory()195     auto UniformRandomGeneratorFactory::DefaultFactory() -> shared_ptr<UniformRandomGeneratorFactory>
196     {
197         static shared_ptr<UniformRandomGeneratorFactory> default_factory{ new SEAL_DEFAULT_PRNG_FACTORY() };
198         return default_factory;
199     }
200 
refill_buffer()201     void Blake2xbPRNG::refill_buffer()
202     {
203         // Fill the randomness buffer
204         if (blake2xb(
205                 buffer_begin_, buffer_size_, &counter_, sizeof(counter_), seed_.cbegin(),
206                 seed_.size() * sizeof(decltype(seed_)::type)) != 0)
207         {
208             throw runtime_error("blake2xb failed");
209         }
210         counter_++;
211     }
212 
refill_buffer()213     void Shake256PRNG::refill_buffer()
214     {
215         // Fill the randomness buffer
216         array<uint64_t, prng_seed_uint64_count + 1> seed_ext;
217         copy_n(seed_.cbegin(), prng_seed_uint64_count, seed_ext.begin());
218         seed_ext[prng_seed_uint64_count] = counter_;
219         shake256(
220             reinterpret_cast<uint8_t *>(buffer_begin_), buffer_size_,
221             reinterpret_cast<const uint8_t *>(seed_ext.data()), seed_ext.size() * bytes_per_uint64);
222         seal_memzero(seed_ext.data(), seed_ext.size() * bytes_per_uint64);
223         counter_++;
224     }
225 } // namespace seal
226