1 #pragma once
2 
3 /// \file Pool.h
4 /// \brief Simple thread pool with fixed number of threads
5 /// \author Pavel Sevecek (sevecek at sirrah.troja.mff.cuni.cz)
6 /// \date 2016-2021
7 
8 #include "objects/containers/Array.h"
9 #include "objects/wrappers/Function.h"
10 #include "objects/wrappers/Optional.h"
11 #include "thread/Scheduler.h"
12 #include <atomic>
13 #include <condition_variable>
14 #include <queue>
15 #include <thread>
16 
17 NAMESPACE_SPH_BEGIN
18 
19 /// \brief Task to be executed by one of available threads.
20 class Task : public ITask, public Shareable<Task> {
21 private:
22     std::condition_variable waitVar;
23     std::mutex waitMutex;
24 
25     /// Number of child tasks + 1 for itself
26     std::atomic<int> tasksLeft{ 1 };
27 
28     Function<void()> callable = nullptr;
29 
30     SharedPtr<Task> parent = nullptr;
31 
32     std::exception_ptr caughtException = nullptr;
33 
34 public:
35     explicit Task(const Function<void()>& callable);
36 
37     ~Task();
38 
39     virtual void wait() override;
40 
41     virtual bool completed() const override;
42 
43     /// \brief Assigns a task that spawned this task.
44     ///
45     /// Can be nullptr if this is the root task.
46     void setParent(SharedPtr<Task> parent);
47 
48     /// \brief Saves exception into the task.
49     ///
50     /// The exception propagates into the top-most task.
51     void setException(std::exception_ptr exception);
52 
53     /// \brief Returns true if this is the top-most task.
54     bool isRoot() const;
55 
56     SharedPtr<Task> getParent() const;
57 
58     /// \brief Returns the currently execute task, or nullptr if no task is currently executed on this thread.
59     static SharedPtr<Task> getCurrent();
60 
61     void runAndNotify();
62 
63 private:
64     void addReference();
65 
66     void removeReference();
67 };
68 
69 /// \brief Thread pool capable of executing tasks concurrently.
70 class ThreadPool : public IScheduler {
71     friend class Task;
72 
73 private:
74     /// Threads managed by this pool
75     Array<AutoPtr<std::thread>> threads;
76 
77     /// Selected granularity of the parallel processing.
78     Size granularity;
79 
80     /// Queue of waiting tasks.
81     std::queue<SharedPtr<Task>> tasks;
82 
83     /// Used for synchronization of the task queue
84     std::condition_variable taskVar;
85     std::mutex taskMutex;
86 
87     /// Used for synchronization of task scheduling
88     std::condition_variable waitVar;
89     std::mutex waitMutex;
90 
91     /// Set to true if all tasks should be stopped ASAP
92     std::atomic<bool> stop;
93 
94     /// Number of unprocessed tasks (either currently processing or waiting).
95     std::atomic<int> tasksLeft;
96 
97     /// Global instance of the ThreadPool.
98     /// \note This is not a singleton, another instances can be created if needed.
99     static SharedPtr<ThreadPool> globalInstance;
100 
101 public:
102     /// \brief Initialize thread pool given the number of threads to use.
103     ///
104     /// By default, all available threads are used.
105     ThreadPool(const Size numThreads = 0, const Size granularity = 1000);
106 
107     ~ThreadPool();
108 
109     /// \brief Submits a task into the thread pool.
110     ///
111     /// The task will be executed asynchronously once tasks submitted before it are completed.
112     virtual SharedPtr<ITask> submit(const Function<void()>& task) override;
113 
114     /// \brief Returns the index of this thread, or NOTHING if this thread was not invoked by the thread pool.
115     ///
116     /// The index is within [0, numThreads-1].
117     virtual Optional<Size> getThreadIdx() const override;
118 
119     /// \brief Returns the number of threads used by this thread pool.
120     ///
121     /// Note that this number is constant during the lifetime of thread pool.
122     virtual Size getThreadCnt() const override;
123 
124     virtual Size getRecommendedGranularity() const override;
125 
126     virtual void parallelFor(const Size from,
127         const Size to,
128         const Size granularity,
129         const RangeFunctor& functor) override;
130 
131     virtual void parallelInvoke(const Functor& task1, const Functor& task2) override;
132     /// \brief Blocks until all submitted tasks has been finished.
133     void waitForAll();
134 
135     /// \brief Returns the number of unfinished tasks.
136     ///
137     /// This includes both tasks currently running and tasks waiting in processing queue.
remainingTaskCnt()138     Size remainingTaskCnt() {
139         return tasksLeft;
140     }
141 
142     /// \brief Modifies the default granularity of the thread pool.
setGranularity(const Size newGranularity)143     void setGranularity(const Size newGranularity) {
144         granularity = newGranularity;
145     }
146 
147     /// \brief Returns the global instance of the thread pool.
148     ///
149     /// Other instances can be constructed if needed.
150     static SharedPtr<ThreadPool> getGlobalInstance();
151 
152 private:
153     SharedPtr<Task> getNextTask(const bool wait);
154 
155     void processTask(const bool wait);
156 };
157 
158 NAMESPACE_SPH_END
159