1 #include "thread/Pool.h" 2 #include "catch.hpp" 3 #include "system/Timer.h" 4 #include "utils/Utils.h" 5 6 using namespace Sph; 7 8 TEST_CASE("Pool submit task", "[thread]") { 9 ThreadPool pool; 10 REQUIRE_THREAD_SAFE(pool.getThreadCnt() == std::thread::hardware_concurrency()); 11 std::atomic<uint64_t> sum; 12 sum = 0; 13 for (Size i = 0; i <= 100; ++i) { __anoncfe36ad00102null14 pool.submit([&sum, i] { sum += i; }); 15 } 16 pool.waitForAll(); 17 REQUIRE_THREAD_SAFE(sum == 5050); 18 REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0); 19 } 20 21 TEST_CASE("Pool thread count", "[thread]") { 22 ThreadPool pool1(0); 23 REQUIRE_THREAD_SAFE(pool1.getThreadCnt() == std::thread::hardware_concurrency()); 24 ThreadPool pool2(5); 25 REQUIRE_THREAD_SAFE(pool2.getThreadCnt() == 5); 26 } 27 28 TEST_CASE("Pool submit task from different thread", "[thread]") { 29 ThreadPool pool; 30 std::atomic<uint64_t> sum1, sum2; 31 sum1 = sum2 = 0; __anoncfe36ad00202null32 std::thread thread([&sum2, &pool] { 33 for (Size i = 0; i <= 100; i += 2) { // even numbers 34 pool.submit([&sum2, i] { sum2 += i; }); 35 } 36 }); 37 for (Size i = 1; i <= 100; i += 2) { __anoncfe36ad00402null38 pool.submit([&sum1, i] { sum1 += i; }); 39 } 40 thread.join(); 41 pool.waitForAll(); 42 REQUIRE_THREAD_SAFE(sum1 + sum2 == 5050); 43 REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0); 44 } 45 46 TEST_CASE("Pool submit single", "[thread]") { 47 ThreadPool pool; 48 bool executed = false; __anoncfe36ad00502null49 pool.submit([&executed] { executed = true; }); 50 pool.waitForAll(); 51 REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0); 52 REQUIRE_THREAD_SAFE(executed); 53 } 54 55 TEST_CASE("Pool one thread", "[thread]") { 56 ThreadPool pool(1); 57 int executed = 0; __anoncfe36ad00602null58 auto task = [&executed] { 59 std::this_thread::sleep_for(std::chrono::milliseconds(15)); 60 ++executed; 61 }; 62 for (Size i = 0; i < 4; ++i) { 63 pool.submit(task); 64 } 65 pool.waitForAll(); 66 REQUIRE(executed == 4); 67 } 68 69 
// Submits a task that itself submits another (nested) task, and checks the parent/child relation
// reported by Task::getCurrent() as well as the behavior of wait() on the root task.
TEST_CASE("Pool submit nested", "[thread]") {
    ThreadPool pool;
    std::atomic_bool innerRun{ 0 };
    auto rootTask = pool.submit([&pool, &innerRun] {
        REQUIRE_THREAD_SAFE(Task::getCurrent());
        REQUIRE_THREAD_SAFE(Task::getCurrent()->isRoot());

        // keep only a weak reference to the root task inside the nested task
        WeakPtr<Task> parent = Task::getCurrent();
        pool.submit([parent, &innerRun] {
            // delay the nested task so that the checks below run before it finishes
            std::this_thread::sleep_for(std::chrono::milliseconds(50));

            auto task = Task::getCurrent();
            REQUIRE_THREAD_SAFE(task);
            REQUIRE_THREAD_SAFE(!task->isRoot());
            REQUIRE_THREAD_SAFE(task->getParent() == parent.lock());
            innerRun = true;
        });
    });
    // timing-based: relies on the 50ms sleep of the nested task
    REQUIRE_THREAD_SAFE(!rootTask->completed());
    REQUIRE_THREAD_SAFE(!innerRun);

    // the test requires the nested task to have run by the time wait() on the root task returns
    rootTask->wait();
    REQUIRE_THREAD_SAFE(innerRun);
    REQUIRE_THREAD_SAFE(rootTask->completed());
    REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0);

    // second doesn't do anything
    rootTask->wait();

    // pool.waitForAll();
}

TEST_CASE("Pool submit parallel", "[thread]") {
    // checks that we can wait for two tasks to finish independently
    ThreadPool pool;
    auto task1 = pool.submit([] {
        REQUIRE_THREAD_SAFE(Task::getCurrent());
        REQUIRE_THREAD_SAFE(Task::getCurrent()->isRoot());
        std::this_thread::sleep_for(std::chrono::milliseconds(20));
    });

    auto task2 = pool.submit([] {
        REQUIRE_THREAD_SAFE(Task::getCurrent());
        REQUIRE_THREAD_SAFE(Task::getCurrent()->isRoot());
        std::this_thread::sleep_for(std::chrono::milliseconds(60));
    });

    // timing-based: both tasks are still sleeping at this point
    REQUIRE_THREAD_SAFE(!task1->completed());
    REQUIRE_THREAD_SAFE(!task2->completed());
    // waiting for the shorter task must not wait for the longer one
    task1->wait();
    REQUIRE_THREAD_SAFE(task1->completed());
    REQUIRE_THREAD_SAFE(!task2->completed());
    task2->wait();
    REQUIRE_THREAD_SAFE(task2->completed());
    REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0);

    // now the same thing, but wait for the second (longer) one
    task1 = pool.submit([] { std::this_thread::sleep_for(std::chrono::milliseconds(20)); });
    task2 = pool.submit([] { std::this_thread::sleep_for(std::chrono::milliseconds(60)); });
    REQUIRE_THREAD_SAFE(!task1->completed());
    REQUIRE_THREAD_SAFE(!task2->completed());
    task2->wait();
    // the shorter task is expected to be done by the time the longer one finishes
    REQUIRE_THREAD_SAFE(task1->completed());
    REQUIRE_THREAD_SAFE(task2->completed());

    // short grace period before checking the remaining task count
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0);

    // pool.waitForAll();
}

// Submits a root task that submits a child task and explicitly waits for it from within the pool,
// then checks that both tasks are completed once the root task has been waited for.
TEST_CASE("Pool wait for child", "[thread]") {
    ThreadPool pool;
    SharedPtr<ITask> taskRoot, taskChild;
    // written by a worker thread and read by the main thread, hence atomic rather than volatile
    std::atomic_bool childFinished{ false };
    taskRoot = pool.submit([&pool, &taskChild, &childFinished] {
        taskChild = pool.submit([&childFinished] {
            std::this_thread::sleep_for(std::chrono::milliseconds(20));
            childFinished = true;
        });
        taskChild->wait();
    });
    taskRoot->wait();

    REQUIRE_THREAD_SAFE(taskRoot->completed());
    REQUIRE_THREAD_SAFE(taskChild->completed());
    REQUIRE_THREAD_SAFE(childFinished);
}

// Exception type thrown from tasks in the tests below; what() returns a fixed message.
class TestException : public std::exception {
public:
    const char* what() const noexcept override {
        return "exception";
    }
};

// Checks that an exception thrown inside a task is rethrown by wait() on that task.
TEST_CASE("Pool task throw", "[thread]") {
    ThreadPool pool;

    auto task = pool.submit([] {
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
        throw TestException();
    });
    REQUIRE_THROWS_AS(task->wait(), TestException);
}

// Checks that an exception thrown inside a nested task is rethrown by wait() on the outer task.
TEST_CASE("Pool task throw nested", "[thread]") {
    ThreadPool pool;

    auto task = pool.submit([&pool] {
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
        pool.submit([] {
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
            throw TestException();
        });
    });
    REQUIRE_THROWS_AS(task->wait(), TestException);
}

// Accumulates the loop indices passed to the functor by parallelFor and checks the resulting sum.
TEST_CASE("Pool ParallelFor", "[thread]") {
    ThreadPool pool;
    std::atomic<uint64_t> total{ 0 };
    parallelFor(pool, 1, 100000, [&total](Size idx) { total += idx; });

    // expected value equals 1 + 2 + ... + 99999
    REQUIRE_THREAD_SAFE(total == 4999950000);
    REQUIRE_THREAD_SAFE(pool.remainingTaskCnt() == 0);
}

// Checks that getThreadIdx returns no value outside of the pool and a valid index inside a task.
TEST_CASE("Pool GetThreadIdx", "[thread]") {
    ThreadPool pool(2);
    REQUIRE_THREAD_SAFE(pool.getThreadCnt() == 2);

    // main thread, not within the pool
    REQUIRE_FALSE(pool.getThreadIdx());

    std::thread outsider([&pool] {
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
        // also not within the pool
        REQUIRE_THREAD_SAFE(!pool.getThreadIdx());
    });
    outsider.join();

    // inside a task, the index has to be one of the two threads of the pool
    pool.submit([&pool] {
        const Optional<Size> workerIdx = pool.getThreadIdx();
        REQUIRE_THREAD_SAFE(workerIdx);
        REQUIRE_THREAD_SAFE((workerIdx.value() == 0 || workerIdx.value() == 1));
    });
}

// Checks that waitForAll can be called on an idle pool and that it blocks until the longest task is done.
TEST_CASE("Pool WaitForAll", "[thread]") {
    ThreadPool pool;
    // waitForAll with no running tasks
    pool.waitForAll();

    Timer timer;
    const Size threadCnt = pool.getThreadCnt();
    std::atomic_int launched{ 0 };

    // run tasks with different duration; the n-th task to start sleeps for 50*n milliseconds
    for (Size t = 0; t < threadCnt; ++t) {
        pool.submit([&launched] {
            const int order = ++launched;
            std::this_thread::sleep_for(std::chrono::milliseconds(50 * order));
        });
    }
    pool.waitForAll();

    // the longest task sleeps for 50 * threadCnt milliseconds
    REQUIRE_THREAD_SAFE(timer.elapsed(TimerUnit::MILLISECOND) >= 50 * threadCnt);
    // second does nothing
    REQUIRE_NOTHROW(pool.waitForAll());
}

#ifdef SPH_DEBUG
// Checks that an assert fired inside the functor of parallelFor is detected by REQUIRE_SPH_ASSERT.
TEST_CASE("Pool ParallelFor assert", "[thread]") {
    ThreadPool pool(2);
    // throw from worker thread
    auto failing = [](Size) { SPH_ASSERT(false); };

    REQUIRE_SPH_ASSERT(parallelFor(pool, 1, 2, failing));
}
#endif