1 use std::prelude::v1::*;
2 use std::any::Any;
3 use std::panic::{catch_unwind, UnwindSafe, AssertUnwindSafe};
4 use std::mem;
5 
6 use super::super::{Poll, Async};
7 use super::Stream;
8 
9 /// Stream for the `catch_unwind` combinator.
10 ///
11 /// This is created by the `Stream::catch_unwind` method.
12 #[derive(Debug)]
13 #[must_use = "streams do nothing unless polled"]
14 pub struct CatchUnwind<S> where S: Stream {
15     state: CatchUnwindState<S>,
16 }
17 
new<S>(stream: S) -> CatchUnwind<S> where S: Stream + UnwindSafe,18 pub fn new<S>(stream: S) -> CatchUnwind<S>
19     where S: Stream + UnwindSafe,
20 {
21     CatchUnwind {
22         state: CatchUnwindState::Stream(stream),
23     }
24 }
25 
/// Lifecycle of a `CatchUnwind` adapter.
#[derive(Debug)]
enum CatchUnwindState<S> {
    /// The wrapped stream is live and will be polled on the next call.
    Stream(S),
    /// Polling the wrapped stream panicked, so the next poll reports
    /// end-of-stream. This is also the placeholder left in place while the
    /// wrapped stream is temporarily moved out to be polled.
    Eof,
    /// End-of-stream has already been reported after a panic; polling again
    /// panics.
    Done,
}
32 
33 impl<S> Stream for CatchUnwind<S>
34     where S: Stream + UnwindSafe,
35 {
36     type Item = Result<S::Item, S::Error>;
37     type Error = Box<Any + Send>;
38 
poll(&mut self) -> Poll<Option<Self::Item>, Self::Error>39     fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
40         let mut stream = match mem::replace(&mut self.state, CatchUnwindState::Eof) {
41             CatchUnwindState::Done => panic!("cannot poll after eof"),
42             CatchUnwindState::Eof => {
43                 self.state = CatchUnwindState::Done;
44                 return Ok(Async::Ready(None));
45             }
46             CatchUnwindState::Stream(stream) => stream,
47         };
48         let res = catch_unwind(|| (stream.poll(), stream));
49         match res {
50             Err(e) => Err(e), // and state is already Eof
51             Ok((poll, stream)) => {
52                 self.state = CatchUnwindState::Stream(stream);
53                 match poll {
54                     Err(e) => Ok(Async::Ready(Some(Err(e)))),
55                     Ok(Async::NotReady) => Ok(Async::NotReady),
56                     Ok(Async::Ready(Some(r))) => Ok(Async::Ready(Some(Ok(r)))),
57                     Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
58                 }
59             }
60         }
61     }
62 }
63 
64 impl<S: Stream> Stream for AssertUnwindSafe<S> {
65     type Item = S::Item;
66     type Error = S::Error;
67 
poll(&mut self) -> Poll<Option<S::Item>, S::Error>68     fn poll(&mut self) -> Poll<Option<S::Item>, S::Error> {
69         self.0.poll()
70     }
71 }
72