1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3 
4 use tokio::io::{AsyncReadExt, AsyncWriteExt};
5 use tokio::net::{TcpListener, TcpStream};
6 use tokio::runtime::{self, Runtime};
7 use tokio::sync::oneshot;
8 use tokio_test::{assert_err, assert_ok};
9 
10 use futures::future::poll_fn;
11 use std::future::Future;
12 use std::pin::Pin;
13 use std::sync::atomic::AtomicUsize;
14 use std::sync::atomic::Ordering::Relaxed;
15 use std::sync::{mpsc, Arc};
16 use std::task::{Context, Poll};
17 
#[test]
fn single_thread() {
    // Building a threaded runtime with exactly one core thread must neither
    // panic nor return an error. The build result is unwrapped (instead of
    // being discarded with `let _ =`) so that a build error fails the test
    // rather than being silently ignored.
    let _rt = runtime::Builder::new()
        .threaded_scheduler()
        .enable_all()
        .core_threads(1)
        .build()
        .unwrap();
}
27 
#[test]
fn many_oneshot_futures() {
    // Number of tasks spawned per round. The task that brings the shared
    // counter up to this value notifies the main thread.
    const NUM: usize = 1_000;

    // Repeat the whole scenario several times, each with a fresh runtime.
    for _round in 0..5 {
        let (done_tx, done_rx) = mpsc::channel();
        let runtime = rt();
        let counter = Arc::new(AtomicUsize::new(0));

        (0..NUM).for_each(|_| {
            let counter = Arc::clone(&counter);
            let done_tx = done_tx.clone();

            runtime.spawn(async move {
                // `fetch_add` returns the previous value; add one to get the
                // count that includes this task.
                if counter.fetch_add(1, Relaxed) + 1 == NUM {
                    done_tx.send(()).unwrap();
                }
            });
        });

        // Block until the last task has run.
        done_rx.recv().unwrap();

        // Dropping the runtime waits for the pool to shut down.
        drop(runtime);
    }
}
#[test]
fn many_multishot_futures() {
    use tokio::sync::mpsc;

    // Number of forwarding tasks in each track.
    const CHAIN: usize = 200;
    // Number of laps a message makes around its track before being reported.
    const CYCLES: usize = 5;
    // Number of independent tracks per runtime.
    const TRACKS: usize = 50;

    for _ in 0..50 {
        let mut runtime = rt();
        let mut heads = Vec::with_capacity(TRACKS);
        let mut tails = Vec::with_capacity(TRACKS);

        for _ in 0..TRACKS {
            let (head_tx, first_rx) = mpsc::channel(10);

            // Build CHAIN forwarding links. After the loop, `link_rx` is the
            // receiving end of the last link.
            let mut link_rx = first_rx;
            for _ in 0..CHAIN {
                let (mut out_tx, out_rx) = mpsc::channel(10);
                let mut in_rx = link_rx;

                // Forward every message from this link to the next one.
                runtime.spawn(async move {
                    while let Some(msg) = in_rx.recv().await {
                        out_tx.send(msg).await.unwrap();
                    }
                });

                link_rx = out_rx;
            }

            // The tail task feeds each message back to the head of the track
            // until CYCLES laps are done, then reports it.
            let (mut report_tx, report_rx) = mpsc::channel(10);
            let mut loop_tx = head_tx.clone();

            runtime.spawn(async move {
                for lap in 1..=CYCLES {
                    let msg = link_rx.recv().await.unwrap();

                    if lap == CYCLES {
                        report_tx.send(msg).await.unwrap();
                    } else {
                        loop_tx.send(msg).await.unwrap();
                    }
                }
            });

            heads.push(head_tx);
            tails.push(report_rx);
        }

        // Inject one message per track, then wait for every track to report.
        runtime.block_on(async move {
            for mut head_tx in heads {
                head_tx.send("ping").await.unwrap();
            }

            for mut report_rx in tails {
                report_rx.recv().await.unwrap();
            }
        });
    }
}
123 
#[test]
fn spawn_shutdown() {
    let mut runtime = rt();
    let (done_tx, done_rx) = mpsc::channel();

    // Spawn one client/server pair from inside the runtime context...
    let inner_tx = done_tx.clone();
    runtime.block_on(async move {
        tokio::spawn(client_server(inner_tx));
    });

    // ...and another one through `Runtime::spawn` from outside it.
    runtime.spawn(client_server(done_tx));

    // Both pairs must finish their exchange.
    for _ in 0..2 {
        assert_ok!(done_rx.recv());
    }

    // After shutdown there must be no extra message in the channel.
    drop(runtime);
    assert_err!(done_rx.try_recv());
}
142 
client_server(tx: mpsc::Sender<()>)143 async fn client_server(tx: mpsc::Sender<()>) {
144     let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
145 
146     // Get the assigned address
147     let addr = assert_ok!(server.local_addr());
148 
149     // Spawn the server
150     tokio::spawn(async move {
151         // Accept a socket
152         let (mut socket, _) = server.accept().await.unwrap();
153 
154         // Write some data
155         socket.write_all(b"hello").await.unwrap();
156     });
157 
158     let mut client = TcpStream::connect(&addr).await.unwrap();
159 
160     let mut buf = vec![];
161     client.read_to_end(&mut buf).await.unwrap();
162 
163     assert_eq!(buf, b"hello");
164     tx.send(()).unwrap();
165 }
166 
#[test]
fn drop_threadpool_drops_futures() {
    /// A future that never completes. Dropping it increments the shared
    /// counter, which lets the test observe that the runtime released it.
    struct Never(Arc<AtomicUsize>);

    impl Future for Never {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Pending
        }
    }

    impl Drop for Never {
        fn drop(&mut self) {
            self.0.fetch_add(1, Relaxed);
        }
    }

    for _ in 0..1_000 {
        // Counters for: threads started, threads stopped, futures dropped.
        let num_inc = Arc::new(AtomicUsize::new(0));
        let num_dec = Arc::new(AtomicUsize::new(0));
        let num_drop = Arc::new(AtomicUsize::new(0));

        let a = num_inc.clone();
        let b = num_dec.clone();

        let rt = runtime::Builder::new()
            .threaded_scheduler()
            .enable_all()
            .on_thread_start(move || {
                a.fetch_add(1, Relaxed);
            })
            .on_thread_stop(move || {
                b.fetch_add(1, Relaxed);
            })
            .build()
            .unwrap();

        rt.spawn(Never(num_drop.clone()));

        // Dropping the runtime waits for the pool to shut down.
        drop(rt);

        // Assert that at least one worker thread was started.
        let a = num_inc.load(Relaxed);
        assert!(a >= 1);

        // Assert that every started thread has also stopped.
        let b = num_dec.load(Relaxed);
        assert_eq!(a, b);

        // Assert that the never-completing future was dropped exactly once.
        let c = num_drop.load(Relaxed);
        assert_eq!(c, 1);
    }
}
223 
#[test]
fn start_stop_callbacks_called() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Incremented by the thread-start and thread-stop callbacks respectively.
    let after_start = Arc::new(AtomicUsize::new(0));
    let before_stop = Arc::new(AtomicUsize::new(0));

    let after_inner = after_start.clone();
    let before_inner = before_stop.clone();
    let mut rt = tokio::runtime::Builder::new()
        .threaded_scheduler()
        .enable_all()
        // Each callback owns its `Arc` and only needs `&AtomicUsize` to bump
        // the counter, so no per-invocation clone is needed.
        .on_thread_start(move || {
            after_inner.fetch_add(1, Ordering::Relaxed);
        })
        .on_thread_stop(move || {
            before_inner.fetch_add(1, Ordering::Relaxed);
        })
        .build()
        .unwrap();

    // Run one task to completion so that at least one worker has executed work.
    let (tx, rx) = oneshot::channel();

    rt.spawn(async move {
        assert_ok!(tx.send(()));
    });

    assert_ok!(rt.block_on(rx));

    // Shut the runtime down before checking that the stop callbacks ran.
    drop(rt);

    assert!(after_start.load(Ordering::Relaxed) > 0);
    assert!(before_stop.load(Ordering::Relaxed) > 0);
}
258 
#[test]
fn blocking() {
    // Number of short tasks spawned while the blockers are parked. The task
    // that brings the counter up to this value notifies the main thread.
    const NUM: usize = 1_000;
    // Number of tasks parked inside `block_in_place`.
    const BLOCKERS: usize = 4;

    for _ in 0..10 {
        let (tx, rx) = mpsc::channel();

        let rt = rt();
        let cnt = Arc::new(AtomicUsize::new(0));

        // Park BLOCKERS tasks inside `block_in_place`. The barrier also counts
        // the main thread, so the first `block.wait()` below returns only once
        // every blocker has entered its blocking section. On a pool with at
        // most BLOCKERS workers, the tasks spawned afterwards can only run if
        // the blockers' workers were handed off to other threads.
        // NOTE(review): `Runtime::new()` does not pin the pool size to four;
        // on larger pools this does not by itself force a handoff.
        let block = Arc::new(std::sync::Barrier::new(BLOCKERS + 1));
        for _ in 0..BLOCKERS {
            let block = block.clone();
            rt.spawn(async move {
                tokio::task::block_in_place(move || {
                    // First rendezvous: signal that this blocker is parked.
                    block.wait();
                    // Second rendezvous: wait until the main thread releases us.
                    block.wait();
                })
            });
        }
        block.wait();

        for _ in 0..NUM {
            let cnt = cnt.clone();
            let tx = tx.clone();

            rt.spawn(async move {
                let num = cnt.fetch_add(1, Relaxed) + 1;

                if num == NUM {
                    tx.send(()).unwrap();
                }
            });
        }

        rx.recv().unwrap();

        // Release the parked blockers (second rendezvous), then drop the
        // runtime, which waits for the pool to shut down.
        block.wait();
        drop(rt);
    }
}
303 
#[test]
fn multi_threadpool() {
    // Two independent runtimes: a task on one must be able to wake a task
    // that is waiting on the other.
    let sender_rt = rt();
    let receiver_rt = rt();

    let (signal_tx, signal_rx) = tokio::sync::oneshot::channel();
    let (done_tx, done_rx) = mpsc::channel();

    // On the second runtime: wait for the signal, then report to the test thread.
    receiver_rt.spawn(async move {
        signal_rx.await.unwrap();
        done_tx.send(()).unwrap();
    });

    // On the first runtime: fire the signal.
    sender_rt.spawn(async move {
        signal_tx.send(()).unwrap();
    });

    done_rx.recv().unwrap();
}
325 
// When `block_in_place` returns, it attempts to reclaim the runtime worker it
// yielded. In that case the remainder of the task runs on the runtime worker
// and must take part in the cooperative task budgeting system.
//
// This test ensures that, when this happens, repeatedly receiving from a
// channel eventually yields (a `recv` poll returns `Pending`) even though
// values are still ready to be received.
#[test]
fn coop_and_block_in_place() {
    use tokio::sync::mpsc;

    let mut rt = tokio::runtime::Builder::new()
        .threaded_scheduler()
        // Setting max threads to 1 prevents another thread from claiming the
        // runtime worker yielded as part of `block_in_place` and guarantees the
        // same thread will reclaim the worker at the end of the
        // `block_in_place` call.
        .max_threads(1)
        .build()
        .unwrap();

    rt.block_on(async move {
        // Bounded channel of capacity 1024, filled completely below so that
        // every `recv` in the spawned task has a value ready.
        let (mut tx, mut rx) = mpsc::channel(1024);

        // Fill the channel
        for _ in 0..1024 {
            tx.send(()).await.unwrap();
        }

        // Drop the only sender: once the buffered values are drained, `recv`
        // resolves to `None`.
        drop(tx);

        // The JoinHandle is awaited and unwrapped below, so a panic inside the
        // spawned task (e.g. "did not yield") fails the test.
        tokio::spawn(async move {
            // Block in place without doing anything
            tokio::task::block_in_place(|| {});

            // Receive all the values, this should trigger a `Pending` as the
            // coop limit will be reached.
            poll_fn(|cx| {
                // Poll a freshly created, pinned `recv()` future on each
                // iteration, and keep looping while it returns `Ready`.
                while let Poll::Ready(v) = {
                    tokio::pin! {
                        let fut = rx.recv();
                    }

                    Pin::new(&mut fut).poll(cx)
                } {
                    // `None` means every buffered value was received without a
                    // single `Pending`: the budget never forced a yield.
                    if v.is_none() {
                        panic!("did not yield");
                    }
                }

                // Reaching here means `recv` returned `Pending`, i.e. the task
                // was made to yield as expected.
                Poll::Ready(())
            })
            .await
        })
        .await
        .unwrap();
    });
}
383 
// Building a threaded runtime capped at a single thread must succeed without
// panicking.
#[test]
fn max_threads() {
    let mut builder = tokio::runtime::Builder::new();
    builder.threaded_scheduler().max_threads(1);
    let _rt = builder.build().unwrap();
}
393 
rt() -> Runtime394 fn rt() -> Runtime {
395     Runtime::new().unwrap()
396 }
397