1 // Copyright (c) Microsoft Corporation. All rights reserved. 2 // Licensed under the MIT license. 3 4 #pragma once 5 6 #include "seal/util/defines.h" 7 8 #ifdef SEAL_USE_SHARED_MUTEX 9 #include <mutex> 10 #include <shared_mutex> 11 12 namespace seal 13 { 14 namespace util 15 { 16 using ReaderLock = std::shared_lock<std::shared_mutex>; 17 18 using WriterLock = std::unique_lock<std::shared_mutex>; 19 20 class SEAL_NODISCARD ReaderWriterLocker 21 { 22 public: 23 ReaderWriterLocker() = default; 24 acquire_read()25 SEAL_NODISCARD inline ReaderLock acquire_read() 26 { 27 return ReaderLock(rw_lock_mutex_); 28 } 29 acquire_write()30 SEAL_NODISCARD inline WriterLock acquire_write() 31 { 32 return WriterLock(rw_lock_mutex_); 33 } 34 try_acquire_read()35 SEAL_NODISCARD inline ReaderLock try_acquire_read() noexcept 36 { 37 return ReaderLock(rw_lock_mutex_, std::try_to_lock); 38 } 39 try_acquire_write()40 SEAL_NODISCARD inline WriterLock try_acquire_write() noexcept 41 { 42 return WriterLock(rw_lock_mutex_, std::try_to_lock); 43 } 44 45 private: 46 ReaderWriterLocker(const ReaderWriterLocker ©) = delete; 47 48 ReaderWriterLocker &operator=(const ReaderWriterLocker &assign) = delete; 49 50 std::shared_mutex rw_lock_mutex_{}; 51 }; 52 } // namespace util 53 } // namespace seal 54 #else 55 #include <atomic> 56 #include <utility> 57 58 namespace seal 59 { 60 namespace util 61 { 62 struct try_to_lock_t 63 {}; 64 65 constexpr try_to_lock_t try_to_lock{}; 66 67 class ReaderWriterLocker; 68 69 class SEAL_NODISCARD ReaderLock 70 { 71 public: ReaderLock()72 ReaderLock() noexcept : locker_(nullptr) 73 {} 74 ReaderLock(ReaderLock && move)75 ReaderLock(ReaderLock &&move) noexcept : locker_(move.locker_) 76 { 77 move.locker_ = nullptr; 78 } 79 ReaderLock(ReaderWriterLocker & locker)80 ReaderLock(ReaderWriterLocker &locker) noexcept : locker_(nullptr) 81 { 82 acquire(locker); 83 } 84 ReaderLock(ReaderWriterLocker & locker,try_to_lock_t)85 ReaderLock(ReaderWriterLocker &locker, try_to_lock_t) noexcept : locker_(nullptr) 86 { 87 try_acquire(locker); 88 } 89 ~ReaderLock()90 ~ReaderLock() noexcept 91 { 92 unlock(); 93 } 94 owns_lock()95 SEAL_NODISCARD inline bool owns_lock() const noexcept 96 { 97 return locker_ != nullptr; 98 } 99 100 void unlock() noexcept; 101 swap_with(ReaderLock & lock)102 inline void swap_with(ReaderLock &lock) noexcept 103 { 104 std::swap(locker_, lock.locker_); 105 } 106 107 inline ReaderLock &operator=(ReaderLock &&lock) noexcept 108 { 109 swap_with(lock); 110 lock.unlock(); 111 return *this; 112 } 113 114 private: 115 void acquire(ReaderWriterLocker &locker) noexcept; 116 117 bool try_acquire(ReaderWriterLocker &locker) noexcept; 118 119 ReaderWriterLocker *locker_; 120 }; 121 122 class SEAL_NODISCARD WriterLock 123 { 124 public: WriterLock()125 WriterLock() noexcept : locker_(nullptr) 126 {} 127 WriterLock(WriterLock && move)128 WriterLock(WriterLock &&move) noexcept : locker_(move.locker_) 129 { 130 move.locker_ = nullptr; 131 } 132 WriterLock(ReaderWriterLocker & locker)133 WriterLock(ReaderWriterLocker &locker) noexcept : locker_(nullptr) 134 { 135 acquire(locker); 136 } 137 WriterLock(ReaderWriterLocker & locker,try_to_lock_t)138 WriterLock(ReaderWriterLocker &locker, try_to_lock_t) noexcept : locker_(nullptr) 139 { 140 try_acquire(locker); 141 } 142 ~WriterLock()143 ~WriterLock() noexcept 144 { 145 unlock(); 146 } 147 owns_lock()148 SEAL_NODISCARD inline bool owns_lock() const noexcept 149 { 150 return locker_ != nullptr; 151 } 152 153 void unlock() noexcept; 154 swap_with(WriterLock & lock)155 inline void swap_with(WriterLock &lock) noexcept 156 { 157 std::swap(locker_, lock.locker_); 158 } 159 160 inline WriterLock &operator=(WriterLock &&lock) noexcept 161 { 162 swap_with(lock); 163 lock.unlock(); 164 return *this; 165 } 166 167 private: 168 void acquire(ReaderWriterLocker &locker) noexcept; 169 170 bool try_acquire(ReaderWriterLocker &locker) noexcept; 171 172 ReaderWriterLocker *locker_; 173 }; 174 175 class SEAL_NODISCARD ReaderWriterLocker 176 { 177 friend class ReaderLock; 178 179 friend class WriterLock; 180 181 public: ReaderWriterLocker()182 ReaderWriterLocker() noexcept : reader_locks_(0), writer_locked_(false) 183 {} 184 acquire_read()185 SEAL_NODISCARD inline ReaderLock acquire_read() noexcept 186 { 187 return ReaderLock(*this); 188 } 189 acquire_write()190 SEAL_NODISCARD inline WriterLock acquire_write() noexcept 191 { 192 return WriterLock(*this); 193 } 194 try_acquire_read()195 SEAL_NODISCARD inline ReaderLock try_acquire_read() noexcept 196 { 197 return ReaderLock(*this, try_to_lock); 198 } 199 try_acquire_write()200 SEAL_NODISCARD inline WriterLock try_acquire_write() noexcept 201 { 202 return WriterLock(*this, try_to_lock); 203 } 204 205 private: 206 ReaderWriterLocker(const ReaderWriterLocker ©) = delete; 207 208 ReaderWriterLocker &operator=(const ReaderWriterLocker &assign) = delete; 209 210 std::atomic<int> reader_locks_; 211 212 std::atomic<bool> writer_locked_; 213 }; 214 unlock()215 inline void ReaderLock::unlock() noexcept 216 { 217 if (locker_ == nullptr) 218 { 219 return; 220 } 221 locker_->reader_locks_.fetch_sub(1, std::memory_order_release); 222 locker_ = nullptr; 223 } 224 acquire(ReaderWriterLocker & locker)225 inline void ReaderLock::acquire(ReaderWriterLocker &locker) noexcept 226 { 227 unlock(); 228 do 229 { 230 locker.reader_locks_.fetch_add(1, std::memory_order_acquire); 231 locker_ = &locker; 232 if (locker.writer_locked_.load(std::memory_order_acquire)) 233 { 234 unlock(); 235 while (locker.writer_locked_.load(std::memory_order_acquire)) 236 ; 237 } 238 } while (locker_ == nullptr); 239 } 240 try_acquire(ReaderWriterLocker & locker)241 SEAL_NODISCARD inline bool ReaderLock::try_acquire(ReaderWriterLocker &locker) noexcept 242 { 243 unlock(); 244 locker.reader_locks_.fetch_add(1, std::memory_order_acquire); 245 locker_ = &locker; 246 if (locker.writer_locked_.load(std::memory_order_acquire)) 247 { 248 unlock(); 249 return false; 250 } 251 return true; 252 } 253 acquire(ReaderWriterLocker & locker)254 inline void WriterLock::acquire(ReaderWriterLocker &locker) noexcept 255 { 256 unlock(); 257 bool expected = false; 258 while (!locker.writer_locked_.compare_exchange_strong(expected, true, std::memory_order_acquire)) 259 { 260 expected = false; 261 } 262 locker_ = &locker; 263 while (locker.reader_locks_.load(std::memory_order_acquire) != 0) 264 ; 265 } 266 try_acquire(ReaderWriterLocker & locker)267 SEAL_NODISCARD inline bool WriterLock::try_acquire(ReaderWriterLocker &locker) noexcept 268 { 269 unlock(); 270 bool expected = false; 271 if (!locker.writer_locked_.compare_exchange_strong(expected, true, std::memory_order_acquire)) 272 { 273 return false; 274 } 275 locker_ = &locker; 276 if (locker.reader_locks_.load(std::memory_order_acquire) != 0) 277 { 278 unlock(); 279 return false; 280 } 281 return true; 282 } 283 unlock()284 inline void WriterLock::unlock() noexcept 285 { 286 if (locker_ == nullptr) 287 { 288 return; 289 } 290 locker_->writer_locked_.store(false, std::memory_order_release); 291 locker_ = nullptr; 292 } 293 } // namespace util 294 } // namespace seal 295 #endif 296