1 use std::future::Future; 2 use std::io; 3 use std::pin::Pin; 4 use std::sync::atomic::{AtomicUsize, Ordering}; 5 use std::sync::{Arc, RwLock, Weak}; 6 use std::task::Context; 7 use std::task::Poll; 8 use std::time::{SystemTime, UNIX_EPOCH}; 9 10 use byteorder::{BigEndian, ByteOrder}; 11 use bytes::Bytes; 12 use futures_core::TryStream; 13 use futures_util::{future, ready, StreamExt, TryStreamExt}; 14 use once_cell::sync::OnceCell; 15 use thiserror::Error; 16 use tokio::sync::mpsc; 17 use tokio_stream::wrappers::UnboundedReceiverStream; 18 19 use crate::apresolve::apresolve; 20 use crate::audio_key::AudioKeyManager; 21 use crate::authentication::Credentials; 22 use crate::cache::Cache; 23 use crate::channel::ChannelManager; 24 use crate::config::SessionConfig; 25 use crate::connection::{self, AuthenticationError}; 26 use crate::mercury::MercuryManager; 27 28 #[derive(Debug, Error)] 29 pub enum SessionError { 30 #[error(transparent)] 31 AuthenticationError(#[from] AuthenticationError), 32 #[error("Cannot create session: {0}")] 33 IoError(#[from] io::Error), 34 } 35 36 struct SessionData { 37 country: String, 38 time_delta: i64, 39 canonical_username: String, 40 invalid: bool, 41 } 42 43 struct SessionInternal { 44 config: SessionConfig, 45 data: RwLock<SessionData>, 46 47 tx_connection: mpsc::UnboundedSender<(u8, Vec<u8>)>, 48 49 audio_key: OnceCell<AudioKeyManager>, 50 channel: OnceCell<ChannelManager>, 51 mercury: OnceCell<MercuryManager>, 52 cache: Option<Arc<Cache>>, 53 54 handle: tokio::runtime::Handle, 55 56 session_id: usize, 57 } 58 59 static SESSION_COUNTER: AtomicUsize = AtomicUsize::new(0); 60 61 #[derive(Clone)] 62 pub struct Session(Arc<SessionInternal>); 63 64 impl Session { connect( config: SessionConfig, credentials: Credentials, cache: Option<Cache>, ) -> Result<Session, SessionError>65 pub async fn connect( 66 config: SessionConfig, 67 credentials: Credentials, 68 cache: Option<Cache>, 69 ) -> Result<Session, SessionError> { 70 let ap = apresolve(config.proxy.as_ref(), config.ap_port).await; 71 72 info!("Connecting to AP \"{}\"", ap); 73 let mut conn = connection::connect(ap, config.proxy.as_ref()).await?; 74 75 let reusable_credentials = 76 connection::authenticate(&mut conn, credentials, &config.device_id).await?; 77 info!("Authenticated as \"{}\" !", reusable_credentials.username); 78 if let Some(cache) = &cache { 79 cache.save_credentials(&reusable_credentials); 80 } 81 82 let session = Session::create( 83 conn, 84 config, 85 cache, 86 reusable_credentials.username, 87 tokio::runtime::Handle::current(), 88 ); 89 90 Ok(session) 91 } 92 create( transport: connection::Transport, config: SessionConfig, cache: Option<Cache>, username: String, handle: tokio::runtime::Handle, ) -> Session93 fn create( 94 transport: connection::Transport, 95 config: SessionConfig, 96 cache: Option<Cache>, 97 username: String, 98 handle: tokio::runtime::Handle, 99 ) -> Session { 100 let (sink, stream) = transport.split(); 101 102 let (sender_tx, sender_rx) = mpsc::unbounded_channel(); 103 let session_id = SESSION_COUNTER.fetch_add(1, Ordering::Relaxed); 104 105 debug!("new Session[{}]", session_id); 106 107 let session = Session(Arc::new(SessionInternal { 108 config, 109 data: RwLock::new(SessionData { 110 country: String::new(), 111 canonical_username: username, 112 invalid: false, 113 time_delta: 0, 114 }), 115 tx_connection: sender_tx, 116 cache: cache.map(Arc::new), 117 audio_key: OnceCell::new(), 118 channel: OnceCell::new(), 119 mercury: OnceCell::new(), 120 handle, 121 session_id, 122 })); 123 124 let sender_task = UnboundedReceiverStream::new(sender_rx) 125 .map(Ok) 126 .forward(sink); 127 let receiver_task = DispatchTask(stream, session.weak()); 128 129 tokio::spawn(async move { 130 let result = future::try_join(sender_task, receiver_task).await; 131 132 if let Err(e) = result { 133 error!("{}", e); 134 } 135 }); 136 137 session 138 } 139 audio_key(&self) -> &AudioKeyManager140 pub fn audio_key(&self) -> &AudioKeyManager { 141 self.0 142 .audio_key 143 .get_or_init(|| AudioKeyManager::new(self.weak())) 144 } 145 channel(&self) -> &ChannelManager146 pub fn channel(&self) -> &ChannelManager { 147 self.0 148 .channel 149 .get_or_init(|| ChannelManager::new(self.weak())) 150 } 151 mercury(&self) -> &MercuryManager152 pub fn mercury(&self) -> &MercuryManager { 153 self.0 154 .mercury 155 .get_or_init(|| MercuryManager::new(self.weak())) 156 } 157 time_delta(&self) -> i64158 pub fn time_delta(&self) -> i64 { 159 self.0.data.read().unwrap().time_delta 160 } 161 spawn<T>(&self, task: T) where T: Future + Send + 'static, T::Output: Send + 'static,162 pub fn spawn<T>(&self, task: T) 163 where 164 T: Future + Send + 'static, 165 T::Output: Send + 'static, 166 { 167 self.0.handle.spawn(task); 168 } 169 debug_info(&self)170 fn debug_info(&self) { 171 debug!( 172 "Session[{}] strong={} weak={}", 173 self.0.session_id, 174 Arc::strong_count(&self.0), 175 Arc::weak_count(&self.0) 176 ); 177 } 178 179 #[allow(clippy::match_same_arms)] dispatch(&self, cmd: u8, data: Bytes)180 fn dispatch(&self, cmd: u8, data: Bytes) { 181 match cmd { 182 0x4 => { 183 let server_timestamp = BigEndian::read_u32(data.as_ref()) as i64; 184 let timestamp = match SystemTime::now().duration_since(UNIX_EPOCH) { 185 Ok(dur) => dur, 186 Err(err) => err.duration(), 187 } 188 .as_secs() as i64; 189 190 self.0.data.write().unwrap().time_delta = server_timestamp - timestamp; 191 192 self.debug_info(); 193 self.send_packet(0x49, vec![0, 0, 0, 0]); 194 } 195 0x4a => (), 196 0x1b => { 197 let country = String::from_utf8(data.as_ref().to_owned()).unwrap(); 198 info!("Country: {:?}", country); 199 self.0.data.write().unwrap().country = country; 200 } 201 202 0x9 | 0xa => self.channel().dispatch(cmd, data), 203 0xd | 0xe => self.audio_key().dispatch(cmd, data), 204 0xb2..=0xb6 => self.mercury().dispatch(cmd, data), 205 _ => (), 206 } 207 } 208 send_packet(&self, cmd: u8, data: Vec<u8>)209 pub fn send_packet(&self, cmd: u8, data: Vec<u8>) { 210 self.0.tx_connection.send((cmd, data)).unwrap(); 211 } 212 cache(&self) -> Option<&Arc<Cache>>213 pub fn cache(&self) -> Option<&Arc<Cache>> { 214 self.0.cache.as_ref() 215 } 216 config(&self) -> &SessionConfig217 fn config(&self) -> &SessionConfig { 218 &self.0.config 219 } 220 username(&self) -> String221 pub fn username(&self) -> String { 222 self.0.data.read().unwrap().canonical_username.clone() 223 } 224 country(&self) -> String225 pub fn country(&self) -> String { 226 self.0.data.read().unwrap().country.clone() 227 } 228 device_id(&self) -> &str229 pub fn device_id(&self) -> &str { 230 &self.config().device_id 231 } 232 weak(&self) -> SessionWeak233 fn weak(&self) -> SessionWeak { 234 SessionWeak(Arc::downgrade(&self.0)) 235 } 236 session_id(&self) -> usize237 pub fn session_id(&self) -> usize { 238 self.0.session_id 239 } 240 shutdown(&self)241 pub fn shutdown(&self) { 242 debug!("Invalidating session[{}]", self.0.session_id); 243 self.0.data.write().unwrap().invalid = true; 244 self.mercury().shutdown(); 245 self.channel().shutdown(); 246 } 247 is_invalid(&self) -> bool248 pub fn is_invalid(&self) -> bool { 249 self.0.data.read().unwrap().invalid 250 } 251 } 252 253 #[derive(Clone)] 254 pub struct SessionWeak(Weak<SessionInternal>); 255 256 impl SessionWeak { try_upgrade(&self) -> Option<Session>257 fn try_upgrade(&self) -> Option<Session> { 258 self.0.upgrade().map(Session) 259 } 260 upgrade(&self) -> Session261 pub(crate) fn upgrade(&self) -> Session { 262 self.try_upgrade().expect("Session died") 263 } 264 } 265 266 impl Drop for SessionInternal { drop(&mut self)267 fn drop(&mut self) { 268 debug!("drop Session[{}]", self.session_id); 269 } 270 } 271 272 struct DispatchTask<S>(S, SessionWeak) 273 where 274 S: TryStream<Ok = (u8, Bytes)> + Unpin; 275 276 impl<S> Future for DispatchTask<S> 277 where 278 S: TryStream<Ok = (u8, Bytes)> + Unpin, 279 <S as TryStream>::Ok: std::fmt::Debug, 280 { 281 type Output = Result<(), S::Error>; 282 poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>283 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 284 let session = match self.1.try_upgrade() { 285 Some(session) => session, 286 None => return Poll::Ready(Ok(())), 287 }; 288 289 loop { 290 let (cmd, data) = match ready!(self.0.try_poll_next_unpin(cx)) { 291 Some(Ok(t)) => t, 292 None => { 293 warn!("Connection to server closed."); 294 session.shutdown(); 295 return Poll::Ready(Ok(())); 296 } 297 Some(Err(e)) => { 298 session.shutdown(); 299 return Poll::Ready(Err(e)); 300 } 301 }; 302 303 session.dispatch(cmd, data); 304 } 305 } 306 } 307 308 impl<S> Drop for DispatchTask<S> 309 where 310 S: TryStream<Ok = (u8, Bytes)> + Unpin, 311 { drop(&mut self)312 fn drop(&mut self) { 313 debug!("drop Dispatch"); 314 } 315 } 316