1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // UNSUPPORTED: c++98, c++03
10 // UNSUPPORTED: libcxxabi-no-threads, libcxxabi-no-exceptions
11 
12 #define TESTING_CXA_GUARD
13 #include "../src/cxa_guard_impl.h"
14 #include <unordered_map>
15 #include <thread>
16 #include <atomic>
17 #include <array>
18 #include <cassert>
19 #include <memory>
20 #include <vector>
21 
22 
23 using namespace __cxxabiv1;
24 
// Misc test configuration. It's used to tune the flakiness of the test.
// ThreadsPerTest - The number of threads used
constexpr int ThreadsPerTest = 10;
// The number of instances of a test to run concurrently.
constexpr int ConcurrentRunsPerTest = 10;
// The number of times to rerun each test.
constexpr int TestSamples = 50;
32 
33 
34 
// One iteration of busy-waiting: give up the rest of this thread's time
// slice so other threads can make progress. Used by the spin loops below.
void BusyWait() { std::this_thread::yield(); }
38 
// Pause briefly, then yield. Called by an initializing thread right after a
// barrier is released, to give the other threads time to reach the guard.
void YieldAfterBarrier() {
  constexpr std::chrono::nanoseconds pause{10};
  std::this_thread::sleep_for(pause);
  std::this_thread::yield();
}
43 
44 struct Barrier {
BarrierBarrier45   explicit Barrier(int n) : m_threads(n), m_remaining(n) { }
46   Barrier(Barrier const&) = delete;
47   Barrier& operator=(Barrier const&) = delete;
48 
arrive_and_waitBarrier49   void arrive_and_wait() const {
50     --m_remaining;
51     while (m_remaining.load()) {
52       BusyWait();
53     }
54   }
55 
arrive_and_dropBarrier56   void arrive_and_drop()  const {
57     --m_remaining;
58   }
59 
wait_for_threadsBarrier60   void wait_for_threads(int n) const {
61     while ((m_threads - m_remaining.load()) < n) {
62       std::this_thread::yield();
63     }
64   }
65 
66 private:
67   const int m_threads;
68   mutable std::atomic<int> m_remaining;
69 };
70 
71 
// The outcome a thread observed when it attempted a guarded initialization.
enum class InitResult {
  COMPLETE,   // the guard's fast path showed initialization already done
  PERFORMED,  // this thread ran the initializer and released the guard
  WAITED,     // another thread performed (or was performing) the init
  ABORTED     // this thread's initializer threw; the guard was aborted
};
// Unscoped shorthands so the tests can write `PERFORMED` rather than
// `InitResult::PERFORMED`.
constexpr InitResult COMPLETE = InitResult::COMPLETE;
constexpr InitResult PERFORMED = InitResult::PERFORMED;
constexpr InitResult WAITED = InitResult::WAITED;
constexpr InitResult ABORTED = InitResult::ABORTED;
82 
83 
// Run `init` at most once, protected by the guard object `g`, using the
// cxa_guard implementation `Impl` (from cxa_guard_impl.h). Returns the
// InitResult describing what this thread observed.
template <class Impl, class GuardType, class Init>
InitResult check_guard(GuardType *g, Init init) {
  uint8_t *first_byte = reinterpret_cast<uint8_t*>(g);
  // Mimic the compiler-emitted fast path: if the first byte of the guard is
  // already non-zero, the initialization is complete and the runtime is
  // never entered.
  if (std::__libcpp_atomic_load(first_byte, std::_AO_Acquire) == 0) {
    Impl impl(g);
    if (impl.cxa_guard_acquire() == INIT_IS_PENDING) {
#ifndef LIBCXXABI_HAS_NO_EXCEPTIONS
      try {
#endif
        init();
        impl.cxa_guard_release();
        return PERFORMED;
#ifndef LIBCXXABI_HAS_NO_EXCEPTIONS
      } catch (...) {
        // The initializer threw: abort the guard so another thread may
        // retry the initialization.
        impl.cxa_guard_abort();
        return ABORTED;
      }
#endif
    }
    // cxa_guard_acquire reported init was not pending: another thread
    // performed it (possibly while we waited inside acquire).
    return WAITED;
  }
  return COMPLETE;
}
107 
108 
// Emulates a function-local static of type GuardType guarded by `Impl`,
// tallying the InitResult outcome of every guarded access made against it.
template <class GuardType, class Impl>
struct FunctionLocalStatic {
  FunctionLocalStatic() {}
  FunctionLocalStatic(FunctionLocalStatic const&) = delete;

  // Run `init` under the guard and record the observed outcome.
  template <class InitFunc>
  InitResult access(InitFunc&& init) {
    auto res = check_guard<Impl>(&guard_object, init);
    ++result_counts[static_cast<int>(res)];
    return res;
  }

  // A copyable callable that forwards to this_obj->access(init), so one
  // callback object can be handed to many threads.
  template <class InitFn>
  struct AccessCallback {
    void operator()() const { this_obj->access(init); }

    FunctionLocalStatic *this_obj;
    InitFn init;
  };

  // Build an AccessCallback bound to this object and `init`.
  template <class InitFn, class Callback = AccessCallback< InitFn >  >
  Callback access_callback(InitFn init) {
    return Callback{this, init};
  }

  // Number of accesses that finished with result `I`.
  int get_count(InitResult I) const {
    return result_counts[static_cast<int>(I)].load();
  }

  // Number of accesses that did not abort (COMPLETE + PERFORMED + WAITED).
  int num_completed() const {
    return get_count(COMPLETE) + get_count(PERFORMED) + get_count(WAITED);
  }

  // NOTE(review): nothing in this file ever increments waiting_threads, so
  // this always reports 0 here -- confirm whether it is vestigial.
  int num_waiting() const {
    return waiting_threads.load();
  }

private:
  GuardType guard_object = {};          // the guard word, zero-initialized
  std::atomic<int> waiting_threads{0};
  // One counter per InitResult enumerator, indexed by static_cast<int>.
  std::array<std::atomic<int>, 4> result_counts{};
  static_assert(static_cast<int>(ABORTED) == 3, "only 4 result kinds expected");
};
152 
153 struct ThreadGroup {
154   ThreadGroup() = default;
155   ThreadGroup(ThreadGroup const&) = delete;
156 
157   template <class ...Args>
CreateThreadGroup158   void Create(Args&& ...args) {
159     threads.emplace_back(std::forward<Args>(args)...);
160   }
161 
162   template <class Callback>
CreateThreadsWithBarrierThreadGroup163   void CreateThreadsWithBarrier(int N, Callback cb) {
164     auto start = std::make_shared<Barrier>(N + 1);
165     for (int I=0; I < N; ++I) {
166       Create([start, cb]() {
167         start->arrive_and_wait();
168         cb();
169       });
170     }
171     start->arrive_and_wait();
172   }
173 
JoinAllThreadGroup174   void JoinAll() {
175     for (auto& t : threads) {
176       t.join();
177     }
178   }
179 
180 private:
181   std::vector<std::thread> threads;
182 };
183 
184 
// Scenario: `num_waiters` threads race to perform the initialization with no
// coordination beyond the common start barrier. Exactly one thread must run
// the initializer; everyone else must see it complete or wait for it.
template <class GuardType, class Impl>
void test_free_for_all(int num_waiters) {
  FunctionLocalStatic<GuardType, Impl> test_obj;

  ThreadGroup threads;

  bool already_init = false;
  threads.CreateThreadsWithBarrier(num_waiters,
    test_obj.access_callback([&]() {
      // The guard must serialize initializers, so this plain bool needs no
      // extra synchronization: at most one thread ever executes this body.
      assert(!already_init);
      already_init = true;
    })
  );

  // wait for the other threads to finish initialization.
  threads.JoinAll();

  // Exactly one PERFORMED; all remaining threads either hit the fast path
  // (COMPLETE) or went through the runtime and waited (WAITED).
  assert(test_obj.get_count(PERFORMED) == 1);
  assert(test_obj.get_count(COMPLETE) + test_obj.get_count(WAITED) == num_waiters - 1);
}
205 
// Scenario: one thread dawdles inside the initializer while `num_waiters`
// other threads attempt the same initialization, to drive as many of them as
// possible into the implementation's waiting path.
template <class GuardType, class Impl>
void test_waiting_for_init(int num_waiters) {
    FunctionLocalStatic<GuardType, Impl> test_obj;

    ThreadGroup threads;

    Barrier start_init(2);
    threads.Create(test_obj.access_callback(
      [&]() {
        start_init.arrive_and_wait();
        // Take our sweet time completing the initialization...
        //
        // There's a race condition between the other threads reaching the
        // start_init barrier, and them actually hitting the cxa guard.
        // But we're trying to test the waiting logic, we want as many
        // threads to enter the waiting loop as possible.
        YieldAfterBarrier();
      }
    ));
    // Ensure the initializing thread has entered the initializer before the
    // waiters are spawned.
    start_init.wait_for_threads(1);

    // These initializers must never run: the first thread performs the init.
    threads.CreateThreadsWithBarrier(num_waiters,
        test_obj.access_callback([]() { assert(false); })
    );
    // unblock the initializing thread
    start_init.arrive_and_drop();

    // wait for the other threads to finish initialization.
    threads.JoinAll();

    assert(test_obj.get_count(PERFORMED) == 1);
    assert(test_obj.get_count(ABORTED) == 0);
    assert(test_obj.get_count(COMPLETE) + test_obj.get_count(WAITED) == num_waiters);
}
240 
241 
// Scenario: the first initialization attempt throws, aborting the guard.
// Exactly one of the `num_waiters` remaining threads must then retry and
// complete the initialization.
template <class GuardType, class Impl>
void test_aborted_init(int num_waiters) {
  FunctionLocalStatic<GuardType, Impl> test_obj;

  Barrier start_init(2);
  ThreadGroup threads;
  threads.Create(test_obj.access_callback(
    [&]() {
      start_init.arrive_and_wait();
      YieldAfterBarrier();
      // Fail this initialization: check_guard turns the exception into
      // cxa_guard_abort() and records ABORTED.
      throw 42;
    })
  );
  start_init.wait_for_threads(1);

  bool already_init = false;
  threads.CreateThreadsWithBarrier(num_waiters,
      test_obj.access_callback([&]() {
        // After the abort, the guard must hand the init to exactly one
        // retrying thread, so this body runs at most once.
        assert(!already_init);
        already_init = true;
      })
    );
  // unblock the initializing thread
  start_init.arrive_and_drop();

  // wait for the other threads to finish initialization.
  threads.JoinAll();

  // One abort, then one successful init; everyone else waited or saw it done.
  assert(test_obj.get_count(ABORTED) == 1);
  assert(test_obj.get_count(PERFORMED) == 1);
  assert(test_obj.get_count(WAITED) + test_obj.get_count(COMPLETE) == num_waiters - 1);
}
274 
275 
// Scenario: the initialization fully completes before any other thread
// touches the guard. Every one of the `num_waiters` threads must take the
// fast path and report COMPLETE.
template <class GuardType, class Impl>
void test_completed_init(int num_waiters) {

  FunctionLocalStatic<GuardType, Impl> test_obj;

  test_obj.access([]() {}); // initialize the object
  assert(test_obj.num_waiting() == 0);
  assert(test_obj.num_completed() == 1);
  assert(test_obj.get_count(PERFORMED) == 1);

  ThreadGroup threads;
  // None of these initializers may run: the object is already initialized.
  threads.CreateThreadsWithBarrier(num_waiters,
      test_obj.access_callback([]() { assert(false); })
  );
  // wait for the other threads to finish initialization.
  threads.JoinAll();

  assert(test_obj.get_count(ABORTED) == 0);
  assert(test_obj.get_count(PERFORMED) == 1);
  // Every thread must have seen the guard byte already set (fast path).
  assert(test_obj.get_count(WAITED) == 0);
  assert(test_obj.get_count(COMPLETE) == num_waiters);
}
298 
299 template <class Impl>
test_impl()300 void test_impl() {
301   using TestFn = void(*)(int);
302   TestFn TestList[] = {
303     test_free_for_all<uint32_t, Impl>,
304     test_free_for_all<uint32_t, Impl>,
305     test_waiting_for_init<uint32_t, Impl>,
306     test_waiting_for_init<uint64_t, Impl>,
307     test_aborted_init<uint32_t, Impl>,
308     test_aborted_init<uint64_t, Impl>,
309     test_completed_init<uint32_t, Impl>,
310     test_completed_init<uint64_t, Impl>
311   };
312 
313   for (auto test_func : TestList) {
314       ThreadGroup test_threads;
315       test_threads.CreateThreadsWithBarrier(ConcurrentRunsPerTest, [=]() {
316         for (int I = 0; I < TestSamples; ++I) {
317           test_func(ThreadsPerTest);
318         }
319       });
320       test_threads.JoinAll();
321     }
322   }
323 
// Run the full test matrix against every cxa_guard implementation usable on
// this platform: always the global-lock implementation, plus the futex one
// when the platform supports it.
void test_all_impls() {
  using MutexImpl = SelectImplementation<Implementation::GlobalLock>::type;

  // Attempt to test the Futex based implementation if it's supported on the
  // target platform.
  using RealFutexImpl = SelectImplementation<Implementation::Futex>::type;
  // Fall back to the mutex implementation in the alias so FutexImpl is
  // always a well-formed type even when futexes are unavailable.
  using FutexImpl = typename std::conditional<
      PlatformSupportsFutex(),
      RealFutexImpl,
      MutexImpl
  >::type;

  test_impl<MutexImpl>();
  if (PlatformSupportsFutex())
    test_impl<FutexImpl>();
}
340 
// A dummy template parameter: making this a template means the body is only
// instantiated when actually called, presumably so the Platform* futex
// wrappers need not be defined on platforms without futex support -- confirm.
// Basic sanity checks of the raw futex wrappers: two waiters blocked on the
// expected value are released by a waker, and a wait whose expected value is
// already stale returns without blocking.
template <bool Dummy = true>
void test_futex_syscall() {
  if (!PlatformSupportsFutex())
    return;
  int lock1 = 0;
  int lock2 = 0;
  int lock3 = 0;
  std::thread waiter1([&]() {
    int expect = 0;
    // Blocks until the waker changes lock1 away from `expect` and wakes us.
    PlatformFutexWait(&lock1, expect);
    assert(lock1 == 1);
  });
  std::thread waiter2([&]() {
    int expect = 0;
    PlatformFutexWait(&lock2, expect);
    assert(lock2 == 2);
  });
  std::thread waiter3([&]() {
    int expect = 42; // not the value
    PlatformFutexWait(&lock3, expect); // doesn't block
  });
  // NOTE(review): lock1/lock2 are plain ints written here and read in the
  // waiter threads; the futex wrappers are assumed to provide the required
  // ordering -- confirm against the Platform* implementations.
  std::thread waker([&]() {
    lock1 = 1;
    PlatformFutexWake(&lock1);
    lock2 = 2;
    PlatformFutexWake(&lock2);
  });
  waiter1.join();
  waiter2.join();
  waiter3.join();
  waker.join();
}
374 
// Entry point: exercise every guard implementation, then the futex wrappers.
int main() {
  // Test each multi-threaded implementation with real threads.
  test_all_impls();
  // Test the basic sanity of the futex syscall wrappers.
  test_futex_syscall();
}
381