1 // Copyright (C) 2018-2020 Sebastian Dröge <sebastian@centricular.com>
2 // Copyright (C) 2019-2020 François Laignel <fengalin@free.fr>
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Library General Public
6 // License as published by the Free Software Foundation; either
7 // version 2 of the License, or (at your option) any later version.
8 //
9 // This library is distributed in the hope that it will be useful,
10 // but WITHOUT ANY WARRANTY; without even the implied warranty of
11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 // Library General Public License for more details.
13 //
14 // You should have received a copy of the GNU Library General Public
15 // License along with this library; if not, write to the
16 // Free Software Foundation, Inc., 51 Franklin Street, Suite 500,
17 // Boston, MA 02110-1335, USA.
18 
19 //! The `Executor` for the `threadshare` GStreamer plugins framework.
20 //!
//! The [`threadshare`]'s `Executor` consists of a set of [`Context`]s. Each [`Context`] is
//! identified by a `name` and runs a loop in a dedicated `thread`. Users can use the [`Context`]
//! to spawn `Future`s. `Future`s are asynchronous computations which allow waiting for resources
//! in a non-blocking way. Examples of non-blocking operations are:
25 //!
26 //! * Waiting for an incoming packet on a Socket.
27 //! * Waiting for an asynchronous `Mutex` `lock` to succeed.
28 //! * Waiting for a time related `Future`.
29 //!
//! `Element` implementations should use [`PadSrc`] & [`PadSink`] which provide high-level features.
31 //!
32 //! [`threadshare`]: ../../index.html
33 //! [`Context`]: struct.Context.html
34 //! [`PadSrc`]: ../pad/struct.PadSrc.html
35 //! [`PadSink`]: ../pad/struct.PadSink.html
36 
37 use futures::channel::oneshot;
38 use futures::future::BoxFuture;
39 use futures::prelude::*;
40 
41 use gst::{gst_debug, gst_log, gst_trace, gst_warning};
42 
43 use once_cell::sync::Lazy;
44 
45 use std::cell::RefCell;
46 use std::collections::{HashMap, VecDeque};
47 use std::fmt;
48 use std::io;
49 use std::mem;
50 use std::pin::Pin;
51 use std::sync::mpsc as sync_mpsc;
52 use std::sync::{Arc, Mutex, Weak};
53 use std::task::Poll;
54 use std::thread;
55 use std::time::Duration;
56 
57 use super::RUNTIME_CAT;
58 
59 // We are bound to using `sync` for the `runtime` `Mutex`es. Attempts to use `async` `Mutex`es
60 // lead to the following issues:
61 //
62 // * `CONTEXTS`: can't `spawn` a `Future` when called from a `Context` thread via `ffi`.
63 // * `timers`: can't automatically `remove` the timer from `BinaryHeap` because `async drop`
64 //    is not available.
65 // * `task_queues`: can't `add` a pending task when called from a `Context` thread via `ffi`.
66 //
67 // Also, we want to be able to `acquire` a `Context` outside of an `async` context.
68 // These `Mutex`es must be `lock`ed for a short period.
// Registry of the named `Context`s, used by `Context::acquire` to hand out an existing
// `Context` for a name instead of starting a new thread. Entries are `Weak` so the registry
// alone does not keep a `Context` alive; `ContextInner::drop` removes the entry.
static CONTEXTS: Lazy<Mutex<HashMap<String, Weak<ContextInner>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));

// The `Context` associated with the current thread: set once by the context thread in
// `ContextThread::spawn`, and temporarily by `block_on` for its dummy context.
thread_local!(static CURRENT_THREAD_CONTEXT: RefCell<Option<ContextWeak>> = RefCell::new(None));

tokio::task_local! {
    // Id of the task currently being polled, set with `CURRENT_TASK_ID.scope` when a task runs.
    static CURRENT_TASK_ID: TaskId;
}
77 
78 /// Blocks on `future` in one way or another if possible.
79 ///
80 /// IO & time related `Future`s must be handled within their own [`Context`].
81 /// Wait for the result using a [`JoinHandle`] or a `channel`.
82 ///
83 /// If there's currently an active `Context` with a task, then the future is only queued up as a
84 /// pending sub task for that task.
85 ///
86 /// Otherwise the current thread is blocking and the passed in future is executed.
87 ///
88 /// Note that you must not pass any futures here that wait for the currently active task in one way
89 /// or another as this would deadlock!
block_on_or_add_sub_task<Fut: Future + Send + 'static>(future: Fut) -> Option<Fut::Output>90 pub fn block_on_or_add_sub_task<Fut: Future + Send + 'static>(future: Fut) -> Option<Fut::Output> {
91     if let Some((cur_context, cur_task_id)) = Context::current_task() {
92         gst_debug!(
93             RUNTIME_CAT,
94             "Adding subtask to task {:?} on context {}",
95             cur_task_id,
96             cur_context.name()
97         );
98         let _ = Context::add_sub_task(async move {
99             future.await;
100             Ok(())
101         });
102         return None;
103     }
104 
105     // Not running in a Context thread so we can block
106     Some(block_on(future))
107 }
108 
109 /// Blocks on `future`.
110 ///
111 /// IO & time related `Future`s must be handled within their own [`Context`].
112 /// Wait for the result using a [`JoinHandle`] or a `channel`.
113 ///
114 /// The current thread is blocking and the passed in future is executed.
115 ///
116 /// # Panics
117 ///
118 /// This function panics if called within a [`Context`] thread.
block_on<Fut: Future>(future: Fut) -> Fut::Output119 pub fn block_on<Fut: Future>(future: Fut) -> Fut::Output {
120     assert!(!Context::is_context_thread());
121 
122     // Not running in a Context thread so we can block
123     gst_debug!(RUNTIME_CAT, "Blocking on new dummy context");
124 
125     let context = Context(Arc::new(ContextInner {
126         real: None,
127         task_queues: Mutex::new((0, HashMap::new())),
128     }));
129 
130     CURRENT_THREAD_CONTEXT.with(move |cur_ctx| {
131         *cur_ctx.borrow_mut() = Some(context.downgrade());
132 
133         let res = futures::executor::block_on(async move {
134             CURRENT_TASK_ID
135                 .scope(TaskId(0), async move {
136                     let task_id = CURRENT_TASK_ID.try_with(|task_id| *task_id).ok();
137                     assert_eq!(task_id, Some(TaskId(0)));
138 
139                     let res = future.await;
140 
141                     while Context::current_has_sub_tasks() {
142                         if Context::drain_sub_tasks().await.is_err() {
143                             break;
144                         }
145                     }
146 
147                     res
148                 })
149                 .await
150         });
151 
152         *cur_ctx.borrow_mut() = None;
153 
154         res
155     })
156 }
157 
158 /// Yields execution back to the runtime
159 #[inline]
yield_now()160 pub async fn yield_now() {
161     tokio::task::yield_now().await;
162 }
163 
164 struct ContextThread {
165     name: String,
166 }
167 
168 impl ContextThread {
start(name: &str, wait: Duration) -> Context169     fn start(name: &str, wait: Duration) -> Context {
170         let context_thread = ContextThread { name: name.into() };
171         let (context_sender, context_receiver) = sync_mpsc::channel();
172         let join = thread::spawn(move || {
173             context_thread.spawn(wait, context_sender);
174         });
175 
176         let context = context_receiver.recv().expect("Context thread init failed");
177         *context
178             .0
179             .real
180             .as_ref()
181             .unwrap()
182             .shutdown
183             .join
184             .lock()
185             .unwrap() = Some(join);
186 
187         context
188     }
189 
spawn(&self, wait: Duration, context_sender: sync_mpsc::Sender<Context>)190     fn spawn(&self, wait: Duration, context_sender: sync_mpsc::Sender<Context>) {
191         gst_debug!(RUNTIME_CAT, "Started context thread '{}'", self.name);
192 
193         let mut runtime = tokio::runtime::Builder::new()
194             .basic_scheduler()
195             .thread_name(self.name.clone())
196             .enable_all()
197             .max_throttling(wait)
198             .build()
199             .expect("Couldn't build the runtime");
200 
201         let (shutdown_sender, shutdown_receiver) = oneshot::channel();
202 
203         let shutdown = ContextShutdown {
204             name: self.name.clone(),
205             shutdown: Some(shutdown_sender),
206             join: Mutex::new(None),
207         };
208 
209         let context = Context(Arc::new(ContextInner {
210             real: Some(ContextRealInner {
211                 name: self.name.clone(),
212                 handle: Mutex::new(runtime.handle().clone()),
213                 shutdown,
214             }),
215             task_queues: Mutex::new((0, HashMap::new())),
216         }));
217 
218         CURRENT_THREAD_CONTEXT.with(|cur_ctx| {
219             *cur_ctx.borrow_mut() = Some(context.downgrade());
220         });
221 
222         context_sender.send(context).unwrap();
223 
224         let _ = runtime.block_on(shutdown_receiver);
225     }
226 }
227 
228 impl Drop for ContextThread {
drop(&mut self)229     fn drop(&mut self) {
230         gst_debug!(RUNTIME_CAT, "Terminated: context thread '{}'", self.name);
231     }
232 }
233 
/// Stops and joins a context thread when dropped (see its `Drop` implementation).
#[derive(Debug)]
struct ContextShutdown {
    // Name of the context thread, used in log messages.
    name: String,
    // Sending half of the channel the context thread's runtime blocks on; dropping it
    // lets `ContextThread::spawn` return.
    shutdown: Option<oneshot::Sender<()>>,
    // Handle of the context thread, stored by `ContextThread::start` once the `Context`
    // has been received.
    join: Mutex<Option<thread::JoinHandle<()>>>,
}
240 
241 impl Drop for ContextShutdown {
drop(&mut self)242     fn drop(&mut self) {
243         gst_debug!(
244             RUNTIME_CAT,
245             "Shutting down context thread thread '{}'",
246             self.name
247         );
248         self.shutdown.take().unwrap();
249 
250         gst_trace!(
251             RUNTIME_CAT,
252             "Waiting for context thread '{}' to shutdown",
253             self.name
254         );
255         let join_handle = self.join.lock().unwrap().take().unwrap();
256         let _ = join_handle.join();
257     }
258 }
259 
/// Identifier of a task on a `Context`, allocated from a per-`Context` counter.
#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
pub struct TaskId(u64);

/// Output of a sub task; an `Err` stops draining the remaining sub tasks of the task.
pub type SubTaskOutput = Result<(), gst::FlowError>;
/// Pending sub tasks of a task, appended at the back and executed front to back when drained.
pub struct SubTaskQueue(VecDeque<BoxFuture<'static, SubTaskOutput>>);
265 
266 impl fmt::Debug for SubTaskQueue {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result267     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
268         fmt.debug_tuple("SubTaskQueue").finish()
269     }
270 }
271 
272 pub struct JoinError(tokio::task::JoinError);
273 
274 impl fmt::Display for JoinError {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result275     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
276         fmt::Display::fmt(&self.0, fmt)
277     }
278 }
279 
280 impl fmt::Debug for JoinError {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result281     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
282         fmt::Debug::fmt(&self.0, fmt)
283     }
284 }
285 
286 impl std::error::Error for JoinError {}
287 
288 impl From<tokio::task::JoinError> for JoinError {
from(src: tokio::task::JoinError) -> Self289     fn from(src: tokio::task::JoinError) -> Self {
290         JoinError(src)
291     }
292 }
293 
/// Wrapper for the underlying runtime JoinHandle implementation.
///
/// Also records the `Context` and `TaskId` of the spawned task so that awaiting the handle from
/// within the task itself can be detected (see the `Future` implementation).
pub struct JoinHandle<T> {
    join_handle: tokio::task::JoinHandle<T>,
    // Weak reference: the handle does not keep the `Context` alive.
    context: ContextWeak,
    task_id: TaskId,
}

// NOTE(review): these manual impls may be redundant with the auto-derived ones if every field is
// already `Send`/`Sync` for `T: Send` — confirm against the runtime's `JoinHandle` before removing.
unsafe impl<T: Send> Send for JoinHandle<T> {}
unsafe impl<T: Send> Sync for JoinHandle<T> {}
303 
304 impl<T> JoinHandle<T> {
is_current(&self) -> bool305     pub fn is_current(&self) -> bool {
306         if let Some((context, task_id)) = Context::current_task() {
307             let self_context = self.context.upgrade();
308             self_context.map(|c| c == context).unwrap_or(false) && task_id == self.task_id
309         } else {
310             false
311         }
312     }
313 
context(&self) -> Option<Context>314     pub fn context(&self) -> Option<Context> {
315         self.context.upgrade()
316     }
317 
task_id(&self) -> TaskId318     pub fn task_id(&self) -> TaskId {
319         self.task_id
320     }
321 }
322 
323 impl<T> Unpin for JoinHandle<T> {}
324 
325 impl<T> Future for JoinHandle<T> {
326     type Output = Result<T, JoinError>;
327 
poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output>328     fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
329         if self.as_ref().is_current() {
330             panic!("Trying to join task {:?} from itself", self.as_ref());
331         }
332 
333         self.as_mut()
334             .join_handle
335             .poll_unpin(cx)
336             .map_err(JoinError::from)
337     }
338 }
339 
340 impl<T> fmt::Debug for JoinHandle<T> {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result341     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
342         let context_name = self.context.upgrade().map(|c| String::from(c.name()));
343 
344         fmt.debug_struct("JoinHandle")
345             .field("context", &context_name)
346             .field("task_id", &self.task_id)
347             .finish()
348     }
349 }
350 
/// State of a `Context` backed by a dedicated context thread.
#[derive(Debug)]
struct ContextRealInner {
    // Name under which the `Context` is registered in `CONTEXTS`.
    name: String,
    // Handle to the context thread's runtime, used to spawn tasks and to `enter` it.
    handle: Mutex<tokio::runtime::Handle>,
    // Only used for dropping: its `Drop` implementation stops the context thread.
    shutdown: ContextShutdown,
}

/// Shared state behind a `Context`.
#[derive(Debug)]
struct ContextInner {
    // `None` for the dummy context created by `block_on`, otherwise a real context thread.
    real: Option<ContextRealInner>,
    // Next task id to allocate, and the sub task queue of each live task keyed by its id.
    task_queues: Mutex<(u64, HashMap<u64, SubTaskQueue>)>,
}
365 
366 impl Drop for ContextInner {
drop(&mut self)367     fn drop(&mut self) {
368         if let Some(ref real) = self.real {
369             let mut contexts = CONTEXTS.lock().unwrap();
370             gst_debug!(RUNTIME_CAT, "Finalizing context '{}'", real.name);
371             contexts.remove(&real.name);
372         }
373     }
374 }
375 
376 #[derive(Clone, Debug)]
377 pub struct ContextWeak(Weak<ContextInner>);
378 
379 impl ContextWeak {
upgrade(&self) -> Option<Context>380     pub fn upgrade(&self) -> Option<Context> {
381         self.0.upgrade().map(Context)
382     }
383 }
384 
/// A `threadshare` `runtime` `Context`.
///
/// The `Context` provides low-level asynchronous processing features to
/// multiplex task execution on a single thread.
///
/// `Element` implementations should use [`PadSrc`] and [`PadSink`] which
/// provide high-level features.
///
/// See the [module-level documentation](index.html) for more.
///
/// [`PadSrc`]: ../pad/struct.PadSrc.html
/// [`PadSink`]: ../pad/struct.PadSink.html
#[derive(Clone, Debug)]
pub struct Context(Arc<ContextInner>);

// Two `Context`s are equal when they share the same `ContextInner` allocation: a clone is equal
// to its original, as are the `Context`s returned by `Context::acquire` for a name while that
// `Context` is alive.
impl PartialEq for Context {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.0, &other.0)
    }
}

impl Eq for Context {}
407 
408 impl Context {
acquire(context_name: &str, wait: Duration) -> Result<Self, io::Error>409     pub fn acquire(context_name: &str, wait: Duration) -> Result<Self, io::Error> {
410         assert_ne!(context_name, "DUMMY");
411 
412         let mut contexts = CONTEXTS.lock().unwrap();
413 
414         if let Some(inner_weak) = contexts.get(context_name) {
415             if let Some(inner_strong) = inner_weak.upgrade() {
416                 gst_debug!(
417                     RUNTIME_CAT,
418                     "Joining Context '{}'",
419                     inner_strong.real.as_ref().unwrap().name
420                 );
421                 return Ok(Context(inner_strong));
422             }
423         }
424 
425         let context = ContextThread::start(context_name, wait);
426         contexts.insert(context_name.into(), Arc::downgrade(&context.0));
427 
428         gst_debug!(
429             RUNTIME_CAT,
430             "New Context '{}'",
431             context.0.real.as_ref().unwrap().name
432         );
433         Ok(context)
434     }
435 
downgrade(&self) -> ContextWeak436     pub fn downgrade(&self) -> ContextWeak {
437         ContextWeak(Arc::downgrade(&self.0))
438     }
439 
name(&self) -> &str440     pub fn name(&self) -> &str {
441         match self.0.real {
442             Some(ref real) => real.name.as_str(),
443             None => "DUMMY",
444         }
445     }
446 
447     /// Returns `true` if a `Context` is running on current thread.
is_context_thread() -> bool448     pub fn is_context_thread() -> bool {
449         CURRENT_THREAD_CONTEXT.with(|cur_ctx| cur_ctx.borrow().is_some())
450     }
451 
452     /// Returns the `Context` running on current thread, if any.
current() -> Option<Context>453     pub fn current() -> Option<Context> {
454         CURRENT_THREAD_CONTEXT.with(|cur_ctx| {
455             cur_ctx
456                 .borrow()
457                 .as_ref()
458                 .and_then(|ctx_weak| ctx_weak.upgrade())
459         })
460     }
461 
462     /// Returns the `TaskId` running on current thread, if any.
current_task() -> Option<(Context, TaskId)>463     pub fn current_task() -> Option<(Context, TaskId)> {
464         CURRENT_THREAD_CONTEXT.with(|cur_ctx| {
465             cur_ctx
466                 .borrow()
467                 .as_ref()
468                 .and_then(|ctx_weak| ctx_weak.upgrade())
469                 .and_then(|ctx| {
470                     let task_id = CURRENT_TASK_ID.try_with(|task_id| *task_id).ok();
471 
472                     task_id.map(move |task_id| (ctx, task_id))
473                 })
474         })
475     }
476 
enter<F, R>(&self, f: F) -> R where F: FnOnce() -> R,477     pub fn enter<F, R>(&self, f: F) -> R
478     where
479         F: FnOnce() -> R,
480     {
481         let real = match self.0.real {
482             Some(ref real) => real,
483             None => panic!("Can't enter on dummy context"),
484         };
485 
486         real.handle.lock().unwrap().enter(f)
487     }
488 
spawn<Fut>(&self, future: Fut) -> JoinHandle<Fut::Output> where Fut: Future + Send + 'static, Fut::Output: Send + 'static,489     pub fn spawn<Fut>(&self, future: Fut) -> JoinHandle<Fut::Output>
490     where
491         Fut: Future + Send + 'static,
492         Fut::Output: Send + 'static,
493     {
494         self.spawn_internal(future, false)
495     }
496 
awake_and_spawn<Fut>(&self, future: Fut) -> JoinHandle<Fut::Output> where Fut: Future + Send + 'static, Fut::Output: Send + 'static,497     pub fn awake_and_spawn<Fut>(&self, future: Fut) -> JoinHandle<Fut::Output>
498     where
499         Fut: Future + Send + 'static,
500         Fut::Output: Send + 'static,
501     {
502         self.spawn_internal(future, true)
503     }
504 
505     #[inline]
spawn_internal<Fut>(&self, future: Fut, must_awake: bool) -> JoinHandle<Fut::Output> where Fut: Future + Send + 'static, Fut::Output: Send + 'static,506     fn spawn_internal<Fut>(&self, future: Fut, must_awake: bool) -> JoinHandle<Fut::Output>
507     where
508         Fut: Future + Send + 'static,
509         Fut::Output: Send + 'static,
510     {
511         let real = match self.0.real {
512             Some(ref real) => real,
513             None => panic!("Can't spawn new tasks on dummy context"),
514         };
515 
516         let mut task_queues = self.0.task_queues.lock().unwrap();
517         let id = task_queues.0;
518         task_queues.0 += 1;
519         task_queues.1.insert(id, SubTaskQueue(VecDeque::new()));
520 
521         let id = TaskId(id);
522         gst_trace!(
523             RUNTIME_CAT,
524             "Spawning new task {:?} on context {}",
525             id,
526             real.name
527         );
528 
529         let spawn_fut = async move {
530             let ctx = Context::current().unwrap();
531             let real = ctx.0.real.as_ref().unwrap();
532 
533             gst_trace!(
534                 RUNTIME_CAT,
535                 "Running task {:?} on context {}",
536                 id,
537                 real.name
538             );
539             let res = CURRENT_TASK_ID.scope(id, future).await;
540 
541             // Remove task from the list
542             {
543                 let mut task_queues = ctx.0.task_queues.lock().unwrap();
544                 if let Some(task_queue) = task_queues.1.remove(&id.0) {
545                     let l = task_queue.0.len();
546                     if l > 0 {
547                         gst_warning!(
548                             RUNTIME_CAT,
549                             "Task {:?} on context {} has {} pending sub tasks",
550                             id,
551                             real.name,
552                             l
553                         );
554                     }
555                 }
556             }
557 
558             gst_trace!(RUNTIME_CAT, "Task {:?} on context {} done", id, real.name);
559 
560             res
561         };
562 
563         let join_handle = {
564             if must_awake {
565                 real.handle.lock().unwrap().awake_and_spawn(spawn_fut)
566             } else {
567                 real.handle.lock().unwrap().spawn(spawn_fut)
568             }
569         };
570 
571         JoinHandle {
572             join_handle,
573             context: self.downgrade(),
574             task_id: id,
575         }
576     }
577 
current_has_sub_tasks() -> bool578     pub fn current_has_sub_tasks() -> bool {
579         let (ctx, task_id) = match Context::current_task() {
580             Some(task) => task,
581             None => {
582                 gst_trace!(RUNTIME_CAT, "No current task");
583                 return false;
584             }
585         };
586 
587         let task_queues = ctx.0.task_queues.lock().unwrap();
588         task_queues
589             .1
590             .get(&task_id.0)
591             .map(|t| !t.0.is_empty())
592             .unwrap_or(false)
593     }
594 
add_sub_task<T>(sub_task: T) -> Result<(), T> where T: Future<Output = SubTaskOutput> + Send + 'static,595     pub fn add_sub_task<T>(sub_task: T) -> Result<(), T>
596     where
597         T: Future<Output = SubTaskOutput> + Send + 'static,
598     {
599         let (ctx, task_id) = match Context::current_task() {
600             Some(task) => task,
601             None => {
602                 gst_trace!(RUNTIME_CAT, "No current task");
603                 return Err(sub_task);
604             }
605         };
606 
607         let mut task_queues = ctx.0.task_queues.lock().unwrap();
608         match task_queues.1.get_mut(&task_id.0) {
609             Some(task_queue) => {
610                 if let Some(ref real) = ctx.0.real {
611                     gst_trace!(
612                         RUNTIME_CAT,
613                         "Adding subtask to {:?} on context {}",
614                         task_id,
615                         real.name
616                     );
617                 } else {
618                     gst_trace!(
619                         RUNTIME_CAT,
620                         "Adding subtask to {:?} on dummy context",
621                         task_id,
622                     );
623                 }
624                 task_queue.0.push_back(sub_task.boxed());
625                 Ok(())
626             }
627             None => {
628                 gst_trace!(RUNTIME_CAT, "Task was removed in the meantime");
629                 Err(sub_task)
630             }
631         }
632     }
633 
drain_sub_tasks() -> SubTaskOutput634     pub async fn drain_sub_tasks() -> SubTaskOutput {
635         let (ctx, task_id) = match Context::current_task() {
636             Some(task) => task,
637             None => return Ok(()),
638         };
639 
640         ctx.drain_sub_tasks_internal(task_id).await
641     }
642 
drain_sub_tasks_internal( &self, id: TaskId, ) -> impl Future<Output = SubTaskOutput> + Send + 'static643     fn drain_sub_tasks_internal(
644         &self,
645         id: TaskId,
646     ) -> impl Future<Output = SubTaskOutput> + Send + 'static {
647         let mut task_queue = {
648             let mut task_queues = self.0.task_queues.lock().unwrap();
649             if let Some(task_queue) = task_queues.1.get_mut(&id.0) {
650                 mem::replace(task_queue, SubTaskQueue(VecDeque::new()))
651             } else {
652                 SubTaskQueue(VecDeque::new())
653             }
654         };
655 
656         let name = self
657             .0
658             .real
659             .as_ref()
660             .map(|r| r.name.clone())
661             .unwrap_or_else(|| String::from("DUMMY"));
662         async move {
663             if !task_queue.0.is_empty() {
664                 gst_log!(
665                     RUNTIME_CAT,
666                     "Scheduling draining {} sub tasks from {:?} on '{}'",
667                     task_queue.0.len(),
668                     id,
669                     &name,
670                 );
671 
672                 for task in task_queue.0.drain(..) {
673                     task.await?;
674                 }
675             }
676 
677             Ok(())
678         }
679     }
680 }
681 
#[cfg(test)]
mod tests {
    use futures::channel::mpsc;
    use futures::lock::Mutex;
    use futures::prelude::*;

    use std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use super::Context;

    type Item = i32;

    const SLEEP_DURATION_MS: u64 = 2;
    const SLEEP_DURATION: Duration = Duration::from_millis(SLEEP_DURATION_MS);
    const DELAY: Duration = Duration::from_millis(SLEEP_DURATION_MS * 10);

    /// Binds a std UDP socket on an OS-assigned localhost port.
    ///
    /// Port 0 is used instead of a hard-coded port so the tests don't fail when that port is
    /// already taken on the host (another process or a concurrent test run).
    fn bind_local_udp() -> UdpSocket {
        let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
        UdpSocket::bind(saddr).unwrap()
    }

    /// Destination of the datagrams sent by the tests, which only check the number of bytes
    /// handed to the socket.
    fn dest_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4000)
    }

    #[tokio::test]
    async fn drain_sub_tasks() {
        // Setup
        gst::init().unwrap();

        let context = Context::acquire("drain_sub_tasks", SLEEP_DURATION).unwrap();

        let join_handle = context.spawn(async move {
            let (sender, mut receiver) = mpsc::channel(1);
            let sender: Arc<Mutex<mpsc::Sender<Item>>> = Arc::new(Mutex::new(sender));

            let add_sub_task = move |item| {
                let sender = sender.clone();
                Context::add_sub_task(async move {
                    sender
                        .lock()
                        .await
                        .send(item)
                        .await
                        .map_err(|_| gst::FlowError::Error)
                })
            };

            // Tests

            // Drain empty queue
            let drain_fut = Context::drain_sub_tasks();
            drain_fut.await.unwrap();

            // Add a subtask
            add_sub_task(0).map_err(drop).unwrap();

            // Check that it was not executed yet
            receiver.try_next().unwrap_err();

            // Drain it now and check that it was executed
            let drain_fut = Context::drain_sub_tasks();
            drain_fut.await.unwrap();
            assert_eq!(receiver.try_next().unwrap(), Some(0));

            // Add another task and check that it's not executed yet
            add_sub_task(1).map_err(drop).unwrap();
            receiver.try_next().unwrap_err();

            // Return the receiver
            receiver
        });

        let mut receiver = join_handle.await.unwrap();

        // The last sub task should be simply dropped at this point
        assert_eq!(receiver.try_next().unwrap(), None);
    }

    #[tokio::test]
    async fn block_on_within_tokio() {
        gst::init().unwrap();

        let context = Context::acquire("block_on_within_tokio", SLEEP_DURATION).unwrap();

        let bytes_sent = crate::runtime::executor::block_on(context.spawn(async {
            let mut socket = tokio::net::UdpSocket::from_std(bind_local_udp()).unwrap();
            socket.send_to(&[0; 10], dest_addr()).await.unwrap()
        }))
        .unwrap();
        assert_eq!(bytes_sent, 10);

        let elapsed = crate::runtime::executor::block_on(context.spawn(async {
            let now = Instant::now();
            crate::runtime::time::delay_for(DELAY).await;
            now.elapsed()
        }))
        .unwrap();
        // Due to throttling, `Delay` may be fired earlier
        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }

    #[test]
    fn block_on_from_sync() {
        gst::init().unwrap();

        let context = Context::acquire("block_on_from_sync", SLEEP_DURATION).unwrap();

        let bytes_sent = crate::runtime::executor::block_on(context.spawn(async {
            let mut socket = tokio::net::UdpSocket::from_std(bind_local_udp()).unwrap();
            socket.send_to(&[0; 10], dest_addr()).await.unwrap()
        }))
        .unwrap();
        assert_eq!(bytes_sent, 10);

        let elapsed = crate::runtime::executor::block_on(context.spawn(async {
            let now = Instant::now();
            crate::runtime::time::delay_for(DELAY).await;
            now.elapsed()
        }))
        .unwrap();
        // Due to throttling, `Delay` may be fired earlier
        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }

    #[test]
    fn block_on_from_context() {
        gst::init().unwrap();

        let context = Context::acquire("block_on_from_context", SLEEP_DURATION).unwrap();
        let join_handle = context.spawn(async {
            crate::runtime::executor::block_on(async {
                crate::runtime::time::delay_for(DELAY).await;
            });
        });
        // Panic: attempt to `runtime::executor::block_on` within a `Context` thread
        futures::executor::block_on(join_handle).unwrap_err();
    }

    #[tokio::test]
    async fn enter_context_from_tokio() {
        gst::init().unwrap();

        let context = Context::acquire("enter_context_from_tokio", SLEEP_DURATION).unwrap();
        let mut socket = context
            .enter(|| tokio::net::UdpSocket::from_std(bind_local_udp()))
            .unwrap();

        let bytes_sent = socket.send_to(&[0; 10], dest_addr()).await.unwrap();
        assert_eq!(bytes_sent, 10);

        let elapsed = context.enter(|| {
            futures::executor::block_on(async {
                let now = Instant::now();
                crate::runtime::time::delay_for(DELAY).await;
                now.elapsed()
            })
        });
        // Due to throttling, `Delay` may be fired earlier
        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }

    #[test]
    fn enter_context_from_sync() {
        gst::init().unwrap();

        let context = Context::acquire("enter_context_from_sync", SLEEP_DURATION).unwrap();
        let mut socket = context
            .enter(|| tokio::net::UdpSocket::from_std(bind_local_udp()))
            .unwrap();

        let bytes_sent =
            futures::executor::block_on(socket.send_to(&[0; 10], dest_addr())).unwrap();
        assert_eq!(bytes_sent, 10);

        let elapsed = context.enter(|| {
            futures::executor::block_on(async {
                let now = Instant::now();
                crate::runtime::time::delay_for(DELAY).await;
                now.elapsed()
            })
        });
        // Due to throttling, `Delay` may be fired earlier
        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }
}
876