1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT license.
3 
4 #include "seal/util/locks.h"
5 #include <atomic>
6 #include <thread>
7 #include "gtest/gtest.h"
8 
9 using namespace seal::util;
10 using namespace std;
11 
12 namespace sealtest
13 {
14     namespace util
15     {
16         class Reader
17         {
18         public:
Reader(ReaderWriterLocker & locker)19             Reader(ReaderWriterLocker &locker) : locker_(locker), locked_(false), trying_(false)
20             {}
21 
is_locked() const22             bool is_locked() const
23             {
24                 return locked_;
25             }
26 
is_trying_to_lock() const27             bool is_trying_to_lock() const
28             {
29                 return trying_;
30             }
31 
acquire_read()32             void acquire_read()
33             {
34                 trying_ = true;
35                 lock_ = locker_.acquire_read();
36                 locked_ = true;
37                 trying_ = false;
38             }
39 
release()40             void release()
41             {
42                 lock_.unlock();
43                 locked_ = false;
44             }
45 
wait_until_trying()46             void wait_until_trying()
47             {
48                 while (!trying_)
49                     ;
50             }
51 
wait_until_locked()52             void wait_until_locked()
53             {
54                 while (!locked_)
55                     ;
56             }
57 
58         private:
59             ReaderWriterLocker &locker_;
60 
61             ReaderLock lock_;
62 
63             volatile bool locked_;
64 
65             volatile bool trying_;
66         };
67 
68         class Writer
69         {
70         public:
Writer(ReaderWriterLocker & locker)71             Writer(ReaderWriterLocker &locker) : locker_(locker), locked_(false), trying_(false)
72             {}
73 
is_locked() const74             bool is_locked() const
75             {
76                 return locked_;
77             }
78 
is_trying_to_lock() const79             bool is_trying_to_lock() const
80             {
81                 return trying_;
82             }
83 
acquire_write()84             void acquire_write()
85             {
86                 trying_ = true;
87                 lock_ = locker_.acquire_write();
88                 locked_ = true;
89                 trying_ = false;
90             }
91 
release()92             void release()
93             {
94                 lock_.unlock();
95                 locked_ = false;
96             }
97 
wait_until_trying()98             void wait_until_trying()
99             {
100                 while (!trying_)
101                     ;
102             }
103 
wait_until_locked()104             void wait_until_locked()
105             {
106                 while (!locked_)
107                     ;
108             }
109 
wait_until_unlocked()110             void wait_until_unlocked()
111             {
112                 while (locked_)
113                     ;
114             }
115 
116         private:
117             ReaderWriterLocker &locker_;
118 
119             WriterLock lock_;
120 
121             volatile bool locked_;
122 
123             volatile bool trying_;
124         };
125 
TEST(ReaderWriterLockerTests,ReaderWriterLockNonBlocking)126         TEST(ReaderWriterLockerTests, ReaderWriterLockNonBlocking)
127         {
128             ReaderWriterLocker locker;
129 
130             WriterLock writeLock = locker.acquire_write();
131             ASSERT_TRUE(writeLock.owns_lock());
132             writeLock.unlock();
133             ASSERT_FALSE(writeLock.owns_lock());
134 
135             ReaderLock readLock = locker.acquire_read();
136             ASSERT_TRUE(readLock.owns_lock());
137             readLock.unlock();
138 
139             ReaderLock readLock2 = locker.acquire_read();
140             ASSERT_TRUE(readLock2.owns_lock());
141             ASSERT_FALSE(readLock.owns_lock());
142             readLock2.unlock();
143             ASSERT_FALSE(readLock2.owns_lock());
144 
145             readLock = locker.try_acquire_read();
146             ASSERT_TRUE(readLock.owns_lock());
147             writeLock = locker.try_acquire_write();
148             ASSERT_FALSE(writeLock.owns_lock());
149 
150             readLock2 = locker.try_acquire_read();
151             ASSERT_TRUE(readLock2.owns_lock());
152             writeLock = locker.try_acquire_write();
153             ASSERT_FALSE(writeLock.owns_lock());
154 
155             readLock.unlock();
156             writeLock = locker.try_acquire_write();
157             ASSERT_FALSE(writeLock.owns_lock());
158 
159             readLock2.unlock();
160             writeLock = locker.try_acquire_write();
161             ASSERT_TRUE(writeLock.owns_lock());
162 
163             WriterLock writeLock2 = locker.try_acquire_write();
164 
165             ASSERT_FALSE(writeLock2.owns_lock());
166             readLock2 = locker.try_acquire_read();
167             ASSERT_FALSE(readLock2.owns_lock());
168 
169             writeLock.unlock();
170 
171             writeLock2 = locker.try_acquire_write();
172             ASSERT_TRUE(writeLock2.owns_lock());
173             readLock2 = locker.try_acquire_read();
174             ASSERT_FALSE(readLock2.owns_lock());
175 
176             writeLock2.unlock();
177         }
178 
TEST(ReaderWriterLockerTests,ReaderWriterLockBlocking)179         TEST(ReaderWriterLockerTests, ReaderWriterLockBlocking)
180         {
181             ReaderWriterLocker locker;
182 
183             Reader *reader1 = new Reader(locker);
184             Reader *reader2 = new Reader(locker);
185             Writer *writer1 = new Writer(locker);
186             Writer *writer2 = new Writer(locker);
187 
188             ASSERT_FALSE(reader1->is_locked());
189             ASSERT_FALSE(reader2->is_locked());
190             ASSERT_FALSE(writer1->is_locked());
191             ASSERT_FALSE(writer2->is_locked());
192 
193             reader1->acquire_read();
194             ASSERT_TRUE(reader1->is_locked());
195             ASSERT_FALSE(reader2->is_locked());
196             reader2->acquire_read();
197             ASSERT_TRUE(reader1->is_locked());
198             ASSERT_TRUE(reader2->is_locked());
199 
200             atomic<bool> should_unlock1{ false };
201             atomic<bool> should_unlock2{ false };
202 
203             thread writer1_thread([&] {
204                 writer1->acquire_write();
205                 while (!should_unlock1)
206                 {
207                     this_thread::sleep_for(10ms);
208                 }
209                 writer1->release();
210             });
211 
212             writer1->wait_until_trying();
213             ASSERT_TRUE(writer1->is_trying_to_lock());
214             ASSERT_FALSE(writer1->is_locked());
215 
216             reader2->release();
217             ASSERT_TRUE(reader1->is_locked());
218             ASSERT_FALSE(reader2->is_locked());
219             ASSERT_TRUE(writer1->is_trying_to_lock());
220             ASSERT_FALSE(writer1->is_locked());
221 
222             thread writer2_thread([&] {
223                 writer2->acquire_write();
224                 while (!should_unlock2)
225                 {
226                     this_thread::sleep_for(10ms);
227                 }
228                 writer2->release();
229             });
230 
231             writer2->wait_until_trying();
232             ASSERT_TRUE(writer1->is_trying_to_lock());
233             ASSERT_FALSE(writer1->is_locked());
234             ASSERT_TRUE(writer2->is_trying_to_lock());
235             ASSERT_FALSE(writer2->is_locked());
236 
237             reader1->release();
238             ASSERT_FALSE(reader1->is_locked());
239 
240             while (writer1->is_trying_to_lock() && writer2->is_trying_to_lock())
241                 ;
242 
243             Writer *winner;
244             Writer *waiting;
245             atomic<bool> *should_unlock_winner;
246             atomic<bool> *should_unlock_waiting;
247 
248             if (writer1->is_locked())
249             {
250                 winner = writer1;
251                 waiting = writer2;
252                 should_unlock_winner = &should_unlock1;
253                 should_unlock_waiting = &should_unlock2;
254             }
255             else
256             {
257                 winner = writer2;
258                 waiting = writer1;
259                 should_unlock_winner = &should_unlock2;
260                 should_unlock_waiting = &should_unlock1;
261             }
262 
263             ASSERT_TRUE(winner->is_locked());
264             ASSERT_FALSE(waiting->is_locked());
265 
266             *should_unlock_winner = true;
267             winner->wait_until_unlocked();
268             ASSERT_FALSE(winner->is_locked());
269 
270             waiting->wait_until_locked();
271             ASSERT_TRUE(waiting->is_locked());
272 
273             thread reader1_thread(&Reader::acquire_read, reader1);
274             reader1->wait_until_trying();
275             ASSERT_TRUE(reader1->is_trying_to_lock());
276             ASSERT_FALSE(reader1->is_locked());
277 
278             thread reader2_thread(&Reader::acquire_read, reader2);
279             reader2->wait_until_trying();
280             ASSERT_TRUE(reader2->is_trying_to_lock());
281             ASSERT_FALSE(reader2->is_locked());
282 
283             *should_unlock_waiting = true;
284 
285             reader1->wait_until_locked();
286             reader2->wait_until_locked();
287             ASSERT_TRUE(reader1->is_locked());
288             ASSERT_TRUE(reader2->is_locked());
289 
290             reader1->release();
291             reader2->release();
292 
293             ASSERT_FALSE(reader1->is_locked());
294             ASSERT_FALSE(reader2->is_locked());
295             ASSERT_FALSE(writer1->is_locked());
296             ASSERT_FALSE(reader2->is_locked());
297             ASSERT_FALSE(reader1->is_trying_to_lock());
298             ASSERT_FALSE(reader2->is_trying_to_lock());
299             ASSERT_FALSE(writer1->is_trying_to_lock());
300             ASSERT_FALSE(reader2->is_trying_to_lock());
301 
302             writer1_thread.join();
303             writer2_thread.join();
304             reader1_thread.join();
305             reader2_thread.join();
306 
307             delete reader1;
308             delete reader2;
309             delete writer1;
310             delete writer2;
311         }
312     } // namespace util
313 } // namespace sealtest
314