1 use futures_core::stream::{Stream, FusedStream}; 2 use futures_core::task::{Context, Poll}; 3 use pin_utils::{unsafe_pinned, unsafe_unpinned}; 4 use std::any::Any; 5 use std::pin::Pin; 6 use std::panic::{catch_unwind, UnwindSafe, AssertUnwindSafe}; 7 8 /// Stream for the [`catch_unwind`](super::StreamExt::catch_unwind) method. 9 #[derive(Debug)] 10 #[must_use = "streams do nothing unless polled"] 11 pub struct CatchUnwind<St: Stream> { 12 stream: St, 13 caught_unwind: bool, 14 } 15 16 impl<St: Stream + UnwindSafe> CatchUnwind<St> { 17 unsafe_pinned!(stream: St); 18 unsafe_unpinned!(caught_unwind: bool); 19 new(stream: St) -> CatchUnwind<St>20 pub(super) fn new(stream: St) -> CatchUnwind<St> { 21 CatchUnwind { stream, caught_unwind: false } 22 } 23 } 24 25 impl<St: Stream + UnwindSafe> Stream for CatchUnwind<St> { 26 type Item = Result<St::Item, Box<dyn Any + Send>>; 27 poll_next( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Self::Item>>28 fn poll_next( 29 mut self: Pin<&mut Self>, 30 cx: &mut Context<'_>, 31 ) -> Poll<Option<Self::Item>> { 32 if self.caught_unwind { 33 Poll::Ready(None) 34 } else { 35 let res = catch_unwind(AssertUnwindSafe(|| { 36 self.as_mut().stream().poll_next(cx) 37 })); 38 39 match res { 40 Ok(poll) => poll.map(|opt| opt.map(Ok)), 41 Err(e) => { 42 *self.as_mut().caught_unwind() = true; 43 Poll::Ready(Some(Err(e))) 44 }, 45 } 46 } 47 } 48 size_hint(&self) -> (usize, Option<usize>)49 fn size_hint(&self) -> (usize, Option<usize>) { 50 if self.caught_unwind { 51 (0, Some(0)) 52 } else { 53 self.stream.size_hint() 54 } 55 } 56 } 57 58 impl<St: FusedStream + UnwindSafe> FusedStream for CatchUnwind<St> { is_terminated(&self) -> bool59 fn is_terminated(&self) -> bool { 60 self.caught_unwind || self.stream.is_terminated() 61 } 62 } 63