1 /*
2     Copyright (c) 2005-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #include "common/test.h"
18 #include "common/utils.h"
19 
20 #include "oneapi/tbb/parallel_for.h"
21 #include "oneapi/tbb/null_mutex.h"
22 #include "oneapi/tbb/null_rw_mutex.h"
23 
24 #include <type_traits>
25 
26 //! Generic test of a TBB mutex
27 /** Does not test features specific to reader-writer locks. */
28 template<typename M, typename Counter = utils::Counter<M>>
29 void GeneralTest(const char* mutex_name, bool check = true) { // check flag is needed to disable correctness check for null mutexes (for test reusage)
30     const int N = 100000;
31     const int GRAIN = 10000;
32     Counter counter;
33     counter.value = 0;
34 
35     // Stress test to force possible race condition of the counter
36     utils::NativeParallelFor(N, GRAIN, [&] (int i) {
37         if (i & 1) {
38             // Try implicit acquire and explicit release
39             typename M::scoped_lock lock(counter.mutex);
40             counter.value = counter.value + 1;
41             lock.release();
42         } else {
43             // Try explicit acquire and implicit release
44             typename M::scoped_lock lock;
45             lock.acquire(counter.mutex);
46             counter.value = counter.value + 1;
47         }
48     });
49     if (check) {
50         REQUIRE_MESSAGE(counter.value == N, "ERROR for " << mutex_name << ": race is detected");
51     }
52 }
53 
54 //! Test try_acquire functionality of a non-reenterable mutex
55 template<typename M>
TestTryAcquire(const char * mutex_name)56 void TestTryAcquire(const char* mutex_name) {
57     M tested_mutex;
58     typename M::scoped_lock lock_outer;
59     if (lock_outer.try_acquire(tested_mutex)) {
60         lock_outer.release();
61     } else {
62         CHECK_MESSAGE(false, "ERROR for " << mutex_name << ": try_acquire failed though it should not");
63     }
64     {
65         typename M::scoped_lock lock_inner(tested_mutex);
66         CHECK_MESSAGE(!lock_outer.try_acquire(tested_mutex), "ERROR for " << mutex_name << ": try_acquire failed though it should not (1)");
67     }
68     if (lock_outer.try_acquire(tested_mutex)) {
69         lock_outer.release();
70     } else {
71         CHECK_MESSAGE(false, "ERROR for " << mutex_name << ": try_acquire failed though it should not");
72     }
73 }
74 
75 template <>
76 void TestTryAcquire<oneapi::tbb::null_mutex>( const char* mutex_name ) {
77     oneapi::tbb::null_mutex tested_mutex;
78     typename oneapi::tbb::null_mutex::scoped_lock lock(tested_mutex);
79     CHECK_MESSAGE(lock.try_acquire(tested_mutex), "ERROR for " << mutex_name << ": try_acquire failed though it should not");
80     lock.release();
81     CHECK_MESSAGE(lock.try_acquire(tested_mutex), "ERROR for " << mutex_name << ": try_acquire failed though it should not");
82 }
83 
84 //! Test try_acquire functionality of a non-reenterable mutex
85 template<typename M>
TestTryAcquireReader(const char * mutex_name)86 void TestTryAcquireReader(const char* mutex_name) {
87     M tested_mutex;
88     typename M::scoped_lock lock_outer;
89     if (lock_outer.try_acquire(tested_mutex, false) ) {
90         lock_outer.release();
91     } else {
92         CHECK_MESSAGE(false, "ERROR for " << mutex_name << ": try_acquire failed though it should not");
93     }
94     {
95         typename M::scoped_lock lock_inner(tested_mutex, false); // read lock
96         // try acquire on write
97         CHECK_MESSAGE(!lock_outer.try_acquire(tested_mutex, true), "ERROR for " << mutex_name << ": try_acquire on write succeed though it should not (1)");
98         lock_inner.release();                                    // unlock
99         lock_inner.acquire(tested_mutex, true);                  // write lock
100         // try acquire on read
101         CHECK_MESSAGE(!lock_outer.try_acquire(tested_mutex, false), "ERROR for " << mutex_name << ": try_acquire on read succeed though it should not (2)");
102     }
103     if (lock_outer.try_acquire(tested_mutex, false) ) {
104         lock_outer.release();
105     } else {
106         CHECK_MESSAGE(false, "ERROR for " << mutex_name << ": try_acquire failed though it should not");
107     }
108 }
109 
110 template <>
111 void TestTryAcquireReader<oneapi::tbb::null_rw_mutex>( const char* mutex_name ) {
112     oneapi::tbb::null_rw_mutex tested_mutex;
113     typename oneapi::tbb::null_rw_mutex::scoped_lock lock(tested_mutex, false);
114     CHECK_MESSAGE(lock.try_acquire(tested_mutex, false), "Error for " << mutex_name << ": try_acquire on read failed though it should not");
115     CHECK_MESSAGE(lock.try_acquire(tested_mutex, true), "Error for " << mutex_name << ": try_acquire on write failed though it should not");
116     lock.release();
117     CHECK_MESSAGE(lock.try_acquire(tested_mutex, false), "Error for " << mutex_name << ": try_acquire on read failed though it should not");
118     CHECK_MESSAGE(lock.try_acquire(tested_mutex, true), "Error for " << mutex_name << ": try_acquire on write failed though it should not");
119 }
120 
//! Counter consisting of N slots that are always modified together
/** The struct carries a mutex of type M next to the data but never locks it itself.
    Slots start at zero; increment() adds one to each of them and value_is() tells
    whether all of them hold the given value. */
template<typename M, size_t N>
struct ArrayCounter {
    using mutex_type = M;
    M mutex;
    long value[N];

    //! Zero-initialize every slot
    ArrayCounter() : value{} {}

    //! Add one to each slot, in index order
    void increment() {
        for (long& slot : value) {
            ++slot;
        }
    }

    //! Return true only if every slot is equal to expected_value
    bool value_is(long expected_value) const {
        for (const long& slot : value) {
            if (slot != expected_value) {
                return false;
            }
        }
        return true;
    }
};
144 
145 template<typename M, typename Counter>
TestReaderWriterLock_Impl(Counter & counter,typename M::scoped_lock & lock,const std::size_t i,const bool write)146 void TestReaderWriterLock_Impl(Counter& counter, typename M::scoped_lock& lock, const std::size_t i, const bool write) {
147     bool okay = true;
148     if (write) {
149         long counter_value = counter.value[0];
150         counter.increment();
151         // Downgrade to reader
152         if (i % 16 == 7) {
153             if (!lock.downgrade_to_reader()) {
154                 // Get the previous value as downgrade with the same lock acquired was failed
155                 counter_value = counter.value[0] - 1;
156             }
157             okay = counter.value_is(counter_value + 1);
158         }
159     } else {
160         okay = counter.value_is(counter.value[0]);
161         // Upgrade to writer
162         if (i % 8 == 3) {
163             long counter_value = counter.value[0];
164             if (!lock.upgrade_to_writer()) {
165                 // Failed to upgrade, reacquiring happened, need to update the value
166                 counter_value = counter.value[0];
167             }
168             counter.increment();
169             okay = counter.value_is(counter_value + 1);
170         }
171     }
172     CHECK_MESSAGE(okay, "Error in read write mutex operations");
173 }
174 
175 //! Shared mutex type test
176 template<typename M>
TestReaderWriterLock(const char * mutex_name)177 void TestReaderWriterLock(const char* mutex_name) {
178     ArrayCounter<M, 8> counter;
179     const int N = 10000;
180 #if TBB_TEST_LOW_WORKLOAD
181     const int GRAIN = 500;
182 #else
183     const int GRAIN = 100;
184 #endif /* TBB_TEST_LOW_WORKLOAD */
185 
186     // Stress test similar to the general, but with upgrade/downgrade cases
187     utils::NativeParallelFor(N, GRAIN, [&](int i) {
188         //! Every 8th access is a write access
189         const bool write = (i % 8) == 7;
190         if (i & 1) {
191             // Try implicit acquire and explicit release
192             typename M::scoped_lock lock(counter.mutex, write);
193             TestReaderWriterLock_Impl<M, ArrayCounter<M, 8>>(counter, lock, i, write);
194             lock.release();
195         } else {
196             // Try explicit acquire and implicit release
197             typename M::scoped_lock lock;
198             lock.acquire(counter.mutex, write);
199             TestReaderWriterLock_Impl<M, ArrayCounter<M, 8>>(counter, lock, i, write);
200         }
201     });
202     // There is either a writer or a reader upgraded to a writer for each 4th iteration
203     REQUIRE_MESSAGE(counter.value_is(N / 4), "ERROR for " << mutex_name << ": race is detected");
204 }
205 
206 template<typename M>
TestRWStateMultipleChange(const char * mutex_name)207 void TestRWStateMultipleChange(const char* mutex_name) {
208     static_assert(M::is_rw_mutex, "Incorrect mutex type");
209 
210     const int N = 1000;
211     const int GRAIN = 100;
212     M mutex;
213     utils::NativeParallelFor(N, GRAIN, [&] (int) {
214         typename M::scoped_lock l(mutex, /*write=*/false);
215         for (int i = 0; i != GRAIN; ++i) {
216             CHECK_MESSAGE(l.downgrade_to_reader(), mutex_name << " downgrade must succeed for read lock");
217         }
218         l.upgrade_to_writer();
219         for (int i = 0; i != GRAIN; ++i) {
220             CHECK_MESSAGE(l.upgrade_to_writer(), mutex_name << " upgrade must succeed for write lock");
221         }
222     });
223 }
224 
//! Adaptor for using ISO C++0x style mutex as a TBB-style mutex.
/** Wraps a mutex M that offers lock()/try_lock()/unlock() (and, for the reader-writer
    methods, lock_shared()/try_lock_shared()/unlock_shared()) and exposes the
    scoped_lock based interface used by the tests in this file.
    M must also provide the static members is_recursive_mutex and is_rw_mutex. */
template<typename M>
class TBB_MutexFromISO_Mutex {
    M my_iso_mutex;
public:
    using mutex_type = TBB_MutexFromISO_Mutex;

    class scoped_lock;
    friend class scoped_lock;

    //! Lock holder; releases the mutex in the destructor if it still owns one
    class scoped_lock {
        //! Mutex currently owned, or nullptr when nothing is held
        mutex_type* my_mutex = nullptr;
        //! Mode of the most recent acquire/try_acquire request: true for exclusive
        bool m_is_writer = false;
    public:
        //! Construct without acquiring anything
        scoped_lock() {}

        //! Construct and acquire m exclusively
        scoped_lock(mutex_type& m) {
            acquire(m);
        }

        //! Construct and acquire m exclusively (is_writer == true) or shared (false)
        scoped_lock(mutex_type& m, bool is_writer) {
            acquire(m, is_writer);
        }

        //! Release the mutex if one is still held
        ~scoped_lock() {
            if (my_mutex != nullptr) {
                release();
            }
        }

        //! Block until m is locked exclusively
        void acquire(mutex_type& m) {
            m_is_writer = true;
            m.my_iso_mutex.lock();
            my_mutex = &m;
        }

        //! Try to lock m exclusively without blocking; returns whether the lock was taken
        bool try_acquire(mutex_type& m) {
            m_is_writer = true;
            if (!m.my_iso_mutex.try_lock()) {
                return false;
            }
            my_mutex = &m;
            return true;
        }

        //! Release for a plain mutex: always an exclusive unlock
        template<typename Q = M>
        typename std::enable_if<!Q::is_rw_mutex>::type release() {
            my_mutex->my_iso_mutex.unlock();
            my_mutex = nullptr;
        }

        //! Release for a reader-writer mutex: the unlock call matches the held mode
        template<typename Q = M>
        typename std::enable_if<Q::is_rw_mutex>::type release() {
            if (!m_is_writer) {
                my_mutex->my_iso_mutex.unlock_shared();
            } else {
                my_mutex->my_iso_mutex.unlock();
            }
            my_mutex = nullptr;
        }

        // Methods for reader-writer mutex
        // These methods can be instantiated only if M supports lock_shared() and try_lock_shared().

        //! Block until m is locked in the requested mode
        void acquire(mutex_type& m, bool is_writer) {
            m_is_writer = is_writer;
            if (!is_writer) {
                m.my_iso_mutex.lock_shared();
            } else {
                m.my_iso_mutex.lock();
            }
            my_mutex = &m;
        }

        //! Try to lock m in the requested mode without blocking; returns whether the lock was taken
        bool try_acquire(mutex_type& m, bool is_writer) {
            m_is_writer = is_writer;
            if (!(is_writer ? m.my_iso_mutex.try_lock() : m.my_iso_mutex.try_lock_shared())) {
                return false;
            }
            my_mutex = &m;
            return true;
        }

        //! Switch to exclusive mode
        /** Returns true when the lock was already exclusive. Otherwise the shared lock is
            dropped, the exclusive lock is taken, and false is returned. */
        bool upgrade_to_writer() {
            if (!m_is_writer) {
                m_is_writer = true;
                my_mutex->my_iso_mutex.unlock_shared();
                my_mutex->my_iso_mutex.lock();
                return false;
            }
            return true;
        }

        //! Switch to shared mode
        /** Returns true when the lock was already shared. Otherwise the exclusive lock is
            dropped, the shared lock is taken, and false is returned. */
        bool downgrade_to_reader() {
            if (m_is_writer) {
                m_is_writer = false;
                my_mutex->my_iso_mutex.unlock();
                my_mutex->my_iso_mutex.lock_shared();
                return false;
            }
            return true;
        }
    };

    // Traits forwarded from the wrapped mutex type
    static constexpr bool is_recursive_mutex = M::is_recursive_mutex;
    static constexpr bool is_rw_mutex = M::is_rw_mutex;
};
319 
320 template<typename C>
321 struct NullRecursive: utils::NoAssign {
recurse_tillNullRecursive322     void recurse_till(std::size_t i, std::size_t till) const {
323         if(i == till) {
324             counter.value = counter.value + 1;
325             return;
326         }
327         if(i & 1) {
328             typename C::mutex_type::scoped_lock lock2(counter.mutex);
329             recurse_till(i + 1, till);
330             lock2.release();
331         } else {
332             typename C::mutex_type::scoped_lock lock2;
333             lock2.acquire(counter.mutex);
334             recurse_till(i + 1, till);
335         }
336     }
337 
operatorNullRecursive338     void operator()(oneapi::tbb::blocked_range<std::size_t>& range) const {
339         typename C::mutex_type::scoped_lock lock(counter.mutex);
340         recurse_till(range.begin(), range.end());
341     }
NullRecursiveNullRecursive342     NullRecursive(C& counter_) : counter(counter_) {
343         REQUIRE_MESSAGE(is_recursive_mutex, "Null mutex should be a recursive mutex.");
344     }
345     C& counter;
346     bool is_recursive_mutex = C::mutex_type::is_recursive_mutex;
347 };
348 
349 template<typename M>
350 struct NullUpgradeDowngrade: utils::NoAssign {
operatorNullUpgradeDowngrade351     void operator()(oneapi::tbb::blocked_range<std::size_t>& range) const {
352         typename M::scoped_lock lock2;
353         for(std::size_t i = range.begin(); i != range.end(); ++i) {
354             if(i & 1) {
355                 typename M::scoped_lock lock1(my_mutex, true);
356                 if(lock1.downgrade_to_reader() == false) {
357                     REQUIRE_MESSAGE(false, "ERROR for " << mutex_name << ": downgrade should always succeed");
358                 }
359             } else {
360                 lock2.acquire(my_mutex, false);
361                 if(lock2.upgrade_to_writer() == false) {
362                     REQUIRE_MESSAGE(false, "ERROR for " << mutex_name << ": upgrade should always succeed");
363                 }
364                 lock2.release();
365             }
366         }
367     }
368 
NullUpgradeDowngradeNullUpgradeDowngrade369     NullUpgradeDowngrade(M& m_, const char* n_) : my_mutex(m_), mutex_name(n_) {}
370     M& my_mutex;
371     const char* mutex_name;
372 };
373 
374 template<typename M>
TestNullMutex(const char * mutex_name)375 void TestNullMutex(const char* mutex_name) {
376     INFO(mutex_name);
377     utils::AtomicCounter<M> counter;
378     counter.value = 0;
379     const std::size_t n = 100;
380     oneapi::tbb::parallel_for(oneapi::tbb::blocked_range<std::size_t>(0, n, 10), NullRecursive<utils::AtomicCounter<M>>(counter));
381     M m;
382     m.lock();
383     REQUIRE(m.try_lock());
384     m.unlock();
385 }
386 
387 template<typename M>
TestNullRWMutex(const char * mutex_name)388 void TestNullRWMutex(const char* mutex_name) {
389     const std::size_t n = 100;
390     M m;
391     oneapi::tbb::parallel_for(oneapi::tbb::blocked_range<std::size_t>(0, n, 10), NullUpgradeDowngrade<M>(m, mutex_name));
392     m.lock();
393     REQUIRE(m.try_lock());
394     m.lock_shared();
395     REQUIRE(m.try_lock_shared());
396     m.unlock_shared();
397     m.unlock();
398 }
399