1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "platform.h"
16
17 #if NCNN_SIMPLEOMP
18
19 #include "simpleomp.h"
20 #include "cpu.h" // ncnn::get_cpu_count()
21
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <stdint.h>
26 #include <stdarg.h>
27
28 extern "C" typedef void (*kmpc_micro)(int32_t* gtid, int32_t* tid, ...);
29 extern "C" typedef void (*kmpc_micro_0)(int32_t* gtid, int32_t* tid);
30 extern "C" typedef void (*kmpc_micro_1)(int32_t* gtid, int32_t* tid, void*);
31 extern "C" typedef void (*kmpc_micro_2)(int32_t* gtid, int32_t* tid, void*, void*);
32 extern "C" typedef void (*kmpc_micro_3)(int32_t* gtid, int32_t* tid, void*, void*, void*);
33 extern "C" typedef void (*kmpc_micro_4)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*);
34 extern "C" typedef void (*kmpc_micro_5)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*);
35 extern "C" typedef void (*kmpc_micro_6)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*);
36 extern "C" typedef void (*kmpc_micro_7)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*);
37 extern "C" typedef void (*kmpc_micro_8)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*);
38 extern "C" typedef void (*kmpc_micro_9)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*);
39 extern "C" typedef void (*kmpc_micro_10)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
40 extern "C" typedef void (*kmpc_micro_11)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
41 extern "C" typedef void (*kmpc_micro_12)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
42 extern "C" typedef void (*kmpc_micro_13)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
43 extern "C" typedef void (*kmpc_micro_14)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
44 extern "C" typedef void (*kmpc_micro_15)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
45
46 #ifdef __cplusplus
47 extern "C" {
48 #endif
49
50 static void init_g_kmp_global();
51 static void* kmp_threadfunc(void* args);
52
53 #ifdef __cplusplus
54 } // extern "C"
55 #endif
56
57 namespace ncnn {
58
// Descriptor for one unit of parallel work handed to a worker thread.
// Instances are stack-allocated by the dispatching thread (alloca in
// __kmpc_fork_call / KMPGlobal::deinit) and must outlive task execution.
class KMPTask
{
public:
    // per-team
    kmpc_micro fn;   // outlined microtask; a null fn tells the worker to exit
    int argc;        // number of entries in argv (0..15, see kmp_invoke_microtask)
    void** argv;     // shared argument pointers forwarded to fn
    int num_threads; // team size, published to workers via tls_num_threads

    // per-task
    int thread_num; // omp_get_thread_num() value for the worker running this task

    // finish status
    int* num_threads_to_wait;            // shared countdown of unfinished tasks
    Mutex* finish_lock;                  // guards num_threads_to_wait
    ConditionVariable* finish_condition; // signaled when the countdown reaches zero
};
76
// Fixed-capacity ring-buffer queue of KMPTask pointers, shared between the
// dispatching thread and the persistent worker threads. One mutex and one
// condition variable serve both producers and consumers.
class KMPTaskQueue
{
public:
    // _max_size: capacity of the ring buffer (number of task pointers)
    KMPTaskQueue(int _max_size)
    {
        max_size = _max_size;
        tasks = new KMPTask*[max_size];
        size = 0;
        front = 0;
        back = 0;
    }

    ~KMPTaskQueue()
    {
        // only the pointer array is owned here; the KMPTask objects live on
        // the dispatcher's stack
        delete[] tasks;
    }

    // Enqueue n consecutive tasks in a single locked section. If the batch
    // does not fit in the remaining capacity, fall back to blocking put()
    // calls one task at a time.
    void dispatch(KMPTask* v, int n)
    {
        lock.lock();

        if (size + n > max_size)
        {
            lock.unlock();

            for (int i = 0; i < n; i++)
            {
                put(&v[i]);
            }
            return;
        }

        for (int i = 0; i < n; i++)
        {
            tasks[back] = &v[i];
            back++;
            if (back == max_size)
                back = 0;
        }

        size += n;

        lock.unlock();

        // wake one waiting consumer
        condition.signal();
    }

    // Enqueue a single task, blocking while the queue is full.
    void put(KMPTask* v)
    {
        lock.lock();
        while (size >= max_size)
        {
            condition.wait(lock);
        }
        tasks[back] = v;
        back++;
        if (back == max_size)
            back = 0;
        size++;
        lock.unlock();

        // wake one waiter (consumers and producers share the condition)
        condition.signal();
    }

    // Dequeue one task, blocking while the queue is empty.
    void get(KMPTask*& v)
    {
        lock.lock();
        while (size == 0)
        {
            condition.wait(lock);
        }
        v = tasks[front];
        front++;
        if (front == max_size)
            front = 0;
        size--;
        lock.unlock();

        // wake one waiter (a producer blocked in put(), or another consumer)
        condition.signal();
    }

private:
    Mutex lock;
    ConditionVariable condition;

    // ring buffer queue
    int max_size;
    KMPTask** tasks;
    int size;  // current number of queued tasks
    int front; // dequeue index
    int back;  // enqueue index
};
169
// Process-wide singleton owning the worker thread pool and the task queue.
// Construction is lazy (pthread_once via try_init); teardown runs from the
// static destructor of g_kmp_global at process exit.
class KMPGlobal
{
public:
    KMPGlobal()
    {
        kmp_max_threads = 0;
        kmp_threads = 0;
        kmp_threads_tid = 0;
        kmp_task_queue = 0;
    }

    ~KMPGlobal()
    {
        deinit();
    }

    // Thread-safe lazy initialization; init() runs at most once per process.
    void try_init()
    {
        pthread_once(&is_initialized, init_g_kmp_global);
    }

public:
    static pthread_once_t is_initialized;

    void init()
    {
        // NCNN_LOGE("KMPGlobal init");
        kmp_max_threads = ncnn::get_cpu_count();

        // generous queue capacity so dispatch() rarely falls back to the
        // blocking one-at-a-time path
        kmp_task_queue = new ncnn::KMPTaskQueue(std::max(kmp_max_threads * 4, 16));

        if (kmp_max_threads > 1)
        {
            // spawn kmp_max_threads - 1 workers; "thread 0" of every team is
            // the thread that calls __kmpc_fork_call
            kmp_threads = new ncnn::Thread*[kmp_max_threads - 1];
            kmp_threads_tid = new int[kmp_max_threads - 1];
            for (int i = 0; i < kmp_max_threads - 1; i++)
            {
                // each worker gets a stable pointer to its own tid slot
                kmp_threads_tid[i] = i + 1;
                kmp_threads[i] = new ncnn::Thread(kmp_threadfunc, (void*)&kmp_threads_tid[i]);
            }
        }
    }

    void deinit()
    {
        // NCNN_LOGE("KMPGlobal deinit");
        if (kmp_max_threads > 1)
        {
            // TODO portable stack allocation
            // poison tasks: fn == 0 makes each worker's loop exit
            ncnn::KMPTask* tasks = (ncnn::KMPTask*)alloca((kmp_max_threads - 1) * sizeof(ncnn::KMPTask));
            for (int i = 0; i < kmp_max_threads - 1; i++)
            {
                tasks[i].fn = 0;
                tasks[i].argc = 0;
                tasks[i].argv = (void**)0;
                tasks[i].num_threads = kmp_max_threads;
                tasks[i].thread_num = i + 1;
                tasks[i].num_threads_to_wait = 0;
                tasks[i].finish_lock = 0;
                tasks[i].finish_condition = 0;
            }

            // dispatch 1 ~ kmp_max_threads
            kmp_task_queue->dispatch(tasks, kmp_max_threads - 1);

            for (int i = 0; i < kmp_max_threads - 1; i++)
            {
#ifndef __EMSCRIPTEN__
                // FIXME emscripten complains
                // pthread_join attempted on thread 12345678,
                // which does not point to a valid thread, or does not exist anymore!
                kmp_threads[i]->join();
#endif
                delete kmp_threads[i];
            }
            delete[] kmp_threads;
            delete[] kmp_threads_tid;
        }

        // deleting a null queue is safe if init() never ran
        delete kmp_task_queue;
    }

public:
    int kmp_max_threads;            // total team capacity (cpu count)
    ncnn::Thread** kmp_threads;     // kmp_max_threads - 1 worker threads
    int* kmp_threads_tid;           // per-worker thread numbers 1..kmp_max_threads-1
    ncnn::KMPTaskQueue* kmp_task_queue; // work handoff queue
};
258
259 } // namespace ncnn
260
pthread_once_t ncnn::KMPGlobal::is_initialized = PTHREAD_ONCE_INIT;

// lazily-initialized process-wide thread pool (see KMPGlobal::try_init)
static ncnn::KMPGlobal g_kmp_global;

// per-thread OpenMP state, stored as integers cast into the void* slot
static ncnn::ThreadLocalStorage tls_num_threads; // requested team size; 0/unset is treated as 1
static ncnn::ThreadLocalStorage tls_thread_num;  // this thread's index within the current team
267
// pthread_once callback performing the one-time pool construction.
static void init_g_kmp_global()
{
    g_kmp_global.init();
}
272
273 #ifdef __cplusplus
274 extern "C" {
275 #endif
276
// Maximum team size; fixed to the number of CPU cores.
int omp_get_max_threads()
{
    return ncnn::get_cpu_count();
}
281
// Dynamic adjustment of team sizes is always reported as enabled.
int omp_get_dynamic()
{
    return 1;
}
286
// No-op: this runtime is always in dynamic mode.
void omp_set_dynamic(int /*dynamic*/)
{
    // always dynamic, ignore
}
291
omp_set_num_threads(int num_threads)292 void omp_set_num_threads(int num_threads)
293 {
294 tls_num_threads.set(reinterpret_cast<void*>((size_t)std::max(num_threads, 1)));
295 }
296
omp_get_num_threads()297 int omp_get_num_threads()
298 {
299 return std::max((int)reinterpret_cast<size_t>(tls_num_threads.get()), 1);
300 }
301
omp_get_thread_num()302 int omp_get_thread_num()
303 {
304 return (int)reinterpret_cast<size_t>(tls_thread_num.get());
305 }
306
// Worker wait policy: 0 means workers block immediately (passive waiting).
int kmp_get_blocktime()
{
    return 0;
}
311
// No-op: the wait policy is always passive in this runtime.
void kmp_set_blocktime(int /*blocktime*/)
{
    // always passive, ignore
}
316
// Invokes the compiler-outlined microtask fn, expanding argv into the exact
// variadic call shape the kmpc microtask ABI expects. Only 0..15 shared
// arguments are supported; a larger argc falls through to the no-op default.
static int kmp_invoke_microtask(kmpc_micro fn, int gtid, int tid, int argc, void** argv)
{
    // fprintf(stderr, "__kmp_invoke_microtask #%lu %d %d %d\n", gettid(), gtid, tid, argc);

    switch (argc)
    {
    case 0:
        (*(kmpc_micro_0)fn)(&gtid, &tid);
        break;
    case 1:
        (*(kmpc_micro_1)fn)(&gtid, &tid, argv[0]);
        break;
    case 2:
        (*(kmpc_micro_2)fn)(&gtid, &tid, argv[0], argv[1]);
        break;
    case 3:
        (*(kmpc_micro_3)fn)(&gtid, &tid, argv[0], argv[1], argv[2]);
        break;
    case 4:
        (*(kmpc_micro_4)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3]);
        break;
    case 5:
        (*(kmpc_micro_5)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4]);
        break;
    case 6:
        (*(kmpc_micro_6)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5]);
        break;
    case 7:
        (*(kmpc_micro_7)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6]);
        break;
    case 8:
        (*(kmpc_micro_8)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7]);
        break;
    case 9:
        (*(kmpc_micro_9)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8]);
        break;
    case 10:
        (*(kmpc_micro_10)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9]);
        break;
    case 11:
        (*(kmpc_micro_11)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10]);
        break;
    case 12:
        (*(kmpc_micro_12)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10], argv[11]);
        break;
    case 13:
        (*(kmpc_micro_13)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10], argv[11], argv[12]);
        break;
    case 14:
        (*(kmpc_micro_14)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10], argv[11], argv[12], argv[13]);
        break;
    case 15:
        (*(kmpc_micro_15)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10], argv[11], argv[12], argv[13], argv[14]);
        break;
    default:
        // assert never reach here
        break;
    }

    return 0;
}
378
// Worker thread main loop: block on the task queue, run each microtask,
// then decrement the owning team's completion counter. A task with a null
// fn is the shutdown signal (see KMPGlobal::deinit).
static void* kmp_threadfunc(void* args)
{
    // args points at this worker's persistent slot in kmp_threads_tid
    int tid = *(int*)args;

    for (;;)
    {
        ncnn::KMPTask* task;
        g_kmp_global.kmp_task_queue->get(task);

        // fprintf(stderr, "get %d\n", tid);

        if (!task->fn)
            break;

        // publish team size and thread number so omp_get_num_threads() /
        // omp_get_thread_num() work inside the microtask on this worker
        tls_num_threads.set(reinterpret_cast<void*>((size_t)task->num_threads));
        tls_thread_num.set(reinterpret_cast<void*>((size_t)task->thread_num));

        kmp_invoke_microtask(task->fn, task->thread_num, tid, task->argc, task->argv);

        // update finished
        {
            task->finish_lock->lock();
            *task->num_threads_to_wait = *task->num_threads_to_wait - 1;
            if (*task->num_threads_to_wait == 0)
            {
                // last task of the team: wake the forking thread
                task->finish_condition->signal();
            }
            task->finish_lock->unlock();
        }
    }

    // fprintf(stderr, "exit\n");
    return 0;
}
413
// Returns the global thread number; this simplified runtime always says 0.
int32_t __kmpc_global_thread_num(void* /*loc*/)
{
    // NCNN_LOGE("__kmpc_global_thread_num");
    return 0;
}
419
// Records a num_threads(...) clause value for the upcoming parallel region.
void __kmpc_push_num_threads(void* /*loc*/, int32_t /*gtid*/, int32_t num_threads)
{
    // NCNN_LOGE("__kmpc_push_num_threads %d", num_threads);
    omp_set_num_threads(num_threads);
}
425
__kmpc_fork_call(void *,int32_t argc,kmpc_micro fn,...)426 void __kmpc_fork_call(void* /*loc*/, int32_t argc, kmpc_micro fn, ...)
427 {
428 g_kmp_global.try_init();
429
430 // NCNN_LOGE("__kmpc_fork_call %d", argc);
431 int num_threads = omp_get_num_threads();
432
433 // build argv
434 void* argv[16];
435 {
436 va_list ap;
437 va_start(ap, fn);
438 for (int i = 0; i < argc; i++)
439 argv[i] = va_arg(ap, void*);
440 va_end(ap);
441 }
442
443 if (g_kmp_global.kmp_max_threads == 1 || num_threads == 1)
444 {
445 for (int i = 0; i < num_threads; i++)
446 {
447 tls_thread_num.set(reinterpret_cast<void*>((size_t)i));
448
449 kmp_invoke_microtask(fn, 0, 0, argc, argv);
450 }
451
452 return;
453 }
454
455 int num_threads_to_wait = num_threads - 1;
456 ncnn::Mutex finish_lock;
457 ncnn::ConditionVariable finish_condition;
458
459 // TODO portable stack allocation
460 ncnn::KMPTask* tasks = (ncnn::KMPTask*)alloca((num_threads - 1) * sizeof(ncnn::KMPTask));
461 for (int i = 0; i < num_threads - 1; i++)
462 {
463 tasks[i].fn = fn;
464 tasks[i].argc = argc;
465 tasks[i].argv = (void**)argv;
466 tasks[i].num_threads = num_threads;
467 tasks[i].thread_num = i + 1;
468 tasks[i].num_threads_to_wait = &num_threads_to_wait;
469 tasks[i].finish_lock = &finish_lock;
470 tasks[i].finish_condition = &finish_condition;
471 }
472
473 // dispatch 1 ~ num_threads
474 g_kmp_global.kmp_task_queue->dispatch(tasks, num_threads - 1);
475
476 // dispatch 0
477 {
478 tls_thread_num.set(reinterpret_cast<void*>((size_t)0));
479
480 kmp_invoke_microtask(fn, 0, 0, argc, argv);
481 }
482
483 // wait for finished
484 {
485 finish_lock.lock();
486 if (num_threads_to_wait != 0)
487 {
488 finish_condition.wait(finish_lock);
489 }
490 finish_lock.unlock();
491 }
492 }
493
__kmpc_for_static_init_4(void *,int32_t gtid,int32_t,int32_t * last,int32_t * lower,int32_t * upper,int32_t *,int32_t,int32_t)494 void __kmpc_for_static_init_4(void* /*loc*/, int32_t gtid, int32_t /*sched*/, int32_t* last, int32_t* lower, int32_t* upper, int32_t* /*stride*/, int32_t /*incr*/, int32_t /*chunk*/)
495 {
496 // NCNN_LOGE("__kmpc_for_static_init_4");
497 int num_threads = omp_get_num_threads();
498
499 // TODO only support i++
500 int32_t count = *upper - *lower + 1;
501 int32_t threads = std::min(count, (int32_t)num_threads);
502 int32_t count_per_thread = count / threads;
503 int32_t remain = count % threads;
504
505 *last = gtid == (int32_t)(threads - 1);
506 *lower = gtid * count_per_thread + std::min(remain, gtid);
507 *upper = std::min((gtid + 1) * count_per_thread + std::min(remain, gtid + 1) - 1, *upper);
508 }
509
__kmpc_for_static_init_4u(void *,int32_t gtid,int32_t,int32_t * last,uint32_t * lower,uint32_t * upper,int32_t *,int32_t,int32_t)510 void __kmpc_for_static_init_4u(void* /*loc*/, int32_t gtid, int32_t /*sched*/, int32_t* last, uint32_t* lower, uint32_t* upper, int32_t* /*stride*/, int32_t /*incr*/, int32_t /*chunk*/)
511 {
512 // NCNN_LOGE("__kmpc_for_static_init_4u");
513 int num_threads = omp_get_num_threads();
514
515 // TODO only support i++
516 uint32_t count = *upper - *lower + 1;
517 uint32_t threads = std::min(count, (uint32_t)num_threads);
518 uint32_t count_per_thread = count / threads;
519 uint32_t remain = count % threads;
520
521 *last = gtid == (int32_t)(threads - 1);
522 *lower = gtid * count_per_thread + std::min(remain, (uint32_t)gtid);
523 *upper = std::min((gtid + 1) * count_per_thread + std::min(remain, (uint32_t)gtid + 1) - 1, *upper);
524 }
525
__kmpc_for_static_init_8(void *,int32_t gtid,int32_t,int32_t * last,int64_t * lower,int64_t * upper,int64_t *,int64_t,int64_t)526 void __kmpc_for_static_init_8(void* /*loc*/, int32_t gtid, int32_t /*sched*/, int32_t* last, int64_t* lower, int64_t* upper, int64_t* /*stride*/, int64_t /*incr*/, int64_t /*chunk*/)
527 {
528 // NCNN_LOGE("__kmpc_for_static_init_8");
529 int num_threads = omp_get_num_threads();
530
531 // TODO only support i++
532 int64_t count = *upper - *lower + 1;
533 int64_t threads = std::min(count, (int64_t)num_threads);
534 int64_t count_per_thread = count / threads;
535 int64_t remain = count % threads;
536
537 *last = gtid == (int64_t)(threads - 1);
538 *lower = gtid * count_per_thread + std::min(remain, (int64_t)gtid);
539 *upper = std::min((gtid + 1) * count_per_thread + std::min(remain, (int64_t)gtid + 1) - 1, *upper);
540 }
541
__kmpc_for_static_init_8u(void *,int32_t gtid,int32_t,int32_t * last,uint64_t * lower,uint64_t * upper,int64_t *,int64_t,int64_t)542 void __kmpc_for_static_init_8u(void* /*loc*/, int32_t gtid, int32_t /*sched*/, int32_t* last, uint64_t* lower, uint64_t* upper, int64_t* /*stride*/, int64_t /*incr*/, int64_t /*chunk*/)
543 {
544 // NCNN_LOGE("__kmpc_for_static_init_8u");
545 int num_threads = omp_get_num_threads();
546
547 // TODO only support i++
548 uint64_t count = *upper - *lower + 1;
549 uint64_t threads = std::min(count, (uint64_t)num_threads);
550 uint64_t count_per_thread = count / threads;
551 uint64_t remain = count % threads;
552
553 *last = gtid == (int64_t)(threads - 1);
554 *lower = gtid * count_per_thread + std::min(remain, (uint64_t)gtid);
555 *upper = std::min((gtid + 1) * count_per_thread + std::min(remain, (uint64_t)gtid + 1) - 1, *upper);
556 }
557
// End of a static worksharing loop; nothing to clean up in this runtime.
void __kmpc_for_static_fini(void* /*loc*/, int32_t gtid)
{
    // NCNN_LOGE("__kmpc_for_static_fini");
    (void)gtid;
}
563
564 #ifdef __cplusplus
565 } // extern "C"
566 #endif
567
568 #endif // NCNN_SIMPLEOMP
569