/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file thread_pool.cc
 * \brief Threadpool for multi-threading runtime.
 */
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/threading_backend.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#if TVM_THREADPOOL_USE_OPENMP
#include <omp.h>
#endif
#include <thread>
#include <condition_variable>
#include <mutex>
#include <atomic>
#include <algorithm>
#include <vector>
#include <string>
#include <cstring>
#include <cstdlib>
#include <memory>
#include <sstream>

// Assumed L1 cache line size, used for padding to avoid false sharing.
constexpr int kL1CacheBytes = 64;

namespace tvm {
namespace runtime {
namespace {

constexpr uint32_t kDefaultSpinCount = 300000;

// Read the spin count from the environment, falling back to the default.
uint32_t GetSpinCount() {
  const char* val = getenv("TVM_THREAD_POOL_SPIN_COUNT");
  if (!val) {
    return kDefaultSpinCount;
  }
  return atoi(val);
}
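
// Usage note (illustrative): the spin count can be tuned per process through
// the environment. For example, a latency-insensitive service might disable
// busy-waiting entirely so idle workers sleep immediately:
//
//   TVM_THREAD_POOL_SPIN_COUNT=0 ./my_tvm_app   (hypothetical binary name)
//
// With a spin count of 0, SpscTaskQueue::Pop() below skips its spin loop and
// parks the worker on a condition variable right away.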

}  // namespace

// Stride in the page, sized to fit a cache line: with 4-byte atomics this is
// 16 elements, so consecutive sync counters sit 64 bytes apart and never
// share a cache line.
constexpr int kSyncStride = 64 / sizeof(std::atomic<int>);

/*!
 * \brief Thread local master environment.
 */
class ParallelLauncher {
 public:
  // Reset the task request.
  void Init(FTVMParallelLambda flambda,
            void* cdata,
            int num_task,
            bool need_sync) {
    num_pending_.store(num_task);
    this->cdata = cdata;
    this->flambda = flambda;
    this->env.num_task = num_task;
    has_error_.store(false);
    // Grow the error and sync-counter storage if more tasks are requested.
    if (static_cast<size_t>(num_task) > par_errors_.size()) {
      par_errors_.resize(num_task + 1);
      if (need_sync) {
        delete[] sync_counter_;
        sync_counter_ = new std::atomic<int32_t>[num_task * kSyncStride];
      }
    }
    if (need_sync) {
      for (int i = 0; i < num_task; ++i) {
        sync_counter_[i * kSyncStride].store(
            0, std::memory_order_relaxed);
      }
      this->env.sync_handle = sync_counter_;
    } else {
      this->env.sync_handle = nullptr;
    }
  }
  ~ParallelLauncher() {
    delete[] sync_counter_;
  }
  // Wait for all jobs to finish.
  int WaitForJobs() {
    while (num_pending_.load() != 0) {
      tvm::runtime::threading::Yield();
    }
    if (!has_error_.load()) return 0;
    // The following intentionally uses std::string due to
    // a security issue raised in the SGX backend.
    std::string err("");
    for (size_t i = 0; i < par_errors_.size(); ++i) {
      if (par_errors_[i].length() != 0) {
        err += "Task " + std::to_string(i) + " error: " + par_errors_[i] + '\n';
        par_errors_[i].clear();
      }
    }
    TVMAPISetLastError(err.c_str());
    return -1;
  }
  // Signal that one job has finished with an error.
  void SignalJobError(int task_id) {
    num_pending_.fetch_sub(1);
    par_errors_[task_id] = TVMGetLastError();
    has_error_.store(true);
  }
  // Signal that one job has finished.
  void SignalJobFinish() {
    num_pending_.fetch_sub(1);
  }
  // Get the thread-local version of the store.
  static ParallelLauncher* ThreadLocal() {
    return dmlc::ThreadLocalStore<ParallelLauncher>::Get();
  }
  // The parallel lambda
  FTVMParallelLambda flambda;
  // The closure data
  void* cdata;
  // Local env
  TVMParallelGroupEnv env;
  // Whether this thread is a worker of the pool,
  // used to prevent recursive launch.
  bool is_worker{false};

 private:
  // The pending jobs.
  std::atomic<int32_t> num_pending_;
  // Whether an error has been encountered.
  std::atomic<bool> has_error_;
  // The counter page.
  std::atomic<int32_t>* sync_counter_{nullptr};
  // The error messages, one slot per task.
  std::vector<std::string> par_errors_;
};
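
// Lifecycle sketch (for orientation only; it mirrors what ThreadPool::Launch
// below actually does): the master thread owns the launcher and blocks until
// every task has signalled completion or error.
//
//   ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
//   launcher->Init(flambda, cdata, /*num_task=*/4, /*need_sync=*/true);
//   // ... enqueue {launcher, task_id} to each worker's queue ...
//   int ret = launcher->WaitForJobs();  // 0 on success, -1 if any task failed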

/*! \brief Lock-free single-producer-single-consumer queue for each thread */
class SpscTaskQueue {
 public:
  /*! \brief The task entry */
  struct Task {
    ParallelLauncher* launcher;
    int32_t task_id;
  };

  SpscTaskQueue() :
    buffer_(new Task[kRingSize]),
    head_(0),
    tail_(0) {
  }

  ~SpscTaskQueue() {
    delete[] buffer_;
  }

  /*!
   * \brief Push a task into the queue and notify the consumer if it is waiting.
   * \param input The task to be enqueued.
   */
  void Push(const Task& input) {
    while (!Enqueue(input)) {
      tvm::runtime::threading::Yield();
    }
    if (pending_.fetch_add(1) == -1) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.notify_one();
    }
  }

  /*!
   * \brief Pop a task out of the queue, or condition-wait if there are no tasks.
   * \param output The pointer to the task to be dequeued.
   * \param spin_count The number of iterations to spin before sleeping.
   * \return Whether the pop is successful (true) or we need to exit now (false).
   */
  bool Pop(Task* output, uint32_t spin_count) {
    // Busy-wait a bit when the queue is empty.
    // If a new task arrives quickly, this keeps the worker from sleeping.
    // The default spin count is set following the typical OMP convention.
    for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {
      tvm::runtime::threading::Yield();
    }
    if (pending_.fetch_sub(1) == 0) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [this] {
          return pending_.load() >= 0 || exit_now_.load();
        });
    }
    if (exit_now_.load(std::memory_order_relaxed)) {
      return false;
    }
    const uint32_t head = head_.load(std::memory_order_relaxed);
    // sanity check that the queue is not empty
    CHECK(tail_.load(std::memory_order_acquire) != head);
    *output = buffer_[head];
    head_.store((head + 1) % kRingSize, std::memory_order_release);
    return true;
  }

  /*!
   * \brief Signal the worker to terminate.
   */
  void SignalForKill() {
    std::lock_guard<std::mutex> lock(mutex_);
    exit_now_.store(true);
    cv_.notify_all();
  }

 protected:
  /*!
   * \brief Lock-free enqueue.
   * \param input The task to be enqueued.
   * \return Whether the task was enqueued.
   */
  bool Enqueue(const Task& input) {
    if (exit_now_.load(std::memory_order_relaxed)) return false;

    const uint32_t tail = tail_.load(std::memory_order_relaxed);

    if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) {
      buffer_[tail] = input;
      tail_.store((tail + 1) % kRingSize, std::memory_order_release);
      return true;
    }
    return false;
  }

  // Cache line padding is used to avoid false sharing between atomic variables.
  typedef char cache_line_pad_t[kL1CacheBytes];
  cache_line_pad_t pad0_;
  // Size of the ring; the queue can host at most kRingSize - 1 items.
  // Defined as a constant for better compiler optimization.
  static constexpr const int kRingSize = 2;
  // pointer to access the items
  Task* const buffer_;

  cache_line_pad_t pad1_;
  // queue head, where one gets a task from the queue
  std::atomic<uint32_t> head_;

  cache_line_pad_t pad2_;
  // queue tail, where one puts a task into the queue
  std::atomic<uint32_t> tail_;

  cache_line_pad_t pad3_;
  // number of pending tasks in the queue
  std::atomic<int8_t> pending_{0};

  cache_line_pad_t pad4_;
  // signal for exit now
  std::atomic<bool> exit_now_{false};

  // internal mutex
  std::mutex mutex_;
  // condition variable for the consumer
  std::condition_variable cv_;
};
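
// Wakeup protocol note (a walk-through of the code above, not extra API):
// pending_ counts queued tasks, and the value -1 marks a sleeping consumer.
//   pending_ == 0   worker's Pop() fetch_sub observes 0 -> worker sleeps (-1)
//   pending_ == -1  master's Push() fetch_add observes -1 -> notify_one()
//   pending_ == 0   worker wakes; the cv_ predicate (pending_ >= 0) now holds
// This keeps the common non-empty path lock-free while still letting idle
// workers sleep instead of spinning forever.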

// The thread pool
class ThreadPool {
 public:
  ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) {
    for (int i = 0; i < num_workers_; ++i) {
      // The SpscTaskQueue only hosts ONE item at a time
      queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));
    }
    const char* exclude_worker0 = getenv("TVM_EXCLUDE_WORKER0");
    if (exclude_worker0 && atoi(exclude_worker0) == 0) {
      exclude_worker0_ = false;
    }
    threads_ = std::unique_ptr<tvm::runtime::threading::ThreadGroup>(
        new tvm::runtime::threading::ThreadGroup(
          num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },
          exclude_worker0_ /* include_main_thread */));
    num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_);
  }
  ~ThreadPool() {
    for (std::unique_ptr<SpscTaskQueue>& q : queues_) {
      q->SignalForKill();
    }
    threads_.reset();
  }
  int Launch(FTVMParallelLambda flambda,
             void* cdata,
             int num_task,
             int need_sync) {
    ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
    CHECK(!launcher->is_worker)
        << "Cannot launch parallel job inside worker, consider fuse then parallel";
    if (num_task == 0) {
      num_task = num_workers_used_;
    }
    if (need_sync != 0) {
      CHECK_LE(num_task, num_workers_used_)
          << "Request parallel sync task larger than number of threads used "
          << " workers=" << num_workers_used_ << " request=" << num_task;
    }
    launcher->Init(flambda, cdata, num_task, need_sync != 0);
    SpscTaskQueue::Task tsk;
    tsk.launcher = launcher;
    // if worker0 is taken by the master, queues_[0] is abandoned
    for (int i = exclude_worker0_; i < num_task; ++i) {
      tsk.task_id = i;
      queues_[i]->Push(tsk);
    }
    // use the master thread to run task 0
    if (exclude_worker0_) {
      TVMParallelGroupEnv* penv = &(tsk.launcher->env);
      if ((*tsk.launcher->flambda)(0, penv, cdata) == 0) {
        tsk.launcher->SignalJobFinish();
      } else {
        tsk.launcher->SignalJobError(tsk.task_id);
      }
    }
    int res = launcher->WaitForJobs();
    return res;
  }

  static ThreadPool* ThreadLocal() {
    return dmlc::ThreadLocalStore<ThreadPool>::Get();
  }

  void UpdateWorkerConfiguration(threading::ThreadGroup::AffinityMode mode, int nthreads) {
    // This also resets the affinity of the ThreadGroup and
    // may use fewer than MaxConcurrency workers.
    num_workers_used_ = threads_->Configure(mode, nthreads,
                                            exclude_worker0_);
    // If MaxConcurrency restricted the number of workers (e.g., due to
    // hyperthreading), respect the restriction.
    num_workers_used_ = std::min(num_workers_, num_workers_used_);
  }

 private:
  // Internal worker function.
  void RunWorker(int worker_id) {
    SpscTaskQueue* queue = queues_[worker_id].get();
    SpscTaskQueue::Task task;
    ParallelLauncher::ThreadLocal()->is_worker = true;
    // Initialize the spin count (from the envvar TVM_THREAD_POOL_SPIN_COUNT)
    // on the global first use of the ThreadPool.
    // TODO(tulloch): should we make this configurable via standard APIs?
    static size_t spin_count = GetSpinCount();
    while (queue->Pop(&task, spin_count)) {
      CHECK(task.launcher != nullptr);
      TVMParallelGroupEnv* penv = &(task.launcher->env);
      void* cdata = task.launcher->cdata;
      if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {
        task.launcher->SignalJobFinish();
      } else {
        task.launcher->SignalJobError(task.task_id);
      }
    }
  }
  int num_workers_;
  // number of workers used (can be restricted with affinity pref)
  int num_workers_used_;
  // whether to exclude worker 0 and use the master thread to run task 0
#ifndef _LIBCPP_SGX_CONFIG
  bool exclude_worker0_{true};
#else
  bool exclude_worker0_{false};
#endif
  std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
  std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
};

TVM_REGISTER_GLOBAL("runtime.config_threadpool")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    threading::ThreadGroup::AffinityMode mode =
        static_cast<threading::ThreadGroup::AffinityMode>(
            static_cast<int>(args[0]));
    int nthreads = args[1];
    ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads);
  });
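
// Usage sketch (illustrative): the registered PackedFunc can be looked up and
// called from C++ as below. The mode argument follows
// threading::ThreadGroup::AffinityMode, whose values are defined in
// threading_backend.h.
//
//   const tvm::runtime::PackedFunc* f =
//       tvm::runtime::Registry::Get("runtime.config_threadpool");
//   if (f != nullptr) {
//     (*f)(static_cast<int>(tvm::runtime::threading::ThreadGroup::kBig),
//          4 /* nthreads */);
//   }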

}  // namespace runtime
}  // namespace tvm


int TVMBackendParallelLaunch(
    FTVMParallelLambda flambda,
    void* cdata,
    int num_task) {
#if !TVM_THREADPOOL_USE_OPENMP
  int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(
      flambda, cdata, num_task, 1);
  return res;
#else
  int num_workers = tvm::runtime::threading::MaxConcurrency();
  if (num_task == 0) num_task = num_workers;
  omp_set_num_threads(num_workers);
  // Allocate the sync counters once, outside the parallel region, so that all
  // tasks share the same counter page (per-thread allocation would break the
  // barrier and leak memory), and free them automatically on return.
  std::unique_ptr<std::atomic<int>[]> sync_counter(
      new std::atomic<int>[num_task * tvm::runtime::kSyncStride]);
  for (int i = 0; i < num_task; ++i) {
    sync_counter[i * tvm::runtime::kSyncStride].store(
        0, std::memory_order_relaxed);
  }
  #pragma omp parallel num_threads(num_workers)
  {
    TVMParallelGroupEnv env;
    env.num_task = num_task;
    env.sync_handle = sync_counter.get();
    (*flambda)(omp_get_thread_num(), &env, cdata);
  }
  return 0;
#endif
}
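
// Example (a hedged sketch, not part of the runtime): what a parallel lambda
// handed to TVMBackendParallelLaunch might look like. `FillClosure` and
// `FillLambda` are hypothetical names; generated TVM kernels emit equivalent
// code automatically.
//
//   struct FillClosure { float* out; int n; };
//
//   static int FillLambda(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
//     FillClosure* c = static_cast<FillClosure*>(cdata);
//     // Each task handles one contiguous chunk of the output.
//     int chunk = (c->n + penv->num_task - 1) / penv->num_task;
//     int begin = task_id * chunk;
//     int end = std::min(begin + chunk, c->n);
//     for (int i = begin; i < end; ++i) c->out[i] = 0.0f;
//     return 0;  // a non-zero return reports an error for this task
//   }
//
//   // num_task = 0 asks the pool to use all available workers:
//   //   FillClosure c{buffer, 1024};
//   //   TVMBackendParallelLaunch(FillLambda, &c, 0);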

int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
#if TVM_THREADPOOL_USE_OPENMP
  #pragma omp barrier
#else
  using tvm::runtime::kSyncStride;
  int num_task = penv->num_task;
  std::atomic<int>* sync_counter =
      reinterpret_cast<std::atomic<int>*>(penv->sync_handle);
  // Announce this task's arrival, then spin until every other task's counter
  // has caught up to the same generation.
  int old_counter = sync_counter[task_id * kSyncStride].fetch_add(
      1, std::memory_order_release);
  for (int i = 0; i < num_task; ++i) {
    if (i != task_id) {
      while (sync_counter[i * kSyncStride].load(
                 std::memory_order_relaxed) <= old_counter) {
        tvm::runtime::threading::Yield();
      }
    }
  }
  std::atomic_thread_fence(std::memory_order_acquire);
#endif
  return 0;
}
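
// Usage sketch (illustrative; `TwoPhaseLambda` and the DoPhase* helpers are
// hypothetical): inside an FTVMParallelLambda, the barrier makes every task
// finish phase 1 before any task begins phase 2.
//
//   static int TwoPhaseLambda(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
//     DoPhase1(task_id, cdata);
//     TVMBackendParallelBarrier(task_id, penv);  // all tasks rendezvous here
//     DoPhase2(task_id, cdata);
//     return 0;
//   }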