1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3
4 use futures::{
5 future::{pending, ready},
6 FutureExt,
7 };
8
9 use tokio::runtime::{self, Runtime};
10 use tokio::sync::{mpsc, oneshot};
11 use tokio::task::{self, LocalSet};
12 use tokio::time;
13
14 use std::cell::Cell;
15 use std::sync::atomic::Ordering::{self, SeqCst};
16 use std::sync::atomic::{AtomicBool, AtomicUsize};
17 use std::time::Duration;
18
19 #[tokio::test(flavor = "current_thread")]
local_basic_scheduler()20 async fn local_basic_scheduler() {
21 LocalSet::new()
22 .run_until(async {
23 task::spawn_local(async {}).await.unwrap();
24 })
25 .await;
26 }
27
28 #[tokio::test(flavor = "multi_thread")]
local_threadpool()29 async fn local_threadpool() {
30 thread_local! {
31 static ON_RT_THREAD: Cell<bool> = Cell::new(false);
32 }
33
34 ON_RT_THREAD.with(|cell| cell.set(true));
35
36 LocalSet::new()
37 .run_until(async {
38 assert!(ON_RT_THREAD.with(|cell| cell.get()));
39 task::spawn_local(async {
40 assert!(ON_RT_THREAD.with(|cell| cell.get()));
41 })
42 .await
43 .unwrap();
44 })
45 .await;
46 }
47
48 #[tokio::test(flavor = "multi_thread")]
localset_future_threadpool()49 async fn localset_future_threadpool() {
50 thread_local! {
51 static ON_LOCAL_THREAD: Cell<bool> = Cell::new(false);
52 }
53
54 ON_LOCAL_THREAD.with(|cell| cell.set(true));
55
56 let local = LocalSet::new();
57 local.spawn_local(async move {
58 assert!(ON_LOCAL_THREAD.with(|cell| cell.get()));
59 });
60 local.await;
61 }
62
63 #[tokio::test(flavor = "multi_thread")]
localset_future_timers()64 async fn localset_future_timers() {
65 static RAN1: AtomicBool = AtomicBool::new(false);
66 static RAN2: AtomicBool = AtomicBool::new(false);
67
68 let local = LocalSet::new();
69 local.spawn_local(async move {
70 time::sleep(Duration::from_millis(5)).await;
71 RAN1.store(true, Ordering::SeqCst);
72 });
73 local.spawn_local(async move {
74 time::sleep(Duration::from_millis(10)).await;
75 RAN2.store(true, Ordering::SeqCst);
76 });
77 local.await;
78 assert!(RAN1.load(Ordering::SeqCst));
79 assert!(RAN2.load(Ordering::SeqCst));
80 }
81
82 #[tokio::test]
localset_future_drives_all_local_futs()83 async fn localset_future_drives_all_local_futs() {
84 static RAN1: AtomicBool = AtomicBool::new(false);
85 static RAN2: AtomicBool = AtomicBool::new(false);
86 static RAN3: AtomicBool = AtomicBool::new(false);
87
88 let local = LocalSet::new();
89 local.spawn_local(async move {
90 task::spawn_local(async {
91 task::yield_now().await;
92 RAN3.store(true, Ordering::SeqCst);
93 });
94 task::yield_now().await;
95 RAN1.store(true, Ordering::SeqCst);
96 });
97 local.spawn_local(async move {
98 task::yield_now().await;
99 RAN2.store(true, Ordering::SeqCst);
100 });
101 local.await;
102 assert!(RAN1.load(Ordering::SeqCst));
103 assert!(RAN2.load(Ordering::SeqCst));
104 assert!(RAN3.load(Ordering::SeqCst));
105 }
106
107 #[tokio::test(flavor = "multi_thread")]
local_threadpool_timer()108 async fn local_threadpool_timer() {
109 // This test ensures that runtime services like the timer are properly
110 // set for the local task set.
111 thread_local! {
112 static ON_RT_THREAD: Cell<bool> = Cell::new(false);
113 }
114
115 ON_RT_THREAD.with(|cell| cell.set(true));
116
117 LocalSet::new()
118 .run_until(async {
119 assert!(ON_RT_THREAD.with(|cell| cell.get()));
120 let join = task::spawn_local(async move {
121 assert!(ON_RT_THREAD.with(|cell| cell.get()));
122 time::sleep(Duration::from_millis(10)).await;
123 assert!(ON_RT_THREAD.with(|cell| cell.get()));
124 });
125 join.await.unwrap();
126 })
127 .await;
128 }
129
#[test]
// This will panic, since the thread that calls `block_on` cannot use
// in-place blocking inside of `block_on`.
#[should_panic]
fn local_threadpool_blocking_in_place() {
    thread_local! {
        static ON_RT_THREAD: Cell<bool> = Cell::new(false);
    }

    // Reads the flag for whichever thread is currently executing.
    fn on_rt_thread() -> bool {
        ON_RT_THREAD.with(Cell::get)
    }

    ON_RT_THREAD.with(|flag| flag.set(true));

    let runtime = runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    let set = LocalSet::new();
    set.block_on(&runtime, async {
        assert!(on_rt_thread());
        let blocking_task = task::spawn_local(async move {
            assert!(on_rt_thread());
            task::block_in_place(|| {});
            assert!(on_rt_thread());
        });
        // A panic inside the task surfaces here as a join error.
        blocking_task.await.unwrap();
    });
}
155
156 #[tokio::test(flavor = "multi_thread")]
local_threadpool_blocking_run()157 async fn local_threadpool_blocking_run() {
158 thread_local! {
159 static ON_RT_THREAD: Cell<bool> = Cell::new(false);
160 }
161
162 ON_RT_THREAD.with(|cell| cell.set(true));
163
164 LocalSet::new()
165 .run_until(async {
166 assert!(ON_RT_THREAD.with(|cell| cell.get()));
167 let join = task::spawn_local(async move {
168 assert!(ON_RT_THREAD.with(|cell| cell.get()));
169 task::spawn_blocking(|| {
170 assert!(
171 !ON_RT_THREAD.with(|cell| cell.get()),
172 "blocking must not run on the local task set's thread"
173 );
174 })
175 .await
176 .unwrap();
177 assert!(ON_RT_THREAD.with(|cell| cell.get()));
178 });
179 join.await.unwrap();
180 })
181 .await;
182 }
183
184 #[tokio::test(flavor = "multi_thread")]
all_spawns_are_local()185 async fn all_spawns_are_local() {
186 use futures::future;
187 thread_local! {
188 static ON_RT_THREAD: Cell<bool> = Cell::new(false);
189 }
190
191 ON_RT_THREAD.with(|cell| cell.set(true));
192
193 LocalSet::new()
194 .run_until(async {
195 assert!(ON_RT_THREAD.with(|cell| cell.get()));
196 let handles = (0..128)
197 .map(|_| {
198 task::spawn_local(async {
199 assert!(ON_RT_THREAD.with(|cell| cell.get()));
200 })
201 })
202 .collect::<Vec<_>>();
203 for joined in future::join_all(handles).await {
204 joined.unwrap();
205 }
206 })
207 .await;
208 }
209
210 #[tokio::test(flavor = "multi_thread")]
nested_spawn_is_local()211 async fn nested_spawn_is_local() {
212 thread_local! {
213 static ON_RT_THREAD: Cell<bool> = Cell::new(false);
214 }
215
216 ON_RT_THREAD.with(|cell| cell.set(true));
217
218 LocalSet::new()
219 .run_until(async {
220 assert!(ON_RT_THREAD.with(|cell| cell.get()));
221 task::spawn_local(async {
222 assert!(ON_RT_THREAD.with(|cell| cell.get()));
223 task::spawn_local(async {
224 assert!(ON_RT_THREAD.with(|cell| cell.get()));
225 task::spawn_local(async {
226 assert!(ON_RT_THREAD.with(|cell| cell.get()));
227 task::spawn_local(async {
228 assert!(ON_RT_THREAD.with(|cell| cell.get()));
229 })
230 .await
231 .unwrap();
232 })
233 .await
234 .unwrap();
235 })
236 .await
237 .unwrap();
238 })
239 .await
240 .unwrap();
241 })
242 .await;
243 }
244
// Awaits a local task's join handle from a task created with `task::spawn`.
// The local task asserts that it sees the thread-local flag set by the test
// body; the spawned task asserts that it does not.
#[test]
fn join_local_future_elsewhere() {
    thread_local! {
        static ON_RT_THREAD: Cell<bool> = Cell::new(false);
    }

    // Reads the flag for whichever thread is currently executing.
    fn on_rt_thread() -> bool {
        ON_RT_THREAD.with(Cell::get)
    }

    ON_RT_THREAD.with(|flag| flag.set(true));

    let runtime = Runtime::new().unwrap();
    let set = LocalSet::new();
    set.block_on(&runtime, async move {
        let (wake_tx, wake_rx) = oneshot::channel();

        let local_task = task::spawn_local(async move {
            println!("hello world running...");
            assert!(
                on_rt_thread(),
                "local task must run on local thread, no matter where it is awaited"
            );
            // Stay parked until the other task signals.
            wake_rx.await.unwrap();

            println!("hello world task done");
            "hello world"
        });

        let remote_task = task::spawn(async move {
            assert!(!on_rt_thread(), "spawned task should be on a worker");

            wake_tx.send(()).expect("task shouldn't have ended yet");
            println!("waking up hello world...");

            local_task.await.expect("task should complete successfully");

            println!("hello world task joined");
        });

        remote_task.await.unwrap()
    });
}
284
#[test]
fn drop_cancels_tasks() {
    use std::rc::Rc;

    // This test reproduces issue #1842
    let runtime = rt();
    let witness = Rc::new(());
    let held_by_task = Rc::clone(&witness);

    let (started_tx, started_rx) = oneshot::channel();

    let set = LocalSet::new();
    set.spawn_local(async move {
        // Move this in
        let _held = held_by_task;

        started_tx.send(()).unwrap();
        // Never completes, so the task still owns the clone when the set
        // is dropped.
        pending::<()>().await;
    });

    // Run the set only until the task reports that it has started.
    set.block_on(&runtime, async {
        started_rx.await.unwrap();
    });
    drop(set);
    drop(runtime);

    // Only the original `Rc` may remain once the set and runtime are gone.
    assert_eq!(1, Rc::strong_count(&witness));
}
313
/// Runs a test function in a separate thread, and panics if the test does not
/// complete within the specified timeout, or if the test function panics.
///
/// This is intended for running tests whose failure mode is a hang or infinite
/// loop that cannot be detected otherwise.
///
/// # Panics
///
/// Panics if no completion message arrives from the worker thread before
/// `timeout` elapses (the worker thread is then left running, detached), or
/// if the worker thread itself panicked while running `f`.
fn with_timeout(timeout: Duration, f: impl FnOnce() + Send + 'static) {
    use std::sync::mpsc::RecvTimeoutError;

    let (done_tx, done_rx) = std::sync::mpsc::channel();
    let thread = std::thread::spawn(move || {
        f();

        // Send a message on the channel so that the calling thread can
        // tell a finished test apart from one that is still running:
        done_tx.send(()).unwrap();
    });

    // Since the failure mode of such a test is an infinite loop, rather than
    // something we can easily make assertions about, it runs in its own
    // thread. When that thread finishes, it sends a message on a channel to
    // this thread. We wait for that message for at most `timeout`, and if we
    // don't receive it, we assume the test thread has hung.
    match done_rx.recv_timeout(timeout) {
        // `{:?}` on a `Duration` already prints the unit (e.g. `60s`), so
        // no unit is appended here.
        Err(RecvTimeoutError::Timeout) => panic!(
            "test did not complete within {:?}, \
             we have (probably) entered an infinite loop!",
            timeout,
        ),
        // The sender was dropped without sending. Did the test thread
        // panic? We'll find out for sure when we `join` with it.
        Err(RecvTimeoutError::Disconnected) => {
            println!("done_tx dropped, did the test thread panic?");
        }
        // Test completed successfully!
        Ok(()) => {}
    }

    thread.join().expect("test thread should not panic!")
}
357
#[test]
fn drop_cancels_remote_tasks() {
    // This test reproduces issue #1885.
    let scenario = || {
        let (tx, mut rx) = mpsc::channel::<()>(1024);

        let runtime = rt();

        let set = LocalSet::new();
        set.spawn_local(async move {
            // Keep receiving until the channel reports that it is closed.
            while let Some(()) = rx.recv().await {}
        });
        set.block_on(&runtime, async {
            time::sleep(Duration::from_millis(1)).await;
        });

        drop(tx);

        // This enters an infinite loop if the remote notified tasks are not
        // properly cancelled.
        drop(set);
    };

    with_timeout(Duration::from_secs(60), scenario);
}
379
#[test]
fn local_tasks_wake_join_all() {
    // This test reproduces issue #2460.
    with_timeout(Duration::from_secs(60), || {
        use futures::future::join_all;

        let runtime = rt();
        let set = LocalSet::new();

        // 128 tasks on the set, each spawning and awaiting one nested
        // local task.
        let handles: Vec<_> = (0..128)
            .map(|_| {
                set.spawn_local(async {
                    let nested = task::spawn_local(async {});
                    nested.await.unwrap();
                })
            })
            .collect();

        runtime.block_on(set.run_until(join_all(handles)));
    });
}
400
#[test]
fn local_tasks_are_polled_after_tick() {
    // This test depends on timing, so we run it up to five times: four
    // attempts whose panics are caught, stopping at the first success...
    let passed = (0..4)
        .any(|_| std::panic::catch_unwind(local_tasks_are_polled_after_tick_inner).is_ok());

    // ...and, if all four failed, one last attempt without catching panics.
    // If it fails again, the test fails.
    if !passed {
        local_tasks_are_polled_after_tick_inner();
    }
}
416
417 #[tokio::main(flavor = "current_thread")]
local_tasks_are_polled_after_tick_inner()418 async fn local_tasks_are_polled_after_tick_inner() {
419 // Reproduces issues #1899 and #1900
420
421 static RX1: AtomicUsize = AtomicUsize::new(0);
422 static RX2: AtomicUsize = AtomicUsize::new(0);
423 const EXPECTED: usize = 500;
424
425 RX1.store(0, SeqCst);
426 RX2.store(0, SeqCst);
427
428 let (tx, mut rx) = mpsc::unbounded_channel();
429
430 let local = LocalSet::new();
431
432 local
433 .run_until(async {
434 let task2 = task::spawn(async move {
435 // Wait a bit
436 time::sleep(Duration::from_millis(10)).await;
437
438 let mut oneshots = Vec::with_capacity(EXPECTED);
439
440 // Send values
441 for _ in 0..EXPECTED {
442 let (oneshot_tx, oneshot_rx) = oneshot::channel();
443 oneshots.push(oneshot_tx);
444 tx.send(oneshot_rx).unwrap();
445 }
446
447 time::sleep(Duration::from_millis(10)).await;
448
449 for tx in oneshots.drain(..) {
450 tx.send(()).unwrap();
451 }
452
453 time::sleep(Duration::from_millis(20)).await;
454 let rx1 = RX1.load(SeqCst);
455 let rx2 = RX2.load(SeqCst);
456 println!("EXPECT = {}; RX1 = {}; RX2 = {}", EXPECTED, rx1, rx2);
457 assert_eq!(EXPECTED, rx1);
458 assert_eq!(EXPECTED, rx2);
459 });
460
461 while let Some(oneshot) = rx.recv().await {
462 RX1.fetch_add(1, SeqCst);
463
464 task::spawn_local(async move {
465 oneshot.await.unwrap();
466 RX2.fetch_add(1, SeqCst);
467 });
468 }
469
470 task2.await.unwrap();
471 })
472 .await;
473 }
474
475 #[tokio::test]
acquire_mutex_in_drop()476 async fn acquire_mutex_in_drop() {
477 use futures::future::pending;
478
479 let (tx1, rx1) = oneshot::channel();
480 let (tx2, rx2) = oneshot::channel();
481 let local = LocalSet::new();
482
483 local.spawn_local(async move {
484 let _ = rx2.await;
485 unreachable!();
486 });
487
488 local.spawn_local(async move {
489 let _ = rx1.await;
490 tx2.send(()).unwrap();
491 unreachable!();
492 });
493
494 // Spawn a task that will never notify
495 local.spawn_local(async move {
496 pending::<()>().await;
497 tx1.send(()).unwrap();
498 });
499
500 // Tick the loop
501 local
502 .run_until(async {
503 task::yield_now().await;
504 })
505 .await;
506
507 // Drop the LocalSet
508 drop(local);
509 }
510
511 #[tokio::test]
spawn_wakes_localset()512 async fn spawn_wakes_localset() {
513 let local = LocalSet::new();
514 futures::select! {
515 _ = local.run_until(pending::<()>()).fuse() => unreachable!(),
516 ret = async { local.spawn_local(ready(())).await.unwrap()}.fuse() => ret
517 }
518 }
519
rt() -> Runtime520 fn rt() -> Runtime {
521 tokio::runtime::Builder::new_current_thread()
522 .enable_all()
523 .build()
524 .unwrap()
525 }
526