1 #pragma once 2 3 /// \file Pool.h 4 /// \brief Simple thread pool with fixed number of threads 5 /// \author Pavel Sevecek (sevecek at sirrah.troja.mff.cuni.cz) 6 /// \date 2016-2021 7 8 #include "objects/containers/Array.h" 9 #include "objects/wrappers/Function.h" 10 #include "objects/wrappers/Optional.h" 11 #include "thread/Scheduler.h" 12 #include <atomic> 13 #include <condition_variable> 14 #include <queue> 15 #include <thread> 16 17 NAMESPACE_SPH_BEGIN 18 19 /// \brief Task to be executed by one of available threads. 20 class Task : public ITask, public Shareable<Task> { 21 private: 22 std::condition_variable waitVar; 23 std::mutex waitMutex; 24 25 /// Number of child tasks + 1 for itself 26 std::atomic<int> tasksLeft{ 1 }; 27 28 Function<void()> callable = nullptr; 29 30 SharedPtr<Task> parent = nullptr; 31 32 std::exception_ptr caughtException = nullptr; 33 34 public: 35 explicit Task(const Function<void()>& callable); 36 37 ~Task(); 38 39 virtual void wait() override; 40 41 virtual bool completed() const override; 42 43 /// \brief Assigns a task that spawned this task. 44 /// 45 /// Can be nullptr if this is the root task. 46 void setParent(SharedPtr<Task> parent); 47 48 /// \brief Saves exception into the task. 49 /// 50 /// The exception propagates into the top-most task. 51 void setException(std::exception_ptr exception); 52 53 /// \brief Returns true if this is the top-most task. 54 bool isRoot() const; 55 56 SharedPtr<Task> getParent() const; 57 58 /// \brief Returns the currently execute task, or nullptr if no task is currently executed on this thread. 59 static SharedPtr<Task> getCurrent(); 60 61 void runAndNotify(); 62 63 private: 64 void addReference(); 65 66 void removeReference(); 67 }; 68 69 /// \brief Thread pool capable of executing tasks concurrently. 70 class ThreadPool : public IScheduler { 71 friend class Task; 72 73 private: 74 /// Threads managed by this pool 75 Array<AutoPtr<std::thread>> threads; 76 77 /// Selected granularity of the parallel processing. 78 Size granularity; 79 80 /// Queue of waiting tasks. 81 std::queue<SharedPtr<Task>> tasks; 82 83 /// Used for synchronization of the task queue 84 std::condition_variable taskVar; 85 std::mutex taskMutex; 86 87 /// Used for synchronization of task scheduling 88 std::condition_variable waitVar; 89 std::mutex waitMutex; 90 91 /// Set to true if all tasks should be stopped ASAP 92 std::atomic<bool> stop; 93 94 /// Number of unprocessed tasks (either currently processing or waiting). 95 std::atomic<int> tasksLeft; 96 97 /// Global instance of the ThreadPool. 98 /// \note This is not a singleton, another instances can be created if needed. 99 static SharedPtr<ThreadPool> globalInstance; 100 101 public: 102 /// \brief Initialize thread pool given the number of threads to use. 103 /// 104 /// By default, all available threads are used. 105 ThreadPool(const Size numThreads = 0, const Size granularity = 1000); 106 107 ~ThreadPool(); 108 109 /// \brief Submits a task into the thread pool. 110 /// 111 /// The task will be executed asynchronously once tasks submitted before it are completed. 112 virtual SharedPtr<ITask> submit(const Function<void()>& task) override; 113 114 /// \brief Returns the index of this thread, or NOTHING if this thread was not invoked by the thread pool. 115 /// 116 /// The index is within [0, numThreads-1]. 117 virtual Optional<Size> getThreadIdx() const override; 118 119 /// \brief Returns the number of threads used by this thread pool. 120 /// 121 /// Note that this number is constant during the lifetime of thread pool. 122 virtual Size getThreadCnt() const override; 123 124 virtual Size getRecommendedGranularity() const override; 125 126 virtual void parallelFor(const Size from, 127 const Size to, 128 const Size granularity, 129 const RangeFunctor& functor) override; 130 131 virtual void parallelInvoke(const Functor& task1, const Functor& task2) override; 132 /// \brief Blocks until all submitted tasks has been finished. 133 void waitForAll(); 134 135 /// \brief Returns the number of unfinished tasks. 136 /// 137 /// This includes both tasks currently running and tasks waiting in processing queue. remainingTaskCnt()138 Size remainingTaskCnt() { 139 return tasksLeft; 140 } 141 142 /// \brief Modifies the default granularity of the thread pool. setGranularity(const Size newGranularity)143 void setGranularity(const Size newGranularity) { 144 granularity = newGranularity; 145 } 146 147 /// \brief Returns the global instance of the thread pool. 148 /// 149 /// Other instances can be constructed if needed. 150 static SharedPtr<ThreadPool> getGlobalInstance(); 151 152 private: 153 SharedPtr<Task> getNextTask(const bool wait); 154 155 void processTask(const bool wait); 156 }; 157 158 NAMESPACE_SPH_END 159