1 use std::error::Error as StdError;
2
3 use pin_project::pin_project;
4 use tokio::io::{AsyncRead, AsyncWrite};
5
6 use super::conn::{SpawnAll, UpgradeableConnection, Watcher};
7 use super::Accept;
8 use crate::body::{Body, HttpBody};
9 use crate::common::drain::{self, Draining, Signal, Watch, Watching};
10 use crate::common::exec::{H2Exec, NewSvcExec};
11 use crate::common::{task, Future, Pin, Poll, Unpin};
12 use crate::service::{HttpService, MakeServiceRef};
13
/// Future returned by the server's graceful-shutdown wrapper.
///
/// While `signal` is pending it drives the wrapped `SpawnAll` (accepting and
/// spawning connections, each wrapped by a `GracefulWatcher`). Once `signal`
/// resolves it stops driving `SpawnAll` and instead waits on the drain
/// channel until the watched connections are drained.
// NOTE(review): `Debug` is intentionally not implemented and the lint is
// silenced; the reason is not stated in this file — confirm before adding one.
#[allow(missing_debug_implementations)]
#[pin_project]
pub struct Graceful<I, S, F, E> {
    /// Current phase of the shutdown state machine (`Running` or `Draining`).
    #[pin]
    state: State<I, S, F, E>,
}
20
/// Phases of a [`Graceful`] server future.
///
/// `pin_project` generates `StateProj` so that `poll` can reach the pinned
/// fields of the `Running` variant.
#[pin_project(project = StateProj)]
pub(super) enum State<I, S, F, E> {
    /// Serving connections while waiting for the shutdown signal.
    Running {
        /// Both halves of the drain channel. `poll` takes this out (leaving
        /// `None`) at the moment the signal fires; it is always `Some` before.
        drain: Option<(Signal, Watch)>,
        /// The accept-and-spawn loop being driven while running.
        #[pin]
        spawn_all: SpawnAll<I, S, E>,
        /// User-supplied future whose completion triggers graceful shutdown.
        #[pin]
        signal: F,
    },
    /// Shutdown has started; waiting for the drain to finish.
    Draining(Draining),
}
32
33 impl<I, S, F, E> Graceful<I, S, F, E> {
new(spawn_all: SpawnAll<I, S, E>, signal: F) -> Self34 pub(super) fn new(spawn_all: SpawnAll<I, S, E>, signal: F) -> Self {
35 let drain = Some(drain::channel());
36 Graceful {
37 state: State::Running {
38 drain,
39 spawn_all,
40 signal,
41 },
42 }
43 }
44 }
45
46 impl<I, IO, IE, S, B, F, E> Future for Graceful<I, S, F, E>
47 where
48 I: Accept<Conn = IO, Error = IE>,
49 IE: Into<Box<dyn StdError + Send + Sync>>,
50 IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
51 S: MakeServiceRef<IO, Body, ResBody = B>,
52 S::Error: Into<Box<dyn StdError + Send + Sync>>,
53 B: HttpBody + Send + Sync + 'static,
54 B::Error: Into<Box<dyn StdError + Send + Sync>>,
55 F: Future<Output = ()>,
56 E: H2Exec<<S::Service as HttpService<Body>>::Future, B>,
57 E: NewSvcExec<IO, S::Future, S::Service, E, GracefulWatcher>,
58 {
59 type Output = crate::Result<()>;
60
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>61 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
62 let mut me = self.project();
63 loop {
64 let next = {
65 match me.state.as_mut().project() {
66 StateProj::Running {
67 drain,
68 spawn_all,
69 signal,
70 } => match signal.poll(cx) {
71 Poll::Ready(()) => {
72 debug!("signal received, starting graceful shutdown");
73 let sig = drain.take().expect("drain channel").0;
74 State::Draining(sig.drain())
75 }
76 Poll::Pending => {
77 let watch = drain.as_ref().expect("drain channel").1.clone();
78 return spawn_all.poll_watch(cx, &GracefulWatcher(watch));
79 }
80 },
81 StateProj::Draining(ref mut draining) => {
82 return Pin::new(draining).poll(cx).map(Ok);
83 }
84 }
85 };
86 me.state.set(next);
87 }
88 }
89 }
90
/// Connection watcher used while a [`Graceful`] server is running.
///
/// Holds a clone of the drain channel's `Watch`. Each connection it watches
/// is wrapped so that `on_drain` (which calls `graceful_shutdown`) runs when
/// draining starts.
// NOTE(review): `Debug` is intentionally not implemented; the reason is not
// stated in this file.
#[allow(missing_debug_implementations)]
#[derive(Clone)]
pub struct GracefulWatcher(Watch);
94
95 impl<I, S, E> Watcher<I, S, E> for GracefulWatcher
96 where
97 I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
98 S: HttpService<Body>,
99 E: H2Exec<S::Future, S::ResBody>,
100 S::ResBody: Send + Sync + 'static,
101 <S::ResBody as HttpBody>::Error: Into<Box<dyn StdError + Send + Sync>>,
102 {
103 type Future =
104 Watching<UpgradeableConnection<I, S, E>, fn(Pin<&mut UpgradeableConnection<I, S, E>>)>;
105
watch(&self, conn: UpgradeableConnection<I, S, E>) -> Self::Future106 fn watch(&self, conn: UpgradeableConnection<I, S, E>) -> Self::Future {
107 self.0.clone().watch(conn, on_drain)
108 }
109 }
110
on_drain<I, S, E>(conn: Pin<&mut UpgradeableConnection<I, S, E>>) where S: HttpService<Body>, S::Error: Into<Box<dyn StdError + Send + Sync>>, I: AsyncRead + AsyncWrite + Unpin, S::ResBody: HttpBody + Send + 'static, <S::ResBody as HttpBody>::Error: Into<Box<dyn StdError + Send + Sync>>, E: H2Exec<S::Future, S::ResBody>,111 fn on_drain<I, S, E>(conn: Pin<&mut UpgradeableConnection<I, S, E>>)
112 where
113 S: HttpService<Body>,
114 S::Error: Into<Box<dyn StdError + Send + Sync>>,
115 I: AsyncRead + AsyncWrite + Unpin,
116 S::ResBody: HttpBody + Send + 'static,
117 <S::ResBody as HttpBody>::Error: Into<Box<dyn StdError + Send + Sync>>,
118 E: H2Exec<S::Future, S::ResBody>,
119 {
120 conn.graceful_shutdown()
121 }
122