1 // Copyright (c) Microsoft Corporation. All rights reserved. 2 // Licensed under the MIT license. 3 4 #include "seal/randomgen.h" 5 #include "seal/util/blake2.h" 6 #include "seal/util/common.h" 7 #include "seal/util/fips202.h" 8 #include <algorithm> 9 #include <iostream> 10 #include <random> 11 #if (SEAL_SYSTEM == SEAL_SYSTEM_WINDOWS) 12 #include <Windows.h> 13 #include <bcrypt.h> 14 #pragma comment(lib, "bcrypt") 15 #endif 16 17 using namespace std; 18 using namespace seal::util; 19 20 #if (SEAL_SYSTEM == SEAL_SYSTEM_WINDOWS) 21 22 constexpr auto RTL_GENRANDOM = "SystemFunction036"; 23 24 // Preserve error codes to diagnose in case of failure 25 NTSTATUS last_bcrypt_error = 0; 26 DWORD last_genrandom_error = 0; 27 28 #endif 29 30 namespace seal 31 { random_bytes(seal_byte * buf,size_t count)32 void random_bytes(seal_byte *buf, size_t count) 33 { 34 #if SEAL_SYSTEM == SEAL_SYSTEM_UNIX_LIKE 35 random_device rd("/dev/urandom"); 36 while (count >= 4) 37 { 38 *reinterpret_cast<uint32_t *>(buf) = rd(); 39 buf += 4; 40 count -= 4; 41 } 42 if (count) 43 { 44 uint32_t last = rd(); 45 memcpy(buf, &last, count); 46 } 47 #elif SEAL_SYSTEM == SEAL_SYSTEM_WINDOWS 48 NTSTATUS status = BCryptGenRandom( 49 NULL, reinterpret_cast<unsigned char *>(buf), safe_cast<ULONG>(count), BCRYPT_USE_SYSTEM_PREFERRED_RNG); 50 51 if (BCRYPT_SUCCESS(status)) 52 { 53 return; 54 } 55 56 last_bcrypt_error = status; 57 58 HMODULE hAdvApi = LoadLibraryA("ADVAPI32.DLL"); 59 if (!hAdvApi) 60 { 61 last_genrandom_error = GetLastError(); 62 throw runtime_error("Failed to load ADVAPI32.DLL"); 63 } 64 65 BOOLEAN(APIENTRY * RtlGenRandom) 66 (void *, ULONG) = (BOOLEAN(APIENTRY *)(void *, ULONG))GetProcAddress(hAdvApi, RTL_GENRANDOM); 67 68 BOOLEAN genrand_result = FALSE; 69 if (RtlGenRandom) 70 { 71 genrand_result = RtlGenRandom(buf, bytes_per_uint64); 72 } 73 74 DWORD dwError = GetLastError(); 75 FreeLibrary(hAdvApi); 76 77 if (!genrand_result) 78 { 79 last_genrandom_error = dwError; 80 throw runtime_error("Failed to call RtlGenRandom"); 81 } 82 #elif SEAL_SYSTEM == SEAL_SYSTEM_OTHER 83 #warning "SECURITY WARNING: System detection failed; falling back to a potentially insecure randomness source!" 84 random_device rd; 85 while (count >= 4) 86 { 87 *reinterpret_cast<uint32_t *>(buf) = rd(); 88 buf += 4; 89 count -= 4; 90 } 91 if (count) 92 { 93 uint32_t last = rd(); 94 memcpy(buf, &last, count); 95 } 96 #endif 97 } 98 save_members(ostream & stream) const99 void UniformRandomGeneratorInfo::save_members(ostream &stream) const 100 { 101 // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit 102 auto old_except_mask = stream.exceptions(); 103 try 104 { 105 stream.exceptions(ios_base::badbit | ios_base::failbit); 106 107 stream.write(reinterpret_cast<const char *>(&type_), sizeof(prng_type)); 108 stream.write(reinterpret_cast<const char *>(seed_.data()), prng_seed_byte_count); 109 } 110 catch (const ios_base::failure &) 111 { 112 stream.exceptions(old_except_mask); 113 throw runtime_error("I/O error"); 114 } 115 catch (...) 116 { 117 stream.exceptions(old_except_mask); 118 throw; 119 } 120 stream.exceptions(old_except_mask); 121 } 122 load_members(istream & stream,SEAL_MAYBE_UNUSED SEALVersion version)123 void UniformRandomGeneratorInfo::load_members(istream &stream, SEAL_MAYBE_UNUSED SEALVersion version) 124 { 125 // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit 126 auto old_except_mask = stream.exceptions(); 127 try 128 { 129 stream.exceptions(ios_base::badbit | ios_base::failbit); 130 131 UniformRandomGeneratorInfo info; 132 133 // Read the PRNG type 134 stream.read(reinterpret_cast<char *>(&info.type_), sizeof(prng_type)); 135 if (!info.has_valid_prng_type()) 136 { 137 throw logic_error("prng_type is invalid"); 138 } 139 140 // Read the seed data 141 stream.read(reinterpret_cast<char *>(info.seed_.data()), prng_seed_byte_count); 142 143 swap(*this, info); 144 145 stream.exceptions(old_except_mask); 146 } 147 catch (const ios_base::failure &) 148 { 149 stream.exceptions(old_except_mask); 150 throw runtime_error("I/O error"); 151 } 152 catch (...) 153 { 154 stream.exceptions(old_except_mask); 155 throw; 156 } 157 stream.exceptions(old_except_mask); 158 } 159 make_prng() const160 shared_ptr<UniformRandomGenerator> UniformRandomGeneratorInfo::make_prng() const 161 { 162 switch (type_) 163 { 164 case prng_type::blake2xb: 165 return make_shared<Blake2xbPRNG>(seed_); 166 167 case prng_type::shake256: 168 return make_shared<Shake256PRNG>(seed_); 169 170 case prng_type::unknown: 171 return nullptr; 172 } 173 return nullptr; 174 } 175 generate(size_t byte_count,seal_byte * destination)176 void UniformRandomGenerator::generate(size_t byte_count, seal_byte *destination) 177 { 178 lock_guard<mutex> lock(mutex_); 179 while (byte_count) 180 { 181 size_t current_bytes = min(byte_count, static_cast<size_t>(distance(buffer_head_, buffer_end_))); 182 copy_n(buffer_head_, current_bytes, destination); 183 buffer_head_ += current_bytes; 184 destination += current_bytes; 185 byte_count -= current_bytes; 186 187 if (buffer_head_ == buffer_end_) 188 { 189 refill_buffer(); 190 buffer_head_ = buffer_begin_; 191 } 192 } 193 } 194 DefaultFactory()195 auto UniformRandomGeneratorFactory::DefaultFactory() -> shared_ptr<UniformRandomGeneratorFactory> 196 { 197 static shared_ptr<UniformRandomGeneratorFactory> default_factory{ new SEAL_DEFAULT_PRNG_FACTORY() }; 198 return default_factory; 199 } 200 refill_buffer()201 void Blake2xbPRNG::refill_buffer() 202 { 203 // Fill the randomness buffer 204 if (blake2xb( 205 buffer_begin_, buffer_size_, &counter_, sizeof(counter_), seed_.cbegin(), 206 seed_.size() * sizeof(decltype(seed_)::type)) != 0) 207 { 208 throw runtime_error("blake2xb failed"); 209 } 210 counter_++; 211 } 212 refill_buffer()213 void Shake256PRNG::refill_buffer() 214 { 215 // Fill the randomness buffer 216 array<uint64_t, prng_seed_uint64_count + 1> seed_ext; 217 copy_n(seed_.cbegin(), prng_seed_uint64_count, seed_ext.begin()); 218 seed_ext[prng_seed_uint64_count] = counter_; 219 shake256( 220 reinterpret_cast<uint8_t *>(buffer_begin_), buffer_size_, 221 reinterpret_cast<const uint8_t *>(seed_ext.data()), seed_ext.size() * bytes_per_uint64); 222 seal_memzero(seed_ext.data(), seed_ext.size() * bytes_per_uint64); 223 counter_++; 224 } 225 } // namespace seal 226