1 /** 2 * @file thread_pool.h 3 * 4 * @section LICENSE 5 * 6 * The MIT License 7 * 8 * @copyright Copyright (c) 2018-2021 TileDB, Inc. 9 * 10 * Permission is hereby granted, free of charge, to any person obtaining a copy 11 * of this software and associated documentation files (the "Software"), to deal 12 * in the Software without restriction, including without limitation the rights 13 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 * copies of the Software, and to permit persons to whom the Software is 15 * furnished to do so, subject to the following conditions: 16 * 17 * The above copyright notice and this permission notice shall be included in 18 * all copies or substantial portions of the Software. 19 * 20 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 26 * THE SOFTWARE. 27 * 28 * @section DESCRIPTION 29 * 30 * This file declares the ThreadPool class. 31 */ 32 33 #ifndef TILEDB_THREAD_POOL_H 34 #define TILEDB_THREAD_POOL_H 35 36 #include <condition_variable> 37 #include <functional> 38 #include <mutex> 39 #include <stack> 40 #include <thread> 41 #include <unordered_map> 42 #include <unordered_set> 43 #include <vector> 44 45 #include "tiledb/common/macros.h" 46 #include "tiledb/common/status.h" 47 48 namespace tiledb { 49 namespace common { 50 51 /** 52 * A recusive-safe thread pool. 53 */ 54 class ThreadPool { 55 private: 56 /* ********************************* */ 57 /* PRIVATE DATATYPES */ 58 /* ********************************* */ 59 60 // Forward-declaration. 61 struct TaskState; 62 63 // Forward-declaration. 64 class PackagedTask; 65 66 public: 67 /* ********************************* */ 68 /* PUBLIC DATATYPES */ 69 /* ********************************* */ 70 71 class Task { 72 public: 73 /** Constructor. */ Task()74 Task() 75 : task_state_(nullptr) { 76 } 77 78 /** Move constructor. */ Task(Task && rhs)79 Task(Task&& rhs) { 80 task_state_ = std::move(rhs.task_state_); 81 } 82 83 /** Move-assign operator. */ 84 Task& operator=(Task&& rhs) { 85 task_state_ = std::move(rhs.task_state_); 86 return *this; 87 } 88 89 /** Returns true if this instance is associated with a valid task. */ valid()90 bool valid() { 91 return task_state_ != nullptr; 92 } 93 94 private: 95 /** Value constructor. */ Task(const tdb_shared_ptr<TaskState> & task_state)96 Task(const tdb_shared_ptr<TaskState>& task_state) 97 : task_state_(std::move(task_state)) { 98 } 99 100 DISABLE_COPY_AND_COPY_ASSIGN(Task); 101 102 /** Blocks until the task has completed or there are other tasks to service. 103 */ wait()104 void wait() { 105 std::unique_lock<std::mutex> ul(task_state_->return_st_mutex_); 106 if (!task_state_->return_st_set_ && !task_state_->check_task_stack_) 107 task_state_->cv_.wait(ul); 108 } 109 110 /** Returns true if the associated task has completed. */ done()111 bool done() { 112 std::lock_guard<std::mutex> lg(task_state_->return_st_mutex_); 113 return task_state_->return_st_set_; 114 } 115 116 /** 117 * Returns the result value from the task. If the task 118 * has not completed, it will wait. 119 */ get()120 Status get() { 121 wait(); 122 std::lock_guard<std::mutex> lg(task_state_->return_st_mutex_); 123 return task_state_->return_st_; 124 } 125 126 /** The shared task state between futures and their associated task. */ 127 tdb_shared_ptr<TaskState> task_state_; 128 129 friend ThreadPool; 130 friend PackagedTask; 131 }; 132 133 /* ********************************* */ 134 /* CONSTRUCTORS & DESTRUCTORS */ 135 /* ********************************* */ 136 137 /** Constructor. */ 138 ThreadPool(); 139 140 /** Destructor. */ 141 ~ThreadPool(); 142 143 /* ********************************* */ 144 /* API */ 145 /* ********************************* */ 146 147 /** 148 * Initialize the thread pool. 149 * 150 * @param concurrency_level Maximum level of concurrency. 151 * @return Status 152 */ 153 Status init(uint64_t concurrency_level = 1); 154 155 /** 156 * Schedule a new task to be executed. If the returned `Task` object 157 * is valid, `function` is guaranteed to execute. The 'function' may 158 * execute immediately on the calling thread. To avoid deadlock, `function` 159 * should not aquire non-recursive locks held by the calling thread. 160 * 161 * @param function Task function to execute. 162 * @return Task for the return status of the task. 163 */ 164 Task execute(std::function<Status()>&& function); 165 166 /** Return the maximum level of concurrency. */ 167 uint64_t concurrency_level() const; 168 169 /** 170 * Wait on all the given tasks to complete. This is safe to call recusively 171 * and may execute pending tasks on the calling thread while waiting. 172 * 173 * @param tasks Task list to wait on. 174 * @return Status::Ok if all tasks returned Status::Ok, otherwise the first 175 * error status is returned 176 */ 177 Status wait_all(std::vector<Task>& tasks); 178 179 /** 180 * Wait on all the given tasks to complete, return a vector of their return 181 * Status. This is safe to call recusively and may execute pending tasks 182 * on the calling thread while waiting. 183 * 184 * @param tasks Task list to wait on 185 * @return Vector of each task's Status. 186 */ 187 std::vector<Status> wait_all_status(std::vector<Task>& tasks); 188 189 private: 190 /* ********************************* */ 191 /* PRIVATE DATATYPES */ 192 /* ********************************* */ 193 194 struct TaskState { 195 /** Constructor. */ TaskStateTaskState196 TaskState() 197 : return_st_() 198 , check_task_stack_(false) 199 , return_st_set_(false) { 200 } 201 202 DISABLE_COPY_AND_COPY_ASSIGN(TaskState); 203 DISABLE_MOVE_AND_MOVE_ASSIGN(TaskState); 204 205 /** The return status from an executed task. */ 206 Status return_st_; 207 208 bool check_task_stack_; 209 210 /** True if the `return_st_` has been set. */ 211 bool return_st_set_; 212 213 /** Waits for a task to complete. */ 214 std::condition_variable cv_; 215 216 /** Protects `return_st_`, `return_st_set_`, and `cv_`. */ 217 std::mutex return_st_mutex_; 218 }; 219 220 class PackagedTask { 221 public: 222 /** Constructor. */ PackagedTask()223 PackagedTask() 224 : fn_(nullptr) 225 , task_state_(nullptr) 226 , parent_(nullptr) { 227 } 228 229 /** Value constructor. */ 230 template <class Fn_T> PackagedTask(Fn_T && fn,tdb_shared_ptr<PackagedTask> && parent)231 explicit PackagedTask(Fn_T&& fn, tdb_shared_ptr<PackagedTask>&& parent) { 232 fn_ = std::move(fn); 233 task_state_ = tdb_make_shared(TaskState); 234 parent_ = std::move(parent); 235 } 236 237 /** Function-call operator. */ operator()238 void operator()() { 239 const Status r = fn_(); 240 { 241 std::lock_guard<std::mutex> lg(task_state_->return_st_mutex_); 242 task_state_->return_st_set_ = true; 243 task_state_->return_st_ = r; 244 } 245 task_state_->cv_.notify_all(); 246 247 fn_ = std::function<Status()>(); 248 task_state_ = nullptr; 249 } 250 251 /** Returns the future associated with this task. */ get_future()252 ThreadPool::Task get_future() const { 253 return Task(task_state_); 254 } 255 get_parent()256 PackagedTask* get_parent() const { 257 return parent_.get(); 258 } 259 260 private: 261 DISABLE_COPY_AND_COPY_ASSIGN(PackagedTask); 262 DISABLE_MOVE_AND_MOVE_ASSIGN(PackagedTask); 263 264 /** The packaged function. */ 265 std::function<Status()> fn_; 266 267 /** The task state to share with futures. */ 268 tdb_shared_ptr<TaskState> task_state_; 269 270 /** The parent task that executed this task. */ 271 tdb_shared_ptr<PackagedTask> parent_; 272 }; 273 274 /* ********************************* */ 275 /* PRIVATE ATTRIBUTES */ 276 /* ********************************* */ 277 278 /** 279 * The maximum level of concurrency among a single waiter and all 280 * of the the `threads_`. 281 */ 282 uint64_t concurrency_level_; 283 284 /** Protects `task_stack_`, `idle_threads_`, and `task_stack_clock_`. */ 285 std::mutex task_stack_mutex_; 286 287 /** Notifies work threads to check `task_stack_` for work. */ 288 std::condition_variable task_stack_cv_; 289 290 /** Pending tasks in LIFO ordering. */ 291 std::vector<tdb_shared_ptr<PackagedTask>> task_stack_; 292 293 /* 294 * A logical, monotonically increasing clock that is incremented 295 * when a task is either added or removed from `task_stack_`. This 296 * is used by threads to determine if `task_stack_` has been modified 297 * between two points in time. 298 */ 299 uint64_t task_stack_clock_; 300 301 /** 302 * The number of threads waiting for the `task_stack_` to 303 * become non-empty. 304 */ 305 uint64_t idle_threads_; 306 307 /** The worker threads. */ 308 std::vector<std::thread> threads_; 309 310 /** When true, all pending tasks will remain unscheduled. */ 311 bool should_terminate_; 312 313 /** All tasks that threads in this instance are waiting on. */ 314 struct BlockedTasksHasher { operatorBlockedTasksHasher315 size_t operator()(const tdb_shared_ptr<TaskState>& task) const { 316 return reinterpret_cast<size_t>(task.get()); 317 } 318 }; 319 std::unordered_set<tdb_shared_ptr<TaskState>, BlockedTasksHasher> 320 blocked_tasks_; 321 322 /** Protects `blocked_tasks_`. */ 323 std::mutex blocked_tasks_mutex_; 324 325 /** Indexes thread ids to the ThreadPool instance they belong to. */ 326 static std::unordered_map<std::thread::id, ThreadPool*> tp_index_; 327 328 /** Protects 'tp_index_'. */ 329 static std::mutex tp_index_lock_; 330 331 /** Indexes thread ids to the task it is currently executing. */ 332 static std::unordered_map<std::thread::id, tdb_shared_ptr<PackagedTask>> 333 task_index_; 334 335 /** Protects 'task_index_'. */ 336 static std::mutex task_index_lock_; 337 338 /* ********************************* */ 339 /* PRIVATE METHODS */ 340 /* ********************************* */ 341 342 /** 343 * Waits for `task`, but will execute other tasks from `task_stack_` 344 * while waiting. While this may be an performance optimization 345 * to perform work on this thread rather than waiting, the primary 346 * motiviation is to prevent deadlock when tasks are enqueued recursively. 347 */ 348 Status wait_or_work(Task&& task); 349 350 /** Terminate the threads in the thread pool. */ 351 void terminate(); 352 353 /** The worker thread routine. */ 354 static void worker(ThreadPool& pool); 355 356 // Add indexes from each thread to this instance. 357 void add_tp_index(); 358 359 // Remove indexes from each thread to this instance. 360 void remove_tp_index(); 361 362 // Lookup the thread pool instance that contains `tid`. 363 static ThreadPool* lookup_tp(std::thread::id tid); 364 365 // Add indexes for each thread on the `task_index_`. 366 void add_task_index(); 367 368 // Remove indexes for each thread on the `task_index_`. 369 void remove_task_index(); 370 371 // Lookup the task executing on `tid`. 372 static tdb_shared_ptr<PackagedTask> lookup_task(std::thread::id tid); 373 374 // Wrapper to update `task_index_` and execute `task`. 375 static void exec_packaged_task(tdb_shared_ptr<PackagedTask> task); 376 }; 377 378 } // namespace common 379 } // namespace tiledb 380 381 #endif // TILEDB_THREAD_POOL_H 382