// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2018 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: vitus@google.com (Michael Vitus)

// This include must come before any #ifndef check on Ceres compile options.
#include "ceres/internal/port.h"

#ifdef CERES_USE_CXX_THREADS

#include <algorithm>
#include <cmath>
#include <condition_variable>
#include <memory>
#include <mutex>

#include "ceres/concurrent_queue.h"
#include "ceres/parallel_for.h"
#include "ceres/scoped_thread_token.h"
#include "ceres/thread_token_provider.h"
#include "glog/logging.h"

namespace ceres {
namespace internal {
namespace {
// This class implements a thread-safe barrier that blocks until a
// pre-specified number of threads call Finished.  This allows us to block the
// main thread until all the parallel threads are finished processing all the
// work.
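//
// Illustrative usage sketch (hypothetical, for exposition only): a barrier
// constructed for N jobs is released once Finished() has been called N times.
//
//   BlockUntilFinished barrier(/*num_total=*/3);
//   // ... each of the 3 jobs calls barrier.Finished() when it completes ...
//   barrier.Block();  // Returns only after all 3 Finished() calls.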
class BlockUntilFinished {
 public:
  explicit BlockUntilFinished(int num_total)
      : num_finished_(0), num_total_(num_total) {}

  // Increment the number of jobs that have finished and signal the blocking
  // thread if all jobs have finished.
  void Finished() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++num_finished_;
    CHECK_LE(num_finished_, num_total_);
    if (num_finished_ == num_total_) {
      condition_.notify_one();
    }
  }

  // Block until all threads have signaled they are finished.
  void Block() {
    std::unique_lock<std::mutex> lock(mutex_);
    condition_.wait(lock, [&]() { return num_finished_ == num_total_; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  // The current number of jobs finished.
  int num_finished_;
  // The total number of jobs.
  int num_total_;
};

// Shared state between the parallel tasks. Each thread will use this
// information to get the next block of work to be performed.
struct SharedState {
  SharedState(int start, int end, int num_work_items)
      : start(start),
        end(end),
        num_work_items(num_work_items),
        i(0),
        thread_token_provider(num_work_items),
        block_until_finished(num_work_items) {}

  // The start and end index of the for loop.
  const int start;
  const int end;
  // The number of blocks that need to be processed.
  const int num_work_items;

  // The next block of work to be assigned to a worker.  The parallel for loop
  // range is split into num_work_items blocks of work, i.e. a single block of
  // work is:
  //  for (int j = start + i; j < end; j += num_work_items) { ... }.
  int i;
  std::mutex mutex_i;

  // Provides a unique thread ID among all active threads working on the same
  // group of tasks.  Thread-safe.
  ThreadTokenProvider thread_token_provider;

  // Used to signal when all the work has been completed.  Thread-safe.
  BlockUntilFinished block_until_finished;
};

}  // namespace

int MaxNumThreadsAvailable() { return ThreadPool::MaxNumThreadsAvailable(); }

// See ParallelFor (below) for more details.
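//
// Illustrative usage sketch (hypothetical names; assumes `context` points to
// a suitably initialized ContextImpl): square the integers in [0, 10) using
// up to 4 threads.
//
//   std::vector<int> squares(10);
//   ParallelFor(context, 0, 10, /*num_threads=*/4,
//               [&squares](int i) { squares[i] = i * i; });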
void ParallelFor(ContextImpl* context,
                 int start,
                 int end,
                 int num_threads,
                 const std::function<void(int)>& function) {
  CHECK_GT(num_threads, 0);
  CHECK(context != NULL);
  if (end <= start) {
    return;
  }

  // Fast path for when it is single threaded.
  if (num_threads == 1) {
    for (int i = start; i < end; ++i) {
      function(i);
    }
    return;
  }

  ParallelFor(
      context, start, end, num_threads, [&function](int /*thread_id*/, int i) {
        function(i);
      });
}

// This implementation uses a worker pool with a fixed maximum size and a
// shared task queue. The problem of executing the function for the interval
// of [start, end) is broken up into at most num_threads blocks and added to
// the thread pool. To avoid deadlocks, the calling thread is allowed to steal
// work from the worker pool. This is implemented via a shared state between
// the tasks. In order for the calling thread or the thread pool to get a
// block of work, it will query the shared state for the next block of work to
// be done. If there is nothing left, it will return. We will exit the
// ParallelFor call when all of the work has been done, not when all of the
// tasks have been popped off the task queue.
//
// A unique thread ID among all active tasks will be acquired once for each
// block of work.  This avoids the significant performance penalty for
// acquiring it on every iteration of the for loop. The thread ID is
// guaranteed to be in [0, num_threads).
//
// A performance analysis has shown this implementation is on par with OpenMP
// and TBB.
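//
// Illustrative usage sketch (hypothetical names; assumes `context` points to
// a suitably initialized ContextImpl): accumulate per-thread partial sums,
// using thread_id to index a private accumulator and avoid data races.
//
//   const int num_threads = 4;
//   std::vector<double> partial_sums(num_threads, 0.0);
//   ParallelFor(context, 0, 1000, num_threads,
//               [&partial_sums](int thread_id, int i) {
//                 partial_sums[thread_id] += std::sqrt(static_cast<double>(i));
//               });
//   // The total is the sum of the per-thread partial sums.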
void ParallelFor(ContextImpl* context,
                 int start,
                 int end,
                 int num_threads,
                 const std::function<void(int thread_id, int i)>& function) {
  CHECK_GT(num_threads, 0);
  CHECK(context != NULL);
  if (end <= start) {
    return;
  }

  // Fast path for when it is single threaded.
  if (num_threads == 1) {
    // Even though we only have one thread, use the thread token provider to
    // guarantee the exact same behavior when running with multiple threads.
    ThreadTokenProvider thread_token_provider(num_threads);
    const ScopedThreadToken scoped_thread_token(&thread_token_provider);
    const int thread_id = scoped_thread_token.token();
    for (int i = start; i < end; ++i) {
      function(thread_id, i);
    }
    return;
  }

  // We use a std::shared_ptr because the main thread can finish all
  // the work before the tasks have been popped off the queue.  So the
  // shared state needs to exist for the duration of all the tasks.
  const int num_work_items = std::min((end - start), num_threads);
  std::shared_ptr<SharedState> shared_state(
      new SharedState(start, end, num_work_items));

  // A function which tries to perform a chunk of work. This returns false if
  // there is no work to be done.
  auto task_function = [shared_state, &function]() {
    int i = 0;
    {
      // Get the next available chunk of work to be performed. If there is no
      // work, return false.
      std::lock_guard<std::mutex> lock(shared_state->mutex_i);
      if (shared_state->i >= shared_state->num_work_items) {
        return false;
      }
      i = shared_state->i;
      ++shared_state->i;
    }

    const ScopedThreadToken scoped_thread_token(
        &shared_state->thread_token_provider);
    const int thread_id = scoped_thread_token.token();

    // Perform each task.
    for (int j = shared_state->start + i; j < shared_state->end;
         j += shared_state->num_work_items) {
      function(thread_id, j);
    }
    shared_state->block_until_finished.Finished();
    return true;
  };

  // Add all the tasks to the thread pool.
  for (int i = 0; i < num_work_items; ++i) {
    // Note that task_function is captured by value, so the shared_state
    // shared pointer is copied and its reference count is increased. This
    // prevents it from being deleted when the main thread finishes all the
    // work and exits before the worker threads finish.
    context->thread_pool.AddTask([task_function]() { task_function(); });
  }

  // Try to do any available work on the main thread. This may steal work from
  // the thread pool, but when there is no work left the thread pool tasks
  // will be no-ops.
  while (task_function()) {
  }

  // Wait until all tasks have finished.
  shared_state->block_until_finished.Block();
}

}  // namespace internal
}  // namespace ceres

#endif  // CERES_USE_CXX_THREADS