1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "platform.h"
16 
17 #if NCNN_SIMPLEOMP
18 
19 #include "simpleomp.h"
20 #include "cpu.h" // ncnn::get_cpu_count()
21 
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <stdint.h>
26 #include <stdarg.h>
27 
// Function pointer types for the microtask handed to __kmpc_fork_call().
// kmpc_micro is the variadic form that appears in the __kmpc_fork_call() signature.
// kmpc_micro_N is the same entry point seen as taking exactly N extra void*
// arguments after (gtid, tid); kmp_invoke_microtask() casts to the matching
// kmpc_micro_N before making the call. N ranges from 0 to 15.
extern "C" typedef void (*kmpc_micro)(int32_t* gtid, int32_t* tid, ...);
extern "C" typedef void (*kmpc_micro_0)(int32_t* gtid, int32_t* tid);
extern "C" typedef void (*kmpc_micro_1)(int32_t* gtid, int32_t* tid, void*);
extern "C" typedef void (*kmpc_micro_2)(int32_t* gtid, int32_t* tid, void*, void*);
extern "C" typedef void (*kmpc_micro_3)(int32_t* gtid, int32_t* tid, void*, void*, void*);
extern "C" typedef void (*kmpc_micro_4)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*);
extern "C" typedef void (*kmpc_micro_5)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*);
extern "C" typedef void (*kmpc_micro_6)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*);
extern "C" typedef void (*kmpc_micro_7)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*);
extern "C" typedef void (*kmpc_micro_8)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*);
extern "C" typedef void (*kmpc_micro_9)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*);
extern "C" typedef void (*kmpc_micro_10)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
extern "C" typedef void (*kmpc_micro_11)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
extern "C" typedef void (*kmpc_micro_12)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
extern "C" typedef void (*kmpc_micro_13)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
extern "C" typedef void (*kmpc_micro_14)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
extern "C" typedef void (*kmpc_micro_15)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
45 
46 #ifdef __cplusplus
47 extern "C" {
48 #endif
49 
// Forward declarations: both functions are referenced inside ncnn::KMPGlobal
// (as the pthread_once callback and as the worker thread entry point) before
// their definitions further down in this file.
static void init_g_kmp_global();
static void* kmp_threadfunc(void* args);
52 
53 #ifdef __cplusplus
54 } // extern "C"
55 #endif
56 
57 namespace ncnn {
58 
// One unit of work handed to a worker thread through KMPTaskQueue.
// __kmpc_fork_call() builds one KMPTask per additional team member; a task
// whose fn is null makes the worker loop in kmp_threadfunc() exit.
class KMPTask
{
public:
    // per-team
    kmpc_micro fn;   // microtask to run; null requests worker shutdown
    int argc;        // number of entries in argv forwarded to fn
    void** argv;     // argument pointers forwarded to fn (not owned by the task)
    int num_threads; // team size; the worker stores it in tls_num_threads

    // per-task
    int thread_num; // index inside the team; stored in tls_thread_num and passed to fn as gtid

    // finish status
    int* num_threads_to_wait;            // shared counter, decremented by the worker under finish_lock
    Mutex* finish_lock;                  // protects *num_threads_to_wait
    ConditionVariable* finish_condition; // signaled by the worker that brings the counter to zero
};
76 
77 class KMPTaskQueue
78 {
79 public:
KMPTaskQueue(int _max_size)80     KMPTaskQueue(int _max_size)
81     {
82         max_size = _max_size;
83         tasks = new KMPTask*[max_size];
84         size = 0;
85         front = 0;
86         back = 0;
87     }
88 
~KMPTaskQueue()89     ~KMPTaskQueue()
90     {
91         delete[] tasks;
92     }
93 
dispatch(KMPTask * v,int n)94     void dispatch(KMPTask* v, int n)
95     {
96         lock.lock();
97 
98         if (size + n > max_size)
99         {
100             lock.unlock();
101 
102             for (int i = 0; i < n; i++)
103             {
104                 put(&v[i]);
105             }
106             return;
107         }
108 
109         for (int i = 0; i < n; i++)
110         {
111             tasks[back] = &v[i];
112             back++;
113             if (back == max_size)
114                 back = 0;
115         }
116 
117         size += n;
118 
119         lock.unlock();
120 
121         condition.signal();
122     }
123 
put(KMPTask * v)124     void put(KMPTask* v)
125     {
126         lock.lock();
127         while (size >= max_size)
128         {
129             condition.wait(lock);
130         }
131         tasks[back] = v;
132         back++;
133         if (back == max_size)
134             back = 0;
135         size++;
136         lock.unlock();
137 
138         condition.signal();
139     }
140 
get(KMPTask * & v)141     void get(KMPTask*& v)
142     {
143         lock.lock();
144         while (size == 0)
145         {
146             condition.wait(lock);
147         }
148         v = tasks[front];
149         front++;
150         if (front == max_size)
151             front = 0;
152         size--;
153         lock.unlock();
154 
155         condition.signal();
156     }
157 
158 private:
159     Mutex lock;
160     ConditionVariable condition;
161 
162     // ring buffer queue
163     int max_size;
164     KMPTask** tasks;
165     int size;
166     int front;
167     int back;
168 };
169 
170 class KMPGlobal
171 {
172 public:
KMPGlobal()173     KMPGlobal()
174     {
175         kmp_max_threads = 0;
176         kmp_threads = 0;
177         kmp_threads_tid = 0;
178         kmp_task_queue = 0;
179     }
180 
~KMPGlobal()181     ~KMPGlobal()
182     {
183         deinit();
184     }
185 
try_init()186     void try_init()
187     {
188         pthread_once(&is_initialized, init_g_kmp_global);
189     }
190 
191 public:
192     static pthread_once_t is_initialized;
193 
init()194     void init()
195     {
196         // NCNN_LOGE("KMPGlobal init");
197         kmp_max_threads = ncnn::get_cpu_count();
198 
199         kmp_task_queue = new ncnn::KMPTaskQueue(std::max(kmp_max_threads * 4, 16));
200 
201         if (kmp_max_threads > 1)
202         {
203             kmp_threads = new ncnn::Thread*[kmp_max_threads - 1];
204             kmp_threads_tid = new int[kmp_max_threads - 1];
205             for (int i = 0; i < kmp_max_threads - 1; i++)
206             {
207                 kmp_threads_tid[i] = i + 1;
208                 kmp_threads[i] = new ncnn::Thread(kmp_threadfunc, (void*)&kmp_threads_tid[i]);
209             }
210         }
211     }
212 
deinit()213     void deinit()
214     {
215         // NCNN_LOGE("KMPGlobal deinit");
216         if (kmp_max_threads > 1)
217         {
218             // TODO portable stack allocation
219             ncnn::KMPTask* tasks = (ncnn::KMPTask*)alloca((kmp_max_threads - 1) * sizeof(ncnn::KMPTask));
220             for (int i = 0; i < kmp_max_threads - 1; i++)
221             {
222                 tasks[i].fn = 0;
223                 tasks[i].argc = 0;
224                 tasks[i].argv = (void**)0;
225                 tasks[i].num_threads = kmp_max_threads;
226                 tasks[i].thread_num = i + 1;
227                 tasks[i].num_threads_to_wait = 0;
228                 tasks[i].finish_lock = 0;
229                 tasks[i].finish_condition = 0;
230             }
231 
232             // dispatch 1 ~ kmp_max_threads
233             kmp_task_queue->dispatch(tasks, kmp_max_threads - 1);
234 
235             for (int i = 0; i < kmp_max_threads - 1; i++)
236             {
237 #ifndef __EMSCRIPTEN__
238                 // FIXME emscripten complains
239                 // pthread_join attempted on thread 12345678,
240                 // which does not point to a valid thread, or does not exist anymore!
241                 kmp_threads[i]->join();
242 #endif
243                 delete kmp_threads[i];
244             }
245             delete[] kmp_threads;
246             delete[] kmp_threads_tid;
247         }
248 
249         delete kmp_task_queue;
250     }
251 
252 public:
253     int kmp_max_threads;
254     ncnn::Thread** kmp_threads;
255     int* kmp_threads_tid;
256     ncnn::KMPTaskQueue* kmp_task_queue;
257 };
258 
259 } // namespace ncnn
260 
// once-flag consumed by KMPGlobal::try_init()
pthread_once_t ncnn::KMPGlobal::is_initialized = PTHREAD_ONCE_INIT;

// the single pool instance used by every entry point in this file;
// its destructor runs KMPGlobal::deinit()
static ncnn::KMPGlobal g_kmp_global;

// team size requested via omp_set_num_threads(), or set by a worker from its task
static ncnn::ThreadLocalStorage tls_num_threads;
// index inside the current team, as reported by omp_get_thread_num()
static ncnn::ThreadLocalStorage tls_thread_num;
267 
// pthread_once callback used by KMPGlobal::try_init(): builds the global pool.
static void init_g_kmp_global()
{
    g_kmp_global.init();
}
272 
273 #ifdef __cplusplus
274 extern "C" {
275 #endif
276 
// Returns ncnn::get_cpu_count(); the value does not depend on omp_set_num_threads().
int omp_get_max_threads()
{
    return ncnn::get_cpu_count();
}
281 
// Always reports 1: this runtime treats dynamic adjustment as permanently
// enabled (see omp_set_dynamic, which ignores its argument).
int omp_get_dynamic()
{
    return 1;
}
286 
// Accepted for API compatibility only; the argument has no effect.
void omp_set_dynamic(int /*dynamic*/)
{
    // always dynamic, ignore
}
291 
// Records the requested team size in tls_num_threads.
// Values below 1 are stored as 1. The integer is packed into the void* slot.
void omp_set_num_threads(int num_threads)
{
    tls_num_threads.set(reinterpret_cast<void*>((size_t)std::max(num_threads, 1)));
}
296 
// Returns the team size held in tls_num_threads, never less than 1.
// NOTE(review): presumably an unset slot reads back as null, which this
// turns into 1 - confirm against ThreadLocalStorage::get().
int omp_get_num_threads()
{
    return std::max((int)reinterpret_cast<size_t>(tls_num_threads.get()), 1);
}
301 
// Returns the team index held in tls_thread_num, unpacked from the void* slot.
// The slot is written by kmp_threadfunc() and __kmpc_fork_call() before they
// run a microtask.
int omp_get_thread_num()
{
    return (int)reinterpret_cast<size_t>(tls_thread_num.get());
}
306 
// Always reports 0 (see kmp_set_blocktime, which ignores its argument).
int kmp_get_blocktime()
{
    return 0;
}
311 
// Accepted for API compatibility only; the argument has no effect.
void kmp_set_blocktime(int /*blocktime*/)
{
    // always passive, ignore
}
316 
kmp_invoke_microtask(kmpc_micro fn,int gtid,int tid,int argc,void ** argv)317 static int kmp_invoke_microtask(kmpc_micro fn, int gtid, int tid, int argc, void** argv)
318 {
319     // fprintf(stderr, "__kmp_invoke_microtask #%lu %d %d %d\n", gettid(), gtid, tid, argc);
320 
321     switch (argc)
322     {
323     case 0:
324         (*(kmpc_micro_0)fn)(&gtid, &tid);
325         break;
326     case 1:
327         (*(kmpc_micro_1)fn)(&gtid, &tid, argv[0]);
328         break;
329     case 2:
330         (*(kmpc_micro_2)fn)(&gtid, &tid, argv[0], argv[1]);
331         break;
332     case 3:
333         (*(kmpc_micro_3)fn)(&gtid, &tid, argv[0], argv[1], argv[2]);
334         break;
335     case 4:
336         (*(kmpc_micro_4)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3]);
337         break;
338     case 5:
339         (*(kmpc_micro_5)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4]);
340         break;
341     case 6:
342         (*(kmpc_micro_6)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5]);
343         break;
344     case 7:
345         (*(kmpc_micro_7)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6]);
346         break;
347     case 8:
348         (*(kmpc_micro_8)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7]);
349         break;
350     case 9:
351         (*(kmpc_micro_9)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8]);
352         break;
353     case 10:
354         (*(kmpc_micro_10)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9]);
355         break;
356     case 11:
357         (*(kmpc_micro_11)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10]);
358         break;
359     case 12:
360         (*(kmpc_micro_12)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10], argv[11]);
361         break;
362     case 13:
363         (*(kmpc_micro_13)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10], argv[11], argv[12]);
364         break;
365     case 14:
366         (*(kmpc_micro_14)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10], argv[11], argv[12], argv[13]);
367         break;
368     case 15:
369         (*(kmpc_micro_15)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10], argv[11], argv[12], argv[13], argv[14]);
370         break;
371     default:
372         // assert never reach here
373         break;
374     }
375 
376     return 0;
377 }
378 
kmp_threadfunc(void * args)379 static void* kmp_threadfunc(void* args)
380 {
381     int tid = *(int*)args;
382 
383     for (;;)
384     {
385         ncnn::KMPTask* task;
386         g_kmp_global.kmp_task_queue->get(task);
387 
388         // fprintf(stderr, "get %d\n", tid);
389 
390         if (!task->fn)
391             break;
392 
393         tls_num_threads.set(reinterpret_cast<void*>((size_t)task->num_threads));
394         tls_thread_num.set(reinterpret_cast<void*>((size_t)task->thread_num));
395 
396         kmp_invoke_microtask(task->fn, task->thread_num, tid, task->argc, task->argv);
397 
398         // update finished
399         {
400             task->finish_lock->lock();
401             *task->num_threads_to_wait = *task->num_threads_to_wait - 1;
402             if (*task->num_threads_to_wait == 0)
403             {
404                 task->finish_condition->signal();
405             }
406             task->finish_lock->unlock();
407         }
408     }
409 
410     // fprintf(stderr, "exit\n");
411     return 0;
412 }
413 
// Always returns 0; the location argument is ignored.
int32_t __kmpc_global_thread_num(void* /*loc*/)
{
    // NCNN_LOGE("__kmpc_global_thread_num");
    return 0;
}
419 
// Forwards num_threads to omp_set_num_threads(); loc and gtid are ignored.
void __kmpc_push_num_threads(void* /*loc*/, int32_t /*gtid*/, int32_t num_threads)
{
    // NCNN_LOGE("__kmpc_push_num_threads %d", num_threads);
    omp_set_num_threads(num_threads);
}
425 
__kmpc_fork_call(void *,int32_t argc,kmpc_micro fn,...)426 void __kmpc_fork_call(void* /*loc*/, int32_t argc, kmpc_micro fn, ...)
427 {
428     g_kmp_global.try_init();
429 
430     // NCNN_LOGE("__kmpc_fork_call %d", argc);
431     int num_threads = omp_get_num_threads();
432 
433     // build argv
434     void* argv[16];
435     {
436         va_list ap;
437         va_start(ap, fn);
438         for (int i = 0; i < argc; i++)
439             argv[i] = va_arg(ap, void*);
440         va_end(ap);
441     }
442 
443     if (g_kmp_global.kmp_max_threads == 1 || num_threads == 1)
444     {
445         for (int i = 0; i < num_threads; i++)
446         {
447             tls_thread_num.set(reinterpret_cast<void*>((size_t)i));
448 
449             kmp_invoke_microtask(fn, 0, 0, argc, argv);
450         }
451 
452         return;
453     }
454 
455     int num_threads_to_wait = num_threads - 1;
456     ncnn::Mutex finish_lock;
457     ncnn::ConditionVariable finish_condition;
458 
459     // TODO portable stack allocation
460     ncnn::KMPTask* tasks = (ncnn::KMPTask*)alloca((num_threads - 1) * sizeof(ncnn::KMPTask));
461     for (int i = 0; i < num_threads - 1; i++)
462     {
463         tasks[i].fn = fn;
464         tasks[i].argc = argc;
465         tasks[i].argv = (void**)argv;
466         tasks[i].num_threads = num_threads;
467         tasks[i].thread_num = i + 1;
468         tasks[i].num_threads_to_wait = &num_threads_to_wait;
469         tasks[i].finish_lock = &finish_lock;
470         tasks[i].finish_condition = &finish_condition;
471     }
472 
473     // dispatch 1 ~ num_threads
474     g_kmp_global.kmp_task_queue->dispatch(tasks, num_threads - 1);
475 
476     // dispatch 0
477     {
478         tls_thread_num.set(reinterpret_cast<void*>((size_t)0));
479 
480         kmp_invoke_microtask(fn, 0, 0, argc, argv);
481     }
482 
483     // wait for finished
484     {
485         finish_lock.lock();
486         if (num_threads_to_wait != 0)
487         {
488             finish_condition.wait(finish_lock);
489         }
490         finish_lock.unlock();
491     }
492 }
493 
__kmpc_for_static_init_4(void *,int32_t gtid,int32_t,int32_t * last,int32_t * lower,int32_t * upper,int32_t *,int32_t,int32_t)494 void __kmpc_for_static_init_4(void* /*loc*/, int32_t gtid, int32_t /*sched*/, int32_t* last, int32_t* lower, int32_t* upper, int32_t* /*stride*/, int32_t /*incr*/, int32_t /*chunk*/)
495 {
496     // NCNN_LOGE("__kmpc_for_static_init_4");
497     int num_threads = omp_get_num_threads();
498 
499     // TODO only support i++
500     int32_t count = *upper - *lower + 1;
501     int32_t threads = std::min(count, (int32_t)num_threads);
502     int32_t count_per_thread = count / threads;
503     int32_t remain = count % threads;
504 
505     *last = gtid == (int32_t)(threads - 1);
506     *lower = gtid * count_per_thread + std::min(remain, gtid);
507     *upper = std::min((gtid + 1) * count_per_thread + std::min(remain, gtid + 1) - 1, *upper);
508 }
509 
__kmpc_for_static_init_4u(void *,int32_t gtid,int32_t,int32_t * last,uint32_t * lower,uint32_t * upper,int32_t *,int32_t,int32_t)510 void __kmpc_for_static_init_4u(void* /*loc*/, int32_t gtid, int32_t /*sched*/, int32_t* last, uint32_t* lower, uint32_t* upper, int32_t* /*stride*/, int32_t /*incr*/, int32_t /*chunk*/)
511 {
512     // NCNN_LOGE("__kmpc_for_static_init_4u");
513     int num_threads = omp_get_num_threads();
514 
515     // TODO only support i++
516     uint32_t count = *upper - *lower + 1;
517     uint32_t threads = std::min(count, (uint32_t)num_threads);
518     uint32_t count_per_thread = count / threads;
519     uint32_t remain = count % threads;
520 
521     *last = gtid == (int32_t)(threads - 1);
522     *lower = gtid * count_per_thread + std::min(remain, (uint32_t)gtid);
523     *upper = std::min((gtid + 1) * count_per_thread + std::min(remain, (uint32_t)gtid + 1) - 1, *upper);
524 }
525 
__kmpc_for_static_init_8(void *,int32_t gtid,int32_t,int32_t * last,int64_t * lower,int64_t * upper,int64_t *,int64_t,int64_t)526 void __kmpc_for_static_init_8(void* /*loc*/, int32_t gtid, int32_t /*sched*/, int32_t* last, int64_t* lower, int64_t* upper, int64_t* /*stride*/, int64_t /*incr*/, int64_t /*chunk*/)
527 {
528     // NCNN_LOGE("__kmpc_for_static_init_8");
529     int num_threads = omp_get_num_threads();
530 
531     // TODO only support i++
532     int64_t count = *upper - *lower + 1;
533     int64_t threads = std::min(count, (int64_t)num_threads);
534     int64_t count_per_thread = count / threads;
535     int64_t remain = count % threads;
536 
537     *last = gtid == (int64_t)(threads - 1);
538     *lower = gtid * count_per_thread + std::min(remain, (int64_t)gtid);
539     *upper = std::min((gtid + 1) * count_per_thread + std::min(remain, (int64_t)gtid + 1) - 1, *upper);
540 }
541 
__kmpc_for_static_init_8u(void *,int32_t gtid,int32_t,int32_t * last,uint64_t * lower,uint64_t * upper,int64_t *,int64_t,int64_t)542 void __kmpc_for_static_init_8u(void* /*loc*/, int32_t gtid, int32_t /*sched*/, int32_t* last, uint64_t* lower, uint64_t* upper, int64_t* /*stride*/, int64_t /*incr*/, int64_t /*chunk*/)
543 {
544     // NCNN_LOGE("__kmpc_for_static_init_8u");
545     int num_threads = omp_get_num_threads();
546 
547     // TODO only support i++
548     uint64_t count = *upper - *lower + 1;
549     uint64_t threads = std::min(count, (uint64_t)num_threads);
550     uint64_t count_per_thread = count / threads;
551     uint64_t remain = count % threads;
552 
553     *last = gtid == (int64_t)(threads - 1);
554     *lower = gtid * count_per_thread + std::min(remain, (uint64_t)gtid);
555     *upper = std::min((gtid + 1) * count_per_thread + std::min(remain, (uint64_t)gtid + 1) - 1, *upper);
556 }
557 
// Counterpart of the __kmpc_for_static_init_* functions; does nothing in
// this runtime. Both arguments are ignored.
void __kmpc_for_static_fini(void* /*loc*/, int32_t gtid)
{
    // NCNN_LOGE("__kmpc_for_static_fini");
    (void)gtid; // silence the unused-parameter warning
}
563 
564 #ifdef __cplusplus
565 } // extern "C"
566 #endif
567 
568 #endif // NCNN_SIMPLEOMP
569