//  Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
//  This source code is licensed under both the GPLv2 (found in the
//  COPYING file in the root directory) and Apache 2.0 License
//  (found in the LICENSE.Apache file in the root directory).

#include <thread>
#include <atomic>
#include <string>

#include "port/port.h"
#include "rocksdb/env.h"
#include "test_util/sync_point.h"
#include "test_util/testharness.h"
#include "test_util/testutil.h"
#include "util/autovector.h"
#include "util/thread_local.h"

namespace ROCKSDB_NAMESPACE {

class ThreadLocalTest : public testing::Test {
 public:
  ThreadLocalTest() : env_(Env::Default()) {}

  Env* env_;
};

namespace {

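// Shared state handed to each test thread.  `tls1` lives inside the struct,
// while `tls2` points to an externally owned ThreadLocalPtr so individual
// tests can control that instance's lifetime separately from Params.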
struct Params {
  Params(port::Mutex* m, port::CondVar* c, int* u, int n,
         UnrefHandler handler = nullptr)
      : mu(m),
        cv(c),
        unref(u),
        total(n),
        started(0),
        completed(0),
        doWrite(false),
        tls1(handler),
        tls2(nullptr) {}

  port::Mutex* mu;
  port::CondVar* cv;
  int* unref;
  int total;
  int started;
  int completed;
  bool doWrite;
  ThreadLocalPtr tls1;
  ThreadLocalPtr* tls2;
};

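// Exposes ThreadLocalPtr::TEST_PeekId() so the tests below can observe which
// instance id would be handed out next, and therefore how released ids get
// recycled.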
class IDChecker : public ThreadLocalPtr {
 public:
  static uint32_t PeekId() {
    return TEST_PeekId();
  }
};

}  // anonymous namespace

// Suppress false positive clang analyzer warnings.
#ifndef __clang_analyzer__
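// Verifies that every new ThreadLocalPtr instance receives a fresh id and
// that ids of destroyed instances are returned to a free queue and reused.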
TEST_F(ThreadLocalTest, UniqueIdTest) {
  port::Mutex mu;
  port::CondVar cv(&mu);

  uint32_t base_id = IDChecker::PeekId();
  // New ThreadLocal instance bumps id by 1
  {
    // Id used 0
    Params p1(&mu, &cv, nullptr, 1u);
    ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
    // Id used 1
    Params p2(&mu, &cv, nullptr, 1u);
    ASSERT_EQ(IDChecker::PeekId(), base_id + 2u);
    // Id used 2
    Params p3(&mu, &cv, nullptr, 1u);
    ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
    // Id used 3
    Params p4(&mu, &cv, nullptr, 1u);
    ASSERT_EQ(IDChecker::PeekId(), base_id + 4u);
  }
  // id 3, 2, 1, 0 are in the free queue in order
  ASSERT_EQ(IDChecker::PeekId(), base_id + 0u);

  // pick up 0
  Params p1(&mu, &cv, nullptr, 1u);
  ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
  // pick up 1
  Params* p2 = new Params(&mu, &cv, nullptr, 1u);
  ASSERT_EQ(IDChecker::PeekId(), base_id + 2u);
  // pick up 2
  Params p3(&mu, &cv, nullptr, 1u);
  ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
  // return id 1 to the free queue
  delete p2;
  ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
  // Now we have 3, 1 in queue
  // pick up 1
  Params p4(&mu, &cv, nullptr, 1u);
  ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
  // pick up 3
  Params p5(&mu, &cv, nullptr, 1u);
  // next new id
  ASSERT_EQ(IDChecker::PeekId(), base_id + 4u);
  // After exit, id sequence in queue:
  // 3, 1, 2, 0
}
#endif  // __clang_analyzer__

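// Spawns one thread at a time; a fresh thread must not observe the value the
// previous thread stored in the same ThreadLocalPtr.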
TEST_F(ThreadLocalTest, SequentialReadWriteTest) {
  // global id list carries over 3, 1, 2, 0
  uint32_t base_id = IDChecker::PeekId();

  port::Mutex mu;
  port::CondVar cv(&mu);
  Params p(&mu, &cv, nullptr, 1);
  ThreadLocalPtr tls2;
  p.tls2 = &tls2;

  auto func = [](void* ptr) {
    auto& params = *static_cast<Params*>(ptr);

    ASSERT_TRUE(params.tls1.Get() == nullptr);
    params.tls1.Reset(reinterpret_cast<int*>(1));
    ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(1));
    params.tls1.Reset(reinterpret_cast<int*>(2));
    ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(2));

    ASSERT_TRUE(params.tls2->Get() == nullptr);
    params.tls2->Reset(reinterpret_cast<int*>(1));
    ASSERT_TRUE(params.tls2->Get() == reinterpret_cast<int*>(1));
    params.tls2->Reset(reinterpret_cast<int*>(2));
    ASSERT_TRUE(params.tls2->Get() == reinterpret_cast<int*>(2));

    params.mu->Lock();
    ++(params.completed);
    params.cv->SignalAll();
    params.mu->Unlock();
  };

  for (int iter = 0; iter < 1024; ++iter) {
    ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
    // Another new thread, read/write should not see value from previous thread
    env_->StartThread(func, static_cast<void*>(&p));
    mu.Lock();
    while (p.completed != iter + 1) {
      cv.Wait();
    }
    mu.Unlock();
    ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
  }
}

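// Runs 16 reader and 16 writer threads against the same two ThreadLocalPtr
// instances; every thread must keep seeing only its own stored value.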
TEST_F(ThreadLocalTest, ConcurrentReadWriteTest) {
  // global id list carries over 3, 1, 2, 0
  uint32_t base_id = IDChecker::PeekId();

  ThreadLocalPtr tls2;
  port::Mutex mu1;
  port::CondVar cv1(&mu1);
  Params p1(&mu1, &cv1, nullptr, 16);
  p1.tls2 = &tls2;

  port::Mutex mu2;
  port::CondVar cv2(&mu2);
  Params p2(&mu2, &cv2, nullptr, 16);
  p2.doWrite = true;
  p2.tls2 = &tls2;

  auto func = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);

    p.mu->Lock();
    // size_t matches the width of the pointer we cast it to below.
    size_t own = ++(p.started);
    p.cv->SignalAll();
    while (p.started != p.total) {
      p.cv->Wait();
    }
    p.mu->Unlock();

    // Let write threads write a different value from the read threads
    if (p.doWrite) {
      own += 8192;
    }

    ASSERT_TRUE(p.tls1.Get() == nullptr);
    ASSERT_TRUE(p.tls2->Get() == nullptr);

    auto* env = Env::Default();
    auto start = env->NowMicros();

    p.tls1.Reset(reinterpret_cast<size_t*>(own));
    p.tls2->Reset(reinterpret_cast<size_t*>(own + 1));
    // Loop for 1 second
    while (env->NowMicros() - start < 1000 * 1000) {
      for (int iter = 0; iter < 100000; ++iter) {
        ASSERT_TRUE(p.tls1.Get() == reinterpret_cast<size_t*>(own));
        ASSERT_TRUE(p.tls2->Get() == reinterpret_cast<size_t*>(own + 1));
        if (p.doWrite) {
          p.tls1.Reset(reinterpret_cast<size_t*>(own));
          p.tls2->Reset(reinterpret_cast<size_t*>(own + 1));
        }
      }
    }

    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();
    p.mu->Unlock();
  };

  // Start two instances: one keeps writing and one keeps reading.
  // The read instance should not see data from the write instance.
  // Each thread-local copy of the value is also different from the
  // others.
  for (int th = 0; th < p1.total; ++th) {
    env_->StartThread(func, static_cast<void*>(&p1));
  }
  for (int th = 0; th < p2.total; ++th) {
    env_->StartThread(func, static_cast<void*>(&p2));
  }

  mu1.Lock();
  while (p1.completed != p1.total) {
    cv1.Wait();
  }
  mu1.Unlock();

  mu2.Lock();
  while (p2.completed != p2.total) {
    cv2.Wait();
  }
  mu2.Unlock();

  ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
}

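// Exercises the UnrefHandler: it must not fire for pointers that were never
// set (case 0), must fire on thread exit for every pointer a thread did set
// (case 1), and must fire when a ThreadLocalPtr instance is destroyed while
// threads still hold values (case 2).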
TEST_F(ThreadLocalTest, Unref) {
  auto unref = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);
    p.mu->Lock();
    ++(*p.unref);
    p.mu->Unlock();
  };

  // Case 0: no unref triggered if ThreadLocalPtr is never accessed
  auto func0 = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);

    p.mu->Lock();
    ++(p.started);
    p.cv->SignalAll();
    while (p.started != p.total) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };

  for (int th = 1; th <= 128; th += th) {
    port::Mutex mu;
    port::CondVar cv(&mu);
    int unref_count = 0;
    Params p(&mu, &cv, &unref_count, th, unref);

    for (int i = 0; i < p.total; ++i) {
      env_->StartThread(func0, static_cast<void*>(&p));
    }
    env_->WaitForJoin();
    ASSERT_EQ(unref_count, 0);
  }

  // Case 1: unref triggered by thread exit
  auto func1 = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);

    p.mu->Lock();
    ++(p.started);
    p.cv->SignalAll();
    while (p.started != p.total) {
      p.cv->Wait();
    }
    p.mu->Unlock();

    ASSERT_TRUE(p.tls1.Get() == nullptr);
    ASSERT_TRUE(p.tls2->Get() == nullptr);

    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);

    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);
  };

  for (int th = 1; th <= 128; th += th) {
    port::Mutex mu;
    port::CondVar cv(&mu);
    int unref_count = 0;
    ThreadLocalPtr tls2(unref);
    Params p(&mu, &cv, &unref_count, th, unref);
    p.tls2 = &tls2;

    for (int i = 0; i < p.total; ++i) {
      env_->StartThread(func1, static_cast<void*>(&p));
    }

    env_->WaitForJoin();

    // N threads x 2 ThreadLocal instance cleanup on thread exit
    ASSERT_EQ(unref_count, 2 * p.total);
  }

  // Case 2: unref triggered by ThreadLocal instance destruction
  auto func2 = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);

    p.mu->Lock();
    ++(p.started);
    p.cv->SignalAll();
    while (p.started != p.total) {
      p.cv->Wait();
    }
    p.mu->Unlock();

    ASSERT_TRUE(p.tls1.Get() == nullptr);
    ASSERT_TRUE(p.tls2->Get() == nullptr);

    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);

    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);

    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();

    // Waiting for instruction to exit thread
    while (p.completed != 0) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };

  for (int th = 1; th <= 128; th += th) {
    port::Mutex mu;
    port::CondVar cv(&mu);
    int unref_count = 0;
    Params p(&mu, &cv, &unref_count, th, unref);
    p.tls2 = new ThreadLocalPtr(unref);

    for (int i = 0; i < p.total; ++i) {
      env_->StartThread(func2, static_cast<void*>(&p));
    }

    // Wait for all threads to finish using Params
    mu.Lock();
    while (p.completed != p.total) {
      cv.Wait();
    }
    mu.Unlock();

    // Now destroy one ThreadLocal instance
    delete p.tls2;
    p.tls2 = nullptr;
    // Destroying the instance unrefs once per thread
    ASSERT_EQ(unref_count, p.total);

    // Signal to exit
    mu.Lock();
    p.completed = 0;
    cv.SignalAll();
    mu.Unlock();
    env_->WaitForJoin();
    // Thread exit adds N more unrefs for the remaining instance
    ASSERT_EQ(unref_count, 2 * p.total);
  }
}

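// Swap() stores the new pointer and returns whatever the calling thread had
// stored before.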
TEST_F(ThreadLocalTest, Swap) {
  ThreadLocalPtr tls;
  tls.Reset(reinterpret_cast<void*>(1));
  ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(nullptr)), 1);
  ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(2)) == nullptr);
  ASSERT_EQ(reinterpret_cast<int64_t>(tls.Get()), 2);
  ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(reinterpret_cast<void*>(3))), 2);
}

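// Scrape() collects the pointers stored by all threads and resets each slot
// to the supplied replacement (nullptr here), so no unref should fire at
// thread exit or instance destruction afterwards.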
TEST_F(ThreadLocalTest, Scrape) {
  auto unref = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);
    p.mu->Lock();
    ++(*p.unref);
    p.mu->Unlock();
  };

  auto func = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);

    ASSERT_TRUE(p.tls1.Get() == nullptr);
    ASSERT_TRUE(p.tls2->Get() == nullptr);

    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);

    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);

    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();

    // Waiting for instruction to exit thread
    while (p.completed != 0) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };

  for (int th = 1; th <= 128; th += th) {
    port::Mutex mu;
    port::CondVar cv(&mu);
    int unref_count = 0;
    Params p(&mu, &cv, &unref_count, th, unref);
    p.tls2 = new ThreadLocalPtr(unref);

    for (int i = 0; i < p.total; ++i) {
      env_->StartThread(func, static_cast<void*>(&p));
    }

    // Wait for all threads to finish using Params
    mu.Lock();
    while (p.completed != p.total) {
      cv.Wait();
    }
    mu.Unlock();

    ASSERT_EQ(unref_count, 0);

    // Scrape all thread local data. No unref at thread
    // exit or ThreadLocalPtr destruction
    autovector<void*> ptrs;
    p.tls1.Scrape(&ptrs, nullptr);
    p.tls2->Scrape(&ptrs, nullptr);
    delete p.tls2;
    // Signal to exit
    mu.Lock();
    p.completed = 0;
    cv.SignalAll();
    mu.Unlock();
    env_->WaitForJoin();

    ASSERT_EQ(unref_count, 0);
  }
}

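// Fold() visits every thread's stored pointer and applies the given function;
// here it sums the per-thread atomic counters.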
TEST_F(ThreadLocalTest, Fold) {
  auto unref = [](void* ptr) {
    delete static_cast<std::atomic<int64_t>*>(ptr);
  };
  static const int kNumThreads = 16;
  static const int kItersPerThread = 10;
  port::Mutex mu;
  port::CondVar cv(&mu);
  Params params(&mu, &cv, nullptr, kNumThreads, unref);
  auto func = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);
    ASSERT_TRUE(p.tls1.Get() == nullptr);
    p.tls1.Reset(new std::atomic<int64_t>(0));

    for (int i = 0; i < kItersPerThread; ++i) {
      static_cast<std::atomic<int64_t>*>(p.tls1.Get())->fetch_add(1);
    }

    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();

    // Waiting for instruction to exit thread
    while (p.completed != 0) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };

  for (int th = 0; th < params.total; ++th) {
    env_->StartThread(func, static_cast<void*>(&params));
  }

  // Wait for all threads to finish using Params
  mu.Lock();
  while (params.completed != params.total) {
    cv.Wait();
  }
  mu.Unlock();

  // Verify Fold() behavior
  int64_t sum = 0;
  params.tls1.Fold(
      [](void* ptr, void* res) {
        auto sum_ptr = static_cast<int64_t*>(res);
        *sum_ptr += static_cast<std::atomic<int64_t>*>(ptr)->load();
      },
      &sum);
  ASSERT_EQ(sum, kNumThreads * kItersPerThread);

  // Signal to exit
  mu.Lock();
  params.completed = 0;
  cv.SignalAll();
  mu.Unlock();
  env_->WaitForJoin();
}

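// CompareAndSwap() replaces the stored pointer only when it equals `expected`;
// on failure it writes the current value back into `expected`.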
TEST_F(ThreadLocalTest, CompareAndSwap) {
  ThreadLocalPtr tls;
  ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(1)) == nullptr);
  void* expected = reinterpret_cast<void*>(1);
  // Swap in 2
  ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast<void*>(2), expected));
  expected = reinterpret_cast<void*>(100);
  // Fail Swap, still 2
  ASSERT_TRUE(!tls.CompareAndSwap(reinterpret_cast<void*>(2), expected));
  ASSERT_EQ(expected, reinterpret_cast<void*>(2));
  // Swap in 3
  expected = reinterpret_cast<void*>(2);
  ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast<void*>(3), expected));
  ASSERT_EQ(tls.Get(), reinterpret_cast<void*>(3));
}

namespace {

void* AccessThreadLocal(void* /*arg*/) {
  TEST_SYNC_POINT("AccessThreadLocal:Start");
  ThreadLocalPtr tlp;
  tlp.Reset(new std::string("hello RocksDB"));
  TEST_SYNC_POINT("AccessThreadLocal:End");
  return nullptr;
}

}  // namespace

// The following test is disabled as it requires manual steps to run it
// correctly.
//
// Currently we have no way to access SyncPoint w/o ASAN error when the
// child thread dies after the main thread dies.  So if you manually enable
// this test and only see an ASAN error on SyncPoint, it means you pass the
// test.
TEST_F(ThreadLocalTest, DISABLED_MainThreadDiesFirst) {
  ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
      {{"AccessThreadLocal:Start", "MainThreadDiesFirst:End"},
       {"PosixEnv::~PosixEnv():End", "AccessThreadLocal:End"}});

  // Triggers the initialization of singletons.
  Env::Default();

#ifndef ROCKSDB_LITE
  try {
#endif  // ROCKSDB_LITE
    ROCKSDB_NAMESPACE::port::Thread th(&AccessThreadLocal, nullptr);
    th.detach();
    TEST_SYNC_POINT("MainThreadDiesFirst:End");
#ifndef ROCKSDB_LITE
  } catch (const std::system_error& ex) {
    std::cerr << "Start thread: " << ex.code() << std::endl;
    FAIL();
  }
#endif  // ROCKSDB_LITE
}

}  // namespace ROCKSDB_NAMESPACE

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}