1 // Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are met:
6 //
7 // * Redistributions of source code must retain the above copyright
8 // notice, this list of conditions and the following disclaimer.
9 //
10 // * Redistributions in binary form must reproduce the above copyright
11 // notice, this list of conditions and the following disclaimer in the
12 // documentation and/or other materials provided with the distribution.
13 //
14 // * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15 // its contributors may be used to endorse or promote products derived
16 // from this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 // POSSIBILITY OF SUCH DAMAGE.
29 //
30 // Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
31
32 #ifndef COLMAP_SRC_UTIL_THREADING_
33 #define COLMAP_SRC_UTIL_THREADING_
34
#include <atomic>
#include <climits>
#include <condition_variable>
#include <functional>
#include <future>
#include <limits>
#include <list>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "util/timer.h"
44
45 namespace colmap {
46
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wkeyword-macro"
#endif

// Define `thread_local` cross-platform for toolchains that lack the keyword.
//
// Since C++11 `thread_local` is a language keyword, and a translation unit
// that includes standard library headers must not #define a name that is
// lexically identical to a keyword. The fallback is therefore only provided
// when compiling as C or as pre-C++11 C++. Note that MSVC reports
// `__cplusplus` as 199711L unless `/Zc:__cplusplus` is passed, so on MSVC the
// `__declspec(thread)` fallback below is still selected as before.
#if !defined(thread_local) && \
    (!defined(__cplusplus) || __cplusplus < 201103L)
#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112 && \
    !defined(__STDC_NO_THREADS__)
#define thread_local _Thread_local
#elif defined _WIN32 && (defined _MSC_VER || defined __ICL || \
                         defined __DMC__ || defined __BORLANDC__)
#define thread_local __declspec(thread)
#elif defined __GNUC__ || defined __SUNPRO_C || defined __xlC__
#define thread_local __thread
#else
#error "Cannot define thread_local"
#endif
#endif

#ifdef __clang__
#pragma clang diagnostic pop  // -Wkeyword-macro
#endif
69
70 // Helper class to create single threads with simple controls and timing, e.g.:
71 //
72 // class MyThread : public Thread {
73 // enum {
74 // PROCESSED_CALLBACK,
75 // };
76 //
77 // MyThread() { RegisterCallback(PROCESSED_CALLBACK); }
78 // void Run() {
79 // // Some setup routine... note that this optional.
80 // if (setup_valid) {
81 // SignalValidSetup();
82 // } else {
83 // SignalInvalidSetup();
84 // }
85 //
86 // // Some pre-processing...
87 // for (const auto& item : items) {
88 // BlockIfPaused();
89 // if (IsStopped()) {
90 // // Tear down...
91 // break;
92 // }
93 // // Process item...
94 // Callback(PROCESSED_CALLBACK);
95 // }
96 // }
97 // };
98 //
99 // MyThread thread;
100 // thread.AddCallback(MyThread::PROCESSED_CALLBACK, []() {
101 // std::cout << "Processed item"; })
102 // thread.AddCallback(MyThread::STARTED_CALLBACK, []() {
103 // std::cout << "Start"; })
104 // thread.AddCallback(MyThread::FINISHED_CALLBACK, []() {
105 // std::cout << "Finished"; })
106 // thread.Start();
107 // // thread.CheckValidSetup();
108 // // Pause, resume, stop, ...
109 // thread.Wait();
110 // thread.Timer().PrintElapsedSeconds();
111 //
112 class Thread {
113 public:
114 enum {
115 STARTED_CALLBACK = INT_MIN,
116 FINISHED_CALLBACK,
117 };
118
119 Thread();
120 virtual ~Thread() = default;
121
122 // Control the state of the thread.
123 virtual void Start();
124 virtual void Stop();
125 virtual void Pause();
126 virtual void Resume();
127 virtual void Wait();
128
129 // Check the state of the thread.
130 bool IsStarted();
131 bool IsStopped();
132 bool IsPaused();
133 bool IsRunning();
134 bool IsFinished();
135
136 // To be called from inside the main run function. This blocks the main
137 // caller, if the thread is paused, until the thread is resumed.
138 void BlockIfPaused();
139
140 // To be called from outside. This blocks the caller until the thread is
141 // setup, i.e. it signaled that its setup was valid or not. If it never gives
142 // this signal, this call will block the caller infinitely. Check whether
143 // setup is valid. Note that the result is only meaningful if the thread gives
144 // a setup signal.
145 bool CheckValidSetup();
146
147 // Set callbacks that can be triggered within the main run function.
148 void AddCallback(const int id, const std::function<void()>& func);
149
150 // Get timing information of the thread, properly accounting for pause times.
151 const Timer& GetTimer() const;
152
153 protected:
154 // This is the main run function to be implemented by the child class. If you
155 // are looping over data and want to support the pause operation, call
156 // `BlockIfPaused` at appropriate places in the loop. To support the stop
157 // operation, check the `IsStopped` state and early return from this method.
158 virtual void Run() = 0;
159
160 // Register a new callback. Note that only registered callbacks can be
161 // set/reset and called from within the thread. Hence, this method should be
162 // called from the derived thread constructor.
163 void RegisterCallback(const int id);
164
165 // Call back to the function with the specified name, if it exists.
166 void Callback(const int id) const;
167
168 // Get the unique identifier of the current thread.
169 std::thread::id GetThreadId() const;
170
171 // Signal that the thread is setup. Only call this function once.
172 void SignalValidSetup();
173 void SignalInvalidSetup();
174
175 private:
176 // Wrapper around the main run function to set the finished flag.
177 void RunFunc();
178
179 std::thread thread_;
180 std::mutex mutex_;
181 std::condition_variable pause_condition_;
182 std::condition_variable setup_condition_;
183
184 Timer timer_;
185
186 bool started_;
187 bool stopped_;
188 bool paused_;
189 bool pausing_;
190 bool finished_;
191 bool setup_;
192 bool setup_valid_;
193
194 std::unordered_map<int, std::list<std::function<void()>>> callbacks_;
195 };
196
// Pool of worker threads that execute generic tasks (functors), e.g.:
//
//   ThreadPool thread_pool;
//   thread_pool.AddTask([]() { /* Do some work */ });
//   auto future = thread_pool.AddTask([]() { /* Do some work */ return 1; });
//   const auto result = future.get();
//   for (int i = 0; i < 10; ++i) {
//     thread_pool.AddTask([](const int i) { /* Do some work */ }, i);
//   }
//   thread_pool.Wait();
//
class ThreadPool {
 public:
  // Default requested pool size. Presumably resolved to the number of logical
  // CPU cores (see GetEffectiveNumThreads) -- TODO confirm in the .cc file.
  static const int kMaxNumThreads = -1;

  explicit ThreadPool(const int num_threads = kMaxNumThreads);
  ~ThreadPool();

  // Number of worker threads owned by the pool.
  inline size_t NumThreads() const;

  // Enqueues `f(args...)` for execution on one of the workers and returns a
  // future holding its result. Throws std::runtime_error if the pool has
  // already been stopped.
  // NOTE(review): std::result_of is deprecated in C++17 and removed in C++20.
  template <class func_t, class... args_t>
  auto AddTask(func_t&& f, args_t&&... args)
      -> std::future<typename std::result_of<func_t(args_t...)>::type>;

  // Stops the execution of all workers.
  void Stop();

  // Blocks until the submitted tasks are finished.
  void Wait();

  // Unique identifier of the current thread.
  std::thread::id GetThreadId() const;

  // 0-based position of the current thread within the pool: in a pool of N
  // workers the indices are 0, ..., N-1. Presumably only meaningful when
  // called from one of the pool's worker threads -- TODO confirm.
  int GetThreadIndex();

 private:
  // Main loop executed by the worker with the given index.
  void WorkerFunc(const int index);

  std::vector<std::thread> workers_;
  std::queue<std::function<void()>> tasks_;

  std::mutex mutex_;
  std::condition_variable task_condition_;
  std::condition_variable finished_condition_;

  bool stopped_;
  int num_active_workers_;

  std::unordered_map<std::thread::id, int> thread_id_to_index_;
};
251
// A job queue class for the producer-consumer paradigm.
//
//   JobQueue<int> job_queue;
//
//   std::thread producer_thread([&job_queue]() {
//     for (int i = 0; i < 10; ++i) {
//       job_queue.Push(i);
//     }
//   });
//
//   std::thread consumer_thread([&job_queue]() {
//     for (int i = 0; i < 10; ++i) {
//       const auto job = job_queue.Pop();
//       if (job.IsValid()) { /* Do some work */ }
//       else { break; }
//     }
//   });
//
//   producer_thread.join();
//   consumer_thread.join();
//
template <typename T>
class JobQueue {
 public:
  // Result of Pop(): either a valid job carrying data, or an invalid job
  // (returned once the queue has been stopped).
  class Job {
   public:
    // Invalid job. The payload is value-initialized so that Data() never
    // exposes an indeterminate value (e.g. for arithmetic types T).
    Job() : data_(), valid_(false) {}
    explicit Job(const T& data) : data_(data), valid_(true) {}
    // Valid job that takes over `data` by moving instead of copying it.
    explicit Job(T&& data) : data_(std::move(data)), valid_(true) {}

    // Check whether the data is valid.
    bool IsValid() const { return valid_; }

    // Get reference to the data.
    T& Data() { return data_; }
    const T& Data() const { return data_; }

   private:
    T data_;
    bool valid_;
  };

  // Queue without a practical limit on the number of pending jobs.
  JobQueue();
  // Queue holding at most `max_num_jobs` pushed-but-not-popped jobs.
  explicit JobQueue(const size_t max_num_jobs);
  // Stops the queue so that blocked Push/Pop callers are released.
  ~JobQueue();

  // The number of pushed and not popped jobs in the queue.
  size_t Size();

  // Push a new job to the queue. Blocks while the queue is full. Returns false
  // if the queue is (or gets) stopped, in which case the job is not added.
  bool Push(const T& data);

  // Pop a job from the queue. Blocks while the queue is empty. Returns an
  // invalid job if the queue is (or gets) stopped.
  Job Pop();

  // Wait for the pushed jobs to be popped, i.e. for the queue to become
  // empty. This does not stop the queue; call Stop() for that.
  void Wait();

  // Stop the queue: blocked and subsequent Push calls return false and blocked
  // and subsequent Pop calls return an invalid job.
  void Stop();

  // Clear all pushed and not popped jobs from the queue.
  void Clear();

 private:
  size_t max_num_jobs_;
  std::atomic<bool> stop_;
  std::queue<T> jobs_;
  std::mutex mutex_;
  std::condition_variable push_condition_;
  std::condition_variable pop_condition_;
  std::condition_variable empty_condition_;
};
324
// Resolves a requested thread count: any value <= 0 is replaced by the number
// of logical CPU cores, whereas a positive value is returned unchanged.
int GetEffectiveNumThreads(int num_threads);
328
329 ////////////////////////////////////////////////////////////////////////////////
330 // Implementation
331 ////////////////////////////////////////////////////////////////////////////////
332
NumThreads()333 size_t ThreadPool::NumThreads() const { return workers_.size(); }
334
335 template <class func_t, class... args_t>
336 auto ThreadPool::AddTask(func_t&& f, args_t&&... args)
337 -> std::future<typename std::result_of<func_t(args_t...)>::type> {
338 typedef typename std::result_of<func_t(args_t...)>::type return_t;
339
340 auto task = std::make_shared<std::packaged_task<return_t()>>(
341 std::bind(std::forward<func_t>(f), std::forward<args_t>(args)...));
342
343 std::future<return_t> result = task->get_future();
344
345 {
346 std::unique_lock<std::mutex> lock(mutex_);
347 if (stopped_) {
348 throw std::runtime_error("Cannot add task to stopped thread pool.");
349 }
350 tasks_.emplace([task]() { (*task)(); });
351 }
352
353 task_condition_.notify_one();
354
355 return result;
356 }
357
358 template <typename T>
JobQueue()359 JobQueue<T>::JobQueue() : JobQueue(std::numeric_limits<size_t>::max()) {}
360
361 template <typename T>
JobQueue(const size_t max_num_jobs)362 JobQueue<T>::JobQueue(const size_t max_num_jobs)
363 : max_num_jobs_(max_num_jobs), stop_(false) {}
364
365 template <typename T>
~JobQueue()366 JobQueue<T>::~JobQueue() {
367 Stop();
368 }
369
370 template <typename T>
Size()371 size_t JobQueue<T>::Size() {
372 std::unique_lock<std::mutex> lock(mutex_);
373 return jobs_.size();
374 }
375
376 template <typename T>
Push(const T & data)377 bool JobQueue<T>::Push(const T& data) {
378 std::unique_lock<std::mutex> lock(mutex_);
379 while (jobs_.size() >= max_num_jobs_ && !stop_) {
380 pop_condition_.wait(lock);
381 }
382 if (stop_) {
383 return false;
384 } else {
385 jobs_.push(data);
386 push_condition_.notify_one();
387 return true;
388 }
389 }
390
391 template <typename T>
Pop()392 typename JobQueue<T>::Job JobQueue<T>::Pop() {
393 std::unique_lock<std::mutex> lock(mutex_);
394 while (jobs_.empty() && !stop_) {
395 push_condition_.wait(lock);
396 }
397 if (stop_) {
398 return Job();
399 } else {
400 const T data = jobs_.front();
401 jobs_.pop();
402 pop_condition_.notify_one();
403 if (jobs_.empty()) {
404 empty_condition_.notify_all();
405 }
406 return Job(data);
407 }
408 }
409
410 template <typename T>
Wait()411 void JobQueue<T>::Wait() {
412 std::unique_lock<std::mutex> lock(mutex_);
413 while (!jobs_.empty()) {
414 empty_condition_.wait(lock);
415 }
416 }
417
418 template <typename T>
Stop()419 void JobQueue<T>::Stop() {
420 stop_ = true;
421 push_condition_.notify_all();
422 pop_condition_.notify_all();
423 }
424
425 template <typename T>
Clear()426 void JobQueue<T>::Clear() {
427 std::unique_lock<std::mutex> lock(mutex_);
428 std::queue<T> empty_jobs;
429 std::swap(jobs_, empty_jobs);
430 }
431
432 } // namespace colmap
433
434 #endif // COLMAP_SRC_UTIL_THREADING_
435