1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3 
4 use tokio::io::{AsyncReadExt, AsyncWriteExt};
5 use tokio::net::{TcpListener, TcpStream};
6 use tokio::runtime::{self, Runtime};
7 use tokio::sync::oneshot;
8 use tokio_test::{assert_err, assert_ok};
9 
10 use futures::future::poll_fn;
11 use std::future::Future;
12 use std::pin::Pin;
13 use std::sync::atomic::AtomicUsize;
14 use std::sync::atomic::Ordering::Relaxed;
15 use std::sync::{mpsc, Arc, Mutex};
16 use std::task::{Context, Poll, Waker};
17 
#[test]
fn single_thread() {
    // Building a multi-threaded runtime with exactly one worker thread must
    // succeed without panicking. The `Result` is unwrapped so that a build
    // error fails the test instead of being silently discarded.
    let _rt = runtime::Builder::new_multi_thread()
        .enable_all()
        .worker_threads(1)
        .build()
        .unwrap();
}
26 
#[test]
fn many_oneshot_futures() {
    // Number of one-shot tasks spawned per round; the task that brings the
    // shared counter up to this value notifies the main thread.
    const NUM: usize = 1_000;

    for _round in 0..5 {
        let (done_tx, done_rx) = mpsc::channel();

        let runtime = rt();
        let counter = Arc::new(AtomicUsize::new(0));

        (0..NUM).for_each(|_| {
            let counter = Arc::clone(&counter);
            let done_tx = done_tx.clone();

            runtime.spawn(async move {
                // Only the task that observes the final count signals.
                if counter.fetch_add(1, Relaxed) + 1 == NUM {
                    done_tx.send(()).unwrap();
                }
            });
        });

        done_rx.recv().unwrap();

        // Drop the runtime so the pool shuts down before the next round.
        drop(runtime);
    }
}
57 
#[test]
fn many_multishot_futures() {
    const CHAIN: usize = 200;
    const CYCLES: usize = 5;
    const TRACKS: usize = 50;

    for _ in 0..50 {
        let rt = rt();
        let mut start_txs = Vec::with_capacity(TRACKS);
        let mut final_rxs = Vec::with_capacity(TRACKS);

        for _ in 0..TRACKS {
            let (start_tx, mut tail_rx) = tokio::sync::mpsc::channel(10);

            // Build a chain of CHAIN forwarding tasks. After the loop, `tail_rx`
            // is the receiving end of the last link.
            for _ in 0..CHAIN {
                let (link_tx, link_rx) = tokio::sync::mpsc::channel(10);

                rt.spawn(async move {
                    while let Some(item) = tail_rx.recv().await {
                        link_tx.send(item).await.unwrap();
                    }
                });

                tail_rx = link_rx;
            }

            // The last task feeds each message back to the head of the chain
            // until it has gone round CYCLES times, then hands it to `final_tx`.
            let (final_tx, final_rx) = tokio::sync::mpsc::channel(10);
            let loop_tx = start_tx.clone();

            rt.spawn(async move {
                for lap in 1..=CYCLES {
                    let item = tail_rx.recv().await.unwrap();

                    if lap == CYCLES {
                        final_tx.send(item).await.unwrap();
                    } else {
                        loop_tx.send(item).await.unwrap();
                    }
                }
            });

            start_txs.push(start_tx);
            final_rxs.push(final_rx);
        }

        rt.block_on(async move {
            for start_tx in start_txs {
                start_tx.send("ping").await.unwrap();
            }

            for mut final_rx in final_rxs {
                final_rx.recv().await.unwrap();
            }
        });
    }
}
121 
#[test]
fn spawn_shutdown() {
    let rt = rt();
    let (tx, rx) = mpsc::channel();

    // Spawn one round trip from inside the runtime context via `tokio::spawn`...
    rt.block_on(async {
        tokio::spawn(client_server(tx.clone()));
    });

    // ...and a second one from outside, through the runtime itself.
    rt.spawn(client_server(tx));

    // Both round trips must report completion.
    for _ in 0..2 {
        assert_ok!(rx.recv());
    }

    // Once the runtime is gone there must be nothing left to receive.
    drop(rt);
    assert_err!(rx.try_recv());
}
140 
141 async fn client_server(tx: mpsc::Sender<()>) {
142     let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
143 
144     // Get the assigned address
145     let addr = assert_ok!(server.local_addr());
146 
147     // Spawn the server
148     tokio::spawn(async move {
149         // Accept a socket
150         let (mut socket, _) = server.accept().await.unwrap();
151 
152         // Write some data
153         socket.write_all(b"hello").await.unwrap();
154     });
155 
156     let mut client = TcpStream::connect(&addr).await.unwrap();
157 
158     let mut buf = vec![];
159     client.read_to_end(&mut buf).await.unwrap();
160 
161     assert_eq!(buf, b"hello");
162     tx.send(()).unwrap();
163 }
164 
#[test]
fn drop_threadpool_drops_futures() {
    // Repeated many times, presumably to surface intermittent shutdown races.
    for _ in 0..1_000 {
        // Counters: worker threads started, worker threads stopped, and
        // `Never` futures dropped.
        let num_inc = Arc::new(AtomicUsize::new(0));
        let num_dec = Arc::new(AtomicUsize::new(0));
        let num_drop = Arc::new(AtomicUsize::new(0));

        // A future that is never ready and bumps its counter when dropped.
        // The test uses it to observe that shutdown drops still-pending tasks.
        struct Never(Arc<AtomicUsize>);

        impl Future for Never {
            type Output = ();

            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Pending
            }
        }

        impl Drop for Never {
            fn drop(&mut self) {
                self.0.fetch_add(1, Relaxed);
            }
        }

        let a = num_inc.clone();
        let b = num_dec.clone();

        let rt = runtime::Builder::new_multi_thread()
            .enable_all()
            .on_thread_start(move || {
                a.fetch_add(1, Relaxed);
            })
            .on_thread_stop(move || {
                b.fetch_add(1, Relaxed);
            })
            .build()
            .unwrap();

        rt.spawn(Never(num_drop.clone()));

        // Dropping the runtime shuts the pool down. The assertions below check
        // that every worker had stopped by the time `drop` returned.
        drop(rt);

        // Assert that at least one worker thread was started.
        let a = num_inc.load(Relaxed);
        assert!(a >= 1);

        // Assert that every started thread also stopped.
        let b = num_dec.load(Relaxed);
        assert_eq!(a, b);

        // Assert that the pending future was dropped exactly once.
        let c = num_drop.load(Relaxed);
        assert_eq!(c, 1);
    }
}
220 
#[test]
fn start_stop_callbacks_called() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // How many times the thread start and stop hooks have fired.
    let after_start = Arc::new(AtomicUsize::new(0));
    let before_stop = Arc::new(AtomicUsize::new(0));

    let after_inner = after_start.clone();
    let before_inner = before_stop.clone();
    let rt = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        // Each hook owns its `Arc` and increments through it directly. Cloning
        // the `Arc` on every call would only churn the reference count.
        .on_thread_start(move || {
            after_inner.fetch_add(1, Ordering::Relaxed);
        })
        .on_thread_stop(move || {
            before_inner.fetch_add(1, Ordering::Relaxed);
        })
        .build()
        .unwrap();

    let (tx, rx) = oneshot::channel();

    // Run one task to completion so the runtime has demonstrably done work.
    rt.spawn(async move {
        assert_ok!(tx.send(()));
    });

    assert_ok!(rt.block_on(rx));

    // Shut down so the stop hooks run before the counters are inspected.
    drop(rt);

    assert!(after_start.load(Ordering::Relaxed) > 0);
    assert!(before_stop.load(Ordering::Relaxed) > 0);
}
254 
#[test]
fn blocking() {
    // Number of short tasks that must complete while the blocking tasks are parked.
    const NUM: usize = 1_000;
    // Pool size. It is pinned explicitly so that parking `WORKERS` tasks in
    // `block_in_place` really does account for every worker.
    const WORKERS: usize = 4;

    for _ in 0..10 {
        let (tx, rx) = mpsc::channel();

        // `Runtime::new()` uses the default worker count (per tokio docs, the
        // number of available cores). That would make the "all workers are
        // parked" premise below depend on the machine, so set it explicitly.
        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(WORKERS)
            .enable_all()
            .build()
            .unwrap();
        let cnt = Arc::new(AtomicUsize::new(0));

        // Park as many `block_in_place` tasks as there are workers. The barrier
        // has one extra party for the main thread: the first rendezvous
        // confirms that all of them are parked, and the second releases them.
        // While they are parked, the remaining tasks can only run if the
        // workers were handed off.
        let block = Arc::new(std::sync::Barrier::new(WORKERS + 1));
        for _ in 0..WORKERS {
            let block = block.clone();
            rt.spawn(async move {
                tokio::task::block_in_place(move || {
                    block.wait();
                    block.wait();
                })
            });
        }
        block.wait();

        for _ in 0..NUM {
            let cnt = cnt.clone();
            let tx = tx.clone();

            rt.spawn(async move {
                let num = cnt.fetch_add(1, Relaxed) + 1;

                if num == NUM {
                    tx.send(()).unwrap();
                }
            });
        }

        rx.recv().unwrap();

        // Release the parked tasks. The runtime is dropped, and so shut down,
        // at the end of this iteration.
        block.wait();
    }
}
299 
#[test]
fn multi_threadpool() {
    use tokio::sync::oneshot;

    // Two independent runtimes. A task on the first signals a task on the second.
    let rt_a = rt();
    let rt_b = rt();

    let (signal_tx, signal_rx) = oneshot::channel();
    let (done_tx, done_rx) = mpsc::channel();

    // Receiving side, on the second runtime: wait for the signal, then report
    // back to the main thread.
    rt_b.spawn(async move {
        signal_rx.await.unwrap();
        done_tx.send(()).unwrap();
    });

    // Sending side, on the first runtime.
    rt_a.spawn(async move {
        signal_tx.send(()).unwrap();
    });

    done_rx.recv().unwrap();
}
321 
// When `block_in_place` returns, it attempts to reclaim the yielded runtime
// worker. In this case, the remainder of the task is on the runtime worker and
// must take part in the cooperative task budgeting system.
//
// The test ensures that, when this happens, attempting to consume from a
// channel yields occasionally even if there are values ready to receive.
#[test]
fn coop_and_block_in_place() {
    let rt = tokio::runtime::Builder::new_multi_thread()
        // Setting max threads to 1 prevents another thread from claiming the
        // runtime worker yielded as part of `block_in_place` and guarantees the
        // same thread will reclaim the worker at the end of the
        // `block_in_place` call.
        .max_blocking_threads(1)
        .build()
        .unwrap();

    rt.block_on(async move {
        // Capacity equals the number of values sent below, so every `send`
        // completes without needing a receiver to make room.
        let (tx, mut rx) = tokio::sync::mpsc::channel(1024);

        // Fill the channel
        for _ in 0..1024 {
            tx.send(()).await.unwrap();
        }

        // Drop the only sender. Once the channel is drained, `recv` yields `None`.
        drop(tx);

        tokio::spawn(async move {
            // Block in place without doing anything
            tokio::task::block_in_place(|| {});

            // Receive all the values, this should trigger a `Pending` as the
            // coop limit will be reached.
            poll_fn(|cx| {
                // Each iteration pins and polls a fresh `recv()` future, and the
                // loop continues while that poll is immediately `Ready`.
                while let Poll::Ready(v) = {
                    tokio::pin! {
                        let fut = rx.recv();
                    }

                    Pin::new(&mut fut).poll(cx)
                } {
                    // `None` means the channel was fully drained without a
                    // single `Pending`, i.e. the task never yielded.
                    if v.is_none() {
                        panic!("did not yield");
                    }
                }

                // Reached on the first `Pending`: the receive was made to yield.
                Poll::Ready(())
            })
            .await
        })
        .await
        .unwrap();
    });
}
376 
// Capping the blocking pool at a single thread is a valid configuration:
// building the runtime must succeed without panicking.
#[test]
fn max_blocking_threads() {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.max_blocking_threads(1);
    let _rt = builder.build().unwrap();
}
385 
// A blocking-pool cap of zero must be rejected with a panic somewhere in
// configuring or building the runtime.
#[test]
#[should_panic]
fn max_blocking_threads_set_to_zero() {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.max_blocking_threads(0);
    let _rt = builder.build().unwrap();
}
394 
// Sets up a task blocked in `block_in_place` at the moment the runtime shuts down.
//
// One task waits in `block_in_place` on a `std` channel whose only sender is
// owned by a second task, which drops it after 2 s. The test body returns
// after 1 s, so the `#[tokio::test]` runtime starts shutting down while the
// first task is still blocked. If the sleeping task is cancelled at shutdown,
// dropping its future also drops `sync_tx`, which unblocks the `recv`.
// NOTE(review): judging by the name, this guards against runtime shutdown
// hanging in this situation. Confirm against the originating issue.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn hang_on_shutdown() {
    let (sync_tx, sync_rx) = std::sync::mpsc::channel::<()>();
    tokio::spawn(async move {
        // Blocks the current thread until a message arrives or every sender
        // is gone. In the latter case `recv` errors and `.ok()` yields `None`.
        tokio::task::block_in_place(|| sync_rx.recv().ok());
    });

    // This block moves `sync_tx` in (because `drop` takes it by value), so the
    // sender lives exactly as long as this task's future.
    tokio::spawn(async {
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        drop(sync_tx);
    });
    // Return before the 2 s sleeper above finishes.
    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
408 
/// Demonstrates tokio-rs/tokio#3869
///
/// Two futures share one waker slot. `f1` stores its waker there every time it
/// is polled. `f2`, when dropped, takes that waker and wakes it. Both futures
/// always return `Pending`, so they are still alive when the runtime is
/// dropped at the end of the test. Dropping `f2` at that point wakes `f1`
/// while the runtime is shutting down.
#[test]
fn wake_during_shutdown() {
    // State shared by the two futures.
    struct Shared {
        // Waker stored by the future with `put_waker == true`.
        waker: Option<Waker>,
    }

    struct MyFuture {
        shared: Arc<Mutex<Shared>>,
        // `true`: store the waker on every poll.
        // `false`: take and wake the stored waker on drop.
        put_waker: bool,
    }

    impl MyFuture {
        // Returns `(f1, f2)` sharing one slot: `f1` stores its waker, `f2`
        // wakes it when dropped.
        fn new() -> (Self, Self) {
            let shared = Arc::new(Mutex::new(Shared { waker: None }));
            let f1 = MyFuture {
                shared: shared.clone(),
                put_waker: true,
            };
            let f2 = MyFuture {
                shared,
                put_waker: false,
            };
            (f1, f2)
        }
    }

    impl Future for MyFuture {
        type Output = ();

        // Never completes. The waker-storing future records the current waker
        // on every poll.
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            let me = Pin::into_inner(self);
            let mut lock = me.shared.lock().unwrap();
            println!("poll {}", me.put_waker);
            if me.put_waker {
                println!("putting");
                lock.waker = Some(cx.waker().clone());
            }
            Poll::Pending
        }
    }

    impl Drop for MyFuture {
        fn drop(&mut self) {
            println!("drop {} start", self.put_waker);
            let mut lock = self.shared.lock().unwrap();
            if !self.put_waker {
                // Panics if no waker has been stored yet, i.e. if `f1` was never
                // polled before `f2` is dropped.
                lock.waker.take().unwrap().wake();
            }
            // Release the mutex explicitly before the closing log line.
            drop(lock);
            println!("drop {} stop", self.put_waker);
        }
    }

    let rt = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(1)
        .enable_all()
        .build()
        .unwrap();

    let (f1, f2) = MyFuture::new();

    rt.spawn(f1);
    rt.spawn(f2);

    // Presumably gives the single worker time to poll both futures (so `f1`
    // stores its waker) before `rt` is dropped at the end of this function.
    rt.block_on(async { tokio::time::sleep(tokio::time::Duration::from_millis(20)).await });
}
476 
477 #[should_panic]
478 #[tokio::test]
479 async fn test_block_in_place1() {
480     tokio::task::block_in_place(|| {});
481 }
482 
483 #[tokio::test(flavor = "multi_thread")]
484 async fn test_block_in_place2() {
485     tokio::task::block_in_place(|| {});
486 }
487 
488 #[should_panic]
489 #[tokio::main(flavor = "current_thread")]
490 #[test]
491 async fn test_block_in_place3() {
492     tokio::task::block_in_place(|| {});
493 }
494 
495 #[tokio::main]
496 #[test]
497 async fn test_block_in_place4() {
498     tokio::task::block_in_place(|| {});
499 }
500 
501 fn rt() -> Runtime {
502     Runtime::new().unwrap()
503 }
504