1 use std::error::Error as StdError;
2 
3 use pin_project::pin_project;
4 use tokio::io::{AsyncRead, AsyncWrite};
5 
6 use super::conn::{SpawnAll, UpgradeableConnection, Watcher};
7 use super::Accept;
8 use crate::body::{Body, HttpBody};
9 use crate::common::drain::{self, Draining, Signal, Watch, Watching};
10 use crate::common::exec::{H2Exec, NewSvcExec};
11 use crate::common::{task, Future, Pin, Poll, Unpin};
12 use crate::service::{HttpService, MakeServiceRef};
13 
/// Future that runs a server until a shutdown signal fires, then drains.
///
/// This is a thin pin-projected wrapper around the private [`State`] machine;
/// see the `Future` impl for the transitions. It resolves to
/// `crate::Result<()>`.
// No `Debug` impl is provided for this type, so the lint is silenced.
#[allow(missing_debug_implementations)]
#[pin_project]
pub struct Graceful<I, S, F, E> {
    /// Current phase of the shutdown state machine.
    #[pin]
    state: State<I, S, F, E>,
}
20 
/// Phases of a [`Graceful`] future.
///
/// The future starts in `Running` and moves to `Draining` once the shutdown
/// signal `F` completes; there is no transition back.
#[pin_project(project = StateProj)]
pub(super) enum State<I, S, F, E> {
    /// Shutdown has not been signalled yet.
    Running {
        /// Both halves of the drain channel. The `Signal` half is taken out
        /// (leaving `None`) when shutdown starts; the `Watch` half is cloned
        /// on every poll and handed to `spawn_all` via a `GracefulWatcher`.
        drain: Option<(Signal, Watch)>,
        /// The inner server future that is polled while `signal` is pending.
        #[pin]
        spawn_all: SpawnAll<I, S, E>,
        /// User-supplied future whose completion triggers the shutdown.
        #[pin]
        signal: F,
    },
    /// Shutdown was signalled; holds the future returned by `Signal::drain`,
    /// which is polled to completion.
    Draining(Draining),
}
32 
33 impl<I, S, F, E> Graceful<I, S, F, E> {
new(spawn_all: SpawnAll<I, S, E>, signal: F) -> Self34     pub(super) fn new(spawn_all: SpawnAll<I, S, E>, signal: F) -> Self {
35         let drain = Some(drain::channel());
36         Graceful {
37             state: State::Running {
38                 drain,
39                 spawn_all,
40                 signal,
41             },
42         }
43     }
44 }
45 
46 impl<I, IO, IE, S, B, F, E> Future for Graceful<I, S, F, E>
47 where
48     I: Accept<Conn = IO, Error = IE>,
49     IE: Into<Box<dyn StdError + Send + Sync>>,
50     IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
51     S: MakeServiceRef<IO, Body, ResBody = B>,
52     S::Error: Into<Box<dyn StdError + Send + Sync>>,
53     B: HttpBody + Send + Sync + 'static,
54     B::Error: Into<Box<dyn StdError + Send + Sync>>,
55     F: Future<Output = ()>,
56     E: H2Exec<<S::Service as HttpService<Body>>::Future, B>,
57     E: NewSvcExec<IO, S::Future, S::Service, E, GracefulWatcher>,
58 {
59     type Output = crate::Result<()>;
60 
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>61     fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
62         let mut me = self.project();
63         loop {
64             let next = {
65                 match me.state.as_mut().project() {
66                     StateProj::Running {
67                         drain,
68                         spawn_all,
69                         signal,
70                     } => match signal.poll(cx) {
71                         Poll::Ready(()) => {
72                             debug!("signal received, starting graceful shutdown");
73                             let sig = drain.take().expect("drain channel").0;
74                             State::Draining(sig.drain())
75                         }
76                         Poll::Pending => {
77                             let watch = drain.as_ref().expect("drain channel").1.clone();
78                             return spawn_all.poll_watch(cx, &GracefulWatcher(watch));
79                         }
80                     },
81                     StateProj::Draining(ref mut draining) => {
82                         return Pin::new(draining).poll(cx).map(Ok);
83                     }
84                 }
85             };
86             me.state.set(next);
87         }
88     }
89 }
90 
/// Watcher backed by the `Watch` half of the drain channel.
///
/// `Graceful` passes one of these to `SpawnAll::poll_watch`; its `Watcher`
/// impl wraps a connection so that `on_drain` is registered as the drain
/// callback for it.
// No `Debug` impl is provided for this type, so the lint is silenced.
#[allow(missing_debug_implementations)]
#[derive(Clone)]
pub struct GracefulWatcher(Watch);
94 
95 impl<I, S, E> Watcher<I, S, E> for GracefulWatcher
96 where
97     I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
98     S: HttpService<Body>,
99     E: H2Exec<S::Future, S::ResBody>,
100     S::ResBody: Send + Sync + 'static,
101     <S::ResBody as HttpBody>::Error: Into<Box<dyn StdError + Send + Sync>>,
102 {
103     type Future =
104         Watching<UpgradeableConnection<I, S, E>, fn(Pin<&mut UpgradeableConnection<I, S, E>>)>;
105 
watch(&self, conn: UpgradeableConnection<I, S, E>) -> Self::Future106     fn watch(&self, conn: UpgradeableConnection<I, S, E>) -> Self::Future {
107         self.0.clone().watch(conn, on_drain)
108     }
109 }
110 
on_drain<I, S, E>(conn: Pin<&mut UpgradeableConnection<I, S, E>>) where S: HttpService<Body>, S::Error: Into<Box<dyn StdError + Send + Sync>>, I: AsyncRead + AsyncWrite + Unpin, S::ResBody: HttpBody + Send + 'static, <S::ResBody as HttpBody>::Error: Into<Box<dyn StdError + Send + Sync>>, E: H2Exec<S::Future, S::ResBody>,111 fn on_drain<I, S, E>(conn: Pin<&mut UpgradeableConnection<I, S, E>>)
112 where
113     S: HttpService<Body>,
114     S::Error: Into<Box<dyn StdError + Send + Sync>>,
115     I: AsyncRead + AsyncWrite + Unpin,
116     S::ResBody: HttpBody + Send + 'static,
117     <S::ResBody as HttpBody>::Error: Into<Box<dyn StdError + Send + Sync>>,
118     E: H2Exec<S::Future, S::ResBody>,
119 {
120     conn.graceful_shutdown()
121 }
122