1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3
4 use tokio::io::{AsyncReadExt, AsyncWriteExt};
5 use tokio::net::{TcpListener, TcpStream};
6 use tokio::runtime::{self, Runtime};
7 use tokio::sync::oneshot;
8 use tokio_test::{assert_err, assert_ok};
9
10 use futures::future::poll_fn;
11 use std::future::Future;
12 use std::pin::Pin;
13 use std::sync::atomic::AtomicUsize;
14 use std::sync::atomic::Ordering::Relaxed;
15 use std::sync::{mpsc, Arc, Mutex};
16 use std::task::{Context, Poll, Waker};
17
#[test]
fn single_thread() {
    // Building a multi-threaded runtime with exactly one worker thread must
    // neither panic nor fail. The `Result` is checked instead of discarded so
    // that a build error fails the test rather than passing silently.
    let rt = assert_ok!(runtime::Builder::new_multi_thread()
        .enable_all()
        .worker_threads(1)
        .build());
    drop(rt);
}
26
#[test]
fn many_oneshot_futures() {
    // Number of tasks spawned on each runtime.
    const NUM: usize = 1_000;

    for _round in 0..5 {
        // The task that performs the final increment notifies the main
        // thread through this channel.
        let (done_tx, done_rx) = mpsc::channel();

        let rt = rt();
        let counter = Arc::new(AtomicUsize::new(0));

        (0..NUM).for_each(|_| {
            let counter = Arc::clone(&counter);
            let done_tx = done_tx.clone();

            rt.spawn(async move {
                if counter.fetch_add(1, Relaxed) + 1 == NUM {
                    done_tx.send(()).unwrap();
                }
            });
        });

        done_rx.recv().unwrap();

        // Dropping the runtime waits for the pool to shut down.
        drop(rt);
    }
}
57
#[test]
fn many_multishot_futures() {
    const CHAIN: usize = 200;
    const CYCLES: usize = 5;
    const TRACKS: usize = 50;

    for _ in 0..50 {
        let rt = rt();
        let mut entry_points = Vec::with_capacity(TRACKS);
        let mut exit_points = Vec::with_capacity(TRACKS);

        for _ in 0..TRACKS {
            let (entry_tx, mut tail) = tokio::sync::mpsc::channel(10);

            // Build a chain of CHAIN relay tasks, each forwarding every
            // message it receives to the next link.
            for _ in 0..CHAIN {
                let (downstream_tx, downstream_rx) = tokio::sync::mpsc::channel(10);

                rt.spawn(async move {
                    while let Some(msg) = tail.recv().await {
                        downstream_tx.send(msg).await.unwrap();
                    }
                });

                tail = downstream_rx;
            }

            // The last task feeds each message back to the start of the
            // track until it has travelled the chain CYCLES times, then
            // hands it to the exit channel.
            let (exit_tx, exit_rx) = tokio::sync::mpsc::channel(10);
            let loopback_tx = entry_tx.clone();

            rt.spawn(async move {
                for remaining in (0..CYCLES).rev() {
                    let msg = tail.recv().await.unwrap();

                    if remaining == 0 {
                        exit_tx.send(msg).await.unwrap();
                    } else {
                        loopback_tx.send(msg).await.unwrap();
                    }
                }
            });

            entry_points.push(entry_tx);
            exit_points.push(exit_rx);
        }

        rt.block_on(async move {
            for entry_tx in entry_points {
                entry_tx.send("ping").await.unwrap();
            }

            for mut exit_rx in exit_points {
                exit_rx.recv().await.unwrap();
            }
        });
    }
}
121
#[test]
fn spawn_shutdown() {
    let rt = rt();
    let (tx, rx) = mpsc::channel();

    // One client/server pair is spawned from inside the runtime...
    let inner_tx = tx.clone();
    rt.block_on(async move {
        tokio::spawn(client_server(inner_tx));
    });

    // ...and a second one through the runtime from the outside.
    rt.spawn(client_server(tx));

    // Both pairs must finish and report back.
    for _ in 0..2 {
        assert_ok!(rx.recv());
    }

    // Only two messages are ever sent, so once the runtime has been shut
    // down nothing else may be waiting in the channel.
    drop(rt);
    assert_err!(rx.try_recv());
}
140
client_server(tx: mpsc::Sender<()>)141 async fn client_server(tx: mpsc::Sender<()>) {
142 let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
143
144 // Get the assigned address
145 let addr = assert_ok!(server.local_addr());
146
147 // Spawn the server
148 tokio::spawn(async move {
149 // Accept a socket
150 let (mut socket, _) = server.accept().await.unwrap();
151
152 // Write some data
153 socket.write_all(b"hello").await.unwrap();
154 });
155
156 let mut client = TcpStream::connect(&addr).await.unwrap();
157
158 let mut buf = vec![];
159 client.read_to_end(&mut buf).await.unwrap();
160
161 assert_eq!(buf, b"hello");
162 tx.send(()).unwrap();
163 }
164
#[test]
fn drop_threadpool_drops_futures() {
    /// A future that never completes. It bumps the shared counter when it is
    /// dropped, so the test can observe that the runtime released it.
    struct Never(Arc<AtomicUsize>);

    impl Future for Never {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Pending
        }
    }

    impl Drop for Never {
        fn drop(&mut self) {
            self.0.fetch_add(1, Relaxed);
        }
    }

    for _ in 0..1_000 {
        let num_inc = Arc::new(AtomicUsize::new(0));
        let num_dec = Arc::new(AtomicUsize::new(0));
        let num_drop = Arc::new(AtomicUsize::new(0));

        let on_start = num_inc.clone();
        let on_stop = num_dec.clone();

        let rt = runtime::Builder::new_multi_thread()
            .enable_all()
            .on_thread_start(move || {
                on_start.fetch_add(1, Relaxed);
            })
            .on_thread_stop(move || {
                on_stop.fetch_add(1, Relaxed);
            })
            .build()
            .unwrap();

        rt.spawn(Never(num_drop.clone()));

        // Wait for the pool to shut down.
        drop(rt);

        // At least one runtime thread must have been started.
        let started = num_inc.load(Relaxed);
        assert!(started >= 1);

        // Every thread that started must also have stopped.
        let stopped = num_dec.load(Relaxed);
        assert_eq!(started, stopped);

        // The never-completing future must have been dropped exactly once.
        let dropped = num_drop.load(Relaxed);
        assert_eq!(dropped, 1);
    }
}
220
#[test]
fn start_stop_callbacks_called() {
    use std::sync::atomic::Ordering;

    // Counters bumped by the thread start / stop callbacks.
    let after_start = Arc::new(AtomicUsize::new(0));
    let before_stop = Arc::new(AtomicUsize::new(0));

    let after_inner = after_start.clone();
    let before_inner = before_stop.clone();
    let rt = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        // Each closure already owns its `Arc`. Calling `fetch_add` through the
        // captured handle is enough; cloning the `Arc` on every callback only
        // adds a needless refcount increment/decrement.
        .on_thread_start(move || {
            after_inner.fetch_add(1, Ordering::Relaxed);
        })
        .on_thread_stop(move || {
            before_inner.fetch_add(1, Ordering::Relaxed);
        })
        .build()
        .unwrap();

    let (tx, rx) = oneshot::channel();

    rt.spawn(async move {
        assert_ok!(tx.send(()));
    });

    assert_ok!(rt.block_on(rx));

    drop(rt);

    assert!(after_start.load(Ordering::Relaxed) > 0);
    assert!(before_stop.load(Ordering::Relaxed) > 0);
}
254
#[test]
fn blocking() {
    // Number of ordinary tasks spawned while the workers are blocked.
    const NUM: usize = 1_000;
    // Number of runtime worker threads. It is also the number of tasks that
    // park themselves inside `block_in_place`.
    const WORKERS: usize = 4;

    for _ in 0..10 {
        // Used by the task that performs the final increment to notify the
        // main thread.
        let (tx, rx) = mpsc::channel();

        // Pin the worker count. `rt()` uses tokio's default worker count,
        // which follows the number of available cores. With the default, the
        // "all workers are blocked" premise below only holds on machines that
        // happen to have exactly WORKERS cores.
        let rt = runtime::Builder::new_multi_thread()
            .worker_threads(WORKERS)
            .enable_all()
            .build()
            .unwrap();
        let cnt = Arc::new(AtomicUsize::new(0));

        // Park WORKERS tasks inside `block_in_place`. Once all of them (plus
        // the main thread) reach the barrier, the NUM tasks below can only
        // make progress if `block_in_place` handed worker cores off to other
        // threads.
        let block = Arc::new(std::sync::Barrier::new(WORKERS + 1));
        for _ in 0..WORKERS {
            let block = block.clone();
            rt.spawn(async move {
                tokio::task::block_in_place(move || {
                    // First rendezvous: tell the main thread we are blocked.
                    block.wait();
                    // Second rendezvous: stay blocked until released.
                    block.wait();
                })
            });
        }
        block.wait();

        for _ in 0..NUM {
            let cnt = cnt.clone();
            let tx = tx.clone();

            rt.spawn(async move {
                let num = cnt.fetch_add(1, Relaxed) + 1;

                if num == NUM {
                    tx.send(()).unwrap();
                }
            });
        }

        rx.recv().unwrap();

        // Release the blocked tasks. The runtime then shuts down when `rt` is
        // dropped at the end of this iteration.
        block.wait();
    }
}
299
#[test]
fn multi_threadpool() {
    // Two independent runtimes: a oneshot value sent by a task on one runtime
    // must wake a task awaiting it on the other.
    let sender_rt = rt();
    let receiver_rt = rt();

    let (signal_tx, signal_rx) = oneshot::channel();
    let (done_tx, done_rx) = mpsc::channel();

    receiver_rt.spawn(async move {
        signal_rx.await.unwrap();
        done_tx.send(()).unwrap();
    });

    sender_rt.spawn(async move {
        signal_tx.send(()).unwrap();
    });

    done_rx.recv().unwrap();
}
321
// When `block_in_place` returns, it attempts to reclaim the yielded runtime
// worker. In this case, the remainder of the task is on the runtime worker and
// must take part in the cooperative task budgeting system.
//
// The test ensures that, when this happens, attempting to consume from a
// channel yields occasionally even if there are values ready to receive.
#[test]
fn coop_and_block_in_place() {
    let rt = tokio::runtime::Builder::new_multi_thread()
        // Setting max threads to 1 prevents another thread from claiming the
        // runtime worker yielded as part of `block_in_place` and guarantees the
        // same thread will reclaim the worker at the end of the
        // `block_in_place` call.
        .max_blocking_threads(1)
        .build()
        .unwrap();

    rt.block_on(async move {
        // Capacity equals the number of values sent below, so none of the
        // sends has to wait for a receiver.
        let (tx, mut rx) = tokio::sync::mpsc::channel(1024);

        // Fill the channel
        for _ in 0..1024 {
            tx.send(()).await.unwrap();
        }

        // Close the sending side: after the 1024 buffered values are
        // consumed, `recv` yields `None` (which the loop below treats as
        // failure).
        drop(tx);

        tokio::spawn(async move {
            // Block in place without doing anything
            tokio::task::block_in_place(|| {});

            // Receive all the values. This should produce a `Pending` once
            // the coop budget is exhausted.
            //
            // Each loop iteration pins and polls a fresh `recv()` future:
            // - `Ready(Some(_))` keeps draining;
            // - `Ready(None)` means every value was consumed without a single
            //   `Pending`, i.e. the task never yielded, so the test panics;
            // - the first `Pending` ends the loop and the closure returns
            //   `Ready(())`, completing the spawned task.
            poll_fn(|cx| {
                while let Poll::Ready(v) = {
                    tokio::pin! {
                        let fut = rx.recv();
                    }

                    Pin::new(&mut fut).poll(cx)
                } {
                    if v.is_none() {
                        panic!("did not yield");
                    }
                }

                Poll::Ready(())
            })
            .await
        })
        .await
        .unwrap();
    });
}
376
// Building a runtime whose blocking pool is capped at a single thread must
// succeed without panicking.
#[test]
fn max_blocking_threads() {
    let built = tokio::runtime::Builder::new_multi_thread()
        .max_blocking_threads(1)
        .build();
    let _rt = built.unwrap();
}
385
// A blocking-thread limit of zero must be rejected: the test passes only if
// configuring or building the runtime panics.
#[test]
#[should_panic]
fn max_blocking_threads_set_to_zero() {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.max_blocking_threads(0);
    let _rt = builder.build().unwrap();
}
394
395 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
hang_on_shutdown()396 async fn hang_on_shutdown() {
397 let (sync_tx, sync_rx) = std::sync::mpsc::channel::<()>();
398 tokio::spawn(async move {
399 tokio::task::block_in_place(|| sync_rx.recv().ok());
400 });
401
402 tokio::spawn(async {
403 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
404 drop(sync_tx);
405 });
406 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
407 }
408
/// Demonstrates tokio-rs/tokio#3869
///
/// Two futures share a single `Waker` slot:
/// - `f1` (`put_waker == true`) stores its waker every time it is polled;
/// - `f2` (`put_waker == false`) takes that waker and calls `wake()` on it
///   from its `Drop` impl.
///
/// Neither future ever completes. Both are spawned on a single-worker runtime,
/// which is dropped when this function returns, so `f2`'s destructor fires a
/// wake-up while the runtime is being torn down. The test contains no explicit
/// assertions. NOTE(review): presumably it guards against a panic or deadlock
/// in that situation; see the linked issue for the original report.
#[test]
fn wake_during_shutdown() {
    // State shared between the two futures.
    struct Shared {
        // Waker stored by the future with `put_waker == true`.
        waker: Option<Waker>,
    }

    struct MyFuture {
        shared: Arc<Mutex<Shared>>,
        // `true`: store the waker on every poll.
        // `false`: take and fire the stored waker on drop.
        put_waker: bool,
    }

    impl MyFuture {
        // Returns `(waker-storing future, waker-firing future)` sharing one
        // `Shared` slot.
        fn new() -> (Self, Self) {
            let shared = Arc::new(Mutex::new(Shared { waker: None }));
            let f1 = MyFuture {
                shared: shared.clone(),
                put_waker: true,
            };
            let f2 = MyFuture {
                shared,
                put_waker: false,
            };
            (f1, f2)
        }
    }

    impl Future for MyFuture {
        type Output = ();

        // Always `Pending`; the storing future records the current waker each
        // time so the latest one is available to the other future's `Drop`.
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            let me = Pin::into_inner(self);
            let mut lock = me.shared.lock().unwrap();
            println!("poll {}", me.put_waker);
            if me.put_waker {
                println!("putting");
                lock.waker = Some(cx.waker().clone());
            }
            Poll::Pending
        }
    }

    impl Drop for MyFuture {
        fn drop(&mut self) {
            println!("drop {} start", self.put_waker);
            let mut lock = self.shared.lock().unwrap();
            if !self.put_waker {
                // Wakes the storing future's task from inside a destructor.
                // Panics (via `unwrap`) if no waker was ever stored, i.e. if
                // `f1` was never polled.
                lock.waker.take().unwrap().wake();
            }
            drop(lock);
            println!("drop {} stop", self.put_waker);
        }
    }

    let rt = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(1)
        .enable_all()
        .build()
        .unwrap();

    let (f1, f2) = MyFuture::new();

    // `f1` is spawned first so that, presumably, it gets polled (and stores
    // its waker) before anything drops `f2`.
    rt.spawn(f1);
    rt.spawn(f2);

    // Give the single worker some time to poll the spawned futures. `rt` is
    // dropped at the end of this function, which tears the runtime down.
    rt.block_on(async { tokio::time::sleep(tokio::time::Duration::from_millis(20)).await });
}
476
477 #[should_panic]
478 #[tokio::test]
test_block_in_place1()479 async fn test_block_in_place1() {
480 tokio::task::block_in_place(|| {});
481 }
482
483 #[tokio::test(flavor = "multi_thread")]
test_block_in_place2()484 async fn test_block_in_place2() {
485 tokio::task::block_in_place(|| {});
486 }
487
488 #[should_panic]
489 #[tokio::main(flavor = "current_thread")]
490 #[test]
test_block_in_place3()491 async fn test_block_in_place3() {
492 tokio::task::block_in_place(|| {});
493 }
494
495 #[tokio::main]
496 #[test]
test_block_in_place4()497 async fn test_block_in_place4() {
498 tokio::task::block_in_place(|| {});
499 }
500
rt() -> Runtime501 fn rt() -> Runtime {
502 Runtime::new().unwrap()
503 }
504