1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3 
4 use tokio::io::{AsyncReadExt, AsyncWriteExt};
5 use tokio::net::{TcpListener, TcpStream};
6 use tokio::runtime::{self, Runtime};
7 use tokio::sync::oneshot;
8 use tokio_test::{assert_err, assert_ok};
9 
10 use futures::future::poll_fn;
11 use std::future::Future;
12 use std::pin::Pin;
13 use std::sync::atomic::AtomicUsize;
14 use std::sync::atomic::Ordering::Relaxed;
15 use std::sync::{mpsc, Arc, Mutex};
16 use std::task::{Context, Poll, Waker};
17 
#[test]
fn single_thread() {
    // Constructing a multi-thread runtime restricted to one worker thread
    // must not panic. The build result is discarded immediately.
    let mut builder = runtime::Builder::new_multi_thread();
    builder.enable_all().worker_threads(1);
    let _ = builder.build();
}
26 
#[test]
fn many_oneshot_futures() {
    // Number of short-lived tasks spawned onto each runtime.
    const NUM: usize = 1_000;

    for _ in 0..5 {
        // Lets the task that performs the final increment notify the main
        // thread.
        let (done_tx, done_rx) = mpsc::channel();

        let runtime = rt();
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..NUM {
            let counter = Arc::clone(&counter);
            let done_tx = done_tx.clone();

            runtime.spawn(async move {
                // `fetch_add` yields the previous value, hence the `+ 1`.
                let reached = counter.fetch_add(1, Relaxed) + 1;

                if reached == NUM {
                    done_tx.send(()).unwrap();
                }
            });
        }

        // Block until the last task has reported in.
        done_rx.recv().unwrap();

        // Dropping the runtime shuts the pool down before the next round.
        drop(runtime);
    }
}
#[test]
fn many_multishot_futures() {
    // Forwarding tasks per track.
    const CHAIN: usize = 200;
    // Times a message travels through a whole track.
    const CYCLES: usize = 5;
    // Independent tracks per runtime.
    const TRACKS: usize = 50;

    for _ in 0..50 {
        let runtime = rt();
        let mut heads = Vec::with_capacity(TRACKS);
        let mut tails = Vec::with_capacity(TRACKS);

        for _ in 0..TRACKS {
            let (head_tx, mut link_rx) = tokio::sync::mpsc::channel(10);

            for _ in 0..CHAIN {
                let (fwd_tx, fwd_rx) = tokio::sync::mpsc::channel(10);

                // Hand the current receiver to a forwarding task and keep the
                // freshly created one as the new end of the chain.
                let mut upstream = std::mem::replace(&mut link_rx, fwd_rx);

                runtime.spawn(async move {
                    while let Some(msg) = upstream.recv().await {
                        fwd_tx.send(msg).await.unwrap();
                    }
                });
            }

            // The last task of a track either feeds the message back to the
            // head of the chain or, on the final lap, reports it as finished.
            let (tail_tx, tail_rx) = tokio::sync::mpsc::channel(10);
            let loop_tx = head_tx.clone();

            runtime.spawn(async move {
                let mut remaining = CYCLES;

                for _ in 0..CYCLES {
                    let msg = link_rx.recv().await.unwrap();

                    remaining -= 1;

                    match remaining {
                        0 => tail_tx.send(msg).await.unwrap(),
                        _ => loop_tx.send(msg).await.unwrap(),
                    }
                }
            });

            heads.push(head_tx);
            tails.push(tail_rx);
        }

        runtime.block_on(async move {
            // Kick off every track ...
            for head in heads {
                head.send("ping").await.unwrap();
            }

            // ... then wait for each one to deliver its final message.
            for mut tail in tails {
                tail.recv().await.unwrap();
            }
        });
    }
}
120 
#[test]
fn spawn_shutdown() {
    let runtime = rt();
    let (done_tx, done_rx) = mpsc::channel();

    // First exchange: spawned with `tokio::spawn` from inside `block_on`.
    let inner_tx = done_tx.clone();
    runtime.block_on(async {
        tokio::spawn(client_server(inner_tx));
    });

    // Second exchange: spawned directly through the runtime.
    runtime.spawn(client_server(done_tx));

    // Each exchange sends exactly one completion message.
    for _ in 0..2 {
        assert_ok!(done_rx.recv());
    }

    // Once the runtime is gone no further message may be observed.
    drop(runtime);
    assert_err!(done_rx.try_recv());
}
139 
client_server(tx: mpsc::Sender<()>)140 async fn client_server(tx: mpsc::Sender<()>) {
141     let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
142 
143     // Get the assigned address
144     let addr = assert_ok!(server.local_addr());
145 
146     // Spawn the server
147     tokio::spawn(async move {
148         // Accept a socket
149         let (mut socket, _) = server.accept().await.unwrap();
150 
151         // Write some data
152         socket.write_all(b"hello").await.unwrap();
153     });
154 
155     let mut client = TcpStream::connect(&addr).await.unwrap();
156 
157     let mut buf = vec![];
158     client.read_to_end(&mut buf).await.unwrap();
159 
160     assert_eq!(buf, b"hello");
161     tx.send(()).unwrap();
162 }
163 
#[test]
fn drop_threadpool_drops_futures() {
    // A future that never completes and bumps the wrapped counter when it is
    // dropped.
    struct Never(Arc<AtomicUsize>);

    impl Future for Never {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Pending
        }
    }

    impl Drop for Never {
        fn drop(&mut self) {
            self.0.fetch_add(1, Relaxed);
        }
    }

    for _ in 0..1_000 {
        let started = Arc::new(AtomicUsize::new(0));
        let stopped = Arc::new(AtomicUsize::new(0));
        let dropped = Arc::new(AtomicUsize::new(0));

        // Handles moved into the thread lifecycle hooks.
        let on_start = Arc::clone(&started);
        let on_stop = Arc::clone(&stopped);

        let runtime = runtime::Builder::new_multi_thread()
            .enable_all()
            .on_thread_start(move || {
                on_start.fetch_add(1, Relaxed);
            })
            .on_thread_stop(move || {
                on_stop.fetch_add(1, Relaxed);
            })
            .build()
            .unwrap();

        runtime.spawn(Never(Arc::clone(&dropped)));

        // Dropping the runtime shuts the pool down.
        drop(runtime);

        // At least one thread must have been started ...
        let started = started.load(Relaxed);
        assert!(started >= 1);

        // ... every started thread must have stopped again ...
        let stopped = stopped.load(Relaxed);
        assert_eq!(started, stopped);

        // ... and the pending future must have been dropped exactly once.
        let dropped = dropped.load(Relaxed);
        assert_eq!(dropped, 1);
    }
}
219 
// Checks that the `on_thread_start` and `on_thread_stop` hooks registered on
// the builder have each run at least once by the time the runtime has been
// used and dropped.
#[test]
fn start_stop_callbacks_called() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    let after_start = Arc::new(AtomicUsize::new(0));
    let before_stop = Arc::new(AtomicUsize::new(0));

    // Each hook takes ownership of its own handle to the counter.
    let after_inner = after_start.clone();
    let before_inner = before_stop.clone();
    let rt = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .on_thread_start(move || {
            // The closure already owns the `Arc`, so increment through it
            // directly instead of cloning a temporary on every call.
            after_inner.fetch_add(1, Ordering::Relaxed);
        })
        .on_thread_stop(move || {
            before_inner.fetch_add(1, Ordering::Relaxed);
        })
        .build()
        .unwrap();

    let (tx, rx) = oneshot::channel();

    // Run one task on the pool and wait for its message from the main thread.
    rt.spawn(async move {
        assert_ok!(tx.send(()));
    });

    assert_ok!(rt.block_on(rx));

    // Shut the runtime down before inspecting the counters.
    drop(rt);

    assert!(after_start.load(Ordering::Relaxed) > 0);
    assert!(before_stop.load(Ordering::Relaxed) > 0);
}
253 
#[test]
fn blocking() {
    // Number of counting tasks spawned per round.
    const NUM: usize = 1_000;
    // Number of tasks parked inside `block_in_place` per round.
    const BLOCKERS: usize = 4;

    for _ in 0..10 {
        // Lets the task that performs the final increment notify the main
        // thread.
        let (done_tx, done_rx) = mpsc::channel();

        let runtime = rt();
        let counter = Arc::new(AtomicUsize::new(0));

        // Park `BLOCKERS` tasks inside `block_in_place`; the extra barrier
        // participant is the main thread. The intent is to force worker
        // handoff while the counting tasks run.
        // NOTE(review): the original comment assumed a four-worker pool;
        // `rt()` does not set a worker count - confirm.
        let gate = Arc::new(std::sync::Barrier::new(BLOCKERS + 1));
        for _ in 0..BLOCKERS {
            let gate = Arc::clone(&gate);
            runtime.spawn(async move {
                tokio::task::block_in_place(move || {
                    // First wait: report that this task is now blocking.
                    gate.wait();
                    // Second wait: stay blocked until the main thread is done.
                    gate.wait();
                })
            });
        }
        // Proceed only once every blocker has entered `block_in_place`.
        gate.wait();

        for _ in 0..NUM {
            let counter = Arc::clone(&counter);
            let done_tx = done_tx.clone();

            runtime.spawn(async move {
                let reached = counter.fetch_add(1, Relaxed) + 1;

                if reached == NUM {
                    done_tx.send(()).unwrap();
                }
            });
        }

        done_rx.recv().unwrap();

        // Release the parked tasks; the runtime is dropped at the end of the
        // iteration.
        gate.wait();
    }
}
298 
#[test]
fn multi_threadpool() {
    let first = rt();
    let second = rt();

    // `wake_*` crosses from the first runtime to the second; `done_*` reports
    // back to the test thread.
    let (wake_tx, wake_rx) = tokio::sync::oneshot::channel();
    let (done_tx, done_rx) = mpsc::channel();

    // Waiter lives on the second runtime.
    second.spawn(async move {
        wake_rx.await.unwrap();
        done_tx.send(()).unwrap();
    });

    // Notifier lives on the first runtime.
    first.spawn(async move {
        wake_tx.send(()).unwrap();
    });

    done_rx.recv().unwrap();
}
320 
// After `block_in_place` returns, the task tries to take back the runtime
// worker it gave up. From then on the rest of the task runs on that worker
// and has to participate in cooperative task budgeting again.
//
// This test checks that, in that situation, draining a channel still yields
// from time to time even though every value is immediately available.
#[test]
fn coop_and_block_in_place() {
    let runtime = tokio::runtime::Builder::new_multi_thread()
        // Capping max threads at 1 keeps any other thread from claiming the
        // worker released by `block_in_place`, which guarantees that the same
        // thread reclaims it once the `block_in_place` call ends.
        .max_blocking_threads(1)
        .build()
        .unwrap();

    runtime.block_on(async move {
        const CAPACITY: usize = 1024;

        let (tx, mut rx) = tokio::sync::mpsc::channel(CAPACITY);

        // Fill the channel completely, then close the sending side.
        for _ in 0..CAPACITY {
            tx.send(()).await.unwrap();
        }

        drop(tx);

        let drain = tokio::spawn(async move {
            // Enter and leave `block_in_place` without doing any work.
            tokio::task::block_in_place(|| {});

            // Pull values within a single poll. Hitting the coop limit must
            // surface as `Pending` before the closed channel reports `None`.
            poll_fn(|cx| {
                loop {
                    tokio::pin! {
                        let next = rx.recv();
                    }

                    match next.as_mut().poll(cx) {
                        Poll::Ready(Some(_)) => {}
                        Poll::Ready(None) => panic!("did not yield"),
                        Poll::Pending => break,
                    }
                }

                Poll::Ready(())
            })
            .await
        });

        drain.await.unwrap();
    });
}
375 
// Smoke test: building with a blocking-thread cap of one must not panic.
#[test]
fn max_blocking_threads() {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.max_blocking_threads(1);

    let _rt = builder.build().unwrap();
}
384 
// A blocking-thread cap of zero is expected to panic.
#[test]
#[should_panic]
fn max_blocking_threads_set_to_zero() {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.max_blocking_threads(0);

    let _rt = builder.build().unwrap();
}
393 
394 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
hang_on_shutdown()395 async fn hang_on_shutdown() {
396     let (sync_tx, sync_rx) = std::sync::mpsc::channel::<()>();
397     tokio::spawn(async move {
398         tokio::task::block_in_place(|| sync_rx.recv().ok());
399     });
400 
401     tokio::spawn(async {
402         tokio::time::sleep(std::time::Duration::from_secs(2)).await;
403         drop(sync_tx);
404     });
405     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
406 }
407 
/// Demonstrates tokio-rs/tokio#3869
///
/// Two never-completing futures share one waker slot: the first stores its
/// waker on every poll, the second takes that waker and wakes it from its own
/// `Drop` impl.
#[test]
fn wake_during_shutdown() {
    // State shared by both futures.
    struct Shared {
        waker: Option<Waker>,
    }

    struct MyFuture {
        shared: Arc<Mutex<Shared>>,
        // `true` for the future that stores its waker, `false` for the one
        // that consumes the stored waker when dropped.
        put_waker: bool,
    }

    impl MyFuture {
        // Builds the (storing, waking) pair over one shared slot.
        fn new() -> (Self, Self) {
            let slot = Arc::new(Mutex::new(Shared { waker: None }));
            let storing = Self {
                shared: Arc::clone(&slot),
                put_waker: true,
            };
            let waking = Self {
                shared: slot,
                put_waker: false,
            };
            (storing, waking)
        }
    }

    impl Future for MyFuture {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            let this = self.get_mut();
            let mut guard = this.shared.lock().unwrap();
            println!("poll {}", this.put_waker);
            if this.put_waker {
                println!("putting");
                guard.waker = Some(cx.waker().clone());
            }
            Poll::Pending
        }
    }

    impl Drop for MyFuture {
        fn drop(&mut self) {
            println!("drop {} start", self.put_waker);
            {
                // The wake call is made while the shared lock is still held.
                let mut guard = self.shared.lock().unwrap();
                if !self.put_waker {
                    let stored = guard.waker.take().unwrap();
                    stored.wake();
                }
            }
            println!("drop {} stop", self.put_waker);
        }
    }

    let runtime = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(1)
        .enable_all()
        .build()
        .unwrap();

    let (storing, waking) = MyFuture::new();

    runtime.spawn(storing);
    runtime.spawn(waking);

    // Run the runtime briefly; it is dropped when the test returns.
    runtime.block_on(async { tokio::time::sleep(tokio::time::Duration::from_millis(20)).await });
}
475 
rt() -> Runtime476 fn rt() -> Runtime {
477     Runtime::new().unwrap()
478 }
479