// Copyright (C) 2018-2020 Sebastian Dröge <sebastian@centricular.com>
// Copyright (C) 2019-2020 François Laignel <fengalin@free.fr>
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Library General Public
// License as published by the Free Software Foundation; either
// version 2 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Library General Public License for more details.
//
// You should have received a copy of the GNU Library General Public
// License along with this library; if not, write to the
// Free Software Foundation, Inc., 51 Franklin Street, Suite 500,
// Boston, MA 02110-1335, USA.

//! The `Executor` for the `threadshare` GStreamer plugins framework.
//!
//! The [`threadshare`] `Executor` consists of a set of [`Context`]s. Each [`Context`] is
//! identified by a `name` and runs a loop in a dedicated `thread`. Users can use a [`Context`]
//! to spawn `Future`s. `Future`s are asynchronous processing units that allow waiting for
//! resources in a non-blocking way. Examples of non-blocking operations are:
//!
//! * Waiting for an incoming packet on a Socket.
//! * Waiting for an asynchronous `Mutex` `lock` to succeed.
//! * Waiting for a time related `Future`.
//!
//! `Element` implementations should use [`PadSrc`] & [`PadSink`] which provide high-level features.
//!
//! A short usage sketch of the `Context` API is given in the example below.
//!
//! [`threadshare`]: ../../index.html
//! [`Context`]: struct.Context.html
//! [`PadSrc`]: ../pad/struct.PadSrc.html
//! [`PadSink`]: ../pad/struct.PadSink.html
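//!
//! ## Example
//!
//! A minimal usage sketch (not run as a doctest; the context name and throttling duration
//! below are arbitrary):
//!
//! ```ignore
//! use std::time::Duration;
//!
//! use crate::runtime::Context;
//!
//! // Acquire (or join) a `Context` whose loop wakes up at most every 10ms.
//! let context = Context::acquire("example-context", Duration::from_millis(10)).unwrap();
//!
//! // Spawn a `Future` on the `Context`...
//! let join_handle = context.spawn(async {
//!     // ... asynchronous, non-blocking processing goes here.
//!     42
//! });
//!
//! // ... and wait for its result from outside the `Context`.
//! let res = futures::executor::block_on(join_handle).unwrap();
//! assert_eq!(res, 42);
//! ```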

use futures::channel::oneshot;
use futures::future::BoxFuture;
use futures::prelude::*;

use gst::{gst_debug, gst_log, gst_trace, gst_warning};

use once_cell::sync::Lazy;

use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::io;
use std::mem;
use std::pin::Pin;
use std::sync::mpsc as sync_mpsc;
use std::sync::{Arc, Mutex, Weak};
use std::task::Poll;
use std::thread;
use std::time::Duration;

use super::RUNTIME_CAT;

// We are bound to using `sync` `Mutex`es for the `runtime`. Attempts to use `async` `Mutex`es
// lead to the following issues:
//
// * `CONTEXTS`: can't `spawn` a `Future` when called from a `Context` thread via `ffi`.
// * `timers`: can't automatically `remove` the timer from `BinaryHeap` because `async drop`
//   is not available.
// * `task_queues`: can't `add` a pending task when called from a `Context` thread via `ffi`.
//
// Also, we want to be able to `acquire` a `Context` outside of an `async` context.
// These `Mutex`es must only be `lock`ed for short periods.
static CONTEXTS: Lazy<Mutex<HashMap<String, Weak<ContextInner>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));

thread_local!(static CURRENT_THREAD_CONTEXT: RefCell<Option<ContextWeak>> = RefCell::new(None));

tokio::task_local! {
    static CURRENT_TASK_ID: TaskId;
}

/// Blocks on `future` in one way or another if possible.
///
/// IO & time related `Future`s must be handled within their own [`Context`].
/// Wait for the result using a [`JoinHandle`] or a `channel`.
///
/// If there's currently an active `Context` with a task, then the future is only queued up as a
/// pending sub task for that task.
///
/// Otherwise, the current thread blocks and the passed-in future is executed.
///
/// Note that you must not pass any futures here that wait for the currently active task in one way
/// or another as this would deadlock!
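///
/// # Example
///
/// A minimal sketch of both behaviours (not run as a doctest):
///
/// ```ignore
/// // From a regular (non-`Context`) thread: blocks and returns `Some(output)`.
/// assert_eq!(block_on_or_add_sub_task(async { 1 + 1 }), Some(2));
///
/// // From a task running on a `Context`, the future would instead be queued
/// // as a sub task of that task and `None` would be returned immediately.
/// ```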
pub fn block_on_or_add_sub_task<Fut: Future + Send + 'static>(future: Fut) -> Option<Fut::Output> {
    if let Some((cur_context, cur_task_id)) = Context::current_task() {
        gst_debug!(
            RUNTIME_CAT,
            "Adding subtask to task {:?} on context {}",
            cur_task_id,
            cur_context.name()
        );
        let _ = Context::add_sub_task(async move {
            future.await;
            Ok(())
        });
        return None;
    }

    // Not running in a Context thread so we can block
    Some(block_on(future))
}

/// Blocks on `future`.
///
/// IO & time related `Future`s must be handled within their own [`Context`].
/// Wait for the result using a [`JoinHandle`] or a `channel`.
///
/// The current thread blocks and the passed-in future is executed.
///
/// # Panics
///
/// This function panics if called within a [`Context`] thread.
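///
/// # Example
///
/// A minimal sketch (not run as a doctest; the context name and throttling duration are
/// arbitrary). The IO / time related work is spawned on a [`Context`] and its `JoinHandle`
/// is awaited from a regular thread:
///
/// ```ignore
/// use std::time::Duration;
///
/// let context = Context::acquire("block-on-example", Duration::from_millis(10)).unwrap();
/// let join_handle = context.spawn(async { 42 });
/// assert_eq!(block_on(join_handle).unwrap(), 42);
/// ```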
pub fn block_on<Fut: Future>(future: Fut) -> Fut::Output {
    assert!(!Context::is_context_thread());

    // Not running in a Context thread so we can block
    gst_debug!(RUNTIME_CAT, "Blocking on new dummy context");

    let context = Context(Arc::new(ContextInner {
        real: None,
        task_queues: Mutex::new((0, HashMap::new())),
    }));

    CURRENT_THREAD_CONTEXT.with(move |cur_ctx| {
        *cur_ctx.borrow_mut() = Some(context.downgrade());

        let res = futures::executor::block_on(async move {
            CURRENT_TASK_ID
                .scope(TaskId(0), async move {
                    let task_id = CURRENT_TASK_ID.try_with(|task_id| *task_id).ok();
                    assert_eq!(task_id, Some(TaskId(0)));

                    let res = future.await;

                    while Context::current_has_sub_tasks() {
                        if Context::drain_sub_tasks().await.is_err() {
                            break;
                        }
                    }

                    res
                })
                .await
        });

        *cur_ctx.borrow_mut() = None;

        res
    })
}

/// Yields execution back to the runtime
#[inline]
pub async fn yield_now() {
    tokio::task::yield_now().await;
}

struct ContextThread {
    name: String,
}

impl ContextThread {
    fn start(name: &str, wait: Duration) -> Context {
        let context_thread = ContextThread { name: name.into() };
        let (context_sender, context_receiver) = sync_mpsc::channel();
        let join = thread::spawn(move || {
            context_thread.spawn(wait, context_sender);
        });

        let context = context_receiver.recv().expect("Context thread init failed");
        *context
            .0
            .real
            .as_ref()
            .unwrap()
            .shutdown
            .join
            .lock()
            .unwrap() = Some(join);

        context
    }

    fn spawn(&self, wait: Duration, context_sender: sync_mpsc::Sender<Context>) {
        gst_debug!(RUNTIME_CAT, "Started context thread '{}'", self.name);

        let mut runtime = tokio::runtime::Builder::new()
            .basic_scheduler()
            .thread_name(self.name.clone())
            .enable_all()
            .max_throttling(wait)
            .build()
            .expect("Couldn't build the runtime");

        let (shutdown_sender, shutdown_receiver) = oneshot::channel();

        let shutdown = ContextShutdown {
            name: self.name.clone(),
            shutdown: Some(shutdown_sender),
            join: Mutex::new(None),
        };

        let context = Context(Arc::new(ContextInner {
            real: Some(ContextRealInner {
                name: self.name.clone(),
                handle: Mutex::new(runtime.handle().clone()),
                shutdown,
            }),
            task_queues: Mutex::new((0, HashMap::new())),
        }));

        CURRENT_THREAD_CONTEXT.with(|cur_ctx| {
            *cur_ctx.borrow_mut() = Some(context.downgrade());
        });

        context_sender.send(context).unwrap();

        let _ = runtime.block_on(shutdown_receiver);
    }
}

impl Drop for ContextThread {
    fn drop(&mut self) {
        gst_debug!(RUNTIME_CAT, "Terminated: context thread '{}'", self.name);
    }
}

#[derive(Debug)]
struct ContextShutdown {
    name: String,
    shutdown: Option<oneshot::Sender<()>>,
    join: Mutex<Option<thread::JoinHandle<()>>>,
}

impl Drop for ContextShutdown {
    fn drop(&mut self) {
        gst_debug!(
            RUNTIME_CAT,
            "Shutting down context thread '{}'",
            self.name
        );
        self.shutdown.take().unwrap();

        gst_trace!(
            RUNTIME_CAT,
            "Waiting for context thread '{}' to shutdown",
            self.name
        );
        let join_handle = self.join.lock().unwrap().take().unwrap();
        let _ = join_handle.join();
    }
}

#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
pub struct TaskId(u64);

pub type SubTaskOutput = Result<(), gst::FlowError>;
pub struct SubTaskQueue(VecDeque<BoxFuture<'static, SubTaskOutput>>);

impl fmt::Debug for SubTaskQueue {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_tuple("SubTaskQueue").finish()
    }
}

pub struct JoinError(tokio::task::JoinError);

impl fmt::Display for JoinError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, fmt)
    }
}

impl fmt::Debug for JoinError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, fmt)
    }
}

impl std::error::Error for JoinError {}

impl From<tokio::task::JoinError> for JoinError {
    fn from(src: tokio::task::JoinError) -> Self {
        JoinError(src)
    }
}

/// Wrapper for the underlying runtime JoinHandle implementation.
pub struct JoinHandle<T> {
    join_handle: tokio::task::JoinHandle<T>,
    context: ContextWeak,
    task_id: TaskId,
}

unsafe impl<T: Send> Send for JoinHandle<T> {}
unsafe impl<T: Send> Sync for JoinHandle<T> {}

impl<T> JoinHandle<T> {
    /// Returns `true` if the caller is running on the task this `JoinHandle` refers to.
    pub fn is_current(&self) -> bool {
        if let Some((context, task_id)) = Context::current_task() {
            let self_context = self.context.upgrade();
            self_context.map(|c| c == context).unwrap_or(false) && task_id == self.task_id
        } else {
            false
        }
    }

    pub fn context(&self) -> Option<Context> {
        self.context.upgrade()
    }

    pub fn task_id(&self) -> TaskId {
        self.task_id
    }
}

impl<T> Unpin for JoinHandle<T> {}

impl<T> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        if self.as_ref().is_current() {
            panic!("Trying to join task {:?} from itself", self.as_ref());
        }

        self.as_mut()
            .join_handle
            .poll_unpin(cx)
            .map_err(JoinError::from)
    }
}

impl<T> fmt::Debug for JoinHandle<T> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let context_name = self.context.upgrade().map(|c| String::from(c.name()));

        fmt.debug_struct("JoinHandle")
            .field("context", &context_name)
            .field("task_id", &self.task_id)
            .finish()
    }
}

#[derive(Debug)]
struct ContextRealInner {
    name: String,
    handle: Mutex<tokio::runtime::Handle>,
    // Only used for dropping
    shutdown: ContextShutdown,
}

#[derive(Debug)]
struct ContextInner {
    // Otherwise a dummy context
    real: Option<ContextRealInner>,
    task_queues: Mutex<(u64, HashMap<u64, SubTaskQueue>)>,
}

impl Drop for ContextInner {
    fn drop(&mut self) {
        if let Some(ref real) = self.real {
            let mut contexts = CONTEXTS.lock().unwrap();
            gst_debug!(RUNTIME_CAT, "Finalizing context '{}'", real.name);
            contexts.remove(&real.name);
        }
    }
}

#[derive(Clone, Debug)]
pub struct ContextWeak(Weak<ContextInner>);

impl ContextWeak {
    pub fn upgrade(&self) -> Option<Context> {
        self.0.upgrade().map(Context)
    }
}

/// A `threadshare` `runtime` `Context`.
///
/// The `Context` provides low-level asynchronous processing features to
/// multiplex task execution on a single thread.
///
/// `Element` implementations should use [`PadSrc`] and [`PadSink`] which
/// provide high-level features.
///
/// See the [module-level documentation](index.html) for more.
///
/// [`PadSrc`]: ../pad/struct.PadSrc.html
/// [`PadSink`]: ../pad/struct.PadSink.html
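///
/// # Example
///
/// A minimal sketch of spawning a task and draining its sub tasks (not run as a doctest;
/// the context name and throttling duration are arbitrary):
///
/// ```ignore
/// use std::time::Duration;
///
/// // Acquire a `Context`: this either joins an existing context with this name
/// // or starts a new context thread.
/// let context = Context::acquire("my-context", Duration::from_millis(10)).unwrap();
///
/// let join_handle = context.spawn(async {
///     // Sub tasks added from within a task are executed when explicitly drained.
///     let _ = Context::add_sub_task(async { Ok(()) });
///     Context::drain_sub_tasks().await
/// });
/// ```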
#[derive(Clone, Debug)]
pub struct Context(Arc<ContextInner>);

impl PartialEq for Context {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.0, &other.0)
    }
}

impl Eq for Context {}

impl Context {
    /// Returns the `Context` identified by `context_name`.
    ///
    /// The existing `Context` is joined if it is still alive, otherwise a new context
    /// thread is started. `wait` is the maximum throttling duration for the `Context`'s
    /// scheduler.
    pub fn acquire(context_name: &str, wait: Duration) -> Result<Self, io::Error> {
        assert_ne!(context_name, "DUMMY");

        let mut contexts = CONTEXTS.lock().unwrap();

        if let Some(inner_weak) = contexts.get(context_name) {
            if let Some(inner_strong) = inner_weak.upgrade() {
                gst_debug!(
                    RUNTIME_CAT,
                    "Joining Context '{}'",
                    inner_strong.real.as_ref().unwrap().name
                );
                return Ok(Context(inner_strong));
            }
        }

        let context = ContextThread::start(context_name, wait);
        contexts.insert(context_name.into(), Arc::downgrade(&context.0));

        gst_debug!(
            RUNTIME_CAT,
            "New Context '{}'",
            context.0.real.as_ref().unwrap().name
        );
        Ok(context)
    }

    pub fn downgrade(&self) -> ContextWeak {
        ContextWeak(Arc::downgrade(&self.0))
    }

    pub fn name(&self) -> &str {
        match self.0.real {
            Some(ref real) => real.name.as_str(),
            None => "DUMMY",
        }
    }

    /// Returns `true` if a `Context` is running on the current thread.
    pub fn is_context_thread() -> bool {
        CURRENT_THREAD_CONTEXT.with(|cur_ctx| cur_ctx.borrow().is_some())
    }

    /// Returns the `Context` running on the current thread, if any.
    pub fn current() -> Option<Context> {
        CURRENT_THREAD_CONTEXT.with(|cur_ctx| {
            cur_ctx
                .borrow()
                .as_ref()
                .and_then(|ctx_weak| ctx_weak.upgrade())
        })
    }

    /// Returns the `Context` and `TaskId` running on the current thread, if any.
    pub fn current_task() -> Option<(Context, TaskId)> {
        CURRENT_THREAD_CONTEXT.with(|cur_ctx| {
            cur_ctx
                .borrow()
                .as_ref()
                .and_then(|ctx_weak| ctx_weak.upgrade())
                .and_then(|ctx| {
                    let task_id = CURRENT_TASK_ID.try_with(|task_id| *task_id).ok();

                    task_id.map(move |task_id| (ctx, task_id))
                })
        })
    }

    /// Executes the provided function within the `Context`'s runtime.
    ///
    /// # Panics
    ///
    /// This function panics if called on a dummy `Context`.
    pub fn enter<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let real = match self.0.real {
            Some(ref real) => real,
            None => panic!("Can't enter on dummy context"),
        };

        real.handle.lock().unwrap().enter(f)
    }

    /// Spawns `future` on this `Context` as a new task.
    pub fn spawn<Fut>(&self, future: Fut) -> JoinHandle<Fut::Output>
    where
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        self.spawn_internal(future, false)
    }

    /// Spawns `future` on this `Context` as a new task, awaking the scheduler
    /// if it is currently throttling.
    pub fn awake_and_spawn<Fut>(&self, future: Fut) -> JoinHandle<Fut::Output>
    where
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        self.spawn_internal(future, true)
    }

    #[inline]
    fn spawn_internal<Fut>(&self, future: Fut, must_awake: bool) -> JoinHandle<Fut::Output>
    where
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        let real = match self.0.real {
            Some(ref real) => real,
            None => panic!("Can't spawn new tasks on dummy context"),
        };

        let mut task_queues = self.0.task_queues.lock().unwrap();
        let id = task_queues.0;
        task_queues.0 += 1;
        task_queues.1.insert(id, SubTaskQueue(VecDeque::new()));

        let id = TaskId(id);
        gst_trace!(
            RUNTIME_CAT,
            "Spawning new task {:?} on context {}",
            id,
            real.name
        );

        let spawn_fut = async move {
            let ctx = Context::current().unwrap();
            let real = ctx.0.real.as_ref().unwrap();

            gst_trace!(
                RUNTIME_CAT,
                "Running task {:?} on context {}",
                id,
                real.name
            );
            let res = CURRENT_TASK_ID.scope(id, future).await;

            // Remove task from the list
            {
                let mut task_queues = ctx.0.task_queues.lock().unwrap();
                if let Some(task_queue) = task_queues.1.remove(&id.0) {
                    let l = task_queue.0.len();
                    if l > 0 {
                        gst_warning!(
                            RUNTIME_CAT,
                            "Task {:?} on context {} has {} pending sub tasks",
                            id,
                            real.name,
                            l
                        );
                    }
                }
            }

            gst_trace!(RUNTIME_CAT, "Task {:?} on context {} done", id, real.name);

            res
        };

        let join_handle = {
            if must_awake {
                real.handle.lock().unwrap().awake_and_spawn(spawn_fut)
            } else {
                real.handle.lock().unwrap().spawn(spawn_fut)
            }
        };

        JoinHandle {
            join_handle,
            context: self.downgrade(),
            task_id: id,
        }
    }

    /// Returns `true` if the current task has pending sub tasks.
    pub fn current_has_sub_tasks() -> bool {
        let (ctx, task_id) = match Context::current_task() {
            Some(task) => task,
            None => {
                gst_trace!(RUNTIME_CAT, "No current task");
                return false;
            }
        };

        let task_queues = ctx.0.task_queues.lock().unwrap();
        task_queues
            .1
            .get(&task_id.0)
            .map(|t| !t.0.is_empty())
            .unwrap_or(false)
    }

    /// Adds `sub_task` to the sub task queue of the current task.
    ///
    /// Returns `Err(sub_task)` if there is no current task.
    pub fn add_sub_task<T>(sub_task: T) -> Result<(), T>
    where
        T: Future<Output = SubTaskOutput> + Send + 'static,
    {
        let (ctx, task_id) = match Context::current_task() {
            Some(task) => task,
            None => {
                gst_trace!(RUNTIME_CAT, "No current task");
                return Err(sub_task);
            }
        };

        let mut task_queues = ctx.0.task_queues.lock().unwrap();
        match task_queues.1.get_mut(&task_id.0) {
            Some(task_queue) => {
                if let Some(ref real) = ctx.0.real {
                    gst_trace!(
                        RUNTIME_CAT,
                        "Adding subtask to {:?} on context {}",
                        task_id,
                        real.name
                    );
                } else {
                    gst_trace!(
                        RUNTIME_CAT,
                        "Adding subtask to {:?} on dummy context",
                        task_id,
                    );
                }
                task_queue.0.push_back(sub_task.boxed());
                Ok(())
            }
            None => {
                gst_trace!(RUNTIME_CAT, "Task was removed in the meantime");
                Err(sub_task)
            }
        }
    }

    /// Executes the sub tasks pending for the current task, in order.
    ///
    /// Draining stops and the error is returned if a sub task fails.
    pub async fn drain_sub_tasks() -> SubTaskOutput {
        let (ctx, task_id) = match Context::current_task() {
            Some(task) => task,
            None => return Ok(()),
        };

        ctx.drain_sub_tasks_internal(task_id).await
    }

    fn drain_sub_tasks_internal(
        &self,
        id: TaskId,
    ) -> impl Future<Output = SubTaskOutput> + Send + 'static {
        let mut task_queue = {
            let mut task_queues = self.0.task_queues.lock().unwrap();
            if let Some(task_queue) = task_queues.1.get_mut(&id.0) {
                mem::replace(task_queue, SubTaskQueue(VecDeque::new()))
            } else {
                SubTaskQueue(VecDeque::new())
            }
        };

        let name = self
            .0
            .real
            .as_ref()
            .map(|r| r.name.clone())
            .unwrap_or_else(|| String::from("DUMMY"));
        async move {
            if !task_queue.0.is_empty() {
                gst_log!(
                    RUNTIME_CAT,
                    "Scheduling draining {} sub tasks from {:?} on '{}'",
                    task_queue.0.len(),
                    id,
                    &name,
                );

                for task in task_queue.0.drain(..) {
                    task.await?;
                }
            }

            Ok(())
        }
    }
}

#[cfg(test)]
mod tests {
    use futures::channel::mpsc;
    use futures::lock::Mutex;
    use futures::prelude::*;

    use std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use super::Context;

    type Item = i32;

    const SLEEP_DURATION_MS: u64 = 2;
    const SLEEP_DURATION: Duration = Duration::from_millis(SLEEP_DURATION_MS);
    const DELAY: Duration = Duration::from_millis(SLEEP_DURATION_MS * 10);

    #[tokio::test]
    async fn drain_sub_tasks() {
        // Setup
        gst::init().unwrap();

        let context = Context::acquire("drain_sub_tasks", SLEEP_DURATION).unwrap();

        let join_handle = context.spawn(async move {
            let (sender, mut receiver) = mpsc::channel(1);
            let sender: Arc<Mutex<mpsc::Sender<Item>>> = Arc::new(Mutex::new(sender));

            let add_sub_task = move |item| {
                let sender = sender.clone();
                Context::add_sub_task(async move {
                    sender
                        .lock()
                        .await
                        .send(item)
                        .await
                        .map_err(|_| gst::FlowError::Error)
                })
            };

            // Tests

            // Drain empty queue
            let drain_fut = Context::drain_sub_tasks();
            drain_fut.await.unwrap();

            // Add a subtask
            add_sub_task(0).map_err(drop).unwrap();

            // Check that it was not executed yet
            receiver.try_next().unwrap_err();

            // Drain it now and check that it was executed
            let drain_fut = Context::drain_sub_tasks();
            drain_fut.await.unwrap();
            assert_eq!(receiver.try_next().unwrap(), Some(0));

            // Add another task and check that it's not executed yet
            add_sub_task(1).map_err(drop).unwrap();
            receiver.try_next().unwrap_err();

            // Return the receiver
            receiver
        });

        let mut receiver = join_handle.await.unwrap();

        // The last sub task should be simply dropped at this point
        assert_eq!(receiver.try_next().unwrap(), None);
    }

    #[tokio::test]
    async fn block_on_within_tokio() {
        gst::init().unwrap();

        let context = Context::acquire("block_on_within_tokio", SLEEP_DURATION).unwrap();

        let bytes_sent = crate::runtime::executor::block_on(context.spawn(async {
            let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);
            let socket = UdpSocket::bind(saddr).unwrap();
            let mut socket = tokio::net::UdpSocket::from_std(socket).unwrap();
            let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4000);
            socket.send_to(&[0; 10], saddr).await.unwrap()
        }))
        .unwrap();
        assert_eq!(bytes_sent, 10);

        let elapsed = crate::runtime::executor::block_on(context.spawn(async {
            let now = Instant::now();
            crate::runtime::time::delay_for(DELAY).await;
            now.elapsed()
        }))
        .unwrap();
        // Due to throttling, `Delay` may be fired earlier
        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }

    #[test]
    fn block_on_from_sync() {
        gst::init().unwrap();

        let context = Context::acquire("block_on_from_sync", SLEEP_DURATION).unwrap();

        let bytes_sent = crate::runtime::executor::block_on(context.spawn(async {
            let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5001);
            let socket = UdpSocket::bind(saddr).unwrap();
            let mut socket = tokio::net::UdpSocket::from_std(socket).unwrap();
            let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4000);
            socket.send_to(&[0; 10], saddr).await.unwrap()
        }))
        .unwrap();
        assert_eq!(bytes_sent, 10);

        let elapsed = crate::runtime::executor::block_on(context.spawn(async {
            let now = Instant::now();
            crate::runtime::time::delay_for(DELAY).await;
            now.elapsed()
        }))
        .unwrap();
        // Due to throttling, `Delay` may be fired earlier
        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }

    #[test]
    fn block_on_from_context() {
        gst::init().unwrap();

        let context = Context::acquire("block_on_from_context", SLEEP_DURATION).unwrap();
        let join_handle = context.spawn(async {
            crate::runtime::executor::block_on(async {
                crate::runtime::time::delay_for(DELAY).await;
            });
        });
        // Panic: attempt to `runtime::executor::block_on` within a `Context` thread
        futures::executor::block_on(join_handle).unwrap_err();
    }

    #[tokio::test]
    async fn enter_context_from_tokio() {
        gst::init().unwrap();

        let context = Context::acquire("enter_context_from_tokio", SLEEP_DURATION).unwrap();
        let mut socket = context
            .enter(|| {
                let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5002);
                let socket = UdpSocket::bind(saddr).unwrap();
                tokio::net::UdpSocket::from_std(socket)
            })
            .unwrap();

        let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4000);
        let bytes_sent = socket.send_to(&[0; 10], saddr).await.unwrap();
        assert_eq!(bytes_sent, 10);

        let elapsed = context.enter(|| {
            futures::executor::block_on(async {
                let now = Instant::now();
                crate::runtime::time::delay_for(DELAY).await;
                now.elapsed()
            })
        });
        // Due to throttling, `Delay` may be fired earlier
        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }

    #[test]
    fn enter_context_from_sync() {
        gst::init().unwrap();

        let context = Context::acquire("enter_context_from_sync", SLEEP_DURATION).unwrap();
        let mut socket = context
            .enter(|| {
                let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5003);
                let socket = UdpSocket::bind(saddr).unwrap();
                tokio::net::UdpSocket::from_std(socket)
            })
            .unwrap();

        let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4000);
        let bytes_sent = futures::executor::block_on(socket.send_to(&[0; 10], saddr)).unwrap();
        assert_eq!(bytes_sent, 10);

        let elapsed = context.enter(|| {
            futures::executor::block_on(async {
                let now = Instant::now();
                crate::runtime::time::delay_for(DELAY).await;
                now.elapsed()
            })
        });
        // Due to throttling, `Delay` may be fired earlier
        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }
}