1 /*
2 Copyright (c) 2005-2021 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 */
16
17 // Do not include task.h directly. Use scheduler_common.h instead
18 #include "scheduler_common.h"
19 #include "governor.h"
20 #include "arena.h"
21 #include "thread_data.h"
22 #include "task_dispatcher.h"
23 #include "waiters.h"
24 #include "itt_notify.h"
25
26 #include "oneapi/tbb/detail/_task.h"
27 #include "oneapi/tbb/partitioner.h"
28 #include "oneapi/tbb/task.h"
29
30 #include <cstring>
31
32 namespace tbb {
33 namespace detail {
34 namespace r1 {
35
36 //------------------------------------------------------------------------
37 // resumable tasks
38 //------------------------------------------------------------------------
39 #if __TBB_RESUMABLE_TASKS
40
suspend(suspend_callback_type suspend_callback,void * user_callback)41 void suspend(suspend_callback_type suspend_callback, void* user_callback) {
42 thread_data& td = *governor::get_thread_data();
43 td.my_task_dispatcher->suspend(suspend_callback, user_callback);
44 // Do not access td after suspend.
45 }
46
resume(suspend_point_type * sp)47 void resume(suspend_point_type* sp) {
48 assert_pointers_valid(sp, sp->m_arena);
49 task_dispatcher& task_disp = sp->m_resume_task.m_target;
50 __TBB_ASSERT(task_disp.m_thread_data == nullptr, nullptr);
51
52 // TODO: remove this work-around
53 // Prolong the arena's lifetime while all coroutines are alive
54 // (otherwise the arena can be destroyed while some tasks are suspended).
55 arena& a = *sp->m_arena;
56 a.my_references += arena::ref_external;
57
58 if (task_disp.m_properties.critical_task_allowed) {
59 // The target is not in the process of executing critical task, so the resume task is not critical.
60 a.my_resume_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
61 } else {
62 #if __TBB_PREVIEW_CRITICAL_TASKS
63 // The target is in the process of executing critical task, so the resume task is critical.
64 a.my_critical_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
65 #endif
66 }
67
68 // Do not access target after that point.
69 a.advertise_new_work<arena::wakeup>();
70
71 // Release our reference to my_arena.
72 a.on_thread_leaving<arena::ref_external>();
73 }
74
current_suspend_point()75 suspend_point_type* current_suspend_point() {
76 thread_data& td = *governor::get_thread_data();
77 return td.my_task_dispatcher->get_suspend_point();
78 }
79
create_coroutine(thread_data & td)80 static task_dispatcher& create_coroutine(thread_data& td) {
81 // We may have some task dispatchers cached
82 task_dispatcher* task_disp = td.my_arena->my_co_cache.pop();
83 if (!task_disp) {
84 void* ptr = cache_aligned_allocate(sizeof(task_dispatcher));
85 task_disp = new(ptr) task_dispatcher(td.my_arena);
86 task_disp->init_suspend_point(td.my_arena, td.my_arena->my_market->worker_stack_size());
87 }
88 // Prolong the arena's lifetime until all coroutines is alive
89 // (otherwise the arena can be destroyed while some tasks are suspended).
90 // TODO: consider behavior if there are more than 4K external references.
91 td.my_arena->my_references += arena::ref_external;
92 return *task_disp;
93 }
94
suspend(suspend_callback_type suspend_callback,void * user_callback)95 void task_dispatcher::suspend(suspend_callback_type suspend_callback, void* user_callback) {
96 __TBB_ASSERT(suspend_callback != nullptr, nullptr);
97 __TBB_ASSERT(user_callback != nullptr, nullptr);
98 __TBB_ASSERT(m_thread_data != nullptr, nullptr);
99
100 arena_slot* slot = m_thread_data->my_arena_slot;
101 __TBB_ASSERT(slot != nullptr, nullptr);
102
103 task_dispatcher& default_task_disp = slot->default_task_dispatcher();
104 // TODO: simplify the next line, e.g. is_task_dispatcher_recalled( task_dispatcher& )
105 bool is_recalled = default_task_disp.get_suspend_point()->m_is_owner_recalled.load(std::memory_order_acquire);
106 task_dispatcher& target = is_recalled ? default_task_disp : create_coroutine(*m_thread_data);
107
108 thread_data::suspend_callback_wrapper callback = { suspend_callback, user_callback, get_suspend_point() };
109 m_thread_data->set_post_resume_action(thread_data::post_resume_action::callback, &callback);
110 resume(target);
111
112 if (m_properties.outermost) {
113 recall_point();
114 }
115 }
116
//! Switch the calling thread from this dispatcher's coroutine to the coroutine of `target`.
/** Preconditions (asserted): `target` is not this dispatcher, this dispatcher is attached to a
    thread_data, and a post-resume action with a non-null argument has been registered.
    Before the context switch, the thread_data is detached from this dispatcher and attached
    to `target`.

    Control comes back past the switch only when this dispatcher's coroutine is resumed again;
    m_thread_data is re-read at that point because it may differ from before the switch.
      - If a thread_data is attached, its pending post-resume action is executed. If this is the
        slot's default dispatcher, its recall flag is also cleared. Returns true.
      - Otherwise returns false.
    @return whether this dispatcher was resumed with a thread_data attached. */
bool task_dispatcher::resume(task_dispatcher& target) {
    // Do not create non-trivial objects on the stack of this function. They might never be destroyed
    // (Locals needed before the switch are confined to this nested scope.)
    {
        thread_data* td = m_thread_data;
        __TBB_ASSERT(&target != this, "We cannot resume to ourself");
        __TBB_ASSERT(td != nullptr, "This task dispatcher must be attach to a thread data");
        __TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher");
        __TBB_ASSERT(td->my_post_resume_action != thread_data::post_resume_action::none, "The post resume action must be set");
        __TBB_ASSERT(td->my_post_resume_arg, "The post resume action must have an argument");

        // Change the task dispatcher
        td->detach_task_dispatcher();
        td->attach_task_dispatcher(target);
    }
    __TBB_ASSERT(m_suspend_point != nullptr, "Suspend point must be created");
    __TBB_ASSERT(target.m_suspend_point != nullptr, "Suspend point must be created");
    // Swap to the target coroutine.
    m_suspend_point->m_co_context.resume(target.m_suspend_point->m_co_context);
    // Pay attention that m_thread_data can be changed after resume
    if (m_thread_data) {
        thread_data* td = m_thread_data;
        __TBB_ASSERT(td != nullptr, "This task dispatcher must be attach to a thread data");
        __TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher");
        // Run the action registered by whoever switched control back to this coroutine.
        td->do_post_resume_action();

        // Remove the recall flag if the thread in its original task dispatcher
        arena_slot* slot = td->my_arena_slot;
        __TBB_ASSERT(slot != nullptr, nullptr);
        if (this == slot->my_default_task_dispatcher) {
            __TBB_ASSERT(m_suspend_point != nullptr, nullptr);
            m_suspend_point->m_is_owner_recalled.store(false, std::memory_order_relaxed);
        }
        return true;
    }
    return false;
}
153
//! Execute the action registered via set_post_resume_action(), then reset it to `none`.
/** Invoked after a coroutine switch (e.g. from task_dispatcher::resume()) on the thread that
    regained control. my_post_resume_arg is interpreted according to my_post_resume_action:
      - register_waiter: a market_concurrent_monitor::resume_context to notify();
      - resume:          a suspend point to hand to r1::resume();
      - callback:        a suspend_callback_wrapper to invoke;
      - cleanup:         a task_dispatcher to return to the arena's coroutine cache;
      - notify:          a suspend point whose owner is being recalled. */
void thread_data::do_post_resume_action() {
    __TBB_ASSERT(my_post_resume_action != thread_data::post_resume_action::none, "The post resume action must be set");
    __TBB_ASSERT(my_post_resume_arg, "The post resume action must have an argument");

    switch (my_post_resume_action) {
    case post_resume_action::register_waiter:
    {
        static_cast<market_concurrent_monitor::resume_context*>(my_post_resume_arg)->notify();
        break;
    }
    case post_resume_action::resume:
    {
        r1::resume(static_cast<suspend_point_type*>(my_post_resume_arg));
        break;
    }
    case post_resume_action::callback:
    {
        // The wrapper is copied before it is invoked. In task_dispatcher::suspend() it is a local
        // on the suspended coroutine's stack; presumably the callback may cause that coroutine to
        // resume, so the original must not be relied upon during the call.
        suspend_callback_wrapper callback = *static_cast<suspend_callback_wrapper*>(my_post_resume_arg);
        callback();
        break;
    }
    case post_resume_action::cleanup:
    {
        task_dispatcher* to_cleanup = static_cast<task_dispatcher*>(my_post_resume_arg);
        // Release coroutine's reference to my_arena
        my_arena->on_thread_leaving<arena::ref_external>();
        // Cache the coroutine for possible later re-usage
        my_arena->my_co_cache.push(to_cleanup);
        break;
    }
    case post_resume_action::notify:
    {
        suspend_point_type* sp = static_cast<suspend_point_type*>(my_post_resume_arg);
        sp->m_is_owner_recalled.store(true, std::memory_order_release);
        // Do not access sp because it can be destroyed after the store

        // Only the address value of sp is compared below; sp is never dereferenced again.
        auto is_our_suspend_point = [sp](market_context ctx) {
            return std::uintptr_t(sp) == ctx.my_uniq_addr;
        };
        my_arena->my_market->get_wait_list().notify(is_our_suspend_point);
        break;
    }
    default:
        __TBB_ASSERT(false, "Unknown post resume action");
    }

    // The action has been consumed.
    my_post_resume_action = post_resume_action::none;
    my_post_resume_arg = nullptr;
}
203
204 #else
205
206 void suspend(suspend_callback_type, void*) {
207 __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
208 }
209
210 void resume(suspend_point_type*) {
211 __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
212 }
213
214 suspend_point_type* current_suspend_point() {
215 __TBB_ASSERT_RELEASE(false, "Resumable tasks are unsupported on this platform");
216 return nullptr;
217 }
218
219 #endif /* __TBB_RESUMABLE_TASKS */
220
notify_waiters(std::uintptr_t wait_ctx_addr)221 void notify_waiters(std::uintptr_t wait_ctx_addr) {
222 auto is_related_wait_ctx = [&] (market_context context) {
223 return wait_ctx_addr == context.my_uniq_addr;
224 };
225
226 r1::governor::get_thread_data()->my_arena->my_market->get_wait_list().notify(is_related_wait_ctx);
227 }
228
229 } // namespace r1
230 } // namespace detail
231 } // namespace tbb
232
233