1 use std::future::Future;
2 use std::sync::{
3     atomic::{fence, AtomicUsize, Ordering},
4     Arc,
5 };
6 use tokio_02::sync::mpsc;
7 
/// Receiving half of the idle-notification pair, created by `Idle::new`.
///
/// Shares the spawned-task counter with every `Idle` handle created from the
/// same `Idle::new` call, and receives a wake-up whenever a task's `Track`
/// brings that counter to zero.
#[derive(Debug)]
pub(super) struct Rx {
    /// Wake-ups sent when the spawned count reaches zero.
    rx: mpsc::UnboundedReceiver<()>,
    /// Number of outstanding reservations made with `Idle::reserve`.
    spawned: Arc<AtomicUsize>,
}
13 
/// Tracks the number of tasks spawned on a runtime.
///
/// This is required to implement the `shutdown_on_idle` and `tokio::run` APIs
/// that exist in `tokio` 0.1, as the `tokio` 0.2 threadpool does not expose a
/// `shutdown_on_idle` API.
///
/// Clones share the same counter (through an `Arc`) and send wake-ups on the
/// same channel, so every clone reports to the single paired `Rx`.
#[derive(Clone, Debug)]
pub(super) struct Idle {
    /// Sender used to wake the paired `Rx` when the spawned count reaches zero.
    tx: mpsc::UnboundedSender<()>,
    /// Number of outstanding reservations; shared with the paired `Rx`.
    spawned: Arc<AtomicUsize>,
}
24 
25 /// Wraps a future to decrement the spawned count when it completes.
26 ///
27 /// This is obtained from `Idle::reserve`.
28 pub(super) struct Track(Idle);
29 
30 impl Idle {
new() -> (Self, Rx)31     pub(super) fn new() -> (Self, Rx) {
32         let (tx, rx) = mpsc::unbounded_channel();
33         let this = Self {
34             tx,
35             spawned: Arc::new(AtomicUsize::new(0)),
36         };
37         let rx = Rx {
38             rx,
39             spawned: this.spawned.clone(),
40         };
41         (this, rx)
42     }
43 
44     /// Prepare to spawn a task on the runtime, incrementing the spawned count.
reserve(&self) -> Track45     pub(super) fn reserve(&self) -> Track {
46         self.spawned.fetch_add(1, Ordering::Relaxed);
47         Track(self.clone())
48     }
49 }
50 
51 impl Rx {
idle(&mut self)52     pub(super) async fn idle(&mut self) {
53         while self.spawned.load(Ordering::Acquire) != 0 {
54             // Wait to be woken up again.
55             let _ = self.rx.recv().await;
56         }
57     }
58 }
59 
60 impl Track {
61     /// Run a task, decrementing the spawn count when it completes.
62     ///
63     /// If the spawned count is now 0, this sends a notification on the idle channel.
with<T>(self, f: impl Future<Output = T>) -> T64     pub(super) async fn with<T>(self, f: impl Future<Output = T>) -> T {
65         let result = f.await;
66         let spawned = self.0.spawned.fetch_sub(1, Ordering::Release);
67         if spawned == 1 {
68             fence(Ordering::Acquire);
69             let _ = self.0.tx.send(());
70         }
71         result
72     }
73 }
74