1 /*
2     Copyright (c) 2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #define TBB_PREVIEW_MUTEXES 1
18 #define TBB_PREVIEW_WAITING_FOR_WORKERS 1
19 #define TBB_PREVIEW_TASK_GROUP_EXTENSIONS 1
20 
21 #include "common/config.h"
22 
23 #include <oneapi/tbb/task_arena.h>
24 #include <oneapi/tbb/concurrent_vector.h>
25 #include <oneapi/tbb/rw_mutex.h>
26 #include <oneapi/tbb/task_group.h>
27 #include <oneapi/tbb/parallel_for.h>
28 
29 #include <oneapi/tbb/global_control.h>
30 
31 #include "common/test.h"
32 #include "common/utils.h"
33 #include "common/utils_concurrency_limit.h"
34 #include "common/spin_barrier.h"
35 
36 #include <stdlib.h> // C11/POSIX aligned_alloc
37 #include <random>
38 
39 
40 //! \file test_scheduler_mix.cpp
41 //! \brief Test for [scheduler.task_arena scheduler.task_scheduler_observer] specification
42 
// Stop threshold for the whole test: global_actor() keeps looping while
// globalNumActions is below this value.
const std::uint64_t maxNumActions = 1 * 100 * 1000;
// Shared progress counter. global_actor() adds to it in batches of 100 instead of once
// per action.
static std::atomic<std::uint64_t> globalNumActions{};
45 
46 //using Random = utils::FastRandom<>;
47 class Random {
48     struct State {
49         std::random_device rd;
50         std::mt19937 gen;
51         std::uniform_int_distribution<> dist;
52 
StateRandom::State53         State() : gen(rd()), dist(0, std::numeric_limits<unsigned short>::max()) {}
54 
getRandom::State55         int get() {
56             return dist(gen);
57         }
58     };
59     static thread_local State* mState;
60     tbb::concurrent_vector<State*> mStateList;
61 public:
~Random()62     ~Random() {
63         for (auto s : mStateList) {
64             delete s;
65         }
66     }
67 
get()68     int get() {
69         auto& s = mState;
70         if (!s) {
71             s = new State;
72             mStateList.push_back(s);
73         }
74         return s->get();
75     }
76 };
77 
78 thread_local Random::State* Random::mState = nullptr;
79 
80 
aligned_malloc(std::size_t alignment,std::size_t size)81 void* aligned_malloc(std::size_t alignment, std::size_t size) {
82 #if _WIN32
83     return _aligned_malloc(size, alignment);
84 #elif __unix__ || __APPLE__
85     void* ptr{};
86     int res = posix_memalign(&ptr, alignment, size);
87     CHECK(res == 0);
88     return ptr;
89 #else
90     return aligned_alloc(alignment, size);
91 #endif
92 }
93 
//! Releases storage obtained from aligned_malloc().
//!
//! Windows needs the matching _aligned_free; every other platform uses plain free().
void aligned_free(void* ptr) {
#if !_WIN32
    free(ptr);
#else
    _aligned_free(ptr);
#endif
}
101 
102 template <typename T, std::size_t Alignment>
103 class PtrRWMutex {
104     static const std::size_t maxThreads = (Alignment >> 1) - 1;
105     static const std::uintptr_t READER_MASK = maxThreads;       // 7F..
106     static const std::uintptr_t LOCKED = Alignment - 1;         // FF..
107     static const std::uintptr_t LOCKED_MASK = LOCKED;           // FF..
108     static const std::uintptr_t LOCK_PENDING = READER_MASK + 1; // 80..
109 
110     std::atomic<std::uintptr_t> mState;
111 
pointer()112     T* pointer() {
113         return reinterpret_cast<T*>(state() & ~LOCKED_MASK);
114     }
115 
state()116     std::uintptr_t state() {
117         return mState.load(std::memory_order_relaxed);
118     }
119 
120 public:
121     class ScopedLock {
122     public:
ScopedLock()123         constexpr ScopedLock() : mMutex(nullptr), mIsWriter(false) {}
124         //! Acquire lock on given mutex.
ScopedLock(PtrRWMutex & m,bool write=true)125         ScopedLock(PtrRWMutex& m, bool write = true) : mMutex(nullptr) {
126             CHECK_FAST(write == true);
127             acquire(m);
128         }
129         //! Release lock (if lock is held).
~ScopedLock()130         ~ScopedLock() {
131             if (mMutex) {
132                 release();
133             }
134         }
135         //! No Copy
136         ScopedLock(const ScopedLock&) = delete;
137         ScopedLock& operator=(const ScopedLock&) = delete;
138 
139         //! Acquire lock on given mutex.
acquire(PtrRWMutex & m)140         void acquire(PtrRWMutex& m) {
141             CHECK_FAST(mMutex == nullptr);
142             mIsWriter = true;
143             mMutex = &m;
144             mMutex->lock();
145         }
146 
147         //! Try acquire lock on given mutex.
tryAcquire(PtrRWMutex & m,bool write=true)148         bool tryAcquire(PtrRWMutex& m, bool write = true) {
149             bool succeed = write ? m.tryLock() : m.tryLockShared();
150             if (succeed) {
151                 mMutex = &m;
152                 mIsWriter = write;
153             }
154             return succeed;
155         }
156 
clear()157         void clear() {
158             CHECK_FAST(mMutex != nullptr);
159             CHECK_FAST(mIsWriter);
160             PtrRWMutex* m = mMutex;
161             mMutex = nullptr;
162             m->clear();
163         }
164 
165         //! Release lock.
release()166         void release() {
167             CHECK_FAST(mMutex != nullptr);
168             PtrRWMutex* m = mMutex;
169             mMutex = nullptr;
170 
171             if (mIsWriter) {
172                 m->unlock();
173             }
174             else {
175                 m->unlockShared();
176             }
177         }
178     protected:
179         PtrRWMutex* mMutex{};
180         bool mIsWriter{};
181     };
182 
trySet(T * ptr)183     bool trySet(T* ptr) {
184         auto p = reinterpret_cast<std::uintptr_t>(ptr);
185         CHECK_FAST((p & (Alignment - 1)) == 0);
186         if (!state()) {
187             std::uintptr_t expected = 0;
188             if (mState.compare_exchange_strong(expected, p)) {
189                 return true;
190             }
191         }
192         return false;
193     }
194 
clear()195     void clear() {
196         CHECK_FAST((state() & LOCKED_MASK) == LOCKED);
197         mState.store(0, std::memory_order_relaxed);
198     }
199 
tryLock()200     bool tryLock() {
201         auto v = state();
202         if (v == 0) {
203             return false;
204         }
205         CHECK_FAST((v & LOCKED_MASK) == LOCKED || (v & READER_MASK) < maxThreads);
206         if ((v & READER_MASK) == 0) {
207             if (mState.compare_exchange_strong(v, v | LOCKED)) {
208                 return true;
209             }
210         }
211         return false;
212     }
213 
tryLockShared()214     bool tryLockShared() {
215         auto v = state();
216         if (v == 0) {
217             return false;
218         }
219         CHECK_FAST((v & LOCKED_MASK) == LOCKED || (v & READER_MASK) < maxThreads);
220         if ((v & LOCKED_MASK) != LOCKED && (v & LOCK_PENDING) == 0) {
221             if (mState.compare_exchange_strong(v, v + 1)) {
222                 return true;
223             }
224         }
225         return false;
226     }
227 
lock()228     void lock() {
229         auto v = state();
230         mState.compare_exchange_strong(v, v | LOCK_PENDING);
231         while (!tryLock()) {
232             utils::yield();
233         }
234     }
235 
unlock()236     void unlock() {
237         auto v = state();
238         CHECK_FAST((v & LOCKED_MASK) == LOCKED);
239         mState.store(v & ~LOCKED, std::memory_order_release);
240     }
241 
unlockShared()242     void unlockShared() {
243         auto v = state();
244         CHECK_FAST((v & LOCKED_MASK) != LOCKED);
245         CHECK_FAST((v & READER_MASK) > 0);
246         mState -= 1;
247     }
248 
operator bool() const249     operator bool() const {
250         return pointer() != 0;
251     }
252 
get()253     T* get() {
254         return pointer();
255     }
256 };
257 
258 class Statistics {
259 public:
260     enum ACTION {
261         ArenaCreate,
262         ArenaDestroy,
263         ArenaAcquire,
264         skippedArenaCreate,
265         skippedArenaDestroy,
266         skippedArenaAcquire,
267         ParallelAlgorithm,
268         ArenaEnqueue,
269         ArenaExecute,
270         numActions
271     };
272 
273     static const char* const mStatNames[numActions];
274 private:
275     struct StatType {
StatTypeStatistics::StatType276         StatType() : mCounters() {}
277         std::array<std::uint64_t, numActions> mCounters;
278     };
279 
280     tbb::concurrent_vector<StatType*> mStatsList;
281     static thread_local StatType* mStats;
282 
get()283     StatType& get() {
284         auto& s = mStats;
285         if (!s) {
286             s = new StatType;
287             mStatsList.push_back(s);
288         }
289         return *s;
290     }
291 public:
~Statistics()292     ~Statistics() {
293         for (auto s : mStatsList) {
294             delete s;
295         }
296     }
297 
notify(ACTION a)298     void notify(ACTION a) {
299         ++get().mCounters[a];
300     }
301 
report()302     void report() {
303         StatType summary;
304         for (auto& s : mStatsList) {
305             for (int i = 0; i < numActions; ++i) {
306                 summary.mCounters[i] += s->mCounters[i];
307             }
308         }
309         std::cout << std::endl << "Statistics:" << std::endl;
310         std::cout << "Total actions: " << globalNumActions << std::endl;
311         for (int i = 0; i < numActions; ++i) {
312             std::cout << mStatNames[i] << ": " << summary.mCounters[i] << std::endl;
313         }
314     }
315 };
316 
317 
318 const char* const Statistics::mStatNames[Statistics::numActions] = {
319     "Arena create", "Arena destroy", "Arena acquire",
320     "Skipped arena create", "Skipped arena destroy", "Skipped arena acquire",
321     "Parallel algorithm", "Arena enqueue", "Arena execute"
322 };
323 thread_local Statistics::StatType* Statistics::mStats;
324 
325 static Statistics gStats;
326 
327 class LifetimeTracker {
328 public:
329     LifetimeTracker() = default;
330 
331     class Guard {
332     public:
Guard(LifetimeTracker * obj)333         Guard(LifetimeTracker* obj) {
334             if (!(obj->mReferences.load(std::memory_order_relaxed) & SHUTDOWN_FLAG)) {
335                 if (obj->mReferences.fetch_add(REFERENCE_FLAG) & SHUTDOWN_FLAG) {
336                     obj->mReferences.fetch_sub(REFERENCE_FLAG);
337                 } else {
338                     mObj = obj;
339                 }
340             }
341         }
342 
Guard(Guard && ing)343         Guard(Guard&& ing) : mObj(ing.mObj) {
344             ing.mObj = nullptr;
345         }
346 
~Guard()347         ~Guard() {
348             if (mObj != nullptr) {
349                 mObj->mReferences.fetch_sub(REFERENCE_FLAG);
350             }
351         }
352 
continue_execution()353         bool continue_execution() {
354             return mObj != nullptr;
355         }
356 
357     private:
358         LifetimeTracker* mObj{nullptr};
359     };
360 
makeGuard()361     Guard makeGuard() {
362         return Guard(this);
363     }
364 
signalShutdown()365     void signalShutdown() {
366         mReferences.fetch_add(SHUTDOWN_FLAG);
367     }
368 
waitCompletion()369     void waitCompletion() {
370         utils::SpinWaitUntilEq(mReferences, SHUTDOWN_FLAG);
371     }
372 
373 private:
374     friend class Guard;
375     static constexpr std::uintptr_t SHUTDOWN_FLAG = 1;
376     static constexpr std::uintptr_t REFERENCE_FLAG = 1 << 1;
377     std::atomic<std::uintptr_t> mReferences{};
378 };
379 
//! Fixed-size table of task arenas that threads create, destroy and lock concurrently.
//!
//! Every slot is a PtrRWMutex that stores the arena pointer together with its lock state.
//! Public operations other than release() first take a LifetimeTracker guard and do nothing
//! once shutdown has been signalled.
class ArenaTable {
    static const std::size_t maxArenas = 64;
    static const std::size_t maxThreads = 1 << 9;
    // Alignment (and allocation size) of every arena; the low bits of the pointer that this
    // alignment frees up are used by PtrRWMutex for its lock state.
    static const std::size_t arenaAligment = maxThreads << 1;

    using ArenaPtrRWMutex = PtrRWMutex<tbb::task_arena, arenaAligment>;
    std::array<ArenaPtrRWMutex, maxArenas> mArenaTable;

    // Per-thread bookkeeping of the slots the thread currently holds via acquire().
    struct ThreadState {
        bool lockedArenas[maxArenas]{};   // lockedArenas[idx] is true while slot idx is held
        int arenaIdxStack[maxArenas];     // held slot indices in acquisition order; only [0, level) is meaningful
        int level{};                      // number of slots currently held
    };

    LifetimeTracker mLifetimeTracker{};
    static thread_local ThreadState mThreadState;

    //! Visits all maxArenas slots once, starting at \c start and wrapping around, and returns
    //! the first result of \c f that converts to true; returns a value-initialized result if
    //! no call succeeds. \c f is called as f(slot, index).
    template <typename F>
    auto find_arena(std::size_t start, F f) -> decltype(f(std::declval<ArenaPtrRWMutex&>(), std::size_t{})) {
        for (std::size_t idx = start, i = 0; i < maxArenas; ++i, idx = (idx + 1) % maxArenas) {
            auto res = f(mArenaTable[idx], idx);
            if (res) {
                return res;
            }
        }
        return {};
    }

public:
    using ScopedLock = ArenaPtrRWMutex::ScopedLock;

    //! Constructs an arena with random concurrency, reserved slots and priority, and
    //! publishes it into the first empty slot found from a random starting index.
    //! If no slot accepts it, the arena is destroyed again and the skip is counted.
    void create(Random& rnd) {
        auto guard = mLifetimeTracker.makeGuard();
        if (guard.continue_execution()) {
            // Concurrency is between 1 and utils::get_platform_max_threads();
            // the reserved count is always below the concurrency.
            int num_threads = rnd.get() % utils::get_platform_max_threads() + 1;
            unsigned int num_reserved = rnd.get() % num_threads;
            tbb::task_arena::priority priorities[] = { tbb::task_arena::priority::low , tbb::task_arena::priority::normal, tbb::task_arena::priority::high };
            tbb::task_arena::priority priority = priorities[rnd.get() % 3];

            // Placement-new into aligned storage so the pointer satisfies trySet()'s alignment check.
            tbb::task_arena* a = new (aligned_malloc(arenaAligment, arenaAligment)) tbb::task_arena{ num_threads , num_reserved , priority };

            if (!find_arena(rnd.get() % maxArenas, [a](ArenaPtrRWMutex& arena, std::size_t) -> bool {
                    if (arena.trySet(a)) {
                        return true;
                    }
                    return false;
                }))
            {
                // Every slot was occupied: undo the construction manually (destructor, then storage).
                gStats.notify(Statistics::skippedArenaCreate);
                a->~task_arena();
                aligned_free(a);
            }
        }
    }

    //! Destroys one arena: the first slot, from a random starting index, that the calling
    //! thread does not hold itself and that can be write-locked without waiting.
    //! Counts a skip if no such slot exists.
    void destroy(Random& rnd) {
        auto guard = mLifetimeTracker.makeGuard();
        if (guard.continue_execution()) {
            auto& ts = mThreadState;
            if (!find_arena(rnd.get() % maxArenas, [&ts](ArenaPtrRWMutex& arena, std::size_t idx) {
                    // Slots held by this thread are skipped.
                    if (!ts.lockedArenas[idx]) {
                        ScopedLock lock;
                        if (lock.tryAcquire(arena, true)) {
                            // Read the pointer before clear() wipes the slot, then destroy
                            // the arena after the slot has been emptied.
                            auto a = arena.get();
                            lock.clear();
                            a->~task_arena();
                            aligned_free(a);
                            return true;
                        }
                    }
                    return false;
                }))
            {
                gStats.notify(Statistics::skippedArenaDestroy);
            }
        }
    }

    //! Signals shutdown, waits until no guard is outstanding, then destroys every arena
    //! still stored in the table.
    void shutdown() {
        mLifetimeTracker.signalShutdown();
        mLifetimeTracker.waitCompletion();
        find_arena(0, [](ArenaPtrRWMutex& arena, std::size_t) {
            if (arena.get()) {
                ScopedLock lock{ arena, true };
                auto a = arena.get();
                lock.clear();
                a->~task_arena();
                aligned_free(a);
            }
            // Always report "not found" so that find_arena keeps going through all slots.
            return false;
        });
    }

    //! Takes a shared lock on one arena that the calling thread does not already hold.
    //! On success \c lock owns the shared lock and the slot is pushed onto the thread's
    //! stack of held slots; undo with release().
    //! \return {arena, slot index} on success; {nullptr, 0} if nothing could be locked or
    //!         shutdown has been signalled
    std::pair<tbb::task_arena*, std::size_t> acquire(Random& rnd, ScopedLock& lock) {
        auto guard = mLifetimeTracker.makeGuard();

        tbb::task_arena* a{nullptr};
        std::size_t resIdx{};
        if (guard.continue_execution()) {
            auto& ts = mThreadState;
            a = find_arena(rnd.get() % maxArenas,
                [&ts, &lock, &resIdx](ArenaPtrRWMutex& arena, std::size_t idx) -> tbb::task_arena* {
                    if (!ts.lockedArenas[idx]) {
                        if (lock.tryAcquire(arena, false)) {
                            ts.lockedArenas[idx] = true;
                            ts.arenaIdxStack[ts.level++] = int(idx);
                            resIdx = idx;
                            return arena.get();
                        }
                    }
                    return nullptr;
                });
            if (!a) {
                gStats.notify(Statistics::skippedArenaAcquire);
            }
        }
        return { a, resIdx };
    }

    //! Undoes the calling thread's most recent successful acquire(): pops the slot index
    //! from the thread's stack, clears its held flag and releases \c lock.
    void release(ScopedLock& lock) {
        auto& ts = mThreadState;
        CHECK_FAST(ts.level > 0);
        auto idx = ts.arenaIdxStack[--ts.level];
        CHECK_FAST(ts.lockedArenas[idx]);
        ts.lockedArenas[idx] = false;
        lock.release();
    }
};

thread_local ArenaTable::ThreadState ArenaTable::mThreadState;
510 
// Arena table operated on by the actor<> specializations.
static ArenaTable arenaTable;
// Random source that global_actor() hands to the actors.
static Random threadRandom;
513 
// Kinds of top-level actions. global_actor() picks one per iteration with
// rnd.get() % num_actions and dispatches to the matching actor<> specialization.
enum ACTIONS {
    arena_create,
    arena_destroy,
    arena_action,
    parallel_algorithm,
    // TODO:
    // observer_attach,
    // observer_detach,
    // flow_graph,
    // task_group,
    // resumable_tasks,

    // Count of the enumerators above; used as the modulus when an action is chosen.
    num_actions
};
528 
// Declared ahead of its definition because the actors below call back into it.
void global_actor();

// Primary template is only declared; each ACTIONS value gets an explicit specialization
// that provides a static do_it(Random&).
template <ACTIONS action>
struct actor;
533 
534 template <>
535 struct actor<arena_create> {
do_itactor536     static void do_it(Random& r) {
537         arenaTable.create(r);
538     }
539 };
540 
541 template <>
542 struct actor<arena_destroy> {
do_itactor543     static void do_it(Random& r) {
544         arenaTable.destroy(r);
545     }
546 };
547 
548 template <>
549 struct actor<arena_action> {
do_itactor550     static void do_it(Random& r) {
551         static thread_local std::size_t arenaLevel = 0;
552         ArenaTable::ScopedLock lock;
553         auto entry = arenaTable.acquire(r, lock);
554         if (entry.first) {
555             enum arena_actions {
556                 arena_execute,
557                 arena_enqueue,
558                 num_arena_actions
559             };
560             auto process = r.get() % 2;
561             auto body = [process] {
562                 if (process) {
563                     tbb::detail::d1::wait_context wctx{ 1 };
564                     tbb::task_group_context ctx;
565                     tbb::this_task_arena::enqueue([&wctx] { wctx.release(); });
566                     tbb::detail::d1::wait(wctx, ctx);
567                 } else {
568                     global_actor();
569                 }
570             };
571             switch (r.get() % (16*num_arena_actions)) {
572             case arena_execute:
573                 if (entry.second > arenaLevel) {
574                     gStats.notify(Statistics::ArenaExecute);
575                     auto oldArenaLevel = arenaLevel;
576                     arenaLevel = entry.second;
577                     entry.first->execute(body);
578                     arenaLevel = oldArenaLevel;
579                     break;
580                 }
581                 utils_fallthrough;
582             case arena_enqueue:
583                 utils_fallthrough;
584             default:
585                 gStats.notify(Statistics::ArenaEnqueue);
586                 entry.first->enqueue([] { global_actor(); });
587                 break;
588             }
589             arenaTable.release(lock);
590         }
591     }
592 };
593 
594 template <>
595 struct actor<parallel_algorithm> {
do_itactor596     static void do_it(Random& rnd) {
597         enum PARTITIONERS {
598             simpl_part,
599             auto_part,
600             aff_part,
601             static_part,
602             num_parts
603         };
604         int sz = rnd.get() % 10000;
605         auto doGlbAction = rnd.get() % 1000 == 42;
606         auto body = [doGlbAction, sz](int i) {
607             if (i == sz / 2 && doGlbAction) {
608                 global_actor();
609             }
610         };
611 
612         switch (rnd.get() % num_parts) {
613         case simpl_part:
614             tbb::parallel_for(0, sz, body, tbb::simple_partitioner{}); break;
615         case auto_part:
616             tbb::parallel_for(0, sz, body, tbb::auto_partitioner{}); break;
617         case aff_part:
618         {
619             tbb::affinity_partitioner aff;
620             tbb::parallel_for(0, sz, body, aff); break;
621         }
622         case static_part:
623             tbb::parallel_for(0, sz, body, tbb::static_partitioner{}); break;
624         }
625     }
626 };
627 
global_actor()628 void global_actor() {
629     static thread_local std::uint64_t localNumActions{};
630 
631     while (globalNumActions < maxNumActions) {
632         auto& rnd = threadRandom;
633         switch (rnd.get() % num_actions) {
634         case arena_create:  gStats.notify(Statistics::ArenaCreate); actor<arena_create>::do_it(rnd);  break;
635         case arena_destroy: gStats.notify(Statistics::ArenaDestroy); actor<arena_destroy>::do_it(rnd); break;
636         case arena_action:  gStats.notify(Statistics::ArenaAcquire); actor<arena_action>::do_it(rnd);  break;
637         case parallel_algorithm: gStats.notify(Statistics::ParallelAlgorithm); actor<parallel_algorithm>::do_it(rnd);  break;
638         }
639 
640         if (++localNumActions == 100) {
641             localNumActions = 0;
642             globalNumActions += 100;
643 
644             // TODO: Enable statistics
645             // static std::mutex mutex;
646             // std::lock_guard<std::mutex> lock{ mutex };
647             // std::cout << globalNumActions << "\r" << std::flush;
648         }
649     }
650 }
651 
#if TBB_USE_EXCEPTIONS
//! Starts 16 native threads that all enter global_actor() together, then shuts the
//! arena table down once they have finished.
//! \brief \ref stress
TEST_CASE("Stress test with mixing functionality") {
    // TODO add thread recreation
    // TODO: Enable statistics
    // tbb::task_scheduler_handle handle = tbb::task_scheduler_handle::get();

    const std::size_t numExtraThreads = 16;
    utils::SpinBarrier startBarrier{numExtraThreads};

    // Every thread waits on the barrier first so that all of them start acting at once.
    auto worker = [&startBarrier](std::size_t) {
        startBarrier.wait();
        global_actor();
    };
    utils::NativeParallelFor(numExtraThreads, worker);

    arenaTable.shutdown();

    // tbb::finalize(handle, std::nothrow_t{});

    // gStats.report();
}
#endif
673