1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <chrono>
18 #include <condition_variable>
19 #include <ctime>
20 #include <deque>
21 #include <memory>
22 #include <mutex>
23 #include <random>
24 #include <thread>
25 
26 #include <folly/CPortability.h>
27 #include <folly/Synchronized.h>
28 #include <folly/executors/Codel.h>
29 #include <folly/lang/Keep.h>
30 #include <folly/portability/GTest.h>
31 #include <folly/portability/PThread.h>
32 #include <folly/portability/SysResource.h>
33 #include <folly/portability/SysTime.h>
34 #include <folly/synchronization/Baton.h>
35 #include <folly/synchronization/Latch.h>
36 #include <thrift/lib/cpp/concurrency/FunctionRunner.h>
37 #include <thrift/lib/cpp/concurrency/PosixThreadFactory.h>
38 #include <thrift/lib/cpp/concurrency/ThreadManager.h>
39 #include <thrift/lib/cpp/concurrency/Util.h>
40 
41 using namespace apache::thrift::concurrency;
42 
43 class ThreadManagerTest : public testing::Test {
44  public:
~ThreadManagerTest()45   ~ThreadManagerTest() override { ThreadManager::setGlobalObserver(nullptr); }
46 
47  private:
48   gflags::FlagSaver flagsaver_;
49 };
50 
// WorkerProvider most recently handed to make_queue_observer_factory();
// nullptr until that function has been called.
static folly::WorkerProvider* kWorkerProviderGlobal = nullptr;
52 
namespace folly {

#ifdef FOLLY_HAVE_WEAK_SYMBOLS
// Test definition of make_queue_observer_factory(), compiled only when
// FOLLY_HAVE_WEAK_SYMBOLS is defined. It remembers the WorkerProvider it is
// given in kWorkerProviderGlobal and hands back an empty factory pointer.
// The first two arguments are ignored.
FOLLY_KEEP std::unique_ptr<folly::QueueObserverFactory>
make_queue_observer_factory(
    const std::string& /*unused*/,
    size_t /*unused*/,
    folly::WorkerProvider* workerProvider) {
  kWorkerProviderGlobal = workerProvider;
  return nullptr;
}
#endif

} // namespace folly
65 
// Polls until x==y for up to timeout ms, yielding the CPU between polls so
// that the threads being waited on can make progress.
// The end result is the same as of {EXPECT,ASSERT}_EQ(x,y)
// (depending on OP) if x!=y after the timeout passes.
// Note: x and y are evaluated once per poll and once more by the final
// OP##_EQ, so they should be cheap to evaluate repeatedly.
#define X_EQUAL_SPECIFIC_TIMEOUT(OP, timeout, x, y)         \
  do {                                                      \
    using std::chrono::steady_clock;                        \
    using std::chrono::milliseconds;                        \
    auto end = steady_clock::now() + milliseconds(timeout); \
    while ((x) != (y) && steady_clock::now() < end) {       \
      std::this_thread::yield();                            \
    }                                                       \
    OP##_EQ(x, y);                                          \
  } while (0)

// Non-fatal (EXPECT) and fatal (ASSERT) flavours of the polling comparison.
#define CHECK_EQUAL_SPECIFIC_TIMEOUT(timeout, x, y) \
  X_EQUAL_SPECIFIC_TIMEOUT(EXPECT, timeout, x, y)
#define REQUIRE_EQUAL_SPECIFIC_TIMEOUT(timeout, x, y) \
  X_EQUAL_SPECIFIC_TIMEOUT(ASSERT, timeout, x, y)

// A default timeout of 1 sec should be long enough for other threads to
// stabilize the values of x and y, and short enough to catch real errors
// when x is not going to be equal to y anytime soon
#define CHECK_EQUAL_TIMEOUT(x, y) CHECK_EQUAL_SPECIFIC_TIMEOUT(1000, x, y)
#define REQUIRE_EQUAL_TIMEOUT(x, y) REQUIRE_EQUAL_SPECIFIC_TIMEOUT(1000, x, y)
89 
90 class LoadTask : public Runnable {
91  public:
LoadTask(std::mutex * mutex,std::condition_variable * cond,size_t * count,int64_t timeout)92   LoadTask(
93       std::mutex* mutex,
94       std::condition_variable* cond,
95       size_t* count,
96       int64_t timeout)
97       : mutex_(mutex),
98         cond_(cond),
99         count_(count),
100         timeout_(timeout),
101         startTime_(0),
102         endTime_(0) {}
103 
run()104   void run() override {
105     startTime_ = Util::currentTime();
106     usleep(timeout_ * Util::US_PER_MS);
107     endTime_ = Util::currentTime();
108 
109     {
110       std::unique_lock<std::mutex> l(*mutex_);
111 
112       (*count_)--;
113       if (*count_ == 0) {
114         cond_->notify_one();
115       }
116     }
117   }
118 
119   std::mutex* mutex_;
120   std::condition_variable* cond_;
121   size_t* count_;
122   int64_t timeout_;
123   int64_t startTime_;
124   int64_t endTime_;
125 };
126 
127 /**
128  * Dispatch count tasks, each of which blocks for timeout milliseconds then
129  * completes. Verify that all tasks completed and that thread manager cleans
130  * up properly on delete.
131  */
loadTest(size_t numTasks,int64_t timeout,size_t numWorkers)132 static void loadTest(size_t numTasks, int64_t timeout, size_t numWorkers) {
133   std::mutex mutex;
134   std::condition_variable cond;
135   size_t tasksLeft = numTasks;
136 
137   auto threadManager = ThreadManager::newSimpleThreadManager(numWorkers);
138   auto threadFactory = std::make_shared<PosixThreadFactory>();
139   threadManager->threadFactory(threadFactory);
140   threadManager->start();
141 
142   std::set<std::shared_ptr<LoadTask>> tasks;
143   for (size_t n = 0; n < numTasks; n++) {
144     tasks.insert(
145         std::make_shared<LoadTask>(&mutex, &cond, &tasksLeft, timeout));
146   }
147 
148   int64_t startTime = Util::currentTime();
149   for (const auto& task : tasks) {
150     threadManager->add(task);
151   }
152 
153   int64_t tasksStartedTime = Util::currentTime();
154 
155   {
156     std::unique_lock<std::mutex> l(mutex);
157     while (tasksLeft > 0) {
158       cond.wait(l);
159     }
160   }
161   int64_t endTime = Util::currentTime();
162 
163   int64_t firstTime = std::numeric_limits<int64_t>::max();
164   int64_t lastTime = 0;
165   double averageTime = 0;
166   int64_t minTime = std::numeric_limits<int64_t>::max();
167   int64_t maxTime = 0;
168 
169   for (const auto& task : tasks) {
170     EXPECT_GT(task->startTime_, 0);
171     EXPECT_GT(task->endTime_, 0);
172 
173     int64_t delta = task->endTime_ - task->startTime_;
174     assert(delta > 0);
175 
176     firstTime = std::min(firstTime, task->startTime_);
177     lastTime = std::max(lastTime, task->endTime_);
178     minTime = std::min(minTime, delta);
179     maxTime = std::max(maxTime, delta);
180 
181     averageTime += delta;
182   }
183   averageTime /= numTasks;
184 
185   LOG(INFO) << "first start: " << firstTime << "ms "
186             << "last end: " << lastTime << "ms "
187             << "min: " << minTime << "ms "
188             << "max: " << maxTime << "ms "
189             << "average: " << averageTime << "ms";
190 
191   double idealTime = ((numTasks + (numWorkers - 1)) / numWorkers) * timeout;
192   double actualTime = endTime - startTime;
193   double taskStartTime = tasksStartedTime - startTime;
194 
195   double overheadPct = (actualTime - idealTime) / idealTime;
196   if (overheadPct < 0) {
197     overheadPct *= -1.0;
198   }
199 
200   LOG(INFO) << "ideal time: " << idealTime << "ms "
201             << "actual time: " << actualTime << "ms "
202             << "task startup time: " << taskStartTime << "ms "
203             << "overhead: " << overheadPct * 100.0 << "%";
204 
205   // Fail if the test took 10% more time than the ideal time
206   EXPECT_LT(overheadPct, 0.10);
207 }
208 
// Runs loadTest() with 10000 tasks of 50 ms each on 100 workers.
TEST_F(ThreadManagerTest, LoadTest) {
  constexpr size_t kTaskCount = 10000;
  constexpr int64_t kTaskDurationMs = 50;
  constexpr size_t kWorkerCount = 100;
  loadTest(kTaskCount, kTaskDurationMs, kWorkerCount);
}
215 
216 class BlockTask : public Runnable {
217  public:
BlockTask(std::mutex * mutex,std::condition_variable * cond,std::mutex * bmutex,std::condition_variable * bcond,bool * blocked,size_t * count)218   BlockTask(
219       std::mutex* mutex,
220       std::condition_variable* cond,
221       std::mutex* bmutex,
222       std::condition_variable* bcond,
223       bool* blocked,
224       size_t* count)
225       : mutex_(mutex),
226         cond_(cond),
227         bmutex_(bmutex),
228         bcond_(bcond),
229         blocked_(blocked),
230         count_(count),
231         started_(false) {}
232 
run()233   void run() override {
234     started_ = true;
235     {
236       std::unique_lock<std::mutex> l(*bmutex_);
237       while (*blocked_) {
238         bcond_->wait(l);
239       }
240     }
241 
242     {
243       std::unique_lock<std::mutex> l(*mutex_);
244       (*count_)--;
245       if (*count_ == 0) {
246         cond_->notify_one();
247       }
248     }
249   }
250 
251   std::mutex* mutex_;
252   std::condition_variable* cond_;
253   std::mutex* bmutex_;
254   std::condition_variable* bcond_;
255   bool* blocked_;
256   size_t* count_;
257   bool started_;
258 };
259 
// Expiration callback used by expireTest(): decrements *count under *mutex
// and signals *cond when the count reaches zero. The Runnable argument is
// ignored.
static void expireTestCallback(
    std::shared_ptr<Runnable>,
    std::mutex* mutex,
    std::condition_variable* cond,
    size_t* count) {
  std::lock_guard<std::mutex> guard(*mutex);
  *count -= 1;
  if (*count != 0) {
    return;
  }
  cond->notify_one();
}
271 
expireTest(size_t numWorkers,int64_t expirationTimeMs)272 static void expireTest(size_t numWorkers, int64_t expirationTimeMs) {
273   size_t maxPendingTasks = numWorkers;
274   size_t activeTasks = numWorkers + maxPendingTasks;
275   std::mutex mutex;
276   std::condition_variable cond;
277 
278   auto threadManager = ThreadManager::newSimpleThreadManager(numWorkers);
279   auto threadFactory = std::make_shared<PosixThreadFactory>();
280   threadManager->threadFactory(threadFactory);
281   threadManager->setExpireCallback(std::bind(
282       expireTestCallback, std::placeholders::_1, &mutex, &cond, &activeTasks));
283   threadManager->start();
284 
285   // Add numWorkers + maxPendingTasks to fill up the ThreadManager's task queue
286   std::vector<std::shared_ptr<BlockTask>> tasks;
287   tasks.reserve(activeTasks);
288 
289   std::mutex bmutex;
290   std::condition_variable bcond;
291   bool blocked = true;
292   for (size_t n = 0; n < numWorkers + maxPendingTasks; ++n) {
293     auto task = std::make_shared<BlockTask>(
294         &mutex, &cond, &bmutex, &bcond, &blocked, &activeTasks);
295     tasks.push_back(task);
296     threadManager->add(task, 0, expirationTimeMs);
297   }
298 
299   // Sleep for more than the expiration time
300   usleep(expirationTimeMs * Util::US_PER_MS * 1.10);
301 
302   // Unblock the tasks
303   {
304     std::unique_lock<std::mutex> l(bmutex);
305     blocked = false;
306     bcond.notify_all();
307   }
308   // Wait for all tasks to complete or expire
309   {
310     std::unique_lock<std::mutex> l(mutex);
311     while (activeTasks != 0) {
312       cond.wait(l);
313     }
314   }
315 
316   // The first numWorkers tasks should have completed,
317   // the remaining ones should have expired without running
318   for (size_t index = 0; index < tasks.size(); ++index) {
319     if (index < numWorkers) {
320       EXPECT_TRUE(tasks[index]->started_);
321     } else {
322       EXPECT_TRUE(!tasks[index]->started_);
323     }
324   }
325 }
326 
// Runs expireTest() with 100 workers and a 50 ms expiration time.
TEST_F(ThreadManagerTest, ExpireTest) {
  constexpr size_t kWorkerCount = 100;
  constexpr int64_t kExpirationMs = 50;
  expireTest(kWorkerCount, kExpirationMs);
}
332 
333 class AddRemoveTask : public Runnable,
334                       public std::enable_shared_from_this<AddRemoveTask> {
335  public:
AddRemoveTask(uint32_t timeoutUs,const std::shared_ptr<ThreadManager> & manager,std::mutex * mutex,std::condition_variable * cond,int64_t * count,int64_t * objectCount)336   AddRemoveTask(
337       uint32_t timeoutUs,
338       const std::shared_ptr<ThreadManager>& manager,
339       std::mutex* mutex,
340       std::condition_variable* cond,
341       int64_t* count,
342       int64_t* objectCount)
343       : timeoutUs_(timeoutUs),
344         manager_(manager),
345         mutex_(mutex),
346         cond_(cond),
347         count_(count),
348         objectCount_(objectCount) {
349     std::unique_lock<std::mutex> l(*mutex_);
350     ++*objectCount_;
351   }
352 
~AddRemoveTask()353   ~AddRemoveTask() override {
354     std::unique_lock<std::mutex> l(*mutex_);
355     --*objectCount_;
356   }
357 
run()358   void run() override {
359     usleep(timeoutUs_);
360 
361     {
362       std::unique_lock<std::mutex> l(*mutex_);
363 
364       if (*count_ <= 0) {
365         // The task count already dropped to 0.
366         // We add more tasks than count_, so some of them may still be running
367         // when count_ drops to 0.
368         return;
369       }
370 
371       --*count_;
372       if (*count_ == 0) {
373         cond_->notify_all();
374         return;
375       }
376     }
377 
378     // Add ourself to the task queue again
379     manager_->add(shared_from_this());
380   }
381 
382  private:
383   int32_t timeoutUs_;
384   std::shared_ptr<ThreadManager> manager_;
385   std::mutex* mutex_;
386   std::condition_variable* cond_;
387   int64_t* count_;
388   int64_t* objectCount_;
389 };
390 
391 class WorkerCountChanger : public Runnable {
392  public:
WorkerCountChanger(const std::shared_ptr<ThreadManager> & manager,std::mutex * mutex,int64_t * count,int64_t * addAndRemoveCount)393   WorkerCountChanger(
394       const std::shared_ptr<ThreadManager>& manager,
395       std::mutex* mutex,
396       int64_t* count,
397       int64_t* addAndRemoveCount)
398       : manager_(manager),
399         mutex_(mutex),
400         count_(count),
401         addAndRemoveCount_(addAndRemoveCount) {}
402 
run()403   void run() override {
404     // Continue adding and removing threads until the tasks are all done
405     while (true) {
406       {
407         std::unique_lock<std::mutex> l(*mutex_);
408         if (*count_ == 0) {
409           return;
410         }
411         ++*addAndRemoveCount_;
412       }
413       addAndRemove();
414     }
415   }
416 
addAndRemove()417   void addAndRemove() {
418     // Add a random number of workers
419     std::uniform_int_distribution<> workerDist(1, 10);
420     uint32_t workersToAdd = workerDist(rng_);
421     manager_->addWorker(workersToAdd);
422 
423     std::uniform_int_distribution<> taskDist(1, 50);
424     uint32_t tasksToAdd = taskDist(rng_);
425     (void)tasksToAdd;
426 
427     // Sleep for a random amount of time
428     std::uniform_int_distribution<> sleepDist(1000, 5000);
429     uint32_t sleepUs = sleepDist(rng_);
430     usleep(sleepUs);
431 
432     // Remove the same number of workers we added
433     manager_->removeWorker(workersToAdd);
434   }
435 
436  private:
437   std::mt19937 rng_;
438   std::shared_ptr<ThreadManager> manager_;
439   std::mutex* mutex_;
440   int64_t* count_;
441   int64_t* addAndRemoveCount_;
442 };
443 
444 // Run lots of tasks, while several threads are all changing
445 // the number of worker threads.
TEST_F(ThreadManagerTest,AddRemoveWorker)446 TEST_F(ThreadManagerTest, AddRemoveWorker) {
447   // Number of tasks to run
448   int64_t numTasks = 100000;
449   // Minimum number of workers to keep at any point in time
450   size_t minNumWorkers = 10;
451   // Number of threads that will be adding and removing workers
452   int64_t numAddRemoveWorkers = 30;
453   // Number of tasks to run in parallel
454   int64_t numParallelTasks = 200;
455 
456   auto threadManager = ThreadManager::newSimpleThreadManager(minNumWorkers);
457   auto threadFactory = std::make_shared<PosixThreadFactory>();
458   threadManager->threadFactory(threadFactory);
459   threadManager->start();
460 
461   std::mutex mutex;
462   std::condition_variable cond;
463   int64_t currentTaskObjects = 0;
464   int64_t count = numTasks;
465   int64_t addRemoveCount = 0;
466 
467   std::mt19937 rng;
468   std::uniform_int_distribution<> taskTimeoutDist(1, 3000);
469   for (int64_t n = 0; n < numParallelTasks; ++n) {
470     int64_t taskTimeoutUs = taskTimeoutDist(rng);
471     auto task = std::make_shared<AddRemoveTask>(
472         taskTimeoutUs,
473         threadManager,
474         &mutex,
475         &cond,
476         &count,
477         &currentTaskObjects);
478     threadManager->add(task);
479   }
480 
481   auto addRemoveFactory = std::make_shared<PosixThreadFactory>();
482   addRemoveFactory->setDetached(false);
483   std::deque<std::shared_ptr<Thread>> addRemoveThreads;
484   for (int64_t n = 0; n < numAddRemoveWorkers; ++n) {
485     auto worker = std::make_shared<WorkerCountChanger>(
486         threadManager, &mutex, &count, &addRemoveCount);
487     auto thread = addRemoveFactory->newThread(worker);
488     addRemoveThreads.push_back(thread);
489     thread->start();
490   }
491 
492   while (!addRemoveThreads.empty()) {
493     addRemoveThreads.front()->join();
494     addRemoveThreads.pop_front();
495   }
496 
497   LOG(INFO) << "add remove count: " << addRemoveCount;
498   EXPECT_GT(addRemoveCount, 0);
499 
500   // Stop the ThreadManager, and ensure that all Task objects have been
501   // destroyed.
502   threadManager->stop();
503   EXPECT_EQ(0, currentTaskObjects);
504 }
505 
// Destroying a ThreadManager that was never started must work, both without
// and with a thread factory installed.
TEST_F(ThreadManagerTest, NeverStartedTest) {
  // Case 1: created and destroyed with nothing else done to it.
  {
    auto threadManager = ThreadManager::newSimpleThreadManager(10);
  }

  // Case 2: a ThreadFactory is installed, but start() is never called.
  {
    auto threadManager = ThreadManager::newSimpleThreadManager(10);
    threadManager->threadFactory(std::make_shared<PosixThreadFactory>());
  }
}
521 
// Repeatedly (1000 times) creates, starts and destroys a ThreadManager that
// has a ThreadFactory but is never given any work.
TEST_F(ThreadManagerTest, OnlyStartedTest) {
  for (int round = 0; round != 1000; ++round) {
    auto threadManager = ThreadManager::newSimpleThreadManager(10);
    threadManager->threadFactory(std::make_shared<PosixThreadFactory>());
    threadManager->start();
  }
}
531 
// Checks that the "test" entry stored in the caller's RequestContext is
// visible where a task added to the ThreadManager is run and destroyed.
TEST_F(ThreadManagerTest, RequestContext) {
  // Context payload carrying a single int.
  class TestData : public folly::RequestData {
   public:
    explicit TestData(int value) : data(value) {}

    bool hasCallback() override { return false; }

    int data;
  };

  // Create new request context for this scope and store 42 under "test".
  folly::RequestContextScopeGuard rctx;
  EXPECT_EQ(nullptr, folly::RequestContext::get()->getContextData("test"));
  folly::RequestContext::get()->setContextData(
      "test", std::make_unique<TestData>(42));
  auto stored = folly::RequestContext::get()->getContextData("test");
  EXPECT_EQ(42, dynamic_cast<TestData*>(stored)->data);

  // The check lives in the destructor, so it runs wherever an instance is
  // destroyed.
  struct VerifyRequestContext {
    ~VerifyRequestContext() {
      auto seen = folly::RequestContext::get()->getContextData("test");
      EXPECT_TRUE(seen != nullptr);
      if (seen == nullptr) {
        return;
      }
      EXPECT_EQ(42, dynamic_cast<TestData*>(seen)->data);
    }
  };

  {
    auto threadManager = ThreadManager::newSimpleThreadManager(10);
    threadManager->threadFactory(std::make_shared<PosixThreadFactory>());
    threadManager->start();
    // Verifier created and destroyed inside the task body.
    threadManager->add([] { VerifyRequestContext(); });
    // Verifier owned by the closure, destroyed together with it.
    threadManager->add([x = VerifyRequestContext()] {});
    threadManager->join();
  }
}
570 
// A task that throws from run() must not bring the process down.
TEST_F(ThreadManagerTest, Exceptions) {
  // Runnable whose only action is to throw.
  class ThrowTask : public Runnable {
   public:
    void run() override {
      throw std::runtime_error("This should not crash the program");
    }
  };

  auto threadManager = ThreadManager::newSimpleThreadManager(10);
  threadManager->threadFactory(std::make_shared<PosixThreadFactory>());
  threadManager->start();
  threadManager->add(std::make_shared<ThrowTask>());
  threadManager->join();
}
587 
588 class TestObserver : public ThreadManager::Observer {
589  public:
TestObserver(int64_t timeout,const std::string & expectedName)590   TestObserver(int64_t timeout, const std::string& expectedName)
591       : timesCalled(0), timeout(timeout), expectedName(expectedName) {}
592 
preRun(folly::RequestContext *)593   void preRun(folly::RequestContext*) override {}
postRun(folly::RequestContext *,const ThreadManager::RunStats & stats)594   void postRun(
595       folly::RequestContext*, const ThreadManager::RunStats& stats) override {
596     EXPECT_EQ(expectedName, stats.threadPoolName);
597 
598     // Note: Technically could fail if system clock changes.
599     EXPECT_GT((stats.workBegin - stats.queueBegin).count(), 0);
600     EXPECT_GT((stats.workEnd - stats.workBegin).count(), 0);
601     EXPECT_GT((stats.workEnd - stats.workBegin).count(), timeout - 1);
602     ++timesCalled;
603   }
604 
605   uint64_t timesCalled;
606   int64_t timeout;
607   std::string expectedName;
608 };
609 
610 class FailThread : public PthreadThread {
611  public:
FailThread(int policy,int priority,int stackSize,bool detached,std::shared_ptr<Runnable> runnable)612   FailThread(
613       int policy,
614       int priority,
615       int stackSize,
616       bool detached,
617       std::shared_ptr<Runnable> runnable)
618       : PthreadThread(policy, priority, stackSize, detached, runnable) {}
619 
start()620   void start() override { throw 2; }
621 };
622 
623 class FailThreadFactory : public PosixThreadFactory {
624  public:
625   class FakeImpl : public Impl {
626    public:
FakeImpl(POLICY policy,PosixThreadFactory::THREAD_PRIORITY priority,int stackSize,DetachState detached)627     FakeImpl(
628         POLICY policy,
629         PosixThreadFactory::THREAD_PRIORITY priority,
630         int stackSize,
631         DetachState detached)
632         : Impl(policy, priority, stackSize, detached) {}
633 
newThread(const std::shared_ptr<Runnable> & runnable,DetachState detachState) const634     std::shared_ptr<Thread> newThread(
635         const std::shared_ptr<Runnable>& runnable,
636         DetachState detachState) const override {
637       auto result = std::make_shared<FailThread>(
638           toPthreadPolicy(policy_),
639           toPthreadPriority(policy_, priority_),
640           stackSize_,
641           detachState == DETACHED,
642           runnable);
643       result->weakRef(result);
644       runnable->thread(result);
645       return result;
646     }
647   };
648 
FailThreadFactory(POLICY=kDefaultPolicy,THREAD_PRIORITY=kDefaultPriority,int=kDefaultStackSizeMB,bool detached=true)649   explicit FailThreadFactory(
650       POLICY /*policy*/ = kDefaultPolicy,
651       THREAD_PRIORITY /*priority*/ = kDefaultPriority,
652       int /*stackSize*/ = kDefaultStackSizeMB,
653       bool detached = true) {
654     impl_ = std::make_shared<FailThreadFactory::FakeImpl>(
655         kDefaultPolicy,
656         kDefaultPriority,
657         kDefaultStackSizeMB,
658         detached ? DETACHED : ATTACHED);
659   }
660 };
661 
662 class DummyFailureClass {
663  public:
DummyFailureClass()664   DummyFailureClass() {
665     threadManager_ = ThreadManager::newSimpleThreadManager(20);
666     threadManager_->setNamePrefix("foo");
667     auto threadFactory = std::make_shared<FailThreadFactory>();
668     threadManager_->threadFactory(threadFactory);
669     threadManager_->start();
670   }
671 
672  private:
673   std::shared_ptr<ThreadManager> threadManager_;
674 };
675 
// Constructing DummyFailureClass must throw an int, ten times in a row.
TEST_F(ThreadManagerTest, ThreadStartFailureTest) {
  for (int attempt = 0; attempt != 10; ++attempt) {
    EXPECT_THROW(DummyFailureClass(), int);
  }
}
681 
// Installs a TestObserver before the ThreadManager is created, runs a single
// task that sleeps 1000 ms, and expects exactly one postRun() notification.
TEST_F(ThreadManagerTest, ObserverTest) {
  auto observer = std::make_shared<TestObserver>(1000, "foo");
  ThreadManager::setGlobalObserver(observer);

  // State the LoadTask reports its completion through.
  std::mutex mutex;
  std::condition_variable cond;
  size_t tasks = 1;

  auto threadManager = ThreadManager::newSimpleThreadManager(10);
  threadManager->setNamePrefix("foo");
  threadManager->threadFactory(std::make_shared<PosixThreadFactory>());
  threadManager->start();

  threadManager->add(std::make_shared<LoadTask>(&mutex, &cond, &tasks, 1000));
  threadManager->join();
  EXPECT_EQ(1, observer->timesCalled);
}
700 
// An observer installed after the ThreadManager has been started must still
// be notified about tasks that run afterwards.
TEST_F(ThreadManagerTest, ObserverAssignedAfterStart) {
  // Task that does nothing.
  class MyTask : public Runnable {
   public:
    void run() override {}
  };
  // Observer that copies its name into a shared string from postRun().
  class MyObserver : public ThreadManager::Observer {
   public:
    MyObserver(std::string name, std::shared_ptr<std::string> tgt)
        : name_(std::move(name)), tgt_(std::move(tgt)) {}

    void preRun(folly::RequestContext*) override {}

    void postRun(
        folly::RequestContext*, const ThreadManager::RunStats&) override {
      *tgt_ = name_;
    }

   private:
    std::string name_;
    std::shared_ptr<std::string> tgt_;
  };

  // Start a single-worker ThreadManager first...
  auto tm = ThreadManager::newSimpleThreadManager(1);
  tm->setNamePrefix("foo");
  tm->threadFactory(std::make_shared<PosixThreadFactory>());
  tm->start();

  // ...and only then install the observer with its observable side effect.
  auto observed = std::make_shared<std::string>();
  ThreadManager::setGlobalObserver(
      std::make_shared<MyObserver>("bar", observed));

  // Running a task should trigger the side effect.
  tm->add(std::make_shared<MyTask>());
  tm->join();

  EXPECT_EQ("bar", *observed);
}
735 
// Checks the nice value seen by threads created through PosixThreadFactory
// for the NORMAL_PRI, LOW_PRI and INHERITED_PRI priorities.
TEST_F(ThreadManagerTest, PosixThreadFactoryPriority) {
  // Creates a joinable thread with the given factory priority and returns
  // what getpriority(PRIO_PROCESS, 0) reports from inside that thread.
  auto niceValueFor = [](PosixThreadFactory::THREAD_PRIORITY prio) -> int {
    PosixThreadFactory factory(PosixThreadFactory::OTHER, prio);
    factory.setDetached(false);
    int observed = 0;
    auto probe = factory.newThread(FunctionRunner::create(
        [&] { observed = getpriority(PRIO_PROCESS, 0); }));
    probe->start();
    probe->join();
    return observed;
  };

  // NOTE: Test may not have permission to raise priority,
  // so use prio <= NORMAL.
  EXPECT_EQ(0, niceValueFor(PosixThreadFactory::NORMAL_PRI));
  EXPECT_LT(0, niceValueFor(PosixThreadFactory::LOW_PRI));

  // Step the nice value from 0 to 19 on a separate thread and check that
  // INHERITED_PRI threads created from it report the same value.
  std::thread parent([&] {
    for (int nice = 0; nice < 20; ++nice) {
      if (setpriority(PRIO_PROCESS, 0, nice) != 0) {
        PLOG(WARNING) << "failed setpriority(" << nice << ")";
        continue;
      }
      EXPECT_EQ(nice, niceValueFor(PosixThreadFactory::INHERITED_PRI));
    }
  });
  parent.join();
}
763 
// Checks the per-priority worker counts of a PriorityThreadManager right
// after start, after adding one worker to every priority, and after removing
// that worker again.
TEST_F(ThreadManagerTest, PriorityThreadManagerWorkerCount) {
  auto threadManager = PriorityThreadManager::newPriorityThreadManager({{
      1 /*HIGH_IMPORTANT*/,
      2 /*HIGH*/,
      3 /*IMPORTANT*/,
      4 /*NORMAL*/,
      5 /*BEST_EFFORT*/
  }});
  threadManager->start();

  // Priorities in the same order as the initial counts above (1..5).
  const PRIORITY priorities[] = {
      PRIORITY::HIGH_IMPORTANT,
      PRIORITY::HIGH,
      PRIORITY::IMPORTANT,
      PRIORITY::NORMAL,
      PRIORITY::BEST_EFFORT,
  };

  // Expects every priority to hold its initial count plus `extra` workers.
  const auto expectCounts = [&](int extra) {
    int initial = 1;
    for (const PRIORITY pri : priorities) {
      EXPECT_EQ(initial + extra, threadManager->workerCount(pri));
      ++initial;
    }
  };

  expectCounts(0);

  for (const PRIORITY pri : priorities) {
    threadManager->addWorker(pri, 1);
  }
  expectCounts(1);

  for (const PRIORITY pri : priorities) {
    threadManager->removeWorker(pri, 1);
  }
  expectCounts(0);
}
804 
// With the single worker blocked, queues three tasks at different priorities
// and checks the order in which they run once the worker is released.
TEST_F(ThreadManagerTest, PriorityQueueThreadManagerExecutor) {
  auto threadManager = ThreadManager::newPriorityQueueThreadManager(1);
  threadManager->start();

  folly::Baton<> workerGate;
  folly::Baton<> lastTaskDone;

  // Occupy the only worker so the following tasks pile up in the queue.
  threadManager->add([&] { workerGate.wait(); });

  // Each queued task appends its letter; the result is the execution order.
  std::string order;
  threadManager->addWithPriority(
      [&] {
        order += "a";
        lastTaskDone.post();
      },
      0);
  // Should be added by default at highest priority
  threadManager->add([&] { order += "b"; });
  threadManager->addWithPriority([&] { order += "c"; }, 1);

  // Release the worker.
  workerGate.post();

  // Wait until the request that's supposed to finish last is done.
  lastTaskDone.wait();

  EXPECT_EQ("bca", order);
}
832 
833 std::array<std::function<std::shared_ptr<ThreadManager>()>, 3> factories = {
834     std::bind(
835         (std::shared_ptr<ThreadManager>(*)(
836             size_t))ThreadManager::newSimpleThreadManager,
837         1),
838     std::bind(ThreadManager::newPriorityQueueThreadManager, 1),
__anon11c8e6b90a02() 839     []() -> std::shared_ptr<apache::thrift::concurrency::ThreadManager> {
840       return PriorityThreadManager::newPriorityThreadManager({{
841           1 /*HIGH_IMPORTANT*/,
842           2 /*HIGH*/,
843           3 /*IMPORTANT*/,
844           4 /*NORMAL*/,
845           5 /*BEST_EFFORT*/
846       }});
847     }};
// Value-parameterized fixture; the parameter is a callable that creates the
// ThreadManager under test.
class JoinTest : public testing::TestWithParam<
                     std::function<std::shared_ptr<ThreadManager>()>> {};
850 
// join() must keep blocking while either of two queued tasks is still
// blocked, and return once both have been released.
TEST_P(JoinTest, Join) {
  auto threadManager = GetParam()();
  threadManager->threadFactory(std::make_shared<PosixThreadFactory>());
  threadManager->start();

  folly::Baton<> wait1;
  folly::Baton<> wait2;
  folly::Baton<> joinStarted;
  folly::Baton<> joined;

  // Queue two tasks that block on their own baton. They differ only in the
  // last argument to add() (false vs. true); its meaning is defined by
  // ThreadManager::add().
  threadManager->add(
      FunctionRunner::create([&] { wait1.wait(); }), 0, 0, false);
  threadManager->add(FunctionRunner::create([&] { wait2.wait(); }), 0, 0, true);

  // Call join() on a helper thread so this thread can observe when it returns.
  std::thread joiner([&] {
    joinStarted.post();
    threadManager->join();
    joined.post();
  });

  joinStarted.wait();
  // Both tasks blocked: join() must not have returned.
  EXPECT_FALSE(joined.try_wait_for(std::chrono::milliseconds(100)));
  joined.reset();

  // First task released, second still blocked: still no return.
  wait1.post();
  EXPECT_FALSE(joined.try_wait_for(std::chrono::milliseconds(100)));
  joined.reset();

  // Both released: join() should now complete.
  wait2.post();
  EXPECT_TRUE(joined.try_wait_for(std::chrono::milliseconds(100)));
  joiner.join();
}
877 
// Instantiates JoinTest once per entry in `factories`, under the
// "ThreadManagerTest" prefix.
// NOTE(review): newer GoogleTest releases name this macro
// INSTANTIATE_TEST_SUITE_P — confirm which spelling the pinned version wants.
INSTANTIATE_TEST_CASE_P(
    ThreadManagerTest, JoinTest, ::testing::ValuesIn(factories));
880 
881 class TMThreadIDCollectorTest : public ::testing::Test {
882  protected:
SetUp()883   void SetUp() override { kWorkerProviderGlobal = nullptr; }
TearDown()884   void TearDown() override { kWorkerProviderGlobal = nullptr; }
885   static constexpr size_t kNumThreads = 4;
886 };
887 
TEST_F(TMThreadIDCollectorTest, BasicTest) {
  // This is a sanity check test. We start a ThreadManager, queue a task,
  // and then invoke the collectThreadIds() API to capture the TID of the
  // active thread.
  auto tm = ThreadManager::newSimpleThreadManager(1);
  tm->setNamePrefix("baz");
  tm->threadFactory(std::make_shared<PosixThreadFactory>());
  tm->start();

  std::atomic<pid_t> threadId = {};
  folly::Baton<> bat;
  tm->add([&]() {
    // Publish the worker's OS thread id, then signal the test thread. The
    // task touches no captured state after the post().
    threadId.store(folly::getOSThreadID());
    bat.post();
  });
  {
    // Once the baton is posted the task is done with the captured locals, so
    // the fatal assertions below can return early safely.
    bat.wait();

    // The provider is only published through the queue observer factory
    // hook; fail cleanly instead of dereferencing a null pointer if that hook
    // never ran.
    ASSERT_TRUE(kWorkerProviderGlobal != nullptr)
        << "no WorkerProvider was registered";
    auto idsWithKA = kWorkerProviderGlobal->collectThreadIds();
    auto& ids = idsWithKA.threadIds;
    // Fatal on mismatch: ids[0] below must not be evaluated on an empty list.
    ASSERT_EQ(ids.size(), 1);
    EXPECT_EQ(ids[0], threadId.load());
  }
}
911 
TEST_F(TMThreadIDCollectorTest, CollectIDMultipleInvocationTest) {
  // This test verifies that multiple invocations of collectThreadIds()
  // do not deadlock.
  auto tm = ThreadManager::newSimpleThreadManager(kNumThreads);
  tm->setNamePrefix("bar");
  tm->threadFactory(std::make_shared<PosixThreadFactory>());
  tm->start();

  folly::Synchronized<std::vector<pid_t>> threadIds;
  std::array<folly::Baton<>, kNumThreads> bats;
  folly::Baton<> tasksAddedBat;
  for (size_t i = 0; i < kNumThreads; ++i) {
    tm->add([i, &threadIds, &bats, &tasksAddedBat]() {
      threadIds.wlock()->push_back(folly::getOSThreadID());
      // The last queued task tells the test thread that it is running.
      if (i == kNumThreads - 1) {
        tasksAddedBat.post();
      }
      // Park until the test releases this task.
      bats[i].wait();
    });
  }
  {
    tasksAddedBat.wait();
    // Non-fatal check plus guard: the parked tasks reference `bats`, so this
    // function must always reach the release loop and join() below rather
    // than returning early or dereferencing a null provider.
    auto* provider = kWorkerProviderGlobal;
    EXPECT_TRUE(provider != nullptr) << "no WorkerProvider was registered";
    if (provider != nullptr) {
      // Hold both results at once: the second call must not block on the
      // first one still being alive.
      auto idsWithKA1 = provider->collectThreadIds();
      auto idsWithKA2 = provider->collectThreadIds();
      auto& ids1 = idsWithKA1.threadIds;
      auto& ids2 = idsWithKA2.threadIds;
      EXPECT_EQ(ids1.size(), kNumThreads);
      EXPECT_EQ(ids1.size(), ids2.size());
      EXPECT_EQ(ids1, ids2);
    }
  }
  for (auto& bat : bats) {
    bat.post();
  }
  tm->join();
}
947 
TEST_F(TMThreadIDCollectorTest, CollectIDBlocksThreadExitTest) {
  // This test verifies that a collectThreadIds() call prevents the
  // active ThreadManager threads from exiting.
  auto tm = ThreadManager::newSimpleThreadManager(kNumThreads);
  tm->setNamePrefix("bar");
  tm->threadFactory(std::make_shared<PosixThreadFactory>());
  tm->start();

  std::array<folly::Baton<>, kNumThreads> bats;
  folly::Baton<> tasksAddedBat;
  for (size_t i = 0; i < kNumThreads; ++i) {
    tm->add([i, &bats, &tasksAddedBat]() {
      // The last queued task tells the collector thread that it is running.
      if (i == kNumThreads - 1) {
        tasksAddedBat.post();
      }
      // Park until the test releases this task.
      bats[i].wait();
    });
  }
  folly::Baton<> waitForCollectBat;
  folly::Baton<> threadCountBat;
  auto bgCollector = std::thread([&]() {
    tasksAddedBat.wait();
    auto* provider = kWorkerProviderGlobal;
    if (provider == nullptr) {
      // Report the problem but still unblock the main thread, which is
      // waiting on waitForCollectBat; never dereference a null provider.
      ADD_FAILURE() << "no WorkerProvider was registered";
      waitForCollectBat.post();
      return;
    }
    // The collected result is kept alive until the end of this lambda.
    auto idsWithKA = provider->collectThreadIds();
    waitForCollectBat.post();
    auto posted = threadCountBat.try_wait_for(std::chrono::milliseconds(100));
    // The thread count reduction should not have returned (thread exit
    // is currently blocked).
    EXPECT_FALSE(posted);
    auto& ids = idsWithKA.threadIds;
    EXPECT_EQ(ids.size(), kNumThreads);
  });
  waitForCollectBat.wait();
  for (auto& bat : bats) {
    bat.post();
  }
  // threadCountBat is only posted after removeWorker() returns, which is what
  // the collector thread's timed wait above observes.
  tm->removeWorker(2);
  threadCountBat.post();
  bgCollector.join();
  EXPECT_EQ(tm->workerCount(), 2);
  tm->join();
}
989 
990 //
991 // =============================================================================
992 // This section validates the cpu time accounting logic in ThreadManager. It
993 // requires thread-specific clocks, so only Linux is supported at this time.
994 // =============================================================================
995 //
996 #ifdef __linux__
997 
// Like ASSERT_NEAR, but for chrono duration types: asserts that durations `a`
// and `b` differ by at most duration `c`. Each argument is converted to
// std::chrono::nanoseconds and compared by tick count. The type is fully
// qualified so the macro does not depend on a using-directive being in scope
// where it is expanded.
#define ASSERT_NEAR_NS(a, b, c)               \
  do {                                        \
    ASSERT_NEAR(                              \
        std::chrono::nanoseconds(a).count(),  \
        std::chrono::nanoseconds(b).count(),  \
        std::chrono::nanoseconds(c).count()); \
  } while (0)
1006 
thread_clock_now()1007 static std::chrono::nanoseconds thread_clock_now() {
1008   timespec tp;
1009   clockid_t clockid;
1010   CHECK(!pthread_getcpuclockid(pthread_self(), &clockid));
1011   CHECK(!clock_gettime(clockid, &tp));
1012   return std::chrono::nanoseconds(tp.tv_nsec) + std::chrono::seconds(tp.tv_sec);
1013 }
1014 
1015 // Burn thread cpu cycles
burn(std::chrono::milliseconds ms)1016 static void burn(std::chrono::milliseconds ms) {
1017   auto expires = thread_clock_now() + ms;
1018   while (thread_clock_now() < expires) {
1019   }
1020 }
1021 
// Loop without using much cpu time: sleep in 100ms slices until at least `ms`
// has elapsed. The deadline is tracked on std::chrono::steady_clock so that
// wall-clock adjustments cannot shorten or stretch the wait. A zero or
// negative `ms` returns without sleeping.
static void idle(std::chrono::milliseconds ms) {
  using clock = std::chrono::steady_clock;
  const auto expires = clock::now() + ms;
  while (clock::now() < expires) {
    /* sleep override */
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
}
1031 
// Exercises getUsedCpuTime() on a simple thread manager: the reported value
// should grow by the cpu time burned in tasks, and should be unaffected by
// idle tasks and by removing workers.
TEST_F(ThreadManagerTest, UsedCpuTime_Simple) {
  using namespace std::chrono;

  auto manager = ThreadManager::newSimpleThreadManager(3);
  manager->start();

  // All expectations below are relative to this starting reading.
  const auto baseline = manager->getUsedCpuTime();
  // Allowed deviation for every cpu-time comparison in this test.
  const auto tolerance = 20ms;

  // Phase 1: three concurrent tasks, two burning cpu and one idling.
  {
    folly::Latch allDone(3);
    // + 500ms cpu time
    manager->add([&allDone] {
      burn(500ms);
      allDone.count_down();
    });
    // + 300ms cpu time
    manager->add([&allDone] {
      burn(300ms);
      allDone.count_down();
    });
    // + ~0ms cpu time
    manager->add([&allDone] {
      idle(500ms);
      allDone.count_down();
    });
    allDone.wait();

    ASSERT_EQ(manager->workerCount(), 3);
    // 500ms + 300ms = 800ms, within tolerance
    ASSERT_NEAR_NS(manager->getUsedCpuTime() - baseline, 800ms, tolerance);
  }

  // Phase 2: dropping one worker must leave the accumulated cpu time as is.
  {
    manager->removeWorker(1);
    ASSERT_EQ(manager->workerCount(), 2);
    ASSERT_NEAR_NS(manager->getUsedCpuTime() - baseline, 800ms, tolerance);
  }

  // Phase 3: additional work is added on top of the previous total.
  {
    folly::Latch oneDone(1);
    // + 200ms cpu time
    manager->add([&oneDone] {
      burn(200ms);
      oneDone.count_down();
    });
    oneDone.wait();
    // 800ms + 200ms = 1s, within tolerance
    ASSERT_NEAR_NS(manager->getUsedCpuTime() - baseline, 1s, tolerance);
  }

  // Phase 4: with every worker removed the total must still be reported.
  {
    manager->removeWorker(2);
    ASSERT_EQ(manager->workerCount(), 0);
    ASSERT_NEAR_NS(manager->getUsedCpuTime() - baseline, 1s, tolerance);
  }
}
1084 
// Exercises getUsedCpuTime() on a PriorityThreadManager: cpu time burned by
// tasks at different priorities is summed, and removing a worker does not
// lose what was already accounted.
TEST_F(ThreadManagerTest, UsedCpuTime_Priority) {
  using namespace std::chrono;

  // One worker for each priority level.
  auto manager = PriorityThreadManager::newPriorityThreadManager({{
      1 /*HIGH_IMPORTANT*/,
      1 /*HIGH*/,
      1 /*IMPORTANT*/,
      1 /*NORMAL*/,
      1 /*BEST_EFFORT*/
  }});
  manager->start();

  // All expectations below are relative to this starting reading.
  const auto baseline = manager->getUsedCpuTime();
  // Allowed deviation for every cpu-time comparison in this test.
  const auto tolerance = 20ms;

  // Wraps a plain callable into the runnable type accepted by add().
  auto asRunnable = [](std::function<void()>&& body) {
    return FunctionRunner::create(std::move(body));
  };

  // Three tasks at three priorities: two burn cpu, one idles.
  folly::Latch allDone(3);
  // + 500ms cpu time
  manager->add(HIGH, asRunnable([&allDone] {
                 burn(500ms);
                 allDone.count_down();
               }));
  // + 300ms cpu time
  manager->add(NORMAL, asRunnable([&allDone] {
                 burn(300ms);
                 allDone.count_down();
               }));
  // + 0ms cpu time
  manager->add(BEST_EFFORT, asRunnable([&allDone] {
                 idle(500ms);
                 allDone.count_down();
               }));
  allDone.wait();
  // 500ms + 300ms = 800ms, within tolerance
  ASSERT_NEAR_NS(manager->getUsedCpuTime() - baseline, 800ms, tolerance);

  // Dropping the NORMAL worker must leave the accumulated cpu time as is.
  manager->removeWorker(NORMAL, 1);
  ASSERT_EQ(manager->workerCount(NORMAL), 0);
  ASSERT_NEAR_NS(manager->getUsedCpuTime() - baseline, 800ms, tolerance);
}
1124 
1125 #else //  __linux__
1126 
1127 //
1128 // On other platforms, just make sure getUsedCpuTime() does not crash and
1129 // returns 0.
1130 //
1131 
// Non-Linux variant: getUsedCpuTime() is expected to report zero both before
// and after a task has kept a worker busy.
TEST_F(ThreadManagerTest, UsedCpuTime) {
  using namespace std::chrono;

  auto manager = ThreadManager::newSimpleThreadManager(3);
  manager->start();

  // Nothing has run yet.
  ASSERT_EQ(manager->getUsedCpuTime().count(), 0);

  // Busy-wait on the steady clock until `span` has elapsed.
  auto spinFor = [](milliseconds span) {
    const auto deadline = steady_clock::now() + span;
    for (;;) {
      if (!(steady_clock::now() < deadline)) {
        return;
      }
    }
  };

  folly::Latch taskDone(1);
  manager->add([&] {
    spinFor(500ms);
    taskDone.count_down();
  });
  taskDone.wait();

  // Still zero after the busy task has completed.
  ASSERT_EQ(manager->getUsedCpuTime().count(), 0);
}
1155 
1156 #endif // __linux__
1157