1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3
4 use tokio::io::{AsyncReadExt, AsyncWriteExt};
5 use tokio::net::{TcpListener, TcpStream};
6 use tokio::runtime::{self, Runtime};
7 use tokio::sync::oneshot;
8 use tokio_test::{assert_err, assert_ok};
9
10 use futures::future::poll_fn;
11 use std::future::Future;
12 use std::pin::Pin;
13 use std::sync::atomic::AtomicUsize;
14 use std::sync::atomic::Ordering::Relaxed;
15 use std::sync::{mpsc, Arc, Mutex};
16 use std::task::{Context, Poll, Waker};
17
#[test]
fn single_thread() {
    // Building a multi-threaded runtime with exactly one worker thread must
    // neither panic nor return an error. The original discarded the `Result`
    // with `let _ =`, so a failed build went unnoticed; assert it instead.
    let built = runtime::Builder::new_multi_thread()
        .enable_all()
        .worker_threads(1)
        .build();
    assert_ok!(built);
}
26
#[test]
fn many_oneshot_futures() {
    // Number of tasks spawned onto each runtime instance.
    const NUM: usize = 1_000;

    for _round in 0..5 {
        // Whichever task observes the final count notifies the main thread here.
        let (done_tx, done_rx) = mpsc::channel();

        let rt = rt();
        let counter = Arc::new(AtomicUsize::new(0));

        (0..NUM).for_each(|_| {
            let counter = Arc::clone(&counter);
            let done_tx = done_tx.clone();

            rt.spawn(async move {
                if counter.fetch_add(1, Relaxed) + 1 == NUM {
                    done_tx.send(()).unwrap();
                }
            });
        });

        done_rx.recv().unwrap();

        // Dropping the runtime blocks until its worker pool has shut down.
        drop(rt);
    }
}
#[test]
fn many_multishot_futures() {
    const CHAIN: usize = 200;
    const CYCLES: usize = 5;
    const TRACKS: usize = 50;

    for _ in 0..50 {
        let rt = rt();
        let mut start_txs = Vec::with_capacity(TRACKS);
        let mut final_rxs = Vec::with_capacity(TRACKS);

        for _ in 0..TRACKS {
            let (start_tx, head_rx) = tokio::sync::mpsc::channel(10);

            // Build CHAIN forwarding tasks; afterwards `tail_rx` is the
            // receiving end of the last link in the chain.
            let mut tail_rx = head_rx;
            for _ in 0..CHAIN {
                let (link_tx, link_rx) = tokio::sync::mpsc::channel(10);
                let mut upstream = tail_rx;

                rt.spawn(async move {
                    while let Some(msg) = upstream.recv().await {
                        link_tx.send(msg).await.unwrap();
                    }
                });

                tail_rx = link_rx;
            }

            // The closing task feeds each message back to the start of the
            // chain until it has completed CYCLES laps, then emits it on
            // `final_tx`.
            let (final_tx, final_rx) = tokio::sync::mpsc::channel(10);
            let loopback_tx = start_tx.clone();

            rt.spawn(async move {
                for lap in 1..=CYCLES {
                    let msg = tail_rx.recv().await.unwrap();
                    let out = if lap == CYCLES { &final_tx } else { &loopback_tx };
                    out.send(msg).await.unwrap();
                }
            });

            start_txs.push(start_tx);
            final_rxs.push(final_rx);
        }

        rt.block_on(async move {
            for start_tx in start_txs {
                start_tx.send("ping").await.unwrap();
            }

            for mut final_rx in final_rxs {
                final_rx.recv().await.unwrap();
            }
        });
    }
}
120
#[test]
fn spawn_shutdown() {
    let rt = rt();
    let (done_tx, done_rx) = mpsc::channel();

    // One client/server pair is spawned from inside the runtime context...
    let inner_tx = done_tx.clone();
    rt.block_on(async move {
        tokio::spawn(client_server(inner_tx));
    });

    // ...and a second one through the runtime handle from outside it.
    rt.spawn(client_server(done_tx));

    // Both round trips must complete.
    for _ in 0..2 {
        assert_ok!(done_rx.recv());
    }

    // With the runtime gone every sender has been dropped and nothing is queued.
    drop(rt);
    assert_err!(done_rx.try_recv());
}
139
client_server(tx: mpsc::Sender<()>)140 async fn client_server(tx: mpsc::Sender<()>) {
141 let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
142
143 // Get the assigned address
144 let addr = assert_ok!(server.local_addr());
145
146 // Spawn the server
147 tokio::spawn(async move {
148 // Accept a socket
149 let (mut socket, _) = server.accept().await.unwrap();
150
151 // Write some data
152 socket.write_all(b"hello").await.unwrap();
153 });
154
155 let mut client = TcpStream::connect(&addr).await.unwrap();
156
157 let mut buf = vec![];
158 client.read_to_end(&mut buf).await.unwrap();
159
160 assert_eq!(buf, b"hello");
161 tx.send(()).unwrap();
162 }
163
#[test]
fn drop_threadpool_drops_futures() {
    // A future that never completes and bumps its counter when dropped.
    struct Never(Arc<AtomicUsize>);

    impl Future for Never {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Pending
        }
    }

    impl Drop for Never {
        fn drop(&mut self) {
            self.0.fetch_add(1, Relaxed);
        }
    }

    for _ in 0..1_000 {
        let num_inc = Arc::new(AtomicUsize::new(0));
        let num_dec = Arc::new(AtomicUsize::new(0));
        let num_drop = Arc::new(AtomicUsize::new(0));

        let a = num_inc.clone();
        let b = num_dec.clone();

        let rt = runtime::Builder::new_multi_thread()
            .enable_all()
            .on_thread_start(move || {
                a.fetch_add(1, Relaxed);
            })
            .on_thread_stop(move || {
                b.fetch_add(1, Relaxed);
            })
            .build()
            .unwrap();

        rt.spawn(Never(num_drop.clone()));

        // Dropping the runtime blocks until the pool has shut down.
        drop(rt);

        // At least one runtime thread must have been started.
        let a = num_inc.load(Relaxed);
        assert!(a >= 1);

        // Every thread that started must also have run its stop hook.
        let b = num_dec.load(Relaxed);
        assert_eq!(a, b);

        // The never-completing task must have been dropped exactly once.
        let c = num_drop.load(Relaxed);
        assert_eq!(c, 1);
    }
}
219
#[test]
fn start_stop_callbacks_called() {
    // Count how many times the per-thread start/stop hooks fire.
    let after_start = Arc::new(AtomicUsize::new(0));
    let before_stop = Arc::new(AtomicUsize::new(0));

    let after_inner = after_start.clone();
    let before_inner = before_stop.clone();
    let rt = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        // Each closure already owns its `Arc`, so increment through it directly
        // instead of cloning (and immediately dropping) a new handle per call.
        .on_thread_start(move || {
            after_inner.fetch_add(1, Relaxed);
        })
        .on_thread_stop(move || {
            before_inner.fetch_add(1, Relaxed);
        })
        .build()
        .unwrap();

    let (tx, rx) = oneshot::channel();

    rt.spawn(async move {
        assert_ok!(tx.send(()));
    });

    assert_ok!(rt.block_on(rx));

    // Dropping the runtime shuts its threads down, running the stop hooks.
    drop(rt);

    assert!(after_start.load(Relaxed) > 0);
    assert!(before_stop.load(Relaxed) > 0);
}
253
#[test]
fn blocking() {
    // Number of short tasks spawned while the blocking tasks are parked.
    const NUM: usize = 1_000;

    for _ in 0..10 {
        // The task that observes the final count notifies the main thread here.
        let (tx, rx) = mpsc::channel();

        let rt = rt();
        let cnt = Arc::new(AtomicUsize::new(0));

        // Park four tasks inside `block_in_place`. The barrier has five parties
        // (four tasks plus this thread), so the first `wait` below returns only
        // once all four are blocked. `Runtime::new()` starts one worker per CPU
        // core, so on machines with four or fewer cores every original worker is
        // now occupied, and the tasks spawned below can only make progress if
        // `block_in_place` handed the workers off.
        let block = Arc::new(std::sync::Barrier::new(5));
        for _ in 0..4 {
            let block = block.clone();
            rt.spawn(async move {
                tokio::task::block_in_place(move || {
                    block.wait();
                    block.wait();
                })
            });
        }
        block.wait();

        for _ in 0..NUM {
            let cnt = cnt.clone();
            let tx = tx.clone();

            rt.spawn(async move {
                let num = cnt.fetch_add(1, Relaxed) + 1;

                if num == NUM {
                    tx.send(()).unwrap();
                }
            });
        }

        rx.recv().unwrap();

        // Release the four blocked tasks (second barrier round). The runtime is
        // shut down when `rt` goes out of scope at the end of this iteration.
        block.wait();
    }
}
298
#[test]
fn multi_threadpool() {
    // A oneshot completed by a task on one runtime must wake a task waiting on
    // a different runtime.
    let sender_rt = rt();
    let receiver_rt = rt();

    let (signal_tx, signal_rx) = oneshot::channel();
    let (done_tx, done_rx) = mpsc::channel();

    receiver_rt.spawn(async move {
        signal_rx.await.unwrap();
        done_tx.send(()).unwrap();
    });

    sender_rt.spawn(async move {
        signal_tx.send(()).unwrap();
    });

    done_rx.recv().unwrap();
}
320
// When `block_in_place` returns, it attempts to reclaim the yielded runtime
// worker. In this case, the remainder of the task is on the runtime worker and
// must take part in the cooperative task budgeting system.
//
// The test ensures that, when this happens, attempting to consume from a
// channel yields occasionally even if there are values ready to receive.
#[test]
fn coop_and_block_in_place() {
    let rt = tokio::runtime::Builder::new_multi_thread()
        // Setting max *blocking* threads to 1 prevents another thread from
        // claiming the runtime worker yielded as part of `block_in_place` and
        // guarantees the same thread will reclaim the worker at the end of the
        // `block_in_place` call.
        .max_blocking_threads(1)
        .build()
        .unwrap();

    rt.block_on(async move {
        let (tx, mut rx) = tokio::sync::mpsc::channel(1024);

        // Fill the channel to capacity so every `recv` below is immediately ready.
        for _ in 0..1024 {
            tx.send(()).await.unwrap();
        }

        // Close the channel: once drained, `recv` resolves to `None`. Reaching
        // `None` below therefore means all 1024 values were received without a
        // single `Pending`, i.e. the coop budget never forced a yield.
        drop(tx);

        tokio::spawn(async move {
            // Block in place without doing anything
            tokio::task::block_in_place(|| {});

            // Receive all the values, this should trigger a `Pending` as the
            // coop limit will be reached.
            poll_fn(|cx| {
                // Each iteration pins a fresh `recv()` future and polls it once:
                // `Ready(Some(_))` keeps looping, `Ready(None)` is a failure,
                // and `Pending` exits the loop (the expected outcome).
                while let Poll::Ready(v) = {
                    tokio::pin! {
                        let fut = rx.recv();
                    }

                    Pin::new(&mut fut).poll(cx)
                } {
                    if v.is_none() {
                        panic!("did not yield");
                    }
                }

                // A `Pending` was observed, so the budget kicked in; finish.
                Poll::Ready(())
            })
            .await
        })
        .await
        .unwrap();
    });
}
375
// Configuring a blocking-thread limit of one must build a runtime without panicking.
#[test]
fn max_blocking_threads() {
    let mut builder = runtime::Builder::new_multi_thread();
    builder.max_blocking_threads(1);
    let _rt = builder.build().unwrap();
}
384
#[test]
#[should_panic]
fn max_blocking_threads_set_to_zero() {
    // A blocking-thread limit of zero is invalid; building with it must panic.
    let mut builder = runtime::Builder::new_multi_thread();
    builder.max_blocking_threads(0);
    let _rt = builder.build().unwrap();
}
393
394 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
hang_on_shutdown()395 async fn hang_on_shutdown() {
396 let (sync_tx, sync_rx) = std::sync::mpsc::channel::<()>();
397 tokio::spawn(async move {
398 tokio::task::block_in_place(|| sync_rx.recv().ok());
399 });
400
401 tokio::spawn(async {
402 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
403 drop(sync_tx);
404 });
405 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
406 }
407
/// Demonstrates tokio-rs/tokio#3869
///
/// Two never-completing futures share one waker slot. `f1` stores its waker
/// whenever it is polled; `f2`, when dropped, takes that waker and wakes it.
/// Both are still pending when the runtime is dropped at the end of this
/// function, so dropping `f2` during runtime shutdown issues a wake-up for `f1`.
#[test]
fn wake_during_shutdown() {
    // Slot holding the waker most recently registered by the `put_waker` future.
    struct Shared {
        waker: Option<Waker>,
    }

    struct MyFuture {
        shared: Arc<Mutex<Shared>>,
        // `true`: store the waker on poll. `false`: wake the stored waker on drop.
        put_waker: bool,
    }

    impl MyFuture {
        /// Returns the pair `(storer, waker-on-drop)` sharing a single slot.
        fn new() -> (Self, Self) {
            let shared = Arc::new(Mutex::new(Shared { waker: None }));
            let f1 = MyFuture {
                shared: shared.clone(),
                put_waker: true,
            };
            let f2 = MyFuture {
                shared,
                put_waker: false,
            };
            (f1, f2)
        }
    }

    impl Future for MyFuture {
        type Output = ();

        // Never completes; the storing future refreshes the shared waker each poll.
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            let me = Pin::into_inner(self);
            let mut lock = me.shared.lock().unwrap();
            println!("poll {}", me.put_waker);
            if me.put_waker {
                println!("putting");
                lock.waker = Some(cx.waker().clone());
            }
            Poll::Pending
        }
    }

    impl Drop for MyFuture {
        fn drop(&mut self) {
            println!("drop {} start", self.put_waker);
            let mut lock = self.shared.lock().unwrap();
            if !self.put_waker {
                // Panics if `f1` was never polled (no waker stored). Note the wake
                // happens while the mutex is still held.
                lock.waker.take().unwrap().wake();
            }
            drop(lock);
            println!("drop {} stop", self.put_waker);
        }
    }

    let rt = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(1)
        .enable_all()
        .build()
        .unwrap();

    let (f1, f2) = MyFuture::new();

    rt.spawn(f1);
    rt.spawn(f2);

    // Presumably gives the single worker time to poll both futures (so `f1`'s
    // waker is stored) before `rt` is dropped at the end of this function.
    rt.block_on(async { tokio::time::sleep(tokio::time::Duration::from_millis(20)).await });
}
475
rt() -> Runtime476 fn rt() -> Runtime {
477 Runtime::new().unwrap()
478 }
479