/*
    Copyright (c) 2021 Intel Corporation

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
*/
16
17 #define TBB_PREVIEW_MUTEXES 1
18 #define TBB_PREVIEW_WAITING_FOR_WORKERS 1
19 #define TBB_PREVIEW_TASK_GROUP_EXTENSIONS 1
20
21 #include "common/config.h"
22
23 #include <oneapi/tbb/task_arena.h>
24 #include <oneapi/tbb/concurrent_vector.h>
25 #include <oneapi/tbb/rw_mutex.h>
26 #include <oneapi/tbb/task_group.h>
27 #include <oneapi/tbb/parallel_for.h>
28
29 #include <oneapi/tbb/global_control.h>
30
31 #include "common/test.h"
32 #include "common/utils.h"
33 #include "common/utils_concurrency_limit.h"
34 #include "common/spin_barrier.h"
35
36 #include <stdlib.h> // C11/POSIX aligned_alloc
37 #include <random>
38
39
//! \file test_scheduler_mix.cpp
//! \brief Test for [scheduler.task_arena scheduler.task_scheduler_observer] specification
42
//! Total number of actions after which global_actor() stops looping.
constexpr std::uint64_t maxNumActions = 100 * 1000;
//! Number of actions completed so far; global_actor() adds to it in batches of 100.
static std::atomic<std::uint64_t> globalNumActions{0};
45
46 //using Random = utils::FastRandom<>;
47 class Random {
48 struct State {
49 std::random_device rd;
50 std::mt19937 gen;
51 std::uniform_int_distribution<> dist;
52
StateRandom::State53 State() : gen(rd()), dist(0, std::numeric_limits<unsigned short>::max()) {}
54
getRandom::State55 int get() {
56 return dist(gen);
57 }
58 };
59 static thread_local State* mState;
60 tbb::concurrent_vector<State*> mStateList;
61 public:
~Random()62 ~Random() {
63 for (auto s : mStateList) {
64 delete s;
65 }
66 }
67
get()68 int get() {
69 auto& s = mState;
70 if (!s) {
71 s = new State;
72 mStateList.push_back(s);
73 }
74 return s->get();
75 }
76 };
77
78 thread_local Random::State* Random::mState = nullptr;
79
80
aligned_malloc(std::size_t alignment,std::size_t size)81 void* aligned_malloc(std::size_t alignment, std::size_t size) {
82 #if _WIN32
83 return _aligned_malloc(size, alignment);
84 #elif __unix__ || __APPLE__
85 void* ptr{};
86 int res = posix_memalign(&ptr, alignment, size);
87 CHECK(res == 0);
88 return ptr;
89 #else
90 return aligned_alloc(alignment, size);
91 #endif
92 }
93
//! Releases a block of memory: with _aligned_free on Windows and with free()
//! on every other platform.
void aligned_free(void* ptr) {
#if !_WIN32
    free(ptr);
#else
    _aligned_free(ptr);
#endif
}
101
102 template <typename T, std::size_t Alignment>
103 class PtrRWMutex {
104 static const std::size_t maxThreads = (Alignment >> 1) - 1;
105 static const std::uintptr_t READER_MASK = maxThreads; // 7F..
106 static const std::uintptr_t LOCKED = Alignment - 1; // FF..
107 static const std::uintptr_t LOCKED_MASK = LOCKED; // FF..
108 static const std::uintptr_t LOCK_PENDING = READER_MASK + 1; // 80..
109
110 std::atomic<std::uintptr_t> mState;
111
pointer()112 T* pointer() {
113 return reinterpret_cast<T*>(state() & ~LOCKED_MASK);
114 }
115
state()116 std::uintptr_t state() {
117 return mState.load(std::memory_order_relaxed);
118 }
119
120 public:
121 class ScopedLock {
122 public:
ScopedLock()123 constexpr ScopedLock() : mMutex(nullptr), mIsWriter(false) {}
124 //! Acquire lock on given mutex.
ScopedLock(PtrRWMutex & m,bool write=true)125 ScopedLock(PtrRWMutex& m, bool write = true) : mMutex(nullptr) {
126 CHECK_FAST(write == true);
127 acquire(m);
128 }
129 //! Release lock (if lock is held).
~ScopedLock()130 ~ScopedLock() {
131 if (mMutex) {
132 release();
133 }
134 }
135 //! No Copy
136 ScopedLock(const ScopedLock&) = delete;
137 ScopedLock& operator=(const ScopedLock&) = delete;
138
139 //! Acquire lock on given mutex.
acquire(PtrRWMutex & m)140 void acquire(PtrRWMutex& m) {
141 CHECK_FAST(mMutex == nullptr);
142 mIsWriter = true;
143 mMutex = &m;
144 mMutex->lock();
145 }
146
147 //! Try acquire lock on given mutex.
tryAcquire(PtrRWMutex & m,bool write=true)148 bool tryAcquire(PtrRWMutex& m, bool write = true) {
149 bool succeed = write ? m.tryLock() : m.tryLockShared();
150 if (succeed) {
151 mMutex = &m;
152 mIsWriter = write;
153 }
154 return succeed;
155 }
156
clear()157 void clear() {
158 CHECK_FAST(mMutex != nullptr);
159 CHECK_FAST(mIsWriter);
160 PtrRWMutex* m = mMutex;
161 mMutex = nullptr;
162 m->clear();
163 }
164
165 //! Release lock.
release()166 void release() {
167 CHECK_FAST(mMutex != nullptr);
168 PtrRWMutex* m = mMutex;
169 mMutex = nullptr;
170
171 if (mIsWriter) {
172 m->unlock();
173 }
174 else {
175 m->unlockShared();
176 }
177 }
178 protected:
179 PtrRWMutex* mMutex{};
180 bool mIsWriter{};
181 };
182
trySet(T * ptr)183 bool trySet(T* ptr) {
184 auto p = reinterpret_cast<std::uintptr_t>(ptr);
185 CHECK_FAST((p & (Alignment - 1)) == 0);
186 if (!state()) {
187 std::uintptr_t expected = 0;
188 if (mState.compare_exchange_strong(expected, p)) {
189 return true;
190 }
191 }
192 return false;
193 }
194
clear()195 void clear() {
196 CHECK_FAST((state() & LOCKED_MASK) == LOCKED);
197 mState.store(0, std::memory_order_relaxed);
198 }
199
tryLock()200 bool tryLock() {
201 auto v = state();
202 if (v == 0) {
203 return false;
204 }
205 CHECK_FAST((v & LOCKED_MASK) == LOCKED || (v & READER_MASK) < maxThreads);
206 if ((v & READER_MASK) == 0) {
207 if (mState.compare_exchange_strong(v, v | LOCKED)) {
208 return true;
209 }
210 }
211 return false;
212 }
213
tryLockShared()214 bool tryLockShared() {
215 auto v = state();
216 if (v == 0) {
217 return false;
218 }
219 CHECK_FAST((v & LOCKED_MASK) == LOCKED || (v & READER_MASK) < maxThreads);
220 if ((v & LOCKED_MASK) != LOCKED && (v & LOCK_PENDING) == 0) {
221 if (mState.compare_exchange_strong(v, v + 1)) {
222 return true;
223 }
224 }
225 return false;
226 }
227
lock()228 void lock() {
229 auto v = state();
230 mState.compare_exchange_strong(v, v | LOCK_PENDING);
231 while (!tryLock()) {
232 utils::yield();
233 }
234 }
235
unlock()236 void unlock() {
237 auto v = state();
238 CHECK_FAST((v & LOCKED_MASK) == LOCKED);
239 mState.store(v & ~LOCKED, std::memory_order_release);
240 }
241
unlockShared()242 void unlockShared() {
243 auto v = state();
244 CHECK_FAST((v & LOCKED_MASK) != LOCKED);
245 CHECK_FAST((v & READER_MASK) > 0);
246 mState -= 1;
247 }
248
operator bool() const249 operator bool() const {
250 return pointer() != 0;
251 }
252
get()253 T* get() {
254 return pointer();
255 }
256 };
257
258 class Statistics {
259 public:
260 enum ACTION {
261 ArenaCreate,
262 ArenaDestroy,
263 ArenaAcquire,
264 skippedArenaCreate,
265 skippedArenaDestroy,
266 skippedArenaAcquire,
267 ParallelAlgorithm,
268 ArenaEnqueue,
269 ArenaExecute,
270 numActions
271 };
272
273 static const char* const mStatNames[numActions];
274 private:
275 struct StatType {
StatTypeStatistics::StatType276 StatType() : mCounters() {}
277 std::array<std::uint64_t, numActions> mCounters;
278 };
279
280 tbb::concurrent_vector<StatType*> mStatsList;
281 static thread_local StatType* mStats;
282
get()283 StatType& get() {
284 auto& s = mStats;
285 if (!s) {
286 s = new StatType;
287 mStatsList.push_back(s);
288 }
289 return *s;
290 }
291 public:
~Statistics()292 ~Statistics() {
293 for (auto s : mStatsList) {
294 delete s;
295 }
296 }
297
notify(ACTION a)298 void notify(ACTION a) {
299 ++get().mCounters[a];
300 }
301
report()302 void report() {
303 StatType summary;
304 for (auto& s : mStatsList) {
305 for (int i = 0; i < numActions; ++i) {
306 summary.mCounters[i] += s->mCounters[i];
307 }
308 }
309 std::cout << std::endl << "Statistics:" << std::endl;
310 std::cout << "Total actions: " << globalNumActions << std::endl;
311 for (int i = 0; i < numActions; ++i) {
312 std::cout << mStatNames[i] << ": " << summary.mCounters[i] << std::endl;
313 }
314 }
315 };
316
317
318 const char* const Statistics::mStatNames[Statistics::numActions] = {
319 "Arena create", "Arena destroy", "Arena acquire",
320 "Skipped arena create", "Skipped arena destroy", "Skipped arena acquire",
321 "Parallel algorithm", "Arena enqueue", "Arena execute"
322 };
323 thread_local Statistics::StatType* Statistics::mStats;
324
325 static Statistics gStats;
326
327 class LifetimeTracker {
328 public:
329 LifetimeTracker() = default;
330
331 class Guard {
332 public:
Guard(LifetimeTracker * obj)333 Guard(LifetimeTracker* obj) {
334 if (!(obj->mReferences.load(std::memory_order_relaxed) & SHUTDOWN_FLAG)) {
335 if (obj->mReferences.fetch_add(REFERENCE_FLAG) & SHUTDOWN_FLAG) {
336 obj->mReferences.fetch_sub(REFERENCE_FLAG);
337 } else {
338 mObj = obj;
339 }
340 }
341 }
342
Guard(Guard && ing)343 Guard(Guard&& ing) : mObj(ing.mObj) {
344 ing.mObj = nullptr;
345 }
346
~Guard()347 ~Guard() {
348 if (mObj != nullptr) {
349 mObj->mReferences.fetch_sub(REFERENCE_FLAG);
350 }
351 }
352
continue_execution()353 bool continue_execution() {
354 return mObj != nullptr;
355 }
356
357 private:
358 LifetimeTracker* mObj{nullptr};
359 };
360
makeGuard()361 Guard makeGuard() {
362 return Guard(this);
363 }
364
signalShutdown()365 void signalShutdown() {
366 mReferences.fetch_add(SHUTDOWN_FLAG);
367 }
368
waitCompletion()369 void waitCompletion() {
370 utils::SpinWaitUntilEq(mReferences, SHUTDOWN_FLAG);
371 }
372
373 private:
374 friend class Guard;
375 static constexpr std::uintptr_t SHUTDOWN_FLAG = 1;
376 static constexpr std::uintptr_t REFERENCE_FLAG = 1 << 1;
377 std::atomic<std::uintptr_t> mReferences{};
378 };
379
380 class ArenaTable {
381 static const std::size_t maxArenas = 64;
382 static const std::size_t maxThreads = 1 << 9;
383 static const std::size_t arenaAligment = maxThreads << 1;
384
385 using ArenaPtrRWMutex = PtrRWMutex<tbb::task_arena, arenaAligment>;
386 std::array<ArenaPtrRWMutex, maxArenas> mArenaTable;
387
388 struct ThreadState {
389 bool lockedArenas[maxArenas]{};
390 int arenaIdxStack[maxArenas];
391 int level{};
392 };
393
394 LifetimeTracker mLifetimeTracker{};
395 static thread_local ThreadState mThreadState;
396
397 template <typename F>
find_arena(std::size_t start,F f)398 auto find_arena(std::size_t start, F f) -> decltype(f(std::declval<ArenaPtrRWMutex&>(), std::size_t{})) {
399 for (std::size_t idx = start, i = 0; i < maxArenas; ++i, idx = (idx + 1) % maxArenas) {
400 auto res = f(mArenaTable[idx], idx);
401 if (res) {
402 return res;
403 }
404 }
405 return {};
406 }
407
408 public:
409 using ScopedLock = ArenaPtrRWMutex::ScopedLock;
410
create(Random & rnd)411 void create(Random& rnd) {
412 auto guard = mLifetimeTracker.makeGuard();
413 if (guard.continue_execution()) {
414 int num_threads = rnd.get() % utils::get_platform_max_threads() + 1;
415 unsigned int num_reserved = rnd.get() % num_threads;
416 tbb::task_arena::priority priorities[] = { tbb::task_arena::priority::low , tbb::task_arena::priority::normal, tbb::task_arena::priority::high };
417 tbb::task_arena::priority priority = priorities[rnd.get() % 3];
418
419 tbb::task_arena* a = new (aligned_malloc(arenaAligment, arenaAligment)) tbb::task_arena{ num_threads , num_reserved , priority };
420
421 if (!find_arena(rnd.get() % maxArenas, [a](ArenaPtrRWMutex& arena, std::size_t) -> bool {
422 if (arena.trySet(a)) {
423 return true;
424 }
425 return false;
426 }))
427 {
428 gStats.notify(Statistics::skippedArenaCreate);
429 a->~task_arena();
430 aligned_free(a);
431 }
432 }
433 }
434
destroy(Random & rnd)435 void destroy(Random& rnd) {
436 auto guard = mLifetimeTracker.makeGuard();
437 if (guard.continue_execution()) {
438 auto& ts = mThreadState;
439 if (!find_arena(rnd.get() % maxArenas, [&ts](ArenaPtrRWMutex& arena, std::size_t idx) {
440 if (!ts.lockedArenas[idx]) {
441 ScopedLock lock;
442 if (lock.tryAcquire(arena, true)) {
443 auto a = arena.get();
444 lock.clear();
445 a->~task_arena();
446 aligned_free(a);
447 return true;
448 }
449 }
450 return false;
451 }))
452 {
453 gStats.notify(Statistics::skippedArenaDestroy);
454 }
455 }
456 }
457
shutdown()458 void shutdown() {
459 mLifetimeTracker.signalShutdown();
460 mLifetimeTracker.waitCompletion();
461 find_arena(0, [](ArenaPtrRWMutex& arena, std::size_t) {
462 if (arena.get()) {
463 ScopedLock lock{ arena, true };
464 auto a = arena.get();
465 lock.clear();
466 a->~task_arena();
467 aligned_free(a);
468 }
469 return false;
470 });
471 }
472
acquire(Random & rnd,ScopedLock & lock)473 std::pair<tbb::task_arena*, std::size_t> acquire(Random& rnd, ScopedLock& lock) {
474 auto guard = mLifetimeTracker.makeGuard();
475
476 tbb::task_arena* a{nullptr};
477 std::size_t resIdx{};
478 if (guard.continue_execution()) {
479 auto& ts = mThreadState;
480 a = find_arena(rnd.get() % maxArenas,
481 [&ts, &lock, &resIdx](ArenaPtrRWMutex& arena, std::size_t idx) -> tbb::task_arena* {
482 if (!ts.lockedArenas[idx]) {
483 if (lock.tryAcquire(arena, false)) {
484 ts.lockedArenas[idx] = true;
485 ts.arenaIdxStack[ts.level++] = int(idx);
486 resIdx = idx;
487 return arena.get();
488 }
489 }
490 return nullptr;
491 });
492 if (!a) {
493 gStats.notify(Statistics::skippedArenaAcquire);
494 }
495 }
496 return { a, resIdx };
497 }
498
release(ScopedLock & lock)499 void release(ScopedLock& lock) {
500 auto& ts = mThreadState;
501 CHECK_FAST(ts.level > 0);
502 auto idx = ts.arenaIdxStack[--ts.level];
503 CHECK_FAST(ts.lockedArenas[idx]);
504 ts.lockedArenas[idx] = false;
505 lock.release();
506 }
507 };
508
509 thread_local ArenaTable::ThreadState ArenaTable::mThreadState;
510
511 static ArenaTable arenaTable;
512 static Random threadRandom;
513
//! Kinds of actions global_actor() chooses from at random.
enum ACTIONS {
    arena_create = 0,
    arena_destroy,
    arena_action,
    parallel_algorithm,
    // TODO:
    // observer_attach,
    // observer_detach,
    // flow_graph,
    // task_group,
    // resumable_tasks,

    num_actions // number of implemented actions
};

//! Main loop of the test; declared here because the actors call it recursively.
void global_actor();

//! Performs one action of the given kind in its static do_it(Random&) function.
//! Only the explicit specializations are defined.
template <ACTIONS action>
struct actor;
533
534 template <>
535 struct actor<arena_create> {
do_itactor536 static void do_it(Random& r) {
537 arenaTable.create(r);
538 }
539 };
540
541 template <>
542 struct actor<arena_destroy> {
do_itactor543 static void do_it(Random& r) {
544 arenaTable.destroy(r);
545 }
546 };
547
548 template <>
549 struct actor<arena_action> {
do_itactor550 static void do_it(Random& r) {
551 static thread_local std::size_t arenaLevel = 0;
552 ArenaTable::ScopedLock lock;
553 auto entry = arenaTable.acquire(r, lock);
554 if (entry.first) {
555 enum arena_actions {
556 arena_execute,
557 arena_enqueue,
558 num_arena_actions
559 };
560 auto process = r.get() % 2;
561 auto body = [process] {
562 if (process) {
563 tbb::detail::d1::wait_context wctx{ 1 };
564 tbb::task_group_context ctx;
565 tbb::this_task_arena::enqueue([&wctx] { wctx.release(); });
566 tbb::detail::d1::wait(wctx, ctx);
567 } else {
568 global_actor();
569 }
570 };
571 switch (r.get() % (16*num_arena_actions)) {
572 case arena_execute:
573 if (entry.second > arenaLevel) {
574 gStats.notify(Statistics::ArenaExecute);
575 auto oldArenaLevel = arenaLevel;
576 arenaLevel = entry.second;
577 entry.first->execute(body);
578 arenaLevel = oldArenaLevel;
579 break;
580 }
581 utils_fallthrough;
582 case arena_enqueue:
583 utils_fallthrough;
584 default:
585 gStats.notify(Statistics::ArenaEnqueue);
586 entry.first->enqueue([] { global_actor(); });
587 break;
588 }
589 arenaTable.release(lock);
590 }
591 }
592 };
593
594 template <>
595 struct actor<parallel_algorithm> {
do_itactor596 static void do_it(Random& rnd) {
597 enum PARTITIONERS {
598 simpl_part,
599 auto_part,
600 aff_part,
601 static_part,
602 num_parts
603 };
604 int sz = rnd.get() % 10000;
605 auto doGlbAction = rnd.get() % 1000 == 42;
606 auto body = [doGlbAction, sz](int i) {
607 if (i == sz / 2 && doGlbAction) {
608 global_actor();
609 }
610 };
611
612 switch (rnd.get() % num_parts) {
613 case simpl_part:
614 tbb::parallel_for(0, sz, body, tbb::simple_partitioner{}); break;
615 case auto_part:
616 tbb::parallel_for(0, sz, body, tbb::auto_partitioner{}); break;
617 case aff_part:
618 {
619 tbb::affinity_partitioner aff;
620 tbb::parallel_for(0, sz, body, aff); break;
621 }
622 case static_part:
623 tbb::parallel_for(0, sz, body, tbb::static_partitioner{}); break;
624 }
625 }
626 };
627
global_actor()628 void global_actor() {
629 static thread_local std::uint64_t localNumActions{};
630
631 while (globalNumActions < maxNumActions) {
632 auto& rnd = threadRandom;
633 switch (rnd.get() % num_actions) {
634 case arena_create: gStats.notify(Statistics::ArenaCreate); actor<arena_create>::do_it(rnd); break;
635 case arena_destroy: gStats.notify(Statistics::ArenaDestroy); actor<arena_destroy>::do_it(rnd); break;
636 case arena_action: gStats.notify(Statistics::ArenaAcquire); actor<arena_action>::do_it(rnd); break;
637 case parallel_algorithm: gStats.notify(Statistics::ParallelAlgorithm); actor<parallel_algorithm>::do_it(rnd); break;
638 }
639
640 if (++localNumActions == 100) {
641 localNumActions = 0;
642 globalNumActions += 100;
643
644 // TODO: Enable statistics
645 // static std::mutex mutex;
646 // std::lock_guard<std::mutex> lock{ mutex };
647 // std::cout << globalNumActions << "\r" << std::flush;
648 }
649 }
650 }
651
652 #if TBB_USE_EXCEPTIONS
653 //! \brief \ref stress
654 TEST_CASE("Stress test with mixing functionality") {
655 // TODO add thread recreation
656 // TODO: Enable statistics
657 // tbb::task_scheduler_handle handle = tbb::task_scheduler_handle::get();
658
659 const std::size_t numExtraThreads = 16;
660 utils::SpinBarrier startBarrier{numExtraThreads};
__anon1f3e715f0902(std::size_t) 661 utils::NativeParallelFor(numExtraThreads, [&startBarrier](std::size_t) {
662 startBarrier.wait();
663 global_actor();
664 });
665
666 arenaTable.shutdown();
667
668 // tbb::finalize(handle, std::nothrow_t{});
669
670 // gStats.report();
671 }
672 #endif
673