1 #include "thread/Pool.h"
2 #include "catch.hpp"
3 #include "system/Timer.h"
4 #include "utils/Utils.h"
5 
6 using namespace Sph;
7 
8 TEST_CASE("Pool submit task", "[thread]") {
9     ThreadPool pool;
10     REQUIRE_THREAD_SAFE(pool.getThreadCnt() == std::thread::hardware_concurrency());
11     std::atomic<uint64_t> sum;
12     sum = 0;
13     for (Size i = 0; i <= 100; ++i) {
__anoncfe36ad00102null14         pool.submit([&sum, i] { sum += i; });
15     }
16     pool.waitForAll();
17     REQUIRE_THREAD_SAFE(sum == 5050);
18     REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0);
19 }
20 
21 TEST_CASE("Pool thread count", "[thread]") {
22     ThreadPool pool1(0);
23     REQUIRE_THREAD_SAFE(pool1.getThreadCnt() == std::thread::hardware_concurrency());
24     ThreadPool pool2(5);
25     REQUIRE_THREAD_SAFE(pool2.getThreadCnt() == 5);
26 }
27 
28 TEST_CASE("Pool submit task from different thread", "[thread]") {
29     ThreadPool pool;
30     std::atomic<uint64_t> sum1, sum2;
31     sum1 = sum2 = 0;
__anoncfe36ad00202null32     std::thread thread([&sum2, &pool] {
33         for (Size i = 0; i <= 100; i += 2) { // even numbers
34             pool.submit([&sum2, i] { sum2 += i; });
35         }
36     });
37     for (Size i = 1; i <= 100; i += 2) {
__anoncfe36ad00402null38         pool.submit([&sum1, i] { sum1 += i; });
39     }
40     thread.join();
41     pool.waitForAll();
42     REQUIRE_THREAD_SAFE(sum1 + sum2 == 5050);
43     REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0);
44 }
45 
46 TEST_CASE("Pool submit single", "[thread]") {
47     ThreadPool pool;
48     bool executed = false;
__anoncfe36ad00502null49     pool.submit([&executed] { executed = true; });
50     pool.waitForAll();
51     REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0);
52     REQUIRE_THREAD_SAFE(executed);
53 }
54 
55 TEST_CASE("Pool one thread", "[thread]") {
56     ThreadPool pool(1);
57     int executed = 0;
__anoncfe36ad00602null58     auto task = [&executed] {
59         std::this_thread::sleep_for(std::chrono::milliseconds(15));
60         ++executed;
61     };
62     for (Size i = 0; i < 4; ++i) {
63         pool.submit(task);
64     }
65     pool.waitForAll();
66     REQUIRE(executed == 4);
67 }
68 
69 TEST_CASE("Pool submit nested", "[thread]") {
70     ThreadPool pool;
71     std::atomic_bool innerRun{ 0 };
__anoncfe36ad00702null72     auto rootTask = pool.submit([&pool, &innerRun] {
73         REQUIRE_THREAD_SAFE(Task::getCurrent());
74         REQUIRE_THREAD_SAFE(Task::getCurrent()->isRoot());
75 
76         WeakPtr<Task> parent = Task::getCurrent();
77         pool.submit([parent, &innerRun] {
78             std::this_thread::sleep_for(std::chrono::milliseconds(50));
79 
80             auto task = Task::getCurrent();
81             REQUIRE_THREAD_SAFE(task);
82             REQUIRE_THREAD_SAFE(!task->isRoot());
83             REQUIRE_THREAD_SAFE(task->getParent() == parent.lock());
84             innerRun = true;
85         });
86     });
87     REQUIRE_THREAD_SAFE(!rootTask->completed());
88     REQUIRE_THREAD_SAFE(!innerRun);
89     rootTask->wait();
90     REQUIRE_THREAD_SAFE(innerRun);
91     REQUIRE_THREAD_SAFE(rootTask->completed());
92     REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0);
93 
94     // second doesn't do anything
95     rootTask->wait();
96 
97     // pool.waitForAll();
98 }
99 
100 TEST_CASE("Pool submit parallel", "[thread]") {
101     // checks that we can wait for two tasks to finish independently
102     ThreadPool pool;
__anoncfe36ad00902null103     auto task1 = pool.submit([] {
104         REQUIRE_THREAD_SAFE(Task::getCurrent());
105         REQUIRE_THREAD_SAFE(Task::getCurrent()->isRoot());
106         std::this_thread::sleep_for(std::chrono::milliseconds(20));
107     });
108 
__anoncfe36ad00a02null109     auto task2 = pool.submit([] {
110         REQUIRE_THREAD_SAFE(Task::getCurrent());
111         REQUIRE_THREAD_SAFE(Task::getCurrent()->isRoot());
112         std::this_thread::sleep_for(std::chrono::milliseconds(60));
113     });
114 
115     REQUIRE_THREAD_SAFE(!task1->completed());
116     REQUIRE_THREAD_SAFE(!task2->completed());
117     task1->wait();
118     REQUIRE_THREAD_SAFE(task1->completed());
119     REQUIRE_THREAD_SAFE(!task2->completed());
120     task2->wait();
121     REQUIRE_THREAD_SAFE(task2->completed());
122     REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0);
123 
124     // now the same thing, but wait for the second (longer) one
__anoncfe36ad00b02null125     task1 = pool.submit([] { std::this_thread::sleep_for(std::chrono::milliseconds(20)); });
__anoncfe36ad00c02null126     task2 = pool.submit([] { std::this_thread::sleep_for(std::chrono::milliseconds(60)); });
127     REQUIRE_THREAD_SAFE(!task1->completed());
128     REQUIRE_THREAD_SAFE(!task2->completed());
129     task2->wait();
130     REQUIRE_THREAD_SAFE(task1->completed());
131     REQUIRE_THREAD_SAFE(task2->completed());
132 
133     std::this_thread::sleep_for(std::chrono::milliseconds(5));
134     REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0);
135 
136     // pool.waitForAll();
137 }
138 
139 TEST_CASE("Pool wait for child", "[thread]") {
140     ThreadPool pool;
141     SharedPtr<ITask> taskRoot, taskChild;
142     volatile bool childFinished = false;
__anoncfe36ad00d02null143     taskRoot = pool.submit([&pool, &taskChild, &childFinished] {
144         taskChild = pool.submit([&childFinished] {
145             std::this_thread::sleep_for(std::chrono::milliseconds(20));
146             childFinished = true;
147         });
148         taskChild->wait();
149     });
150     taskRoot->wait();
151 
152     REQUIRE_THREAD_SAFE(taskRoot->completed());
153     REQUIRE_THREAD_SAFE(taskChild->completed());
154     REQUIRE_THREAD_SAFE(childFinished);
155 }
156 
namespace {
/// Exception thrown from pool tasks in the tests below.
///
/// Kept in an anonymous namespace so it has internal linkage and cannot clash (ODR) with
/// similarly named helper types in other test translation units.
class TestException final : public std::exception {
public:
    const char* what() const noexcept override {
        return "exception";
    }
};
} // namespace
162 
163 TEST_CASE("Pool task throw", "[thread]") {
164     ThreadPool pool;
165 
__anoncfe36ad00f02null166     auto task = pool.submit([] {
167         std::this_thread::sleep_for(std::chrono::milliseconds(10));
168         throw TestException();
169     });
170     REQUIRE_THROWS_AS(task->wait(), TestException);
171 }
172 
173 TEST_CASE("Pool task throw nested", "[thread]") {
174     ThreadPool pool;
175 
__anoncfe36ad01002null176     auto task = pool.submit([&pool] {
177         std::this_thread::sleep_for(std::chrono::milliseconds(10));
178         pool.submit([] {
179             std::this_thread::sleep_for(std::chrono::milliseconds(10));
180             throw TestException();
181         });
182     });
183     REQUIRE_THROWS_AS(task->wait(), TestException);
184 }
185 
186 TEST_CASE("Pool ParallelFor", "[thread]") {
187     ThreadPool pool;
188     std::atomic<uint64_t> sum;
189     sum = 0;
__anoncfe36ad01202(Size i) 190     parallelFor(pool, 1, 100000, [&sum](Size i) { sum += i; });
191     REQUIRE_THREAD_SAFE(sum == 4999950000);
192     REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0);
193 }
194 
195 TEST_CASE("Pool GetThreadIdx", "[thread]") {
196     ThreadPool pool(2);
197     REQUIRE_THREAD_SAFE(pool.getThreadCnt() == 2);
198     REQUIRE_FALSE(pool.getThreadIdx()); // main thread, not within the pool
199 
__anoncfe36ad01302null200     std::thread thread([&pool] {
201         std::this_thread::sleep_for(std::chrono::milliseconds(50));
202         REQUIRE_THREAD_SAFE(!pool.getThreadIdx()); // also not within the pool
203     });
204     thread.join();
205 
__anoncfe36ad01402null206     pool.submit([&pool] {
207         const Optional<Size> idx = pool.getThreadIdx();
208         REQUIRE_THREAD_SAFE(idx);
209         REQUIRE_THREAD_SAFE((idx.value() == 0 || idx.value() == 1));
210     });
211 }
212 
213 TEST_CASE("Pool WaitForAll", "[thread]") {
214     ThreadPool pool;
215     pool.waitForAll(); // waitForAll with no running tasks
216 
217     Timer timer;
218     const Size cnt = pool.getThreadCnt();
219     std::atomic_int taskIdx;
220     taskIdx = 0;
221     // run tasks with different duration
222     for (Size i = 0; i < cnt; ++i) {
__anoncfe36ad01502null223         pool.submit([&taskIdx] { std::this_thread::sleep_for(std::chrono::milliseconds(50 * ++taskIdx)); });
224     }
225     pool.waitForAll();
226     REQUIRE_THREAD_SAFE(timer.elapsed(TimerUnit::MILLISECOND) >= 50 * cnt);
227     REQUIRE_NOTHROW(pool.waitForAll()); // second does nothing
228 }
229 
#ifdef SPH_DEBUG
TEST_CASE("Pool ParallelFor assert", "[thread]") {
    ThreadPool pool(2);

    // The assert fires on a worker thread; REQUIRE_SPH_ASSERT expects it to be reported
    // back to the calling thread.
    auto alwaysFails = [](Size) { SPH_ASSERT(false); };
    REQUIRE_SPH_ASSERT(parallelFor(pool, 1, 2, alwaysFails));
}
#endif
239