1 use std::prelude::v1::*;
2 use std::any::Any;
3 use std::panic::{catch_unwind, UnwindSafe, AssertUnwindSafe};
4 use std::mem;
5
6 use super::super::{Poll, Async};
7 use super::Stream;
8
9 /// Stream for the `catch_unwind` combinator.
10 ///
11 /// This is created by the `Stream::catch_unwind` method.
12 #[derive(Debug)]
13 #[must_use = "streams do nothing unless polled"]
14 pub struct CatchUnwind<S> where S: Stream {
15 state: CatchUnwindState<S>,
16 }
17
new<S>(stream: S) -> CatchUnwind<S> where S: Stream + UnwindSafe,18 pub fn new<S>(stream: S) -> CatchUnwind<S>
19 where S: Stream + UnwindSafe,
20 {
21 CatchUnwind {
22 state: CatchUnwindState::Stream(stream),
23 }
24 }
25
/// Internal lifecycle of a `CatchUnwind` adaptor.
#[derive(Debug)]
enum CatchUnwindState<S> {
    /// The wrapped stream is live and is polled on every call.
    Stream(S),
    /// Placeholder left in place while the inner stream is being polled.
    /// It remains if that poll panicked, in which case the next poll
    /// reports end-of-stream.
    Eof,
    /// End-of-stream has already been reported after a panic; polling
    /// again panics.
    Done,
}
32
33 impl<S> Stream for CatchUnwind<S>
34 where S: Stream + UnwindSafe,
35 {
36 type Item = Result<S::Item, S::Error>;
37 type Error = Box<Any + Send>;
38
poll(&mut self) -> Poll<Option<Self::Item>, Self::Error>39 fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
40 let mut stream = match mem::replace(&mut self.state, CatchUnwindState::Eof) {
41 CatchUnwindState::Done => panic!("cannot poll after eof"),
42 CatchUnwindState::Eof => {
43 self.state = CatchUnwindState::Done;
44 return Ok(Async::Ready(None));
45 }
46 CatchUnwindState::Stream(stream) => stream,
47 };
48 let res = catch_unwind(|| (stream.poll(), stream));
49 match res {
50 Err(e) => Err(e), // and state is already Eof
51 Ok((poll, stream)) => {
52 self.state = CatchUnwindState::Stream(stream);
53 match poll {
54 Err(e) => Ok(Async::Ready(Some(Err(e)))),
55 Ok(Async::NotReady) => Ok(Async::NotReady),
56 Ok(Async::Ready(Some(r))) => Ok(Async::Ready(Some(Ok(r)))),
57 Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
58 }
59 }
60 }
61 }
62 }
63
64 impl<S: Stream> Stream for AssertUnwindSafe<S> {
65 type Item = S::Item;
66 type Error = S::Error;
67
poll(&mut self) -> Poll<Option<S::Item>, S::Error>68 fn poll(&mut self) -> Poll<Option<S::Item>, S::Error> {
69 self.0.poll()
70 }
71 }
72