1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT license.
3 
4 #pragma once
5 
6 #include "seal/util/defines.h"
7 
8 #ifdef SEAL_USE_SHARED_MUTEX
9 #include <mutex>
10 #include <shared_mutex>
11 
12 namespace seal
13 {
14     namespace util
15     {
16         using ReaderLock = std::shared_lock<std::shared_mutex>;
17 
18         using WriterLock = std::unique_lock<std::shared_mutex>;
19 
20         class SEAL_NODISCARD ReaderWriterLocker
21         {
22         public:
23             ReaderWriterLocker() = default;
24 
acquire_read()25             SEAL_NODISCARD inline ReaderLock acquire_read()
26             {
27                 return ReaderLock(rw_lock_mutex_);
28             }
29 
acquire_write()30             SEAL_NODISCARD inline WriterLock acquire_write()
31             {
32                 return WriterLock(rw_lock_mutex_);
33             }
34 
try_acquire_read()35             SEAL_NODISCARD inline ReaderLock try_acquire_read() noexcept
36             {
37                 return ReaderLock(rw_lock_mutex_, std::try_to_lock);
38             }
39 
try_acquire_write()40             SEAL_NODISCARD inline WriterLock try_acquire_write() noexcept
41             {
42                 return WriterLock(rw_lock_mutex_, std::try_to_lock);
43             }
44 
45         private:
46             ReaderWriterLocker(const ReaderWriterLocker &copy) = delete;
47 
48             ReaderWriterLocker &operator=(const ReaderWriterLocker &assign) = delete;
49 
50             std::shared_mutex rw_lock_mutex_{};
51         };
52     } // namespace util
53 } // namespace seal
54 #else
55 #include <atomic>
56 #include <utility>
57 
58 namespace seal
59 {
60     namespace util
61     {
62         struct try_to_lock_t
63         {};
64 
65         constexpr try_to_lock_t try_to_lock{};
66 
67         class ReaderWriterLocker;
68 
69         class SEAL_NODISCARD ReaderLock
70         {
71         public:
ReaderLock()72             ReaderLock() noexcept : locker_(nullptr)
73             {}
74 
ReaderLock(ReaderLock && move)75             ReaderLock(ReaderLock &&move) noexcept : locker_(move.locker_)
76             {
77                 move.locker_ = nullptr;
78             }
79 
ReaderLock(ReaderWriterLocker & locker)80             ReaderLock(ReaderWriterLocker &locker) noexcept : locker_(nullptr)
81             {
82                 acquire(locker);
83             }
84 
ReaderLock(ReaderWriterLocker & locker,try_to_lock_t)85             ReaderLock(ReaderWriterLocker &locker, try_to_lock_t) noexcept : locker_(nullptr)
86             {
87                 try_acquire(locker);
88             }
89 
~ReaderLock()90             ~ReaderLock() noexcept
91             {
92                 unlock();
93             }
94 
owns_lock()95             SEAL_NODISCARD inline bool owns_lock() const noexcept
96             {
97                 return locker_ != nullptr;
98             }
99 
100             void unlock() noexcept;
101 
swap_with(ReaderLock & lock)102             inline void swap_with(ReaderLock &lock) noexcept
103             {
104                 std::swap(locker_, lock.locker_);
105             }
106 
107             inline ReaderLock &operator=(ReaderLock &&lock) noexcept
108             {
109                 swap_with(lock);
110                 lock.unlock();
111                 return *this;
112             }
113 
114         private:
115             void acquire(ReaderWriterLocker &locker) noexcept;
116 
117             bool try_acquire(ReaderWriterLocker &locker) noexcept;
118 
119             ReaderWriterLocker *locker_;
120         };
121 
122         class SEAL_NODISCARD WriterLock
123         {
124         public:
WriterLock()125             WriterLock() noexcept : locker_(nullptr)
126             {}
127 
WriterLock(WriterLock && move)128             WriterLock(WriterLock &&move) noexcept : locker_(move.locker_)
129             {
130                 move.locker_ = nullptr;
131             }
132 
WriterLock(ReaderWriterLocker & locker)133             WriterLock(ReaderWriterLocker &locker) noexcept : locker_(nullptr)
134             {
135                 acquire(locker);
136             }
137 
WriterLock(ReaderWriterLocker & locker,try_to_lock_t)138             WriterLock(ReaderWriterLocker &locker, try_to_lock_t) noexcept : locker_(nullptr)
139             {
140                 try_acquire(locker);
141             }
142 
~WriterLock()143             ~WriterLock() noexcept
144             {
145                 unlock();
146             }
147 
owns_lock()148             SEAL_NODISCARD inline bool owns_lock() const noexcept
149             {
150                 return locker_ != nullptr;
151             }
152 
153             void unlock() noexcept;
154 
swap_with(WriterLock & lock)155             inline void swap_with(WriterLock &lock) noexcept
156             {
157                 std::swap(locker_, lock.locker_);
158             }
159 
160             inline WriterLock &operator=(WriterLock &&lock) noexcept
161             {
162                 swap_with(lock);
163                 lock.unlock();
164                 return *this;
165             }
166 
167         private:
168             void acquire(ReaderWriterLocker &locker) noexcept;
169 
170             bool try_acquire(ReaderWriterLocker &locker) noexcept;
171 
172             ReaderWriterLocker *locker_;
173         };
174 
175         class SEAL_NODISCARD ReaderWriterLocker
176         {
177             friend class ReaderLock;
178 
179             friend class WriterLock;
180 
181         public:
ReaderWriterLocker()182             ReaderWriterLocker() noexcept : reader_locks_(0), writer_locked_(false)
183             {}
184 
acquire_read()185             SEAL_NODISCARD inline ReaderLock acquire_read() noexcept
186             {
187                 return ReaderLock(*this);
188             }
189 
acquire_write()190             SEAL_NODISCARD inline WriterLock acquire_write() noexcept
191             {
192                 return WriterLock(*this);
193             }
194 
try_acquire_read()195             SEAL_NODISCARD inline ReaderLock try_acquire_read() noexcept
196             {
197                 return ReaderLock(*this, try_to_lock);
198             }
199 
try_acquire_write()200             SEAL_NODISCARD inline WriterLock try_acquire_write() noexcept
201             {
202                 return WriterLock(*this, try_to_lock);
203             }
204 
205         private:
206             ReaderWriterLocker(const ReaderWriterLocker &copy) = delete;
207 
208             ReaderWriterLocker &operator=(const ReaderWriterLocker &assign) = delete;
209 
210             std::atomic<int> reader_locks_;
211 
212             std::atomic<bool> writer_locked_;
213         };
214 
unlock()215         inline void ReaderLock::unlock() noexcept
216         {
217             if (locker_ == nullptr)
218             {
219                 return;
220             }
221             locker_->reader_locks_.fetch_sub(1, std::memory_order_release);
222             locker_ = nullptr;
223         }
224 
acquire(ReaderWriterLocker & locker)225         inline void ReaderLock::acquire(ReaderWriterLocker &locker) noexcept
226         {
227             unlock();
228             do
229             {
230                 locker.reader_locks_.fetch_add(1, std::memory_order_acquire);
231                 locker_ = &locker;
232                 if (locker.writer_locked_.load(std::memory_order_acquire))
233                 {
234                     unlock();
235                     while (locker.writer_locked_.load(std::memory_order_acquire))
236                         ;
237                 }
238             } while (locker_ == nullptr);
239         }
240 
try_acquire(ReaderWriterLocker & locker)241         SEAL_NODISCARD inline bool ReaderLock::try_acquire(ReaderWriterLocker &locker) noexcept
242         {
243             unlock();
244             locker.reader_locks_.fetch_add(1, std::memory_order_acquire);
245             locker_ = &locker;
246             if (locker.writer_locked_.load(std::memory_order_acquire))
247             {
248                 unlock();
249                 return false;
250             }
251             return true;
252         }
253 
acquire(ReaderWriterLocker & locker)254         inline void WriterLock::acquire(ReaderWriterLocker &locker) noexcept
255         {
256             unlock();
257             bool expected = false;
258             while (!locker.writer_locked_.compare_exchange_strong(expected, true, std::memory_order_acquire))
259             {
260                 expected = false;
261             }
262             locker_ = &locker;
263             while (locker.reader_locks_.load(std::memory_order_acquire) != 0)
264                 ;
265         }
266 
try_acquire(ReaderWriterLocker & locker)267         SEAL_NODISCARD inline bool WriterLock::try_acquire(ReaderWriterLocker &locker) noexcept
268         {
269             unlock();
270             bool expected = false;
271             if (!locker.writer_locked_.compare_exchange_strong(expected, true, std::memory_order_acquire))
272             {
273                 return false;
274             }
275             locker_ = &locker;
276             if (locker.reader_locks_.load(std::memory_order_acquire) != 0)
277             {
278                 unlock();
279                 return false;
280             }
281             return true;
282         }
283 
unlock()284         inline void WriterLock::unlock() noexcept
285         {
286             if (locker_ == nullptr)
287             {
288                 return;
289             }
290             locker_->writer_locked_.store(false, std::memory_order_release);
291             locker_ = nullptr;
292         }
293     } // namespace util
294 } // namespace seal
295 #endif
296