1 /**
2  * @file   thread_pool.h
3  *
4  * @section LICENSE
5  *
6  * The MIT License
7  *
8  * @copyright Copyright (c) 2018-2021 TileDB, Inc.
9  *
10  * Permission is hereby granted, free of charge, to any person obtaining a copy
11  * of this software and associated documentation files (the "Software"), to deal
12  * in the Software without restriction, including without limitation the rights
13  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14  * copies of the Software, and to permit persons to whom the Software is
15  * furnished to do so, subject to the following conditions:
16  *
17  * The above copyright notice and this permission notice shall be included in
18  * all copies or substantial portions of the Software.
19  *
20  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
26  * THE SOFTWARE.
27  *
28  * @section DESCRIPTION
29  *
30  * This file declares the ThreadPool class.
31  */
32 
33 #ifndef TILEDB_THREAD_POOL_H
34 #define TILEDB_THREAD_POOL_H
35 
#include <condition_variable>
#include <functional>
#include <mutex>
#include <stack>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
44 
45 #include "tiledb/common/macros.h"
46 #include "tiledb/common/status.h"
47 
48 namespace tiledb {
49 namespace common {
50 
51 /**
 * A recursive-safe thread pool.
53  */
54 class ThreadPool {
55  private:
56   /* ********************************* */
57   /*          PRIVATE DATATYPES        */
58   /* ********************************* */
59 
60   // Forward-declaration.
61   struct TaskState;
62 
63   // Forward-declaration.
64   class PackagedTask;
65 
66  public:
67   /* ********************************* */
68   /*          PUBLIC DATATYPES         */
69   /* ********************************* */
70 
71   class Task {
72    public:
73     /** Constructor. */
Task()74     Task()
75         : task_state_(nullptr) {
76     }
77 
78     /** Move constructor. */
Task(Task && rhs)79     Task(Task&& rhs) {
80       task_state_ = std::move(rhs.task_state_);
81     }
82 
83     /** Move-assign operator. */
84     Task& operator=(Task&& rhs) {
85       task_state_ = std::move(rhs.task_state_);
86       return *this;
87     }
88 
89     /** Returns true if this instance is associated with a valid task. */
valid()90     bool valid() {
91       return task_state_ != nullptr;
92     }
93 
94    private:
95     /** Value constructor. */
Task(const tdb_shared_ptr<TaskState> & task_state)96     Task(const tdb_shared_ptr<TaskState>& task_state)
97         : task_state_(std::move(task_state)) {
98     }
99 
100     DISABLE_COPY_AND_COPY_ASSIGN(Task);
101 
102     /** Blocks until the task has completed or there are other tasks to service.
103      */
wait()104     void wait() {
105       std::unique_lock<std::mutex> ul(task_state_->return_st_mutex_);
106       if (!task_state_->return_st_set_ && !task_state_->check_task_stack_)
107         task_state_->cv_.wait(ul);
108     }
109 
110     /** Returns true if the associated task has completed. */
done()111     bool done() {
112       std::lock_guard<std::mutex> lg(task_state_->return_st_mutex_);
113       return task_state_->return_st_set_;
114     }
115 
116     /**
117      * Returns the result value from the task. If the task
118      * has not completed, it will wait.
119      */
get()120     Status get() {
121       wait();
122       std::lock_guard<std::mutex> lg(task_state_->return_st_mutex_);
123       return task_state_->return_st_;
124     }
125 
126     /** The shared task state between futures and their associated task. */
127     tdb_shared_ptr<TaskState> task_state_;
128 
129     friend ThreadPool;
130     friend PackagedTask;
131   };
132 
133   /* ********************************* */
134   /*     CONSTRUCTORS & DESTRUCTORS    */
135   /* ********************************* */
136 
137   /** Constructor. */
138   ThreadPool();
139 
140   /** Destructor. */
141   ~ThreadPool();
142 
143   /* ********************************* */
144   /*                API                */
145   /* ********************************* */
146 
147   /**
148    * Initialize the thread pool.
149    *
150    * @param concurrency_level Maximum level of concurrency.
151    * @return Status
152    */
153   Status init(uint64_t concurrency_level = 1);
154 
155   /**
156    * Schedule a new task to be executed. If the returned `Task` object
157    * is valid, `function` is guaranteed to execute. The 'function' may
158    * execute immediately on the calling thread. To avoid deadlock, `function`
   * should not acquire non-recursive locks held by the calling thread.
160    *
161    * @param function Task function to execute.
162    * @return Task for the return status of the task.
163    */
164   Task execute(std::function<Status()>&& function);
165 
166   /** Return the maximum level of concurrency. */
167   uint64_t concurrency_level() const;
168 
169   /**
   * Wait on all the given tasks to complete. This is safe to call
   * recursively and may execute pending tasks on the calling thread while
   * waiting.
172    *
173    * @param tasks Task list to wait on.
174    * @return Status::Ok if all tasks returned Status::Ok, otherwise the first
175    * error status is returned
176    */
177   Status wait_all(std::vector<Task>& tasks);
178 
179   /**
180    * Wait on all the given tasks to complete, return a vector of their return
   * Status. This is safe to call recursively and may execute pending tasks
182    * on the calling thread while waiting.
183    *
184    * @param tasks Task list to wait on
185    * @return Vector of each task's Status.
186    */
187   std::vector<Status> wait_all_status(std::vector<Task>& tasks);
188 
189  private:
190   /* ********************************* */
191   /*          PRIVATE DATATYPES        */
192   /* ********************************* */
193 
194   struct TaskState {
195     /** Constructor. */
TaskStateTaskState196     TaskState()
197         : return_st_()
198         , check_task_stack_(false)
199         , return_st_set_(false) {
200     }
201 
202     DISABLE_COPY_AND_COPY_ASSIGN(TaskState);
203     DISABLE_MOVE_AND_MOVE_ASSIGN(TaskState);
204 
205     /** The return status from an executed task. */
206     Status return_st_;
207 
208     bool check_task_stack_;
209 
210     /** True if the `return_st_` has been set. */
211     bool return_st_set_;
212 
213     /** Waits for a task to complete. */
214     std::condition_variable cv_;
215 
216     /** Protects `return_st_`, `return_st_set_`, and `cv_`. */
217     std::mutex return_st_mutex_;
218   };
219 
220   class PackagedTask {
221    public:
222     /** Constructor. */
PackagedTask()223     PackagedTask()
224         : fn_(nullptr)
225         , task_state_(nullptr)
226         , parent_(nullptr) {
227     }
228 
229     /** Value constructor. */
230     template <class Fn_T>
PackagedTask(Fn_T && fn,tdb_shared_ptr<PackagedTask> && parent)231     explicit PackagedTask(Fn_T&& fn, tdb_shared_ptr<PackagedTask>&& parent) {
232       fn_ = std::move(fn);
233       task_state_ = tdb_make_shared(TaskState);
234       parent_ = std::move(parent);
235     }
236 
237     /** Function-call operator. */
operator()238     void operator()() {
239       const Status r = fn_();
240       {
241         std::lock_guard<std::mutex> lg(task_state_->return_st_mutex_);
242         task_state_->return_st_set_ = true;
243         task_state_->return_st_ = r;
244       }
245       task_state_->cv_.notify_all();
246 
247       fn_ = std::function<Status()>();
248       task_state_ = nullptr;
249     }
250 
251     /** Returns the future associated with this task. */
get_future()252     ThreadPool::Task get_future() const {
253       return Task(task_state_);
254     }
255 
get_parent()256     PackagedTask* get_parent() const {
257       return parent_.get();
258     }
259 
260    private:
261     DISABLE_COPY_AND_COPY_ASSIGN(PackagedTask);
262     DISABLE_MOVE_AND_MOVE_ASSIGN(PackagedTask);
263 
264     /** The packaged function. */
265     std::function<Status()> fn_;
266 
267     /** The task state to share with futures. */
268     tdb_shared_ptr<TaskState> task_state_;
269 
270     /** The parent task that executed this task. */
271     tdb_shared_ptr<PackagedTask> parent_;
272   };
273 
274   /* ********************************* */
275   /*         PRIVATE ATTRIBUTES        */
276   /* ********************************* */
277 
278   /**
279    * The maximum level of concurrency among a single waiter and all
   * of the `threads_`.
281    */
282   uint64_t concurrency_level_;
283 
284   /** Protects `task_stack_`, `idle_threads_`, and `task_stack_clock_`. */
285   std::mutex task_stack_mutex_;
286 
287   /** Notifies work threads to check `task_stack_` for work. */
288   std::condition_variable task_stack_cv_;
289 
290   /** Pending tasks in LIFO ordering. */
291   std::vector<tdb_shared_ptr<PackagedTask>> task_stack_;
292 
  /**
294    * A logical, monotonically increasing clock that is incremented
295    * when a task is either added or removed from `task_stack_`. This
296    * is used by threads to determine if `task_stack_` has been modified
297    * between two points in time.
298    */
299   uint64_t task_stack_clock_;
300 
301   /**
302    * The number of threads waiting for the `task_stack_` to
303    * become non-empty.
304    */
305   uint64_t idle_threads_;
306 
307   /** The worker threads. */
308   std::vector<std::thread> threads_;
309 
310   /** When true, all pending tasks will remain unscheduled. */
311   bool should_terminate_;
312 
313   /** All tasks that threads in this instance are waiting on. */
314   struct BlockedTasksHasher {
operatorBlockedTasksHasher315     size_t operator()(const tdb_shared_ptr<TaskState>& task) const {
316       return reinterpret_cast<size_t>(task.get());
317     }
318   };
319   std::unordered_set<tdb_shared_ptr<TaskState>, BlockedTasksHasher>
320       blocked_tasks_;
321 
322   /** Protects `blocked_tasks_`. */
323   std::mutex blocked_tasks_mutex_;
324 
325   /** Indexes thread ids to the ThreadPool instance they belong to. */
326   static std::unordered_map<std::thread::id, ThreadPool*> tp_index_;
327 
328   /** Protects 'tp_index_'. */
329   static std::mutex tp_index_lock_;
330 
331   /** Indexes thread ids to the task it is currently executing. */
332   static std::unordered_map<std::thread::id, tdb_shared_ptr<PackagedTask>>
333       task_index_;
334 
335   /** Protects 'task_index_'. */
336   static std::mutex task_index_lock_;
337 
338   /* ********************************* */
339   /*          PRIVATE METHODS          */
340   /* ********************************* */
341 
342   /**
343    * Waits for `task`, but will execute other tasks from `task_stack_`
   * while waiting. While this may be a performance optimization
   * to perform work on this thread rather than waiting, the primary
   * motivation is to prevent deadlock when tasks are enqueued recursively.
347    */
348   Status wait_or_work(Task&& task);
349 
350   /** Terminate the threads in the thread pool. */
351   void terminate();
352 
353   /** The worker thread routine. */
354   static void worker(ThreadPool& pool);
355 
356   // Add indexes from each thread to this instance.
357   void add_tp_index();
358 
359   // Remove indexes from each thread to this instance.
360   void remove_tp_index();
361 
362   // Lookup the thread pool instance that contains `tid`.
363   static ThreadPool* lookup_tp(std::thread::id tid);
364 
365   // Add indexes for each thread on the `task_index_`.
366   void add_task_index();
367 
368   // Remove indexes for each thread on the `task_index_`.
369   void remove_task_index();
370 
371   // Lookup the task executing on `tid`.
372   static tdb_shared_ptr<PackagedTask> lookup_task(std::thread::id tid);
373 
374   // Wrapper to update `task_index_` and execute `task`.
375   static void exec_packaged_task(tdb_shared_ptr<PackagedTask> task);
376 };
377 
378 }  // namespace common
379 }  // namespace tiledb
380 
381 #endif  // TILEDB_THREAD_POOL_H
382