/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file thread_pool.cc
 * \brief Threadpool for multi-threading runtime.
 */
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/threading_backend.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#if TVM_THREADPOOL_USE_OPENMP
#include <omp.h>
#endif
#include <thread>
#include <condition_variable>
#include <mutex>
#include <atomic>
#include <algorithm>
#include <vector>
#include <string>
#include <cstring>
#include <cstdlib>
#include <memory>
#include <sstream>

const constexpr int kL1CacheBytes = 64;

namespace tvm {
namespace runtime {
namespace {

constexpr uint32_t kDefaultSpinCount = 300000;

uint32_t GetSpinCount() {
  const char* val = getenv("TVM_THREAD_POOL_SPIN_COUNT");
  if (!val) {
    return kDefaultSpinCount;
  }
  return atoi(val);
}
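// Usage note (illustrative, not an exhaustive contract): the spin count can be
// tuned via the environment, e.g. TVM_THREAD_POOL_SPIN_COUNT=0 makes workers
// skip the busy-wait loop in SpscTaskQueue::Pop and go straight to the
// condition-variable sleep when their queue is empty.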

}  // namespace

// Stride between adjacent per-task sync counters, chosen so that each counter
// sits on its own cache line.
constexpr int kSyncStride = 64 / sizeof(std::atomic<int>);
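// For example, on a platform where sizeof(std::atomic<int>) == 4 (typical),
// kSyncStride is 64 / 4 = 16, so counters at indices 0, 16, 32, ... start
// 64 bytes apart and never share a cache line (assuming 64-byte lines).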

/*!
 * \brief Thread local master environment.
 */
class ParallelLauncher {
 public:
  // Reset the task request.
  void Init(FTVMParallelLambda flambda,
            void* cdata,
            int num_task,
            bool need_sync) {
    num_pending_.store(num_task);
    this->cdata = cdata;
    this->flambda = flambda;
    this->env.num_task = num_task;
    has_error_.store(false);
    // grow the per-task state if more tasks are requested than before
    if (static_cast<size_t>(num_task) > par_errors_.size()) {
      par_errors_.resize(num_task + 1);
      if (need_sync) {
        delete[] sync_counter_;
        sync_counter_ = new std::atomic<int>[num_task * kSyncStride];
      }
    }
    if (need_sync) {
      for (int i = 0; i < num_task; ++i) {
        sync_counter_[i * kSyncStride].store(
            0, std::memory_order_relaxed);
      }
      this->env.sync_handle = sync_counter_;
    } else {
      this->env.sync_handle = nullptr;
    }
  }
  ~ParallelLauncher() {
    delete[] sync_counter_;
  }
  // Wait for all launched jobs to finish.
  int WaitForJobs() {
    while (num_pending_.load() != 0) {
      tvm::runtime::threading::Yield();
    }
    if (!has_error_.load()) return 0;
    // intentionally use std::string here due to a
    // security issue raised in the SGX backend
    std::string err("");
    for (size_t i = 0; i < par_errors_.size(); ++i) {
      if (par_errors_[i].length() != 0) {
        err += "Task " + std::to_string(i) + " error: " + par_errors_[i] + '\n';
        par_errors_[i].clear();
      }
    }
    TVMAPISetLastError(err.c_str());
    return -1;
  }
  // Signal that one job has finished with an error.
  void SignalJobError(int task_id) {
    num_pending_.fetch_sub(1);
    par_errors_[task_id] = TVMGetLastError();
    has_error_.store(true);
  }
  // Signal that one job has finished.
  void SignalJobFinish() {
    num_pending_.fetch_sub(1);
  }
  // Get the thread local version of the store.
  static ParallelLauncher* ThreadLocal() {
    return dmlc::ThreadLocalStore<ParallelLauncher>::Get();
  }
  // The parallel lambda
  FTVMParallelLambda flambda;
  // The closure data
  void* cdata;
  // Local env
  TVMParallelGroupEnv env;
  // Whether this thread is a worker of the pool.
  // Used to prevent recursive launches.
  bool is_worker{false};

 private:
  // The number of pending jobs.
  std::atomic<int32_t> num_pending_;
  // Whether an error has been encountered.
  std::atomic<bool> has_error_;
  // The counter page used for barrier synchronization.
  std::atomic<int32_t>* sync_counter_{nullptr};
  // The per-task error messages.
  std::vector<std::string> par_errors_;
};

/*! \brief Lock-free single-producer-single-consumer queue for each thread */
class SpscTaskQueue {
 public:
  /*! \brief The task entry */
  struct Task {
    ParallelLauncher* launcher;
    int32_t task_id;
  };

  SpscTaskQueue() :
    buffer_(new Task[kRingSize]),
    head_(0),
    tail_(0) {
  }

  ~SpscTaskQueue() {
    delete[] buffer_;
  }

  /*!
   * \brief Push a task into the queue and notify the consumer if it is waiting.
   * \param input The task to be enqueued.
   */
  void Push(const Task& input) {
    while (!Enqueue(input)) {
      tvm::runtime::threading::Yield();
    }
    if (pending_.fetch_add(1) == -1) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.notify_one();
    }
  }

  /*!
   * \brief Pop a task out of the queue; condition-wait if there are no tasks.
   * \param output The pointer to the task to be dequeued.
   * \param spin_count The number of iterations to spin before sleeping.
   * \return Whether pop is successful (true) or we need to exit now (false).
   */
  bool Pop(Task* output, uint32_t spin_count) {
    // Busy wait a bit when the queue is empty.
    // If a new task arrives quickly, this keeps the worker from going to sleep.
    // The default spin count follows the typical OpenMP convention.
    for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
      tvm::runtime::threading::Yield();
    }
    if (pending_.fetch_sub(1) == 0) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [this] {
          return pending_.load() >= 0 || exit_now_.load();
        });
    }
    if (exit_now_.load(std::memory_order_relaxed)) {
      return false;
    }
    const uint32_t head = head_.load(std::memory_order_relaxed);
    // sanity check that the queue is not empty
    CHECK(tail_.load(std::memory_order_acquire) != head);
    *output = buffer_[head];
    head_.store((head + 1) % kRingSize, std::memory_order_release);
    return true;
  }

  /*!
   * \brief Signal to terminate the worker.
   */
  void SignalForKill() {
    std::lock_guard<std::mutex> lock(mutex_);
    exit_now_.store(true);
    cv_.notify_all();
  }

 protected:
  /*!
   * \brief Lock-free enqueue.
   * \param input The task to be enqueued.
   * \return Whether the task is enqueued.
   */
  bool Enqueue(const Task& input) {
    if (exit_now_.load(std::memory_order_relaxed)) return false;

    const uint32_t tail = tail_.load(std::memory_order_relaxed);

    if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) {
      buffer_[tail] = input;
      tail_.store((tail + 1) % kRingSize, std::memory_order_release);
      return true;
    }
    return false;
  }

  // cache line paddings used to avoid false sharing between the atomic variables
  typedef char cache_line_pad_t[kL1CacheBytes];
  cache_line_pad_t pad0_;
  // size of the ring buffer; the queue can hold at most kRingSize - 1 items
  // defined as a constant for better compiler optimization
  static constexpr const int kRingSize = 2;
  // pointer to access the items
  Task* const buffer_;

  cache_line_pad_t pad1_;
  // queue head, where one gets a task from the queue
  std::atomic<uint32_t> head_;

  cache_line_pad_t pad2_;
  // queue tail, where one puts a task into the queue
  std::atomic<uint32_t> tail_;

  cache_line_pad_t pad3_;
  // pending tasks in the queue
  std::atomic<int8_t> pending_{0};

  cache_line_pad_t pad4_;
  // signal for exit now
  std::atomic<bool> exit_now_{false};

  // internal mutex
  std::mutex mutex_;
  // cv for the consumer
  std::condition_variable cv_;
};
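// Note on the Push/Pop handshake above (a summary of the code, not an extra
// contract): pending_ counts queued tasks and may transiently reach -1 when the
// consumer decrements an empty queue.
//  * Push: pending_.fetch_add(1); if the previous value was -1 the consumer is
//    sleeping (or about to sleep), so notify the condition variable.
//  * Pop:  pending_.fetch_sub(1); if the previous value was 0 the queue was
//    empty, so wait until pending_ >= 0 (a task arrived) or exit_now_ is set
//    by SignalForKill().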

// The thread pool
class ThreadPool {
 public:
  ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) {
    for (int i = 0; i < num_workers_; ++i) {
      // The SpscTaskQueue only hosts ONE item at a time
      queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
    }
    const char* exclude_worker0 = getenv("TVM_EXCLUDE_WORKER0");
    if (exclude_worker0 && atoi(exclude_worker0) == 0) {
      exclude_worker0_ = false;
    }
    threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>(
        new tvm::runtime::threading::ThreadGroup(
          num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
          exclude_worker0_ /* include_main_thread */));
    num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_);
  }
  ~ThreadPool() {
    for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
      q->SignalForKill();
    }
    threads_.reset();
  }
  int Launch(FTVMParallelLambda flambda,
             void* cdata,
             int num_task,
             int need_sync) {
    ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
    CHECK(!launcher->is_worker)
        << "Cannot launch parallel job inside worker, consider fuse then parallel";
    if (num_task == 0) {
      num_task = num_workers_used_;
    }
    if (need_sync != 0) {
      CHECK_LE(num_task, num_workers_used_)
          << "Requested a sync parallel job with more tasks than worker threads:"
          << " workers=" << num_workers_used_ << " request=" << num_task;
    }
    launcher->Init(flambda, cdata, num_task, need_sync != 0);
    SpscTaskQueue::Task tsk;
    tsk.launcher = launcher;
    // if worker0 is taken by the master, queues_[0] is abandoned
    for (int i = exclude_worker0_; i < num_task; ++i) {
      tsk.task_id = i;
      queues_[i]->Push(tsk);
    }
    // use the master thread to run task 0
    if (exclude_worker0_) {
      TVMParallelGroupEnv* penv = &(tsk.launcher->env);
      if ((*tsk.launcher->flambda)(0, penv, cdata) == 0) {
        tsk.launcher->SignalJobFinish();
      } else {
        // the master ran task 0, so attribute the error to task 0
        tsk.launcher->SignalJobError(0);
      }
    }
    int res = launcher->WaitForJobs();
    return res;
  }

  static ThreadPool* ThreadLocal() {
    return dmlc::ThreadLocalStore<ThreadPool>::Get();
  }

  void UpdateWorkerConfiguration(threading::ThreadGroup::AffinityMode mode, int nthreads) {
    // this will also reset the affinity of the ThreadGroup
    // may use less than the MaxConcurrency number of workers
    num_workers_used_ = threads_->Configure(mode, nthreads,
                                            exclude_worker0_);
    // if MaxConcurrency restricted the number of workers (e.g., due to
    // hyperthreading), respect the restriction
    num_workers_used_ = std::min(num_workers_, num_workers_used_);
  }

 private:
  // Internal worker function.
  void RunWorker(int worker_id) {
    SpscTaskQueue* queue = queues_[worker_id].get();
    SpscTaskQueue::Task task;
    ParallelLauncher::ThreadLocal()->is_worker = true;
    // Initialize the spin count (from envvar TVM_THREAD_POOL_SPIN_COUNT) on
    // the global first use of the ThreadPool.
    // TODO(tulloch): should we make this configurable via standard APIs?
    static size_t spin_count = GetSpinCount();
    while (queue->Pop(&task, spin_count)) {
      CHECK(task.launcher != nullptr);
      TVMParallelGroupEnv* penv = &(task.launcher->env);
      void* cdata = task.launcher->cdata;
      if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {
        task.launcher->SignalJobFinish();
      } else {
        task.launcher->SignalJobError(task.task_id);
      }
    }
  }
  int num_workers_;
  // number of workers used (can be restricted with affinity pref)
  int num_workers_used_;
  // whether to exclude worker 0 and use the master thread to run task 0
#ifndef _LIBCPP_SGX_CONFIG
  bool exclude_worker0_{true};
#else
  bool exclude_worker0_{false};
#endif
  std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
  std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
};
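// Behavioral note (summarizing the constructor and ThreadPool::Launch above):
// by default the launching (master) thread runs task 0 itself and queues_[0] is
// left unused; setting the environment variable TVM_EXCLUDE_WORKER0=0 disables
// this, so task 0 is pushed to worker 0 like every other task.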

TVM_REGISTER_GLOBAL("runtime.config_threadpool")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    threading::ThreadGroup::AffinityMode mode =
        static_cast<threading::ThreadGroup::AffinityMode>(
            static_cast<int>(args[0]));
    int nthreads = args[1];
    ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads);
});
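// Usage sketch (illustrative only): the affinity mode and worker count can be
// reconfigured at runtime through the registered PackedFunc, e.g. from C++:
//
//   const tvm::runtime::PackedFunc* config =
//       tvm::runtime::Registry::Get("runtime.config_threadpool");
//   // pin workers to the big cores and use 4 threads
//   (*config)(static_cast<int>(tvm::runtime::threading::ThreadGroup::kBig), 4);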


}  // namespace runtime
}  // namespace tvm


int TVMBackendParallelLaunch(
    FTVMParallelLambda flambda,
    void* cdata,
    int num_task) {
#if !TVM_THREADPOOL_USE_OPENMP
  int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(
      flambda, cdata, num_task, 1);
  return res;
#else
  int num_workers = tvm::runtime::threading::MaxConcurrency();
  if (num_task == 0) num_task = num_workers;
  omp_set_num_threads(num_workers);
  #pragma omp parallel num_threads(num_workers)
  {
    TVMParallelGroupEnv env;
    env.num_task = num_task;
    std::atomic<int32_t>* sync_counter = new std::atomic<int>[num_task * tvm::runtime::kSyncStride];
    for (int i = 0; i < num_task; ++i) {
      sync_counter[i * tvm::runtime::kSyncStride].store(
          0, std::memory_order_relaxed);
    }
    env.sync_handle = sync_counter;
    (*flambda)(omp_get_thread_num(), &env, cdata);
  }
  return 0;
#endif
}

int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
#if TVM_THREADPOOL_USE_OPENMP
  #pragma omp barrier
#else
  using tvm::runtime::kSyncStride;
  int num_task = penv->num_task;
  std::atomic<int>* sync_counter =
      reinterpret_cast<std::atomic<int>*>(penv->sync_handle);
  // Counter barrier: bump this task's counter, then wait until every other
  // task's counter has advanced past the value we just observed.
  int old_counter = sync_counter[task_id * kSyncStride].fetch_add(
      1, std::memory_order_release);
  for (int i = 0; i < num_task; ++i) {
    if (i != task_id) {
      while (sync_counter[i * kSyncStride].load(
                 std::memory_order_relaxed) <= old_counter) {
        tvm::runtime::threading::Yield();
      }
    }
  }
  std::atomic_thread_fence(std::memory_order_acquire);
#endif
  return 0;
}
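// Illustrative sketch (not part of the runtime) of how generated code typically
// drives these two entry points; Workspace and ParallelTask are made-up names:
//
//   struct Workspace { float* data; int64_t n; };
//
//   static int ParallelTask(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
//     Workspace* ws = static_cast<Workspace*>(cdata);
//     int64_t chunk = (ws->n + penv->num_task - 1) / penv->num_task;
//     int64_t begin = task_id * chunk;
//     int64_t end = std::min(ws->n, begin + chunk);
//     for (int64_t i = begin; i < end; ++i) ws->data[i] *= 2.0f;
//     // all tasks meet here before any of them proceeds
//     TVMBackendParallelBarrier(task_id, penv);
//     return 0;  // a non-zero return is reported as a task error
//   }
//
//   // num_task == 0 lets the runtime pick the number of workers:
//   // int ret = TVMBackendParallelLaunch(ParallelTask, &ws, 0);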