1 //! Abstract implementation of a channel manager 2 3 use crate::{Error, Result}; 4 5 use async_trait::async_trait; 6 use futures::channel::oneshot; 7 use futures::future::{FutureExt, Shared}; 8 use std::hash::Hash; 9 use std::sync::Arc; 10 11 mod map; 12 13 /// Trait to describe as much of a 14 /// [`Channel`](tor_proto::channel::Channel) as `AbstractChanMgr` 15 /// needs to use. 16 pub(crate) trait AbstractChannel { 17 /// Identity type for the other side of the channel. 18 type Ident: Hash + Eq + Clone; 19 /// Return this channel's identity. ident(&self) -> &Self::Ident20 fn ident(&self) -> &Self::Ident; 21 /// Return true if this channel is usable. 22 /// 23 /// A channel might be unusable because it is closed, because it has 24 /// hit a bug, or for some other reason. We don't return unusable 25 /// channels back to the user. is_usable(&self) -> bool26 fn is_usable(&self) -> bool; 27 } 28 29 /// Trait to describe how channels are created. 30 #[async_trait] 31 pub(crate) trait ChannelFactory { 32 /// The type of channel that this factory can build. 33 type Channel: AbstractChannel; 34 /// Type that explains how to build a channel. 35 type BuildSpec; 36 37 /// Construct a new channel to the destination described at `target`. 38 /// 39 /// This function must take care of all timeouts, error detection, 40 /// and so on. 41 /// 42 /// It should not retry; that is handled at a higher level. build_channel(&self, target: &Self::BuildSpec) -> Result<Arc<Self::Channel>>43 async fn build_channel(&self, target: &Self::BuildSpec) -> Result<Arc<Self::Channel>>; 44 } 45 46 /// A type- and network-agnostic implementation for 47 /// [`ChanMgr`](crate::ChanMgr). 48 /// 49 /// This type does the work of keeping track of open channels and 50 /// pending channel requests, launching requests as needed, waiting 51 /// for pending requests, and so forth. 52 /// 53 /// The actual job of launching connections is deferred to a ChannelFactory 54 /// type. 55 pub(crate) struct AbstractChanMgr<CF: ChannelFactory> { 56 /// A 'connector' object that we use to create channels. 57 connector: CF, 58 59 /// A map from ed25519 identity to channel, or to pending channel status. 60 channels: map::ChannelMap<CF::Channel>, 61 } 62 63 /// Type alias for a future that we wait on to see when a pending 64 /// channel is done or failed. 65 type Pending<C> = Shared<oneshot::Receiver<Result<Arc<C>>>>; 66 67 /// Type alias for the sender we notify when we complete a channel (or 68 /// fail to complete it). 69 type Sending<C> = oneshot::Sender<Result<Arc<C>>>; 70 71 impl<CF: ChannelFactory> AbstractChanMgr<CF> { 72 /// Make a new empty channel manager. new(connector: CF) -> Self73 pub(crate) fn new(connector: CF) -> Self { 74 AbstractChanMgr { 75 connector, 76 channels: map::ChannelMap::new(), 77 } 78 } 79 80 /// Remove every unusable entry from this channel manager. 81 #[cfg(test)] remove_unusable_entries(&self) -> Result<()>82 pub(crate) fn remove_unusable_entries(&self) -> Result<()> { 83 self.channels.remove_unusable() 84 } 85 86 /// Helper: return the objects used to inform pending tasks 87 /// about a newly open or failed channel. setup_launch<C>(&self) -> (map::ChannelState<C>, Sending<C>)88 fn setup_launch<C>(&self) -> (map::ChannelState<C>, Sending<C>) { 89 let (snd, rcv) = oneshot::channel(); 90 let shared = rcv.shared(); 91 (map::ChannelState::Building(shared), snd) 92 } 93 94 /// Get a channel whose identity is `ident`. 95 /// 96 /// If a usable channel exists with that identity, return it. 97 /// 98 /// If no such channel exists already, and none is in progress, 99 /// launch a new request using `target`, which must match `ident`. 100 /// 101 /// If no such channel exists already, but we have one that's in 102 /// progress, wait for it to succeed or fail. get_or_launch( &self, ident: <<CF as ChannelFactory>::Channel as AbstractChannel>::Ident, target: CF::BuildSpec, ) -> Result<Arc<CF::Channel>>103 pub(crate) async fn get_or_launch( 104 &self, 105 ident: <<CF as ChannelFactory>::Channel as AbstractChannel>::Ident, 106 target: CF::BuildSpec, 107 ) -> Result<Arc<CF::Channel>> { 108 use map::ChannelState::*; 109 110 /// Possible actions that we'll decide to take based on the 111 /// channel's initial state. 112 enum Action<C> { 113 /// We found no channel. We're going to launch a new one, 114 /// then tell everybody about it. 115 Launch(Sending<C>), 116 /// We found an in-progress attempt at making a channel. 117 /// We're going to wait for it to finish. 118 Wait(Pending<C>), 119 /// We found a usable channel. We're going to return it. 120 Return(Result<Arc<C>>), 121 } 122 /// How many times do we try? 123 const N_ATTEMPTS: usize = 2; 124 125 // XXXX It would be neat to use tor_retry instead, but it's 126 // too tied to anyhow right now. 127 let mut last_err = Err(Error::Internal("Error was never set!?")); 128 129 for _ in 0..N_ATTEMPTS { 130 // First, see what state we're in, and what we should do 131 // about it. 132 let action = self 133 .channels 134 .change_state(&ident, |oldstate| match oldstate { 135 Some(Open(ref ch)) => { 136 if ch.is_usable() { 137 // Good channel. Return it. 138 let action = Action::Return(Ok(Arc::clone(ch))); 139 (oldstate, action) 140 } else { 141 // Unusable channel. Move to the Building 142 // state and launch a new channel. 143 let (newstate, send) = self.setup_launch(); 144 let action = Action::Launch(send); 145 (Some(newstate), action) 146 } 147 } 148 Some(Building(ref pending)) => { 149 let action = Action::Wait(pending.clone()); 150 (oldstate, action) 151 } 152 Some(Poisoned(_)) => { 153 // We should never be able to see this state; this 154 // is a bug. 155 ( 156 None, 157 Action::Return(Err(Error::Internal("Found a poisoned entry"))), 158 ) 159 } 160 None => { 161 // No channel. Move to the Building 162 // state and launch a new channel. 163 let (newstate, send) = self.setup_launch(); 164 let action = Action::Launch(send); 165 (Some(newstate), action) 166 } 167 })?; 168 169 // Now we act based on the channel. 170 match action { 171 // Easy case: we have an error or a channel to return. 172 Action::Return(v) => { 173 return v; 174 } 175 // There's an in-progress channel. Wait for it. 176 Action::Wait(pend) => match pend.await { 177 Ok(Ok(chan)) => return Ok(chan), 178 Ok(Err(e)) => { 179 last_err = Err(e); 180 } 181 Err(_) => { 182 last_err = Err(Error::Internal("channel build task disappeared")); 183 } 184 }, 185 // We need to launch a channel. 186 Action::Launch(send) => match self.connector.build_channel(&target).await { 187 Ok(chan) => { 188 // The channel got built: remember it, tell the 189 // others, and return it. 190 self.channels 191 .replace(ident.clone(), Open(Arc::clone(&chan)))?; 192 // It's okay if all the receivers went away: 193 // that means that nobody was waiting for this channel. 194 let _ignore_err = send.send(Ok(Arc::clone(&chan))); 195 return Ok(chan); 196 } 197 Err(e) => { 198 // The channel failed. Make it non-pending, tell the 199 // others, and set the error. 200 self.channels.remove(&ident)?; 201 // (As above) 202 let _ignore_err = send.send(Err(e.clone())); 203 last_err = Err(e); 204 } 205 }, 206 } 207 } 208 209 last_err 210 } 211 212 /// Test only: return the current open usable channel with a given 213 /// `ident`, if any. 214 #[cfg(test)] get_nowait( &self, ident: &<<CF as ChannelFactory>::Channel as AbstractChannel>::Ident, ) -> Option<Arc<CF::Channel>>215 pub(crate) fn get_nowait( 216 &self, 217 ident: &<<CF as ChannelFactory>::Channel as AbstractChannel>::Ident, 218 ) -> Option<Arc<CF::Channel>> { 219 use map::ChannelState::*; 220 match self.channels.get(ident) { 221 Ok(Some(Open(ref ch))) if ch.is_usable() => Some(Arc::clone(ch)), 222 _ => None, 223 } 224 } 225 } 226 227 #[cfg(test)] 228 mod test { 229 #![allow(clippy::unwrap_used)] 230 use super::*; 231 use crate::Error; 232 233 use futures::join; 234 use std::sync::atomic::{AtomicBool, Ordering}; 235 use std::time::Duration; 236 237 use tor_rtcompat::{task::yield_now, test_with_one_runtime, Runtime}; 238 239 struct FakeChannelFactory<RT> { 240 runtime: RT, 241 } 242 243 #[derive(Debug)] 244 struct FakeChannel { 245 ident: u32, 246 mood: char, 247 closing: AtomicBool, 248 } 249 250 impl AbstractChannel for FakeChannel { 251 type Ident = u32; ident(&self) -> &u32252 fn ident(&self) -> &u32 { 253 &self.ident 254 } is_usable(&self) -> bool255 fn is_usable(&self) -> bool { 256 !self.closing.load(Ordering::SeqCst) 257 } 258 } 259 260 impl FakeChannel { start_closing(&self)261 fn start_closing(&self) { 262 self.closing.store(true, Ordering::SeqCst); 263 } 264 } 265 266 impl<RT: Runtime> FakeChannelFactory<RT> { new(runtime: RT) -> Self267 fn new(runtime: RT) -> Self { 268 FakeChannelFactory { runtime } 269 } 270 } 271 272 #[async_trait] 273 impl<RT: Runtime> ChannelFactory for FakeChannelFactory<RT> { 274 type Channel = FakeChannel; 275 type BuildSpec = (u32, char); 276 build_channel(&self, target: &Self::BuildSpec) -> Result<Arc<FakeChannel>>277 async fn build_channel(&self, target: &Self::BuildSpec) -> Result<Arc<FakeChannel>> { 278 yield_now().await; 279 let (ident, mood) = *target; 280 match mood { 281 // "X" means never connect. 282 '❌' | '' => return Err(Error::UnusableTarget("emoji".into())), 283 // "zzz" means wait for 15 seconds then succeed. 284 '' => { 285 self.runtime.sleep(Duration::new(15, 0)).await; 286 } 287 _ => {} 288 } 289 Ok(Arc::new(FakeChannel { 290 ident, 291 mood, 292 closing: AtomicBool::new(false), 293 })) 294 } 295 } 296 297 #[test] connect_one_ok()298 fn connect_one_ok() { 299 test_with_one_runtime!(|runtime| async { 300 let cf = FakeChannelFactory::new(runtime); 301 let mgr = AbstractChanMgr::new(cf); 302 let target = (413, '!'); 303 let chan1 = mgr.get_or_launch(413, target).await.unwrap(); 304 let chan2 = mgr.get_or_launch(413, target).await.unwrap(); 305 306 assert!(Arc::ptr_eq(&chan1, &chan2)); 307 308 let chan3 = mgr.get_nowait(&413).unwrap(); 309 assert!(Arc::ptr_eq(&chan1, &chan3)); 310 }); 311 } 312 313 #[test] connect_one_fail()314 fn connect_one_fail() { 315 test_with_one_runtime!(|runtime| async { 316 let cf = FakeChannelFactory::new(runtime); 317 let mgr = AbstractChanMgr::new(cf); 318 319 // This is set up to always fail. 320 let target = (999, '❌'); 321 let res1 = mgr.get_or_launch(999, target).await; 322 assert!(matches!(res1, Err(Error::UnusableTarget(_)))); 323 324 let chan3 = mgr.get_nowait(&999); 325 assert!(chan3.is_none()); 326 }); 327 } 328 329 #[test] test_concurrent()330 fn test_concurrent() { 331 test_with_one_runtime!(|runtime| async { 332 let cf = FakeChannelFactory::new(runtime); 333 let mgr = AbstractChanMgr::new(cf); 334 335 // TODO XXXX: figure out how to make these actually run 336 // concurrently. Right now it seems that they don't actually 337 // interact. 338 let (ch3a, ch3b, ch44a, ch44b, ch86a, ch86b) = join!( 339 mgr.get_or_launch(3, (3, 'a')), 340 mgr.get_or_launch(3, (3, 'b')), 341 mgr.get_or_launch(44, (44, 'a')), 342 mgr.get_or_launch(44, (44, 'b')), 343 mgr.get_or_launch(86, (86, '❌')), 344 mgr.get_or_launch(86, (86, '')), 345 ); 346 let ch3a = ch3a.unwrap(); 347 let ch3b = ch3b.unwrap(); 348 let ch44a = ch44a.unwrap(); 349 let ch44b = ch44b.unwrap(); 350 let err_a = ch86a.unwrap_err(); 351 let err_b = ch86b.unwrap_err(); 352 353 assert!(Arc::ptr_eq(&ch3a, &ch3b)); 354 assert!(Arc::ptr_eq(&ch44a, &ch44b)); 355 assert!(!Arc::ptr_eq(&ch44a, &ch3a)); 356 357 assert!(matches!(err_a, Error::UnusableTarget(_))); 358 assert!(matches!(err_b, Error::UnusableTarget(_))); 359 }); 360 } 361 362 #[test] unusable_entries()363 fn unusable_entries() { 364 test_with_one_runtime!(|runtime| async { 365 let cf = FakeChannelFactory::new(runtime); 366 let mgr = AbstractChanMgr::new(cf); 367 368 let (ch3, ch4, ch5) = join!( 369 mgr.get_or_launch(3, (3, 'a')), 370 mgr.get_or_launch(4, (4, 'a')), 371 mgr.get_or_launch(5, (5, 'a')), 372 ); 373 374 let ch3 = ch3.unwrap(); 375 let _ch4 = ch4.unwrap(); 376 let ch5 = ch5.unwrap(); 377 378 ch3.start_closing(); 379 ch5.start_closing(); 380 381 let ch3_new = mgr.get_or_launch(3, (3, 'b')).await.unwrap(); 382 assert!(!Arc::ptr_eq(&ch3, &ch3_new)); 383 assert_eq!(ch3_new.mood, 'b'); 384 385 mgr.remove_unusable_entries().unwrap(); 386 387 assert!(mgr.get_nowait(&3).is_some()); 388 assert!(mgr.get_nowait(&4).is_some()); 389 assert!(mgr.get_nowait(&5).is_none()); 390 }); 391 } 392 } 393