1 use std::alloc::Layout;
2 use std::future::Future;
3 use std::panic::AssertUnwindSafe;
4 use std::pin::Pin;
5 use std::ptr::{self, NonNull};
6 use std::task::{Context, Poll};
7 use std::{fmt, panic};
8 
/// A reusable `Pin<Box<dyn Future<Output = T> + Send>>`.
///
/// The future lives in a single heap allocation. Storing a new future whose
/// size and alignment match the current one reuses that allocation instead of
/// freeing it and allocating a fresh box.
pub struct ReusableBoxFuture<T> {
    // Owning, type-erased pointer to the heap-allocated future. It originates
    // from a `Box` and always points at a live future; ownership is given back
    // to a `Box` when this value is dropped.
    boxed: NonNull<dyn Future<Output = T> + Send>,
}
16 
17 impl<T> ReusableBoxFuture<T> {
18     /// Create a new `ReusableBoxFuture<T>` containing the provided future.
new<F>(future: F) -> Self where F: Future<Output = T> + Send + 'static,19     pub fn new<F>(future: F) -> Self
20     where
21         F: Future<Output = T> + Send + 'static,
22     {
23         let boxed: Box<dyn Future<Output = T> + Send> = Box::new(future);
24 
25         let boxed = Box::into_raw(boxed);
26 
27         // SAFETY: Box::into_raw does not return null pointers.
28         let boxed = unsafe { NonNull::new_unchecked(boxed) };
29 
30         Self { boxed }
31     }
32 
33     /// Replace the future currently stored in this box.
34     ///
35     /// This reallocates if and only if the layout of the provided future is
36     /// different from the layout of the currently stored future.
set<F>(&mut self, future: F) where F: Future<Output = T> + Send + 'static,37     pub fn set<F>(&mut self, future: F)
38     where
39         F: Future<Output = T> + Send + 'static,
40     {
41         if let Err(future) = self.try_set(future) {
42             *self = Self::new(future);
43         }
44     }
45 
46     /// Replace the future currently stored in this box.
47     ///
48     /// This function never reallocates, but returns an error if the provided
49     /// future has a different size or alignment from the currently stored
50     /// future.
try_set<F>(&mut self, future: F) -> Result<(), F> where F: Future<Output = T> + Send + 'static,51     pub fn try_set<F>(&mut self, future: F) -> Result<(), F>
52     where
53         F: Future<Output = T> + Send + 'static,
54     {
55         // SAFETY: The pointer is not dangling.
56         let self_layout = {
57             let dyn_future: &(dyn Future<Output = T> + Send) = unsafe { self.boxed.as_ref() };
58             Layout::for_value(dyn_future)
59         };
60 
61         if Layout::new::<F>() == self_layout {
62             // SAFETY: We just checked that the layout of F is correct.
63             unsafe {
64                 self.set_same_layout(future);
65             }
66 
67             Ok(())
68         } else {
69             Err(future)
70         }
71     }
72 
73     /// Set the current future.
74     ///
75     /// # Safety
76     ///
77     /// This function requires that the layout of the provided future is the
78     /// same as `self.layout`.
set_same_layout<F>(&mut self, future: F) where F: Future<Output = T> + Send + 'static,79     unsafe fn set_same_layout<F>(&mut self, future: F)
80     where
81         F: Future<Output = T> + Send + 'static,
82     {
83         // Drop the existing future, catching any panics.
84         let result = panic::catch_unwind(AssertUnwindSafe(|| {
85             ptr::drop_in_place(self.boxed.as_ptr());
86         }));
87 
88         // Overwrite the future behind the pointer. This is safe because the
89         // allocation was allocated with the same size and alignment as the type F.
90         let self_ptr: *mut F = self.boxed.as_ptr() as *mut F;
91         ptr::write(self_ptr, future);
92 
93         // Update the vtable of self.boxed. The pointer is not null because we
94         // just got it from self.boxed, which is not null.
95         self.boxed = NonNull::new_unchecked(self_ptr);
96 
97         // If the old future's destructor panicked, resume unwinding.
98         match result {
99             Ok(()) => {}
100             Err(payload) => {
101                 panic::resume_unwind(payload);
102             }
103         }
104     }
105 
106     /// Get a pinned reference to the underlying future.
get_pin(&mut self) -> Pin<&mut (dyn Future<Output = T> + Send)>107     pub fn get_pin(&mut self) -> Pin<&mut (dyn Future<Output = T> + Send)> {
108         // SAFETY: The user of this box cannot move the box, and we do not move it
109         // either.
110         unsafe { Pin::new_unchecked(self.boxed.as_mut()) }
111     }
112 
113     /// Poll the future stored inside this box.
poll(&mut self, cx: &mut Context<'_>) -> Poll<T>114     pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<T> {
115         self.get_pin().poll(cx)
116     }
117 }
118 
119 impl<T> Future for ReusableBoxFuture<T> {
120     type Output = T;
121 
122     /// Poll the future stored inside this box.
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T>123     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
124         Pin::into_inner(self).get_pin().poll(cx)
125     }
126 }
127 
// SAFETY: `ReusableBoxFuture<T>` uniquely owns its boxed future, and every
// future it can hold is required to be `Send` by the bounds on `new`, `set`
// and `try_set`. Moving the wrapper to another thread is therefore sound, even
// though the `NonNull` field is not `Send` on its own.
unsafe impl<T> Send for ReusableBoxFuture<T> {}
130 
// SAFETY: nothing reachable through `&ReusableBoxFuture<T>` touches the boxed
// future. `poll`, `get_pin`, `set` and `try_set` all take `&mut self`, and the
// `Debug` impl prints only the type name. Sharing `&Self` across threads thus
// grants no access to the future, so the future itself need not be `Sync`.
unsafe impl<T> Sync for ReusableBoxFuture<T> {}
135 
// The future is pinned inside its own heap allocation, which does not move when
// the `ReusableBoxFuture` handle is moved. So, just like
// `Pin<Box<dyn Future>>`, this wrapper can be `Unpin`.
impl<T> Unpin for ReusableBoxFuture<T> {}
138 
139 impl<T> Drop for ReusableBoxFuture<T> {
drop(&mut self)140     fn drop(&mut self) {
141         unsafe {
142             drop(Box::from_raw(self.boxed.as_ptr()));
143         }
144     }
145 }
146 
147 impl<T> fmt::Debug for ReusableBoxFuture<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result148     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149         f.debug_struct("ReusableBoxFuture").finish()
150     }
151 }
152