1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3 
4 use futures::{
5     future::{pending, ready},
6     FutureExt,
7 };
8 
9 use tokio::runtime::{self, Runtime};
10 use tokio::sync::{mpsc, oneshot};
11 use tokio::task::{self, LocalSet};
12 use tokio::time;
13 
14 use std::cell::Cell;
15 use std::sync::atomic::Ordering::{self, SeqCst};
16 use std::sync::atomic::{AtomicBool, AtomicUsize};
17 use std::time::Duration;
18 
19 #[tokio::test(flavor = "current_thread")]
local_basic_scheduler()20 async fn local_basic_scheduler() {
21     LocalSet::new()
22         .run_until(async {
23             task::spawn_local(async {}).await.unwrap();
24         })
25         .await;
26 }
27 
28 #[tokio::test(flavor = "multi_thread")]
local_threadpool()29 async fn local_threadpool() {
30     thread_local! {
31         static ON_RT_THREAD: Cell<bool> = Cell::new(false);
32     }
33 
34     ON_RT_THREAD.with(|cell| cell.set(true));
35 
36     LocalSet::new()
37         .run_until(async {
38             assert!(ON_RT_THREAD.with(|cell| cell.get()));
39             task::spawn_local(async {
40                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
41             })
42             .await
43             .unwrap();
44         })
45         .await;
46 }
47 
48 #[tokio::test(flavor = "multi_thread")]
localset_future_threadpool()49 async fn localset_future_threadpool() {
50     thread_local! {
51         static ON_LOCAL_THREAD: Cell<bool> = Cell::new(false);
52     }
53 
54     ON_LOCAL_THREAD.with(|cell| cell.set(true));
55 
56     let local = LocalSet::new();
57     local.spawn_local(async move {
58         assert!(ON_LOCAL_THREAD.with(|cell| cell.get()));
59     });
60     local.await;
61 }
62 
63 #[tokio::test(flavor = "multi_thread")]
localset_future_timers()64 async fn localset_future_timers() {
65     static RAN1: AtomicBool = AtomicBool::new(false);
66     static RAN2: AtomicBool = AtomicBool::new(false);
67 
68     let local = LocalSet::new();
69     local.spawn_local(async move {
70         time::sleep(Duration::from_millis(5)).await;
71         RAN1.store(true, Ordering::SeqCst);
72     });
73     local.spawn_local(async move {
74         time::sleep(Duration::from_millis(10)).await;
75         RAN2.store(true, Ordering::SeqCst);
76     });
77     local.await;
78     assert!(RAN1.load(Ordering::SeqCst));
79     assert!(RAN2.load(Ordering::SeqCst));
80 }
81 
82 #[tokio::test]
localset_future_drives_all_local_futs()83 async fn localset_future_drives_all_local_futs() {
84     static RAN1: AtomicBool = AtomicBool::new(false);
85     static RAN2: AtomicBool = AtomicBool::new(false);
86     static RAN3: AtomicBool = AtomicBool::new(false);
87 
88     let local = LocalSet::new();
89     local.spawn_local(async move {
90         task::spawn_local(async {
91             task::yield_now().await;
92             RAN3.store(true, Ordering::SeqCst);
93         });
94         task::yield_now().await;
95         RAN1.store(true, Ordering::SeqCst);
96     });
97     local.spawn_local(async move {
98         task::yield_now().await;
99         RAN2.store(true, Ordering::SeqCst);
100     });
101     local.await;
102     assert!(RAN1.load(Ordering::SeqCst));
103     assert!(RAN2.load(Ordering::SeqCst));
104     assert!(RAN3.load(Ordering::SeqCst));
105 }
106 
107 #[tokio::test(flavor = "multi_thread")]
local_threadpool_timer()108 async fn local_threadpool_timer() {
109     // This test ensures that runtime services like the timer are properly
110     // set for the local task set.
111     thread_local! {
112         static ON_RT_THREAD: Cell<bool> = Cell::new(false);
113     }
114 
115     ON_RT_THREAD.with(|cell| cell.set(true));
116 
117     LocalSet::new()
118         .run_until(async {
119             assert!(ON_RT_THREAD.with(|cell| cell.get()));
120             let join = task::spawn_local(async move {
121                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
122                 time::sleep(Duration::from_millis(10)).await;
123                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
124             });
125             join.await.unwrap();
126         })
127         .await;
128 }
129 
// Expected to panic: the local task calls `task::block_in_place` while being
// driven by `LocalSet::block_on` on a current-thread runtime, where in-place
// blocking is not available. A panic inside the local task comes back as an
// error from its join handle, and `unwrap()` re-raises it on this thread.
#[test]
#[should_panic]
fn local_threadpool_blocking_in_place() {
    thread_local! {
        static ON_RT_THREAD: Cell<bool> = Cell::new(false);
    }

    // Reads the flag that the test body sets on its own thread.
    fn on_rt_thread() -> bool {
        ON_RT_THREAD.with(Cell::get)
    }

    ON_RT_THREAD.with(|flag| flag.set(true));

    let runtime = runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    LocalSet::new().block_on(&runtime, async {
        assert!(on_rt_thread());
        let blocker = task::spawn_local(async move {
            assert!(on_rt_thread());
            task::block_in_place(|| {});
            assert!(on_rt_thread());
        });
        blocker.await.unwrap();
    });
}
155 
156 #[tokio::test(flavor = "multi_thread")]
local_threadpool_blocking_run()157 async fn local_threadpool_blocking_run() {
158     thread_local! {
159         static ON_RT_THREAD: Cell<bool> = Cell::new(false);
160     }
161 
162     ON_RT_THREAD.with(|cell| cell.set(true));
163 
164     LocalSet::new()
165         .run_until(async {
166             assert!(ON_RT_THREAD.with(|cell| cell.get()));
167             let join = task::spawn_local(async move {
168                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
169                 task::spawn_blocking(|| {
170                     assert!(
171                         !ON_RT_THREAD.with(|cell| cell.get()),
172                         "blocking must not run on the local task set's thread"
173                     );
174                 })
175                 .await
176                 .unwrap();
177                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
178             });
179             join.await.unwrap();
180         })
181         .await;
182 }
183 
184 #[tokio::test(flavor = "multi_thread")]
all_spawns_are_local()185 async fn all_spawns_are_local() {
186     use futures::future;
187     thread_local! {
188         static ON_RT_THREAD: Cell<bool> = Cell::new(false);
189     }
190 
191     ON_RT_THREAD.with(|cell| cell.set(true));
192 
193     LocalSet::new()
194         .run_until(async {
195             assert!(ON_RT_THREAD.with(|cell| cell.get()));
196             let handles = (0..128)
197                 .map(|_| {
198                     task::spawn_local(async {
199                         assert!(ON_RT_THREAD.with(|cell| cell.get()));
200                     })
201                 })
202                 .collect::<Vec<_>>();
203             for joined in future::join_all(handles).await {
204                 joined.unwrap();
205             }
206         })
207         .await;
208 }
209 
210 #[tokio::test(flavor = "multi_thread")]
nested_spawn_is_local()211 async fn nested_spawn_is_local() {
212     thread_local! {
213         static ON_RT_THREAD: Cell<bool> = Cell::new(false);
214     }
215 
216     ON_RT_THREAD.with(|cell| cell.set(true));
217 
218     LocalSet::new()
219         .run_until(async {
220             assert!(ON_RT_THREAD.with(|cell| cell.get()));
221             task::spawn_local(async {
222                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
223                 task::spawn_local(async {
224                     assert!(ON_RT_THREAD.with(|cell| cell.get()));
225                     task::spawn_local(async {
226                         assert!(ON_RT_THREAD.with(|cell| cell.get()));
227                         task::spawn_local(async {
228                             assert!(ON_RT_THREAD.with(|cell| cell.get()));
229                         })
230                         .await
231                         .unwrap();
232                     })
233                     .await
234                     .unwrap();
235                 })
236                 .await
237                 .unwrap();
238             })
239             .await
240             .unwrap();
241         })
242         .await;
243 }
244 
// A local task's join handle may be awaited from a task spawned onto a worker
// thread; the local task itself must still run on the `LocalSet`'s thread.
#[test]
fn join_local_future_elsewhere() {
    thread_local! {
        static ON_RT_THREAD: Cell<bool> = Cell::new(false);
    }

    // Reads the flag that this test sets on its own thread.
    fn on_rt_thread() -> bool {
        ON_RT_THREAD.with(Cell::get)
    }

    ON_RT_THREAD.with(|flag| flag.set(true));

    let runtime = runtime::Runtime::new().unwrap();
    let local_set = LocalSet::new();
    local_set.block_on(&runtime, async move {
        let (wake_tx, wake_rx) = oneshot::channel();

        let local_task = task::spawn_local(async move {
            println!("hello world running...");
            assert!(
                on_rt_thread(),
                "local task must run on local thread, no matter where it is awaited"
            );
            wake_rx.await.unwrap();

            println!("hello world task done");
            "hello world"
        });

        let remote_task = task::spawn(async move {
            assert!(!on_rt_thread(), "spawned task should be on a worker");

            wake_tx.send(()).expect("task shouldn't have ended yet");
            println!("waking up hello world...");

            local_task.await.expect("task should complete successfully");

            println!("hello world task joined");
        });

        remote_task.await.unwrap()
    });
}
284 
#[test]
fn drop_cancels_tasks() {
    use std::rc::Rc;

    // Reproduces issue #1842: once the `LocalSet` (and runtime) are dropped,
    // the still-pending task must have been dropped too, releasing the `Rc`
    // clone it owns.
    let runtime = rt();
    let tracked = Rc::new(());
    let owned_by_task = Rc::clone(&tracked);

    let (started_tx, started_rx) = oneshot::channel();

    let local_set = LocalSet::new();
    local_set.spawn_local(async move {
        // Take ownership so the clone lives exactly as long as this task.
        let _held = owned_by_task;

        started_tx.send(()).unwrap();
        futures::future::pending::<()>().await;
    });

    local_set.block_on(&runtime, async {
        started_rx.await.unwrap();
    });
    drop(local_set);
    drop(runtime);

    assert_eq!(1, Rc::strong_count(&tracked));
}
313 
/// Runs `f` on a separate thread and panics if it does not finish within
/// `timeout`, or if `f` itself panics.
///
/// This is intended for tests whose failure mode is a hang or an infinite
/// loop, which cannot be detected by ordinary assertions.
///
/// # Panics
///
/// * If `f` has not completed within `timeout`. The worker thread cannot be
///   cancelled and is left running.
/// * If `f` panics; that is surfaced when the worker thread is joined.
fn with_timeout(timeout: Duration, f: impl FnOnce() + Send + 'static) {
    use std::sync::mpsc::RecvTimeoutError;

    let (done_tx, done_rx) = std::sync::mpsc::channel();
    let thread = std::thread::spawn(move || {
        f();

        // Signal completion. If the waiting side already timed out, `done_rx`
        // is gone and the send fails; ignore that instead of raising a second,
        // misleading panic on this detached thread.
        let _ = done_tx.send(());
    });

    // The failure mode we care about is a hang, so wait for the completion
    // signal with a deadline instead of joining unconditionally.
    match done_rx.recv_timeout(timeout) {
        // `Duration`'s `Debug` output already includes a unit (e.g. `60s`),
        // so no unit word is appended after it.
        Err(RecvTimeoutError::Timeout) => panic!(
            "test did not complete within {:?}, \
             we have (probably) entered an infinite loop!",
            timeout,
        ),
        // The sender was dropped without sending, which happens when `f`
        // panicked; `join` below reports that for sure.
        Err(RecvTimeoutError::Disconnected) => {
            println!("done_tx dropped, did the test thread panic?");
        }
        // Test completed successfully!
        Ok(()) => {}
    }

    thread.join().expect("test thread should not panic!")
}
357 
#[test]
fn drop_cancels_remote_tasks() {
    // Reproduces issue #1885. The failure mode is an infinite loop while
    // dropping the `LocalSet`, hence the `with_timeout` guard.
    with_timeout(Duration::from_secs(60), || {
        let (sender, mut receiver) = mpsc::channel::<()>(1024);

        let runtime = rt();

        let local_set = LocalSet::new();
        local_set.spawn_local(async move {
            while receiver.recv().await.is_some() {}
        });
        local_set.block_on(&runtime, async {
            time::sleep(Duration::from_millis(1)).await;
        });

        // Closing the channel notifies the receiving task from outside the
        // set...
        drop(sender);

        // ...and dropping the set must cancel that remotely notified task
        // rather than looping forever.
        drop(local_set);
    });
}
379 
#[test]
fn local_tasks_wake_join_all() {
    // Reproduces issue #2460. The failure mode is a hang, hence the
    // `with_timeout` guard: 128 local tasks each spawn and await one more
    // local task, and `join_all` over them must complete.
    with_timeout(Duration::from_secs(60), || {
        use futures::future::join_all;

        let runtime = rt();
        let local_set = LocalSet::new();

        let handles: Vec<_> = (0..128)
            .map(|_| {
                local_set.spawn_local(async move {
                    tokio::task::spawn_local(async move {}).await.unwrap();
                })
            })
            .collect();

        runtime.block_on(local_set.run_until(join_all(handles)));
    });
}
400 
#[test]
fn local_tasks_are_polled_after_tick() {
    // The inner test depends on timing, so it gets five tries in total: the
    // first four have their panics caught, and if none of them passed, a
    // fifth runs uncaught so a persistent failure fails this test.
    const CAUGHT_ATTEMPTS: usize = 4;

    let passed = (0..CAUGHT_ATTEMPTS)
        .any(|_| std::panic::catch_unwind(local_tasks_are_polled_after_tick_inner).is_ok());

    if !passed {
        local_tasks_are_polled_after_tick_inner();
    }
}
416 
// One attempt of `local_tasks_are_polled_after_tick`.
//
// `#[tokio::main]` turns this into an ordinary blocking function that builds
// its own current-thread runtime on each call, which is what lets the caller
// run it under `catch_unwind` and retry it.
//
// Flow: a task started with `task::spawn` pushes `EXPECTED` oneshot receivers
// through an unbounded channel, then completes every oneshot. Inside the
// `LocalSet`, each received receiver increments `RX1` and gets its own local
// task that increments `RX2` once its oneshot fires. The spawned task finally
// asserts that both counters reached `EXPECTED`. The fixed sleeps make the
// outcome timing-dependent.
#[tokio::main(flavor = "current_thread")]
async fn local_tasks_are_polled_after_tick_inner() {
    // Reproduces issues #1899 and #1900

    static RX1: AtomicUsize = AtomicUsize::new(0);
    static RX2: AtomicUsize = AtomicUsize::new(0);
    const EXPECTED: usize = 500;

    // The statics persist across calls, so reset them for every retry.
    RX1.store(0, SeqCst);
    RX2.store(0, SeqCst);

    let (tx, mut rx) = mpsc::unbounded_channel();

    let local = LocalSet::new();

    local
        .run_until(async {
            // `tx` is moved into this task, so the channel closes (ending the
            // receive loop below) once the task has finished.
            let task2 = task::spawn(async move {
                // Wait a bit
                time::sleep(Duration::from_millis(10)).await;

                let mut oneshots = Vec::with_capacity(EXPECTED);

                // Send values
                for _ in 0..EXPECTED {
                    let (oneshot_tx, oneshot_rx) = oneshot::channel();
                    oneshots.push(oneshot_tx);
                    tx.send(oneshot_rx).unwrap();
                }

                time::sleep(Duration::from_millis(10)).await;

                // Complete every oneshot, waking the local tasks awaiting them.
                for tx in oneshots.drain(..) {
                    tx.send(()).unwrap();
                }

                // Give the local tasks time to run before checking the counts.
                time::sleep(Duration::from_millis(20)).await;
                let rx1 = RX1.load(SeqCst);
                let rx2 = RX2.load(SeqCst);
                println!("EXPECT = {}; RX1 = {}; RX2 = {}", EXPECTED, rx1, rx2);
                assert_eq!(EXPECTED, rx1);
                assert_eq!(EXPECTED, rx2);
            });

            // Count each received oneshot and hand it to a fresh local task.
            while let Some(oneshot) = rx.recv().await {
                RX1.fetch_add(1, SeqCst);

                task::spawn_local(async move {
                    oneshot.await.unwrap();
                    RX2.fetch_add(1, SeqCst);
                });
            }

            // Propagate the spawned task's assertion failures, if any.
            task2.await.unwrap();
        })
        .await;
}
474 
// Spawns three local tasks linked by oneshot channels, ticks the set briefly,
// then drops the `LocalSet` while all three are still pending.
//
// The third task never completes, so `tx1` is only released when that task's
// future is dropped along with the set. Dropping a oneshot sender makes its
// receiver (`rx1`) ready, so a wakeup can happen while the set is being torn
// down. The `unreachable!()` calls mark that the first two tasks are not
// expected to resume. NOTE(review): the test name suggests the regression
// was a waker acquiring the set's internal lock during drop. Confirm that
// against the tokio history before relying on this description.
#[tokio::test]
async fn acquire_mutex_in_drop() {
    use futures::future::pending;

    let (tx1, rx1) = oneshot::channel();
    let (tx2, rx2) = oneshot::channel();
    let local = LocalSet::new();

    // Waits on `rx2`, which only the second task would complete.
    local.spawn_local(async move {
        let _ = rx2.await;
        unreachable!();
    });

    // Waits on `rx1`, which only the third task would complete.
    local.spawn_local(async move {
        let _ = rx1.await;
        tx2.send(()).unwrap();
        unreachable!();
    });

    // Spawn a task that will never notify
    local.spawn_local(async move {
        pending::<()>().await;
        tx1.send(()).unwrap();
    });

    // Tick the loop
    local
        .run_until(async {
            task::yield_now().await;
        })
        .await;

    // Drop the LocalSet
    drop(local);
}
510 
// Uses `futures::select!` to race two branches. The first is a `run_until`
// that can never finish on its own. The second spawns a ready task onto the
// same `LocalSet` and awaits its handle.
//
// Only the second branch can complete, and the spawned task is only driven by
// the `run_until` branch. The test therefore finishes only if spawning causes
// the set to poll the new task. The first branch can never win, hence
// `unreachable!()`.
#[tokio::test]
async fn spawn_wakes_localset() {
    let local = LocalSet::new();
    futures::select! {
        _ = local.run_until(pending::<()>()).fuse() => unreachable!(),
        ret = async { local.spawn_local(ready(())).await.unwrap()}.fuse() => ret
    }
}
519 
rt() -> Runtime520 fn rt() -> Runtime {
521     tokio::runtime::Builder::new_current_thread()
522         .enable_all()
523         .build()
524         .unwrap()
525 }
526