/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file thread_pool.cc
 * \brief Thread pool for the multi-threading runtime.
 */
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>
#if TVM_THREADPOOL_USE_OPENMP
#include <omp.h>
#endif
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

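// Assumed cache line size in bytes; the padding below is sized to this so
// that frequently-updated atomics land on separate cache lines.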
constexpr int kL1CacheBytes = 64;

namespace tvm {
namespace runtime {
namespace {

constexpr uint32_t kDefaultSpinCount = 300000;

uint32_t GetSpinCount() {
  const char* val = getenv("TVM_THREAD_POOL_SPIN_COUNT");
  if (!val) {
    return kDefaultSpinCount;
  }
  return atoi(val);
}

}  // namespace

// stride in the page, fit to cache line.
constexpr int kSyncStride = 64 / sizeof(std::atomic<int>);
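// With 64-byte cache lines and 4-byte atomics (the common case) this gives a
// stride of 16 entries, so each task's sync counter sits on its own cache
// line and the barrier below does not false-share between tasks.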

/*!
 * \brief Thread-local master environment.
 */
class ParallelLauncher {
 public:
  // Reset the task request.
  void Init(FTVMParallelLambda flambda, void* cdata, int num_task, bool need_sync) {
    num_pending_.store(num_task);
    this->cdata = cdata;
    this->flambda = flambda;
    this->env.num_task = num_task;
    has_error_.store(false);
    // Grow the per-task error slots and sync counters if more tasks than
    // before are requested.
    if (static_cast<size_t>(num_task) > par_errors_.size()) {
      par_errors_.resize(num_task + 1);
      if (need_sync) {
        delete[] sync_counter_;
        sync_counter_ = new std::atomic<int>[num_task * kSyncStride];
      }
    }
    if (need_sync) {
      for (int i = 0; i < num_task; ++i) {
        sync_counter_[i * kSyncStride].store(0, std::memory_order_relaxed);
      }
      this->env.sync_handle = sync_counter_;
    } else {
      this->env.sync_handle = nullptr;
    }
  }
  ~ParallelLauncher() { delete[] sync_counter_; }
  // Wait for all jobs to finish.
  int WaitForJobs() {
    while (num_pending_.load() != 0) {
      tvm::runtime::threading::Yield();
    }
    if (!has_error_.load()) return 0;
    std::ostringstream os;
    for (size_t i = 0; i < par_errors_.size(); ++i) {
      if (par_errors_[i].length() != 0) {
        os << "Task " << i << " error: " << par_errors_[i] << '\n';
        par_errors_[i].clear();
      }
    }
    TVMAPISetLastError(os.str().c_str());
    return -1;
  }
  // Signal that one job has finished with an error.
  void SignalJobError(int task_id) {
    num_pending_.fetch_sub(1);
    par_errors_[task_id] = TVMGetLastError();
    has_error_.store(true);
  }
  // Signal that one job has finished.
  void SignalJobFinish() { num_pending_.fetch_sub(1); }
  // Get the thread-local version of the store.
  static ParallelLauncher* ThreadLocal() { return dmlc::ThreadLocalStore<ParallelLauncher>::Get(); }
  // The parallel lambda.
  FTVMParallelLambda flambda;
  // The closure data.
  void* cdata;
  // Local env.
  TVMParallelGroupEnv env;
  // Whether this thread is a worker of the pool;
  // used to prevent recursive launches.
  bool is_worker{false};

 private:
  // The number of pending jobs.
  std::atomic<int32_t> num_pending_;
  // Whether an error has been encountered.
  std::atomic<bool> has_error_;
  // The counter page.
  std::atomic<int32_t>* sync_counter_{nullptr};
  // The per-task error messages.
  std::vector<std::string> par_errors_;
};
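// Typical lifetime of a launch (sketch): ThreadPool::Launch below calls
// Init() on the calling thread's launcher, pushes one task per worker queue,
// and then WaitForJobs() spins until num_pending_ reaches zero, collecting
// any per-task error strings on the way out.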

/*! \brief Lock-free single-producer single-consumer queue for each thread */
class SpscTaskQueue {
 public:
  /*! \brief The task entry */
  struct Task {
    ParallelLauncher* launcher;
    int32_t task_id;
  };

  SpscTaskQueue() : buffer_(new Task[kRingSize]), head_(0), tail_(0) {}

  ~SpscTaskQueue() { delete[] buffer_; }

  /*!
   * \brief Push a task into the queue and notify the consumer if it is waiting.
   * \param input The task to be enqueued.
   */
  void Push(const Task& input) {
    while (!Enqueue(input)) {
      tvm::runtime::threading::Yield();
    }
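    // pending_ is -1 only when the consumer decremented past empty and went
    // to sleep in Pop(); seeing -1 here means we must wake it up.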
    if (pending_.fetch_add(1) == -1) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.notify_one();
    }
  }

  /*!
   * \brief Pop a task out of the queue, condition-waiting if there are no tasks.
   * \param output The pointer to the task to be dequeued.
   * \param spin_count The number of iterations to spin before sleeping.
   * \return Whether the pop succeeded (true) or we need to exit now (false).
   */
  bool Pop(Task* output, uint32_t spin_count) {
    // Busy-wait for a while when the queue is empty.
    // If a new task arrives quickly, this keeps the worker from going to sleep.
    // The default spin count follows the typical OMP convention.
    for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
      tvm::runtime::threading::Yield();
    }
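    // If the queue is still empty after spinning, this fetch_sub takes
    // pending_ to -1, which tells Push() above to notify us; block until a
    // task arrives or shutdown is requested.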
    if (pending_.fetch_sub(1) == 0) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [this] { return pending_.load() >= 0 || exit_now_.load(); });
    }
    if (exit_now_.load(std::memory_order_relaxed)) {
      return false;
    }
    const uint32_t head = head_.load(std::memory_order_relaxed);
    // sanity check that the queue is not empty
    CHECK(tail_.load(std::memory_order_acquire) != head);
    *output = buffer_[head];
    head_.store((head + 1) % kRingSize, std::memory_order_release);
    return true;
  }

  /*!
   * \brief Signal the worker to terminate.
   */
  void SignalForKill() {
    std::lock_guard<std::mutex> lock(mutex_);
    exit_now_.store(true);
    cv_.notify_all();
  }

 protected:
  /*!
   * \brief Lock-free enqueue.
   * \param input The task to be enqueued.
   * \return Whether the task was enqueued.
   */
  bool Enqueue(const Task& input) {
    if (exit_now_.load(std::memory_order_relaxed)) return false;

    const uint32_t tail = tail_.load(std::memory_order_relaxed);

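    // The slot at tail is free iff advancing tail would not collide with
    // head; the acquire load of head_ pairs with the consumer's release
    // store in Pop(), ensuring the consumer is done reading that slot.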
    if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) {
      buffer_[tail] = input;
      tail_.store((tail + 1) % kRingSize, std::memory_order_release);
      return true;
    }
    return false;
  }

  // Cache-line padding used to avoid false sharing between the atomic variables.
  typedef char cache_line_pad_t[kL1CacheBytes];
  cache_line_pad_t pad0_;
  // Size of the ring; the queue can hold at most kRingSize - 1 items.
  // Defined as a constant for better compiler optimization.
  static constexpr int kRingSize = 2;
  // Pointer used to access the items.
  Task* const buffer_;

  cache_line_pad_t pad1_;
  // Queue head: where one takes a task from the queue.
  std::atomic<uint32_t> head_;

  cache_line_pad_t pad2_;
  // Queue tail: where one puts a task into the queue.
  std::atomic<uint32_t> tail_;

  cache_line_pad_t pad3_;
  // Number of pending tasks in the queue.
  std::atomic<int8_t> pending_{0};

  cache_line_pad_t pad4_;
  // Signal to exit now.
  std::atomic<bool> exit_now_{false};

  // internal mutex
  std::mutex mutex_;
  // cv for the consumer
  std::condition_variable cv_;
};

// The thread pool
class ThreadPool {
 public:
  ThreadPool() : num_workers_(tvm::runtime::threading::MaxConcurrency()) {
    for (int i = 0; i < num_workers_; ++i) {
      // The SpscTaskQueue only hosts ONE item at a time
      queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
    }
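    // By default the calling (master) thread runs task 0 itself, so worker
    // 0's queue goes unused; setting TVM_EXCLUDE_WORKER0=0 turns worker 0
    // back into a regular pool thread.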
    const char* exclude_worker0 = getenv("TVM_EXCLUDE_WORKER0");
    if (exclude_worker0 && atoi(exclude_worker0) == 0) {
      exclude_worker0_ = false;
    }
    threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>(
        new tvm::runtime::threading::ThreadGroup(
            num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
            exclude_worker0_ /* include_main_thread */));
    num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_);
  }
  ~ThreadPool() {
    for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
      q->SignalForKill();
    }
    threads_.reset();
  }
  int Launch(FTVMParallelLambda flambda, void* cdata, int num_task, int need_sync) {
    ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
    CHECK(!launcher->is_worker)
        << "Cannot launch a parallel job inside a worker; consider fusing before parallelizing";
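    // num_task == 0 requests the default parallelism: one task per active worker.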
    if (num_task == 0) {
      num_task = num_workers_used_;
    }
    if (need_sync != 0) {
      CHECK_LE(num_task, num_workers_used_)
          << "Requested number of parallel sync tasks exceeds the number of workers used:"
          << " workers=" << num_workers_used_ << " request=" << num_task;
    }
    launcher->Init(flambda, cdata, num_task, need_sync != 0);
    SpscTaskQueue::Task tsk;
    tsk.launcher = launcher;
    // If worker0 is taken by the master, queues_[0] is left unused.
    for (int i = exclude_worker0_; i < num_task; ++i) {
      tsk.task_id = i;
      queues_[i]->Push(tsk);
    }
    // Use the master thread to run task 0.
    if (exclude_worker0_) {
      TVMParallelGroupEnv* penv = &(tsk.launcher->env);
      if ((*tsk.launcher->flambda)(0, penv, cdata) == 0) {
        tsk.launcher->SignalJobFinish();
      } else {
        // The master ran task 0, so attribute the error to task 0
        // (tsk.task_id still holds the id of the last pushed task).
        tsk.launcher->SignalJobError(0);
      }
    }
    int res = launcher->WaitForJobs();
    return res;
  }

  static ThreadPool* ThreadLocal() { return dmlc::ThreadLocalStore<ThreadPool>::Get(); }

  void UpdateWorkerConfiguration(threading::ThreadGroup::AffinityMode mode, int nthreads) {
    // This also resets the affinity of the ThreadGroup and
    // may use fewer than MaxConcurrency workers.
    num_workers_used_ = threads_->Configure(mode, nthreads, exclude_worker0_);
    // If MaxConcurrency restricted the number of workers (e.g., due to
    // hyperthreading), respect the restriction.
    num_workers_used_ = std::min(num_workers_, num_workers_used_);
  }

 private:
  // Internal worker function.
  void RunWorker(int worker_id) {
    SpscTaskQueue* queue = queues_[worker_id].get();
    SpscTaskQueue::Task task;
    ParallelLauncher::ThreadLocal()->is_worker = true;
    // Initialize the spin count (from envvar TVM_THREAD_POOL_SPIN_COUNT) on
    // the global first use of the ThreadPool.
    // TODO(tulloch): should we make this configurable via standard APIs?
    static size_t spin_count = GetSpinCount();
    while (queue->Pop(&task, spin_count)) {
      CHECK(task.launcher != nullptr);
      TVMParallelGroupEnv* penv = &(task.launcher->env);
      void* cdata = task.launcher->cdata;
      if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {
        task.launcher->SignalJobFinish();
      } else {
        task.launcher->SignalJobError(task.task_id);
      }
    }
  }
  int num_workers_;
  // Number of workers actually used (can be restricted by affinity preference).
  int num_workers_used_;
  // Whether to exclude worker 0 and use the master thread to run task 0.
  bool exclude_worker0_{true};
  std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
  std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
};

TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body([](TVMArgs args, TVMRetValue* rv) {
  threading::ThreadGroup::AffinityMode mode =
      static_cast<threading::ThreadGroup::AffinityMode>(static_cast<int>(args[0]));
  int nthreads = args[1];
  ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads);
});
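// Illustrative use from the frontend (a sketch, assuming the standard TVM
// Python FFI):
//   config = tvm.get_global_func("runtime.config_threadpool")
//   config(affinity_mode, nthreads)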

}  // namespace runtime
}  // namespace tvm

int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) {
#if !TVM_THREADPOOL_USE_OPENMP
  int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(flambda, cdata, num_task, 1);
  return res;
#else
  int num_workers = tvm::runtime::threading::MaxConcurrency();
  if (num_task == 0) num_task = num_workers;
  omp_set_num_threads(num_workers);
#pragma omp parallel num_threads(num_workers)
  {
    TVMParallelGroupEnv env;
    env.num_task = num_task;
    (*flambda)(omp_get_thread_num(), &env, cdata);
  }
  return 0;
#endif
}
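
// Illustrative sketch of a parallel lambda as generated code would supply
// one; AddCtx and AddKernel are hypothetical names, not part of this runtime:
//
//   struct AddCtx { const float* a; const float* b; float* c; int n; };
//   static int AddKernel(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
//     AddCtx* ctx = static_cast<AddCtx*>(cdata);
//     int chunk = (ctx->n + penv->num_task - 1) / penv->num_task;
//     int begin = std::min(task_id * chunk, ctx->n);
//     int end = std::min(begin + chunk, ctx->n);
//     for (int i = begin; i < end; ++i) ctx->c[i] = ctx->a[i] + ctx->b[i];
//     return 0;  // nonzero would be recorded via SignalJobError
//   }
//   // TVMBackendParallelLaunch(AddKernel, &ctx, /*num_task=*/0) splits the
//   // loop across the pool; num_task == 0 selects the default parallelism.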

int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
#if TVM_THREADPOOL_USE_OPENMP
#pragma omp barrier
#else
  using tvm::runtime::kSyncStride;
  int num_task = penv->num_task;
  std::atomic<int>* sync_counter = reinterpret_cast<std::atomic<int>*>(penv->sync_handle);
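  // Counting barrier: each task bumps its own cache-line-padded counter with
  // release ordering, then spins until every other task's counter passes the
  // value observed locally; the trailing acquire fence pairs with those
  // releases so writes made before the barrier are visible after it.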
  int old_counter = sync_counter[task_id * kSyncStride].fetch_add(1, std::memory_order_release);
  for (int i = 0; i < num_task; ++i) {
    if (i != task_id) {
      while (sync_counter[i * kSyncStride].load(std::memory_order_relaxed) <= old_counter) {
        tvm::runtime::threading::Yield();
      }
    }
  }
  std::atomic_thread_fence(std::memory_order_acquire);
#endif
  return 0;
}