//! Integration tests for Tokio's multi-threaded runtime: spawning, shutdown,
//! thread start/stop callbacks, `block_in_place`, and cooperative budgeting.
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::runtime::{self, Runtime};
use tokio::sync::oneshot;
use tokio_test::{assert_err, assert_ok};

use futures::future::poll_fn;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::{mpsc, Arc, Mutex};
use std::task::{Context, Poll, Waker};

/// Building a multi-threaded runtime with a single worker thread must succeed.
#[test]
fn single_thread() {
    // Unwrap so a build error fails the test instead of being silently discarded.
    let _rt = runtime::Builder::new_multi_thread()
        .enable_all()
        .worker_threads(1)
        .build()
        .unwrap();
}

/// Spawns many short-lived tasks and waits until every one of them has run,
/// then shuts the runtime down. Repeated several times.
#[test]
fn many_oneshot_futures() {
    // Number of tasks spawned per iteration; the task that brings the shared
    // counter up to `NUM` notifies the main thread.
    const NUM: usize = 1_000;

    for _ in 0..5 {
        let (tx, rx) = mpsc::channel();

        let rt = rt();
        let cnt = Arc::new(AtomicUsize::new(0));

        for _ in 0..NUM {
            let cnt = cnt.clone();
            let tx = tx.clone();

            rt.spawn(async move {
                let num = cnt.fetch_add(1, Relaxed) + 1;

                if num == NUM {
                    tx.send(()).unwrap();
                }
            });
        }

        rx.recv().unwrap();

        // Wait for the pool to shutdown
        drop(rt);
    }
}

/// Builds `TRACKS` chains of `CHAIN` forwarding tasks each. The last task of
/// every chain sends the message back to the start of its chain until it has
/// travelled the chain `CYCLES` times, then hands it to the final receiver.
#[test]
fn many_multishot_futures() {
    const CHAIN: usize = 200;
    const CYCLES: usize = 5;
    const TRACKS: usize = 50;

    for _ in 0..50 {
        let rt = rt();
        let mut start_txs = Vec::with_capacity(TRACKS);
        let mut final_rxs = Vec::with_capacity(TRACKS);

        for _ in 0..TRACKS {
            let (start_tx, mut chain_rx) = tokio::sync::mpsc::channel(10);

            for _ in 0..CHAIN {
                let (next_tx, next_rx) = tokio::sync::mpsc::channel(10);

                // Forward all the messages
                rt.spawn(async move {
                    while let Some(v) = chain_rx.recv().await {
                        next_tx.send(v).await.unwrap();
                    }
                });

                chain_rx = next_rx;
            }

            // This final task cycles if needed
            let (final_tx, final_rx) = tokio::sync::mpsc::channel(10);
            let cycle_tx = start_tx.clone();
            let mut rem = CYCLES;

            rt.spawn(async move {
                for _ in 0..CYCLES {
                    let msg = chain_rx.recv().await.unwrap();

                    rem -= 1;

                    if rem == 0 {
                        final_tx.send(msg).await.unwrap();
                    } else {
                        cycle_tx.send(msg).await.unwrap();
                    }
                }
            });

            start_txs.push(start_tx);
            final_rxs.push(final_rx);
        }

        rt.block_on(async move {
            for start_tx in start_txs {
                start_tx.send("ping").await.unwrap();
            }

            for mut final_rx in final_rxs {
                final_rx.recv().await.unwrap();
            }
        });
    }
}

/// Spawns TCP client/server pairs both from inside the runtime
/// (`tokio::spawn`) and from outside (`Runtime::spawn`). Checks that both
/// complete, and that no further notification arrives after the runtime is
/// dropped.
#[test]
fn spawn_shutdown() {
    let rt = rt();
    let (tx, rx) = mpsc::channel();

    rt.block_on(async {
        tokio::spawn(client_server(tx.clone()));
    });

    // Use spawner
    rt.spawn(client_server(tx));

    assert_ok!(rx.recv());
    assert_ok!(rx.recv());

    drop(rt);
    assert_err!(rx.try_recv());
}

/// Binds a listener on an ephemeral localhost port and spawns a server task
/// that writes `b"hello"` to the first accepted socket. The client reads to
/// EOF, asserts it received `b"hello"`, then signals `tx`.
async fn client_server(tx: mpsc::Sender<()>) {
    let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);

    // Get the assigned address
    let addr = assert_ok!(server.local_addr());

    // Spawn the server
    tokio::spawn(async move {
        // Accept a socket
        let (mut socket, _) = server.accept().await.unwrap();

        // Write some data
        socket.write_all(b"hello").await.unwrap();
    });

    let mut client = TcpStream::connect(&addr).await.unwrap();

    let mut buf = vec![];
    client.read_to_end(&mut buf).await.unwrap();

    assert_eq!(buf, b"hello");
    tx.send(()).unwrap();
}

/// Dropping the runtime must drop a spawned future that never completes, and
/// must run the stop callback for every worker thread the start callback saw.
#[test]
fn drop_threadpool_drops_futures() {
    for _ in 0..1_000 {
        let num_inc = Arc::new(AtomicUsize::new(0));
        let num_dec = Arc::new(AtomicUsize::new(0));
        let num_drop = Arc::new(AtomicUsize::new(0));

        /// A future that is always pending and bumps its counter when dropped.
        struct Never(Arc<AtomicUsize>);

        impl Future for Never {
            type Output = ();

            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Pending
            }
        }

        impl Drop for Never {
            fn drop(&mut self) {
                self.0.fetch_add(1, Relaxed);
            }
        }

        let a = num_inc.clone();
        let b = num_dec.clone();

        let rt = runtime::Builder::new_multi_thread()
            .enable_all()
            .on_thread_start(move || {
                a.fetch_add(1, Relaxed);
            })
            .on_thread_stop(move || {
                b.fetch_add(1, Relaxed);
            })
            .build()
            .unwrap();

        rt.spawn(Never(num_drop.clone()));

        // Wait for the pool to shutdown
        drop(rt);

        // Assert that at least one thread was started.
        let a = num_inc.load(Relaxed);
        assert!(a >= 1);

        // Assert that all threads shutdown
        let b = num_dec.load(Relaxed);
        assert_eq!(a, b);

        // Assert that the future was dropped
        let c = num_drop.load(Relaxed);
        assert_eq!(c, 1);
    }
}

/// The `on_thread_start` and `on_thread_stop` callbacks each run at least
/// once over the lifetime of a runtime that executes a task.
#[test]
fn start_stop_callbacks_called() {
    let after_start = Arc::new(AtomicUsize::new(0));
    let before_stop = Arc::new(AtomicUsize::new(0));

    let after_inner = after_start.clone();
    let before_inner = before_stop.clone();
    let rt = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        // The closures own their `Arc`s; no per-call clone is needed to bump
        // the counters.
        .on_thread_start(move || {
            after_inner.fetch_add(1, Relaxed);
        })
        .on_thread_stop(move || {
            before_inner.fetch_add(1, Relaxed);
        })
        .build()
        .unwrap();

    let (tx, rx) = oneshot::channel();

    rt.spawn(async move {
        assert_ok!(tx.send(()));
    });

    assert_ok!(rt.block_on(rx));

    drop(rt);

    assert!(after_start.load(Relaxed) > 0);
    assert!(before_stop.load(Relaxed) > 0);
}

/// With every worker blocked inside `block_in_place`, newly spawned tasks must
/// still run, which is only possible if the workers were handed off.
#[test]
fn blocking() {
    // Number of tasks spawned per iteration; the task that brings the shared
    // counter up to `NUM` notifies the main thread.
    const NUM: usize = 1_000;

    for _ in 0..10 {
        let (tx, rx) = mpsc::channel();

        // Pin the worker count explicitly: the handoff argument below relies
        // on there being exactly four workers, whereas `Runtime::new()` sizes
        // the pool from the host.
        let rt = runtime::Builder::new_multi_thread()
            .worker_threads(4)
            .enable_all()
            .build()
            .unwrap();
        let cnt = Arc::new(AtomicUsize::new(0));

        // there are four workers in the pool
        // so, if we run 4 blocking tasks, we know that handoff must have happened
        let block = Arc::new(std::sync::Barrier::new(5));
        for _ in 0..4 {
            let block = block.clone();
            rt.spawn(async move {
                tokio::task::block_in_place(move || {
                    block.wait();
                    block.wait();
                })
            });
        }
        block.wait();

        for _ in 0..NUM {
            let cnt = cnt.clone();
            let tx = tx.clone();

            rt.spawn(async move {
                let num = cnt.fetch_add(1, Relaxed) + 1;

                if num == NUM {
                    tx.send(()).unwrap();
                }
            });
        }

        rx.recv().unwrap();

        // Release the four `block_in_place` tasks before `rt` is dropped at
        // the end of this iteration.
        block.wait();
    }
}

/// A task on one runtime can be woken by a oneshot sender fired from a task
/// on a different runtime.
#[test]
fn multi_threadpool() {
    let rt1 = rt();
    let rt2 = rt();

    let (tx, rx) = oneshot::channel();
    let (done_tx, done_rx) = mpsc::channel();

    rt2.spawn(async move {
        rx.await.unwrap();
        done_tx.send(()).unwrap();
    });

    rt1.spawn(async move {
        tx.send(()).unwrap();
    });

    done_rx.recv().unwrap();
}

// When `block_in_place` returns, it attempts to reclaim the yielded runtime
// worker. In this case, the remainder of the task is on the runtime worker and
// must take part in the cooperative task budgeting system.
//
// The test ensures that, when this happens, attempting to consume from a
// channel yields occasionally even if there are values ready to receive.
#[test]
fn coop_and_block_in_place() {
    let rt = tokio::runtime::Builder::new_multi_thread()
        // Setting max threads to 1 prevents another thread from claiming the
        // runtime worker yielded as part of `block_in_place` and guarantees the
        // same thread will reclaim the worker at the end of the
        // `block_in_place` call.
        .max_blocking_threads(1)
        .build()
        .unwrap();

    rt.block_on(async move {
        let (tx, mut rx) = tokio::sync::mpsc::channel(1024);

        // Fill the channel
        for _ in 0..1024 {
            tx.send(()).await.unwrap();
        }

        drop(tx);

        tokio::spawn(async move {
            // Block in place without doing anything
            tokio::task::block_in_place(|| {});

            // Receive all the values, this should trigger a `Pending` as the
            // coop limit will be reached.
            poll_fn(|cx| {
                while let Poll::Ready(v) = {
                    tokio::pin! {
                        let fut = rx.recv();
                    }

                    Pin::new(&mut fut).poll(cx)
                } {
                    if v.is_none() {
                        panic!("did not yield");
                    }
                }

                Poll::Ready(())
            })
            .await
        })
        .await
        .unwrap();
    });
}

// Testing this does not panic
#[test]
fn max_blocking_threads() {
    let _rt = tokio::runtime::Builder::new_multi_thread()
        .max_blocking_threads(1)
        .build()
        .unwrap();
}

/// Configuring zero blocking threads is expected to panic.
#[test]
#[should_panic]
fn max_blocking_threads_set_to_zero() {
    let _rt = tokio::runtime::Builder::new_multi_thread()
        .max_blocking_threads(0)
        .build()
        .unwrap();
}

/// The test body returns after 1s, while a task is still stuck inside
/// `block_in_place`. That task only unblocks when `sync_tx` is dropped at the
/// 2s mark. Runtime shutdown must not hang in this situation.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn hang_on_shutdown() {
    let (sync_tx, sync_rx) = std::sync::mpsc::channel::<()>();
    tokio::spawn(async move {
        tokio::task::block_in_place(|| sync_rx.recv().ok());
    });

    tokio::spawn(async {
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        drop(sync_tx);
    });
    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}

/// Demonstrates tokio-rs/tokio#3869
///
/// Two never-completing futures share a waker slot. `f1` stores its waker
/// whenever it is polled. Dropping `f2` (which only happens when the runtime
/// drops its tasks at shutdown) wakes that stored waker. The result is a wake
/// issued during shutdown.
#[test]
fn wake_during_shutdown() {
    struct Shared {
        waker: Option<Waker>,
    }

    struct MyFuture {
        shared: Arc<Mutex<Shared>>,
        put_waker: bool,
    }

    impl MyFuture {
        fn new() -> (Self, Self) {
            let shared = Arc::new(Mutex::new(Shared { waker: None }));
            let f1 = MyFuture {
                shared: shared.clone(),
                put_waker: true,
            };
            let f2 = MyFuture {
                shared,
                put_waker: false,
            };
            (f1, f2)
        }
    }

    impl Future for MyFuture {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            let me = Pin::into_inner(self);
            let mut lock = me.shared.lock().unwrap();
            println!("poll {}", me.put_waker);
            if me.put_waker {
                println!("putting");
                lock.waker = Some(cx.waker().clone());
            }
            Poll::Pending
        }
    }

    impl Drop for MyFuture {
        fn drop(&mut self) {
            println!("drop {} start", self.put_waker);
            let mut lock = self.shared.lock().unwrap();
            if !self.put_waker {
                // NOTE(review): assumes `f1` was polled (and stored its waker)
                // during the 20 ms `block_on` below; otherwise this unwrap panics.
                lock.waker.take().unwrap().wake();
            }
            drop(lock);
            println!("drop {} stop", self.put_waker);
        }
    }

    let rt = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(1)
        .enable_all()
        .build()
        .unwrap();

    let (f1, f2) = MyFuture::new();

    rt.spawn(f1);
    rt.spawn(f2);

    rt.block_on(async { tokio::time::sleep(tokio::time::Duration::from_millis(20)).await });
}

/// `block_in_place` under the default `#[tokio::test]` runtime is expected to panic.
#[should_panic]
#[tokio::test]
async fn test_block_in_place1() {
    tokio::task::block_in_place(|| {});
}

/// `block_in_place` on a multi-thread `#[tokio::test]` runtime is allowed.
#[tokio::test(flavor = "multi_thread")]
async fn test_block_in_place2() {
    tokio::task::block_in_place(|| {});
}

/// `block_in_place` under a current-thread `#[tokio::main]` is expected to panic.
#[should_panic]
#[tokio::main(flavor = "current_thread")]
#[test]
async fn test_block_in_place3() {
    tokio::task::block_in_place(|| {});
}

/// `block_in_place` under the default `#[tokio::main]` runtime is allowed.
#[tokio::main]
#[test]
async fn test_block_in_place4() {
    tokio::task::block_in_place(|| {});
}

/// Creates a runtime with `Runtime::new()` defaults, panicking if it fails.
fn rt() -> Runtime {
    Runtime::new().unwrap()
}