1 //! Abstract implementation of a channel manager
2 
3 use crate::{Error, Result};
4 
5 use async_trait::async_trait;
6 use futures::channel::oneshot;
7 use futures::future::{FutureExt, Shared};
8 use std::hash::Hash;
9 use std::sync::Arc;
10 
11 mod map;
12 
/// The minimal view of a [`Channel`](tor_proto::channel::Channel) that
/// `AbstractChanMgr` relies on.
///
/// Anything that can report its own identity and say whether it is
/// still usable can be managed.
pub(crate) trait AbstractChannel {
    /// The type used to identify the party at the far end of the channel.
    ///
    /// It must be hashable, comparable and cloneable.
    type Ident: Clone + Eq + Hash;

    /// Return a reference to the identity of this channel.
    fn ident(&self) -> &Self::Ident;

    /// Report whether this channel can still be used.
    ///
    /// A channel can stop being usable because it was closed, because
    /// it ran into a bug, or for some other reason.  Unusable channels
    /// are never handed back to the user.
    fn is_usable(&self) -> bool;
}
28 
29 /// Trait to describe how channels are created.
30 #[async_trait]
31 pub(crate) trait ChannelFactory {
32     /// The type of channel that this factory can build.
33     type Channel: AbstractChannel;
34     /// Type that explains how to build a channel.
35     type BuildSpec;
36 
37     /// Construct a new channel to the destination described at `target`.
38     ///
39     /// This function must take care of all timeouts, error detection,
40     /// and so on.
41     ///
42     /// It should not retry; that is handled at a higher level.
build_channel(&self, target: &Self::BuildSpec) -> Result<Arc<Self::Channel>>43     async fn build_channel(&self, target: &Self::BuildSpec) -> Result<Arc<Self::Channel>>;
44 }
45 
46 /// A type- and network-agnostic implementation for
47 /// [`ChanMgr`](crate::ChanMgr).
48 ///
49 /// This type does the work of keeping track of open channels and
50 /// pending channel requests, launching requests as needed, waiting
51 /// for pending requests, and so forth.
52 ///
53 /// The actual job of launching connections is deferred to a ChannelFactory
54 /// type.
55 pub(crate) struct AbstractChanMgr<CF: ChannelFactory> {
56     /// A 'connector' object that we use to create channels.
57     connector: CF,
58 
59     /// A map from ed25519 identity to channel, or to pending channel status.
60     channels: map::ChannelMap<CF::Channel>,
61 }
62 
63 /// Type alias for a future that we wait on to see when a pending
64 /// channel is done or failed.
65 type Pending<C> = Shared<oneshot::Receiver<Result<Arc<C>>>>;
66 
67 /// Type alias for the sender we notify when we complete a channel (or
68 /// fail to complete it).
69 type Sending<C> = oneshot::Sender<Result<Arc<C>>>;
70 
71 impl<CF: ChannelFactory> AbstractChanMgr<CF> {
72     /// Make a new empty channel manager.
new(connector: CF) -> Self73     pub(crate) fn new(connector: CF) -> Self {
74         AbstractChanMgr {
75             connector,
76             channels: map::ChannelMap::new(),
77         }
78     }
79 
80     /// Remove every unusable entry from this channel manager.
81     #[cfg(test)]
remove_unusable_entries(&self) -> Result<()>82     pub(crate) fn remove_unusable_entries(&self) -> Result<()> {
83         self.channels.remove_unusable()
84     }
85 
86     /// Helper: return the objects used to inform pending tasks
87     /// about a newly open or failed channel.
setup_launch<C>(&self) -> (map::ChannelState<C>, Sending<C>)88     fn setup_launch<C>(&self) -> (map::ChannelState<C>, Sending<C>) {
89         let (snd, rcv) = oneshot::channel();
90         let shared = rcv.shared();
91         (map::ChannelState::Building(shared), snd)
92     }
93 
94     /// Get a channel whose identity is `ident`.
95     ///
96     /// If a usable channel exists with that identity, return it.
97     ///
98     /// If no such channel exists already, and none is in progress,
99     /// launch a new request using `target`, which must match `ident`.
100     ///
101     /// If no such channel exists already, but we have one that's in
102     /// progress, wait for it to succeed or fail.
get_or_launch( &self, ident: <<CF as ChannelFactory>::Channel as AbstractChannel>::Ident, target: CF::BuildSpec, ) -> Result<Arc<CF::Channel>>103     pub(crate) async fn get_or_launch(
104         &self,
105         ident: <<CF as ChannelFactory>::Channel as AbstractChannel>::Ident,
106         target: CF::BuildSpec,
107     ) -> Result<Arc<CF::Channel>> {
108         use map::ChannelState::*;
109 
110         /// Possible actions that we'll decide to take based on the
111         /// channel's initial state.
112         enum Action<C> {
113             /// We found no channel.  We're going to launch a new one,
114             /// then tell everybody about it.
115             Launch(Sending<C>),
116             /// We found an in-progress attempt at making a channel.
117             /// We're going to wait for it to finish.
118             Wait(Pending<C>),
119             /// We found a usable channel.  We're going to return it.
120             Return(Result<Arc<C>>),
121         }
122         /// How many times do we try?
123         const N_ATTEMPTS: usize = 2;
124 
125         // XXXX It would be neat to use tor_retry instead, but it's
126         // too tied to anyhow right now.
127         let mut last_err = Err(Error::Internal("Error was never set!?"));
128 
129         for _ in 0..N_ATTEMPTS {
130             // First, see what state we're in, and what we should do
131             // about it.
132             let action = self
133                 .channels
134                 .change_state(&ident, |oldstate| match oldstate {
135                     Some(Open(ref ch)) => {
136                         if ch.is_usable() {
137                             // Good channel. Return it.
138                             let action = Action::Return(Ok(Arc::clone(ch)));
139                             (oldstate, action)
140                         } else {
141                             // Unusable channel.  Move to the Building
142                             // state and launch a new channel.
143                             let (newstate, send) = self.setup_launch();
144                             let action = Action::Launch(send);
145                             (Some(newstate), action)
146                         }
147                     }
148                     Some(Building(ref pending)) => {
149                         let action = Action::Wait(pending.clone());
150                         (oldstate, action)
151                     }
152                     Some(Poisoned(_)) => {
153                         // We should never be able to see this state; this
154                         // is a bug.
155                         (
156                             None,
157                             Action::Return(Err(Error::Internal("Found a poisoned entry"))),
158                         )
159                     }
160                     None => {
161                         // No channel.  Move to the Building
162                         // state and launch a new channel.
163                         let (newstate, send) = self.setup_launch();
164                         let action = Action::Launch(send);
165                         (Some(newstate), action)
166                     }
167                 })?;
168 
169             // Now we act based on the channel.
170             match action {
171                 // Easy case: we have an error or a channel to return.
172                 Action::Return(v) => {
173                     return v;
174                 }
175                 // There's an in-progress channel.  Wait for it.
176                 Action::Wait(pend) => match pend.await {
177                     Ok(Ok(chan)) => return Ok(chan),
178                     Ok(Err(e)) => {
179                         last_err = Err(e);
180                     }
181                     Err(_) => {
182                         last_err = Err(Error::Internal("channel build task disappeared"));
183                     }
184                 },
185                 // We need to launch a channel.
186                 Action::Launch(send) => match self.connector.build_channel(&target).await {
187                     Ok(chan) => {
188                         // The channel got built: remember it, tell the
189                         // others, and return it.
190                         self.channels
191                             .replace(ident.clone(), Open(Arc::clone(&chan)))?;
192                         // It's okay if all the receivers went away:
193                         // that means that nobody was waiting for this channel.
194                         let _ignore_err = send.send(Ok(Arc::clone(&chan)));
195                         return Ok(chan);
196                     }
197                     Err(e) => {
198                         // The channel failed. Make it non-pending, tell the
199                         // others, and set the error.
200                         self.channels.remove(&ident)?;
201                         // (As above)
202                         let _ignore_err = send.send(Err(e.clone()));
203                         last_err = Err(e);
204                     }
205                 },
206             }
207         }
208 
209         last_err
210     }
211 
212     /// Test only: return the current open usable channel with a given
213     /// `ident`, if any.
214     #[cfg(test)]
get_nowait( &self, ident: &<<CF as ChannelFactory>::Channel as AbstractChannel>::Ident, ) -> Option<Arc<CF::Channel>>215     pub(crate) fn get_nowait(
216         &self,
217         ident: &<<CF as ChannelFactory>::Channel as AbstractChannel>::Ident,
218     ) -> Option<Arc<CF::Channel>> {
219         use map::ChannelState::*;
220         match self.channels.get(ident) {
221             Ok(Some(Open(ref ch))) if ch.is_usable() => Some(Arc::clone(ch)),
222             _ => None,
223         }
224     }
225 }
226 
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::Error;

    use futures::join;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use tor_rtcompat::{task::yield_now, test_with_one_runtime, Runtime};

    /// A channel factory for tests: it builds `FakeChannel`s and never
    /// touches the network.
    struct FakeChannelFactory<RT> {
        /// Runtime used for sleeping when a slow build is requested.
        runtime: RT,
    }

    /// A stand-in channel whose usability can be switched off by hand.
    #[derive(Debug)]
    struct FakeChannel {
        /// The identity this channel reports.
        ident: u32,
        /// The "mood" character it was built with; see `build_channel`.
        mood: char,
        /// Set once `start_closing` has been called.
        closing: AtomicBool,
    }

    impl AbstractChannel for FakeChannel {
        type Ident = u32;
        fn ident(&self) -> &u32 {
            &self.ident
        }
        fn is_usable(&self) -> bool {
            !self.closing.load(Ordering::SeqCst)
        }
    }

    impl FakeChannel {
        /// Mark this channel as closing, which makes it unusable.
        fn start_closing(&self) {
            self.closing.store(true, Ordering::SeqCst);
        }
    }

    impl<RT: Runtime> FakeChannelFactory<RT> {
        /// Make a new factory that sleeps using `runtime`.
        fn new(runtime: RT) -> Self {
            FakeChannelFactory { runtime }
        }
    }

    #[async_trait]
    impl<RT: Runtime> ChannelFactory for FakeChannelFactory<RT> {
        type Channel = FakeChannel;
        type BuildSpec = (u32, char);

        /// Build a fake channel for `target`, an `(ident, mood)` pair.
        ///
        /// The mood selects the outcome: '❌' and '🔥' always fail,
        /// '💤' succeeds after a 15 second sleep, and anything else
        /// succeeds right away.
        async fn build_channel(&self, target: &Self::BuildSpec) -> Result<Arc<FakeChannel>> {
            yield_now().await;
            let (ident, mood) = *target;
            match mood {
                // "X" means never connect.
                '❌' | '🔥' => return Err(Error::UnusableTarget("emoji".into())),
                // "zzz" means wait for 15 seconds then succeed.
                '💤' => {
                    self.runtime.sleep(Duration::new(15, 0)).await;
                }
                _ => {}
            }
            Ok(Arc::new(FakeChannel {
                ident,
                mood,
                closing: AtomicBool::new(false),
            }))
        }
    }

    /// A successful launch is remembered and handed out again.
    #[test]
    fn connect_one_ok() {
        test_with_one_runtime!(|runtime| async {
            let cf = FakeChannelFactory::new(runtime);
            let mgr = AbstractChanMgr::new(cf);
            let target = (413, '!');
            let chan1 = mgr.get_or_launch(413, target).await.unwrap();
            let chan2 = mgr.get_or_launch(413, target).await.unwrap();

            assert!(Arc::ptr_eq(&chan1, &chan2));

            let chan3 = mgr.get_nowait(&413).unwrap();
            assert!(Arc::ptr_eq(&chan1, &chan3));
        });
    }

    /// A failing launch reports its error and leaves no channel behind.
    #[test]
    fn connect_one_fail() {
        test_with_one_runtime!(|runtime| async {
            let cf = FakeChannelFactory::new(runtime);
            let mgr = AbstractChanMgr::new(cf);

            // This is set up to always fail.
            let target = (999, '❌');
            let res1 = mgr.get_or_launch(999, target).await;
            assert!(matches!(res1, Err(Error::UnusableTarget(_))));

            let chan3 = mgr.get_nowait(&999);
            assert!(chan3.is_none());
        });
    }

    /// Requests for the same identity issued together share one outcome.
    #[test]
    fn test_concurrent() {
        test_with_one_runtime!(|runtime| async {
            let cf = FakeChannelFactory::new(runtime);
            let mgr = AbstractChanMgr::new(cf);

            // TODO XXXX: figure out how to make these actually run
            // concurrently. Right now it seems that they don't actually
            // interact.
            let (ch3a, ch3b, ch44a, ch44b, ch86a, ch86b) = join!(
                mgr.get_or_launch(3, (3, 'a')),
                mgr.get_or_launch(3, (3, 'b')),
                mgr.get_or_launch(44, (44, 'a')),
                mgr.get_or_launch(44, (44, 'b')),
                mgr.get_or_launch(86, (86, '❌')),
                mgr.get_or_launch(86, (86, '🔥')),
            );
            let ch3a = ch3a.unwrap();
            let ch3b = ch3b.unwrap();
            let ch44a = ch44a.unwrap();
            let ch44b = ch44b.unwrap();
            let err_a = ch86a.unwrap_err();
            let err_b = ch86b.unwrap_err();

            assert!(Arc::ptr_eq(&ch3a, &ch3b));
            assert!(Arc::ptr_eq(&ch44a, &ch44b));
            assert!(!Arc::ptr_eq(&ch44a, &ch3a));

            assert!(matches!(err_a, Error::UnusableTarget(_)));
            assert!(matches!(err_b, Error::UnusableTarget(_)));
        });
    }

    /// Unusable channels get replaced on demand and can be purged.
    #[test]
    fn unusable_entries() {
        test_with_one_runtime!(|runtime| async {
            let cf = FakeChannelFactory::new(runtime);
            let mgr = AbstractChanMgr::new(cf);

            let (ch3, ch4, ch5) = join!(
                mgr.get_or_launch(3, (3, 'a')),
                mgr.get_or_launch(4, (4, 'a')),
                mgr.get_or_launch(5, (5, 'a')),
            );

            let ch3 = ch3.unwrap();
            let _ch4 = ch4.unwrap();
            let ch5 = ch5.unwrap();

            ch3.start_closing();
            ch5.start_closing();

            // Channel 3 is unusable now, so asking again builds a new one.
            let ch3_new = mgr.get_or_launch(3, (3, 'b')).await.unwrap();
            assert!(!Arc::ptr_eq(&ch3, &ch3_new));
            assert_eq!(ch3_new.mood, 'b');

            mgr.remove_unusable_entries().unwrap();

            assert!(mgr.get_nowait(&3).is_some());
            assert!(mgr.get_nowait(&4).is_some());
            assert!(mgr.get_nowait(&5).is_none());
        });
    }
}
393