/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file thread_pool.cc
 * \brief Threadpool for multi-threading runtime.
 */
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>
#if TVM_THREADPOOL_USE_OPENMP
#include <omp.h>
#endif
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

const constexpr int kL1CacheBytes = 64;

namespace tvm {
namespace runtime {
namespace {

constexpr uint32_t kDefaultSpinCount = 300000;

// Read the worker spin count from the TVM_THREAD_POOL_SPIN_COUNT environment
// variable, falling back to the default value.
uint32_t GetSpinCount() {
  const char* val = getenv("TVM_THREAD_POOL_SPIN_COUNT");
  if (!val) {
    return kDefaultSpinCount;
  }
  return atoi(val);
}

}  // namespace

// Stride in the counter page, sized so that each task's counter occupies its own cache line.
constexpr int kSyncStride = 64 / sizeof(std::atomic<int>);
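// Worked through (assuming the common case of sizeof(std::atomic<int>) == 4):
// kSyncStride == 64 / 4 == 16, so consecutive per-task counters sit 16 elements
// (64 bytes == kL1CacheBytes) apart and never share a cache line, which avoids
// false sharing between tasks spinning in TVMBackendParallelBarrier.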

/*!
 * \brief Thread local master environment.
 */
class ParallelLauncher {
 public:
  // Reset the task request.
  void Init(FTVMParallelLambda flambda, void* cdata, int num_task, bool need_sync) {
    num_pending_.store(num_task);
    this->cdata = cdata;
    this->flambda = flambda;
    this->env.num_task = num_task;
    has_error_.store(false);
    // Grow the error buffer and sync counters if more tasks are requested.
    if (static_cast<size_t>(num_task) > par_errors_.size()) {
      par_errors_.resize(num_task + 1);
      if (need_sync) {
        delete[] sync_counter_;
        sync_counter_ = new std::atomic<int>[num_task * kSyncStride];
      }
    }
    if (need_sync) {
      for (int i = 0; i < num_task; ++i) {
        sync_counter_[i * kSyncStride].store(0, std::memory_order_relaxed);
      }
      this->env.sync_handle = sync_counter_;
    } else {
      this->env.sync_handle = nullptr;
    }
  }
  ~ParallelLauncher() { delete[] sync_counter_; }
  // Wait for all pending jobs to finish.
  int WaitForJobs() {
    while (num_pending_.load() != 0) {
      tvm::runtime::threading::Yield();
    }
    if (!has_error_.load()) return 0;
    std::ostringstream os;
    for (size_t i = 0; i < par_errors_.size(); ++i) {
      if (par_errors_[i].length() != 0) {
        os << "Task " << i << " error: " << par_errors_[i] << '\n';
        par_errors_[i].clear();
      }
    }
    TVMAPISetLastError(os.str().c_str());
    return -1;
  }
  // Signal that one job has failed with an error.
  void SignalJobError(int task_id) {
    num_pending_.fetch_sub(1);
    par_errors_[task_id] = TVMGetLastError();
    has_error_.store(true);
  }
  // Signal that one job has finished.
  void SignalJobFinish() { num_pending_.fetch_sub(1); }
  // Get the thread local version of the store.
  static ParallelLauncher* ThreadLocal() { return dmlc::ThreadLocalStore<ParallelLauncher>::Get(); }
  // The parallel lambda
  FTVMParallelLambda flambda;
  // The closure data
  void* cdata;
  // Local env
  TVMParallelGroupEnv env;
  // Whether this thread is a worker of the pool,
  // used to prevent recursive launch.
  bool is_worker{false};

 private:
  // The number of pending jobs.
  std::atomic<int32_t> num_pending_;
  // Whether an error has been encountered.
  std::atomic<bool> has_error_;
  // The counter page.
  std::atomic<int32_t>* sync_counter_{nullptr};
  // The error messages, one slot per task.
  std::vector<std::string> par_errors_;
};

/*! \brief Lock-free single-producer-single-consumer queue for each thread */
class SpscTaskQueue {
 public:
  /*! \brief The task entry */
  struct Task {
    ParallelLauncher* launcher;
    int32_t task_id;
  };

  SpscTaskQueue() : buffer_(new Task[kRingSize]), head_(0), tail_(0) {}

  ~SpscTaskQueue() { delete[] buffer_; }

  /*!
   * \brief Push a task into the queue and notify the consumer if it is waiting.
   * \param input The task to be enqueued.
   */
  void Push(const Task& input) {
    while (!Enqueue(input)) {
      tvm::runtime::threading::Yield();
    }
    // A previous value of -1 means the consumer announced in Pop that it is
    // going to sleep on the condition variable, so wake it up.
    if (pending_.fetch_add(1) == -1) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.notify_one();
    }
  }

  /*!
   * \brief Pop a task out of the queue and condition wait if no tasks.
   * \param output The pointer to the task to be dequeued.
   * \param spin_count The number of iterations to spin before sleep.
   * \return Whether pop is successful (true) or we need to exit now (false).
   */
  bool Pop(Task* output, uint32_t spin_count) {
    // Busy wait a bit when the queue is empty.
    // If a new task arrives quickly, this wait avoids putting the worker to sleep.
    // The default spin count is set by following the typical omp convention.
    for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
      tvm::runtime::threading::Yield();
    }
    // Dropping pending_ to -1 announces that this consumer is about to sleep;
    // it then waits until a producer pushes a task or a kill signal arrives.
    if (pending_.fetch_sub(1) == 0) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [this] { return pending_.load() >= 0 || exit_now_.load(); });
    }
    if (exit_now_.load(std::memory_order_relaxed)) {
      return false;
    }
    const uint32_t head = head_.load(std::memory_order_relaxed);
    // sanity check if the queue is empty
    CHECK(tail_.load(std::memory_order_acquire) != head);
    *output = buffer_[head];
    head_.store((head + 1) % kRingSize, std::memory_order_release);
    return true;
  }

  /*!
   * \brief Signal to terminate the worker.
   */
  void SignalForKill() {
    std::lock_guard<std::mutex> lock(mutex_);
    exit_now_.store(true);
    cv_.notify_all();
  }

 protected:
  /*!
   * \brief Lock-free enqueue.
   * \param input The task to be enqueued.
   * \return Whether the task is enqueued.
   */
  bool Enqueue(const Task& input) {
    if (exit_now_.load(std::memory_order_relaxed)) return false;

    const uint32_t tail = tail_.load(std::memory_order_relaxed);

    if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) {
      buffer_[tail] = input;
      tail_.store((tail + 1) % kRingSize, std::memory_order_release);
      return true;
    }
    return false;
  }

  // the cache line paddings are used to avoid false sharing between the atomic variables
  typedef char cache_line_pad_t[kL1CacheBytes];
  cache_line_pad_t pad0_;
  // size of the ring buffer; the queue can host kRingSize - 1 items at most
  // define it as a constant for better compiler optimization
  static constexpr const int kRingSize = 2;
  // pointer to access the items
  Task* const buffer_;

  cache_line_pad_t pad1_;
  // queue head, where one gets a task from the queue
  std::atomic<uint32_t> head_;

  cache_line_pad_t pad2_;
  // queue tail, where one puts a task into the queue
  std::atomic<uint32_t> tail_;

  cache_line_pad_t pad3_;
  // pending tasks in the queue
  std::atomic<int8_t> pending_{0};

  cache_line_pad_t pad4_;
  // signal for exit now
  std::atomic<bool> exit_now_{false};

  // internal mutex
  std::mutex mutex_;
  // cv for consumer
  std::condition_variable cv_;
};
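// For illustration only (a sketch, not part of the runtime): the queue is meant
// to be used with exactly one producer and one consumer thread, which is how
// ThreadPool below pairs the launching thread with each worker.
//
//   SpscTaskQueue queue;
//   std::thread consumer([&queue] {
//     SpscTaskQueue::Task task;
//     // Pop returns false once SignalForKill() has been observed.
//     while (queue.Pop(&task, /*spin_count=*/kDefaultSpinCount)) {
//       // ... run the task ...
//     }
//   });
//   queue.Push({/*launcher=*/nullptr, /*task_id=*/0});  // producer side
//   queue.SignalForKill();  // shut the consumer down once work is done
//   consumer.join();
//
// In the real pool, ParallelLauncher::WaitForJobs() guarantees every pushed
// task has finished before the queues are ever killed.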

// The thread pool
class ThreadPool {
 public:
  ThreadPool() : num_workers_(tvm::runtime::threading::MaxConcurrency()) {
    for (int i = 0; i < num_workers_; ++i) {
      // The SpscTaskQueue only hosts ONE item at a time
      queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
    }
    const char* exclude_worker0 = getenv("TVM_EXCLUDE_WORKER0");
    if (exclude_worker0 && atoi(exclude_worker0) == 0) {
      exclude_worker0_ = false;
    }
    threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>(
        new tvm::runtime::threading::ThreadGroup(
            num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
            exclude_worker0_ /* include_main_thread */));
    num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_);
  }
  ~ThreadPool() {
    for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
      q->SignalForKill();
    }
    threads_.reset();
  }
  int Launch(FTVMParallelLambda flambda, void* cdata, int num_task, int need_sync) {
    ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
    CHECK(!launcher->is_worker)
        << "Cannot launch parallel job inside worker, consider fuse then parallel";
    if (num_task == 0) {
      num_task = num_workers_used_;
    }
    if (need_sync != 0) {
      CHECK_LE(num_task, num_workers_used_)
          << "Request parallel sync task larger than number of threads used "
          << " workers=" << num_workers_used_ << " request=" << num_task;
    }
    launcher->Init(flambda, cdata, num_task, need_sync != 0);
    SpscTaskQueue::Task tsk;
    tsk.launcher = launcher;
    // if worker0 is taken by the master, queues_[0] is abandoned
    for (int i = exclude_worker0_; i < num_task; ++i) {
      tsk.task_id = i;
      queues_[i]->Push(tsk);
    }
    // use the master thread to run task 0
    if (exclude_worker0_) {
      TVMParallelGroupEnv* penv = &(tsk.launcher->env);
      if ((*tsk.launcher->flambda)(0, penv, cdata) == 0) {
        tsk.launcher->SignalJobFinish();
      } else {
        tsk.launcher->SignalJobError(tsk.task_id);
      }
    }
    int res = launcher->WaitForJobs();
    return res;
  }

  static ThreadPool* ThreadLocal() { return dmlc::ThreadLocalStore<ThreadPool>::Get(); }

  void UpdateWorkerConfiguration(threading::ThreadGroup::AffinityMode mode, int nthreads) {
    // this will also reset the affinity of the ThreadGroup
    // may use less than the MaxConcurrency number of workers
    num_workers_used_ = threads_->Configure(mode, nthreads, exclude_worker0_);
    // if MaxConcurrency restricted the number of workers (e.g., due to
    // hyperthreading), respect the restriction
    num_workers_used_ = std::min(num_workers_, num_workers_used_);
  }

 private:
  // Internal worker function.
  void RunWorker(int worker_id) {
    SpscTaskQueue* queue = queues_[worker_id].get();
    SpscTaskQueue::Task task;
    ParallelLauncher::ThreadLocal()->is_worker = true;
    // Initialize the spin count (from envvar TVM_THREAD_POOL_SPIN_COUNT) on
    // the global first use of the ThreadPool.
    // TODO(tulloch): should we make this configurable via standard APIs?
    static size_t spin_count = GetSpinCount();
    while (queue->Pop(&task, spin_count)) {
      CHECK(task.launcher != nullptr);
      TVMParallelGroupEnv* penv = &(task.launcher->env);
      void* cdata = task.launcher->cdata;
      if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {
        task.launcher->SignalJobFinish();
      } else {
        task.launcher->SignalJobError(task.task_id);
      }
    }
  }
  // total number of workers, initialized from MaxConcurrency()
  int num_workers_;
  // number of workers used (can be restricted with affinity pref)
  int num_workers_used_;
  // whether to exclude worker 0 and use the master thread to run task 0
  bool exclude_worker0_{true};
  std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
  std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
};

TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body([](TVMArgs args, TVMRetValue* rv) {
  threading::ThreadGroup::AffinityMode mode =
      static_cast<threading::ThreadGroup::AffinityMode>(static_cast<int>(args[0]));
  int nthreads = args[1];
  ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads);
});
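// For illustration (a sketch, not part of this file): the global registered
// above can be looked up and invoked from C++ like any other packed function,
// e.g. to request the kBig affinity mode with four threads:
//
//   const tvm::runtime::PackedFunc* f =
//       tvm::runtime::Registry::Get("runtime.config_threadpool");
//   if (f != nullptr) {
//     (*f)(static_cast<int>(tvm::runtime::threading::ThreadGroup::kBig), 4);
//   }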

}  // namespace runtime
}  // namespace tvm

int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) {
#if !TVM_THREADPOOL_USE_OPENMP
  int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(flambda, cdata, num_task, 1);
  return res;
#else
  int num_workers = tvm::runtime::threading::MaxConcurrency();
  if (num_task == 0) num_task = num_workers;
  omp_set_num_threads(num_workers);
#pragma omp parallel num_threads(num_workers)
  {
    TVMParallelGroupEnv env;
    env.num_task = num_task;
    (*flambda)(omp_get_thread_num(), &env, cdata);
  }
  return 0;
#endif
}
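// A minimal sketch (illustration only; Workload and VecDouble are hypothetical
// names, not part of TVM) of the calling contract: generated code supplies an
// FTVMParallelLambda, packs its closure into cdata, and splits the iteration
// space using task_id and penv->num_task.
//
//   struct Workload {
//     float* data;
//     int64_t size;
//   };
//
//   static int VecDouble(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
//     auto* w = static_cast<Workload*>(cdata);
//     int64_t chunk = (w->size + penv->num_task - 1) / penv->num_task;
//     int64_t begin = std::min<int64_t>(task_id * chunk, w->size);
//     int64_t end = std::min<int64_t>(begin + chunk, w->size);
//     for (int64_t i = begin; i < end; ++i) w->data[i] *= 2.0f;
//     // All tasks wait for each other before returning (uses penv->sync_handle,
//     // which Launch sets up because need_sync is 1 in the call above).
//     return TVMBackendParallelBarrier(task_id, penv);
//   }
//
//   // num_task == 0 asks the pool to use one task per available worker:
//   //   TVMBackendParallelLaunch(VecDouble, &workload, 0);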

int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
#if TVM_THREADPOOL_USE_OPENMP
#pragma omp barrier
#else
  using tvm::runtime::kSyncStride;
  int num_task = penv->num_task;
  std::atomic<int>* sync_counter = reinterpret_cast<std::atomic<int>*>(penv->sync_handle);
  // Counter-based barrier: bump this task's padded counter, then spin until every
  // other task's counter has moved past the value observed here, i.e. until all
  // tasks have entered this barrier instance.
  int old_counter = sync_counter[task_id * kSyncStride].fetch_add(1, std::memory_order_release);
  for (int i = 0; i < num_task; ++i) {
    if (i != task_id) {
      while (sync_counter[i * kSyncStride].load(std::memory_order_relaxed) <= old_counter) {
        tvm::runtime::threading::Yield();
      }
    }
  }
  std::atomic_thread_fence(std::memory_order_acquire);
#endif
  return 0;
}