1 use std::future::Future;
2 use std::io;
3 use std::pin::Pin;
4 use std::sync::atomic::{AtomicUsize, Ordering};
5 use std::sync::{Arc, RwLock, Weak};
6 use std::task::Context;
7 use std::task::Poll;
8 use std::time::{SystemTime, UNIX_EPOCH};
9 
10 use byteorder::{BigEndian, ByteOrder};
11 use bytes::Bytes;
12 use futures_core::TryStream;
13 use futures_util::{future, ready, StreamExt, TryStreamExt};
14 use once_cell::sync::OnceCell;
15 use thiserror::Error;
16 use tokio::sync::mpsc;
17 use tokio_stream::wrappers::UnboundedReceiverStream;
18 
19 use crate::apresolve::apresolve;
20 use crate::audio_key::AudioKeyManager;
21 use crate::authentication::Credentials;
22 use crate::cache::Cache;
23 use crate::channel::ChannelManager;
24 use crate::config::SessionConfig;
25 use crate::connection::{self, AuthenticationError};
26 use crate::mercury::MercuryManager;
27 
28 #[derive(Debug, Error)]
29 pub enum SessionError {
30     #[error(transparent)]
31     AuthenticationError(#[from] AuthenticationError),
32     #[error("Cannot create session: {0}")]
33     IoError(#[from] io::Error),
34 }
35 
/// Mutable per-session state, kept behind the `RwLock` in `SessionInternal`.
struct SessionData {
    /// Username returned by authentication when the session was created.
    canonical_username: String,
    /// Country string sent by the server (packet `0x1b`); empty until received.
    country: String,
    /// Server timestamp minus local UNIX time, in seconds (updated on packet `0x4`).
    time_delta: i64,
    /// Set to `true` once the session has been shut down.
    invalid: bool,
}
42 
/// Shared state behind every clone of a `Session`.
///
/// NOTE: Rust drops struct fields in declaration order, so reordering these fields
/// changes the order in which the managers, the packet sender and the cache are torn down.
struct SessionInternal {
    /// Configuration the session was created with.
    config: SessionConfig,
    /// Mutable session state (country, clock offset, username, invalid flag).
    data: RwLock<SessionData>,

    /// Sending half of the outgoing packet queue: `(command byte, payload)` pairs that the
    /// task spawned in `Session::create` forwards to the connection's sink.
    tx_connection: mpsc::UnboundedSender<(u8, Vec<u8>)>,

    // Sub-managers, each created on first access via `OnceCell::get_or_init`.
    audio_key: OnceCell<AudioKeyManager>,
    channel: OnceCell<ChannelManager>,
    mercury: OnceCell<MercuryManager>,
    /// Optional cache, shared with callers of `Session::cache` through an `Arc`.
    cache: Option<Arc<Cache>>,

    /// Runtime handle that `Session::spawn` submits tasks to.
    handle: tokio::runtime::Handle,

    /// Id taken from `SESSION_COUNTER`; used to tag log messages.
    session_id: usize,
}
58 
/// Source of `session_id` values: each new session takes the current value and increments it.
static SESSION_COUNTER: AtomicUsize = AtomicUsize::new(0_usize);
60 
/// Handle to a connected, authenticated session.
///
/// Cloning is cheap: every clone shares the same `SessionInternal` through an `Arc`,
/// and that state is freed when the last clone is dropped.
#[derive(Clone)]
pub struct Session(Arc<SessionInternal>);
63 
64 impl Session {
connect( config: SessionConfig, credentials: Credentials, cache: Option<Cache>, ) -> Result<Session, SessionError>65     pub async fn connect(
66         config: SessionConfig,
67         credentials: Credentials,
68         cache: Option<Cache>,
69     ) -> Result<Session, SessionError> {
70         let ap = apresolve(config.proxy.as_ref(), config.ap_port).await;
71 
72         info!("Connecting to AP \"{}\"", ap);
73         let mut conn = connection::connect(ap, config.proxy.as_ref()).await?;
74 
75         let reusable_credentials =
76             connection::authenticate(&mut conn, credentials, &config.device_id).await?;
77         info!("Authenticated as \"{}\" !", reusable_credentials.username);
78         if let Some(cache) = &cache {
79             cache.save_credentials(&reusable_credentials);
80         }
81 
82         let session = Session::create(
83             conn,
84             config,
85             cache,
86             reusable_credentials.username,
87             tokio::runtime::Handle::current(),
88         );
89 
90         Ok(session)
91     }
92 
create( transport: connection::Transport, config: SessionConfig, cache: Option<Cache>, username: String, handle: tokio::runtime::Handle, ) -> Session93     fn create(
94         transport: connection::Transport,
95         config: SessionConfig,
96         cache: Option<Cache>,
97         username: String,
98         handle: tokio::runtime::Handle,
99     ) -> Session {
100         let (sink, stream) = transport.split();
101 
102         let (sender_tx, sender_rx) = mpsc::unbounded_channel();
103         let session_id = SESSION_COUNTER.fetch_add(1, Ordering::Relaxed);
104 
105         debug!("new Session[{}]", session_id);
106 
107         let session = Session(Arc::new(SessionInternal {
108             config,
109             data: RwLock::new(SessionData {
110                 country: String::new(),
111                 canonical_username: username,
112                 invalid: false,
113                 time_delta: 0,
114             }),
115             tx_connection: sender_tx,
116             cache: cache.map(Arc::new),
117             audio_key: OnceCell::new(),
118             channel: OnceCell::new(),
119             mercury: OnceCell::new(),
120             handle,
121             session_id,
122         }));
123 
124         let sender_task = UnboundedReceiverStream::new(sender_rx)
125             .map(Ok)
126             .forward(sink);
127         let receiver_task = DispatchTask(stream, session.weak());
128 
129         tokio::spawn(async move {
130             let result = future::try_join(sender_task, receiver_task).await;
131 
132             if let Err(e) = result {
133                 error!("{}", e);
134             }
135         });
136 
137         session
138     }
139 
audio_key(&self) -> &AudioKeyManager140     pub fn audio_key(&self) -> &AudioKeyManager {
141         self.0
142             .audio_key
143             .get_or_init(|| AudioKeyManager::new(self.weak()))
144     }
145 
channel(&self) -> &ChannelManager146     pub fn channel(&self) -> &ChannelManager {
147         self.0
148             .channel
149             .get_or_init(|| ChannelManager::new(self.weak()))
150     }
151 
mercury(&self) -> &MercuryManager152     pub fn mercury(&self) -> &MercuryManager {
153         self.0
154             .mercury
155             .get_or_init(|| MercuryManager::new(self.weak()))
156     }
157 
time_delta(&self) -> i64158     pub fn time_delta(&self) -> i64 {
159         self.0.data.read().unwrap().time_delta
160     }
161 
spawn<T>(&self, task: T) where T: Future + Send + 'static, T::Output: Send + 'static,162     pub fn spawn<T>(&self, task: T)
163     where
164         T: Future + Send + 'static,
165         T::Output: Send + 'static,
166     {
167         self.0.handle.spawn(task);
168     }
169 
debug_info(&self)170     fn debug_info(&self) {
171         debug!(
172             "Session[{}] strong={} weak={}",
173             self.0.session_id,
174             Arc::strong_count(&self.0),
175             Arc::weak_count(&self.0)
176         );
177     }
178 
179     #[allow(clippy::match_same_arms)]
dispatch(&self, cmd: u8, data: Bytes)180     fn dispatch(&self, cmd: u8, data: Bytes) {
181         match cmd {
182             0x4 => {
183                 let server_timestamp = BigEndian::read_u32(data.as_ref()) as i64;
184                 let timestamp = match SystemTime::now().duration_since(UNIX_EPOCH) {
185                     Ok(dur) => dur,
186                     Err(err) => err.duration(),
187                 }
188                 .as_secs() as i64;
189 
190                 self.0.data.write().unwrap().time_delta = server_timestamp - timestamp;
191 
192                 self.debug_info();
193                 self.send_packet(0x49, vec![0, 0, 0, 0]);
194             }
195             0x4a => (),
196             0x1b => {
197                 let country = String::from_utf8(data.as_ref().to_owned()).unwrap();
198                 info!("Country: {:?}", country);
199                 self.0.data.write().unwrap().country = country;
200             }
201 
202             0x9 | 0xa => self.channel().dispatch(cmd, data),
203             0xd | 0xe => self.audio_key().dispatch(cmd, data),
204             0xb2..=0xb6 => self.mercury().dispatch(cmd, data),
205             _ => (),
206         }
207     }
208 
send_packet(&self, cmd: u8, data: Vec<u8>)209     pub fn send_packet(&self, cmd: u8, data: Vec<u8>) {
210         self.0.tx_connection.send((cmd, data)).unwrap();
211     }
212 
cache(&self) -> Option<&Arc<Cache>>213     pub fn cache(&self) -> Option<&Arc<Cache>> {
214         self.0.cache.as_ref()
215     }
216 
config(&self) -> &SessionConfig217     fn config(&self) -> &SessionConfig {
218         &self.0.config
219     }
220 
username(&self) -> String221     pub fn username(&self) -> String {
222         self.0.data.read().unwrap().canonical_username.clone()
223     }
224 
country(&self) -> String225     pub fn country(&self) -> String {
226         self.0.data.read().unwrap().country.clone()
227     }
228 
device_id(&self) -> &str229     pub fn device_id(&self) -> &str {
230         &self.config().device_id
231     }
232 
weak(&self) -> SessionWeak233     fn weak(&self) -> SessionWeak {
234         SessionWeak(Arc::downgrade(&self.0))
235     }
236 
session_id(&self) -> usize237     pub fn session_id(&self) -> usize {
238         self.0.session_id
239     }
240 
shutdown(&self)241     pub fn shutdown(&self) {
242         debug!("Invalidating session[{}]", self.0.session_id);
243         self.0.data.write().unwrap().invalid = true;
244         self.mercury().shutdown();
245         self.channel().shutdown();
246     }
247 
is_invalid(&self) -> bool248     pub fn is_invalid(&self) -> bool {
249         self.0.data.read().unwrap().invalid
250     }
251 }
252 
/// Non-owning handle to a session.
///
/// The sub-managers and the dispatch task hold this instead of a `Session`, so they do
/// not keep the session's shared state alive on their own.
#[derive(Clone)]
pub struct SessionWeak(Weak<SessionInternal>);
255 
256 impl SessionWeak {
try_upgrade(&self) -> Option<Session>257     fn try_upgrade(&self) -> Option<Session> {
258         self.0.upgrade().map(Session)
259     }
260 
upgrade(&self) -> Session261     pub(crate) fn upgrade(&self) -> Session {
262         self.try_upgrade().expect("Session died")
263     }
264 }
265 
266 impl Drop for SessionInternal {
drop(&mut self)267     fn drop(&mut self) {
268         debug!("drop Session[{}]", self.session_id);
269     }
270 }
271 
272 struct DispatchTask<S>(S, SessionWeak)
273 where
274     S: TryStream<Ok = (u8, Bytes)> + Unpin;
275 
276 impl<S> Future for DispatchTask<S>
277 where
278     S: TryStream<Ok = (u8, Bytes)> + Unpin,
279     <S as TryStream>::Ok: std::fmt::Debug,
280 {
281     type Output = Result<(), S::Error>;
282 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>283     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
284         let session = match self.1.try_upgrade() {
285             Some(session) => session,
286             None => return Poll::Ready(Ok(())),
287         };
288 
289         loop {
290             let (cmd, data) = match ready!(self.0.try_poll_next_unpin(cx)) {
291                 Some(Ok(t)) => t,
292                 None => {
293                     warn!("Connection to server closed.");
294                     session.shutdown();
295                     return Poll::Ready(Ok(()));
296                 }
297                 Some(Err(e)) => {
298                     session.shutdown();
299                     return Poll::Ready(Err(e));
300                 }
301             };
302 
303             session.dispatch(cmd, data);
304         }
305     }
306 }
307 
308 impl<S> Drop for DispatchTask<S>
309 where
310     S: TryStream<Ok = (u8, Bytes)> + Unpin,
311 {
drop(&mut self)312     fn drop(&mut self) {
313         debug!("drop Dispatch");
314     }
315 }
316