1 use crate::{Authentication, Error, IntoTargetAddr, Result, TargetAddr, ToProxyAddrs};
2 use derefable::Derefable;
3 use futures::{
4     stream,
5     stream::Fuse,
6     task::{Context, Poll},
7     Stream, StreamExt,
8 };
9 use std::{
10     borrow::Borrow,
11     io,
12     net::{Ipv4Addr, Ipv6Addr, SocketAddr},
13     pin::Pin,
14 };
15 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16 use tokio::net::TcpStream;
17 
/// SOCKS5 request command codes, i.e. the CMD byte of a request.
///
/// `repr(u8)` lets a variant be written straight onto the wire with `as u8`.
#[derive(Copy, Clone)]
#[repr(u8)]
enum Command {
    /// CONNECT.
    Connect = 0x01,
    /// BIND.
    Bind = 0x02,
    /// UDP ASSOCIATE. Not constructed anywhere in this module, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    Associate = 0x03,
    /// Tor extension command RESOLVE.
    #[cfg(feature = "tor")]
    TorResolve = 0xF0,
    /// Tor extension command RESOLVE_PTR.
    #[cfg(feature = "tor")]
    TorResolvePtr = 0xF1,
}
30 
31 /// A SOCKS5 client.
32 ///
33 /// For convenience, it can be dereferenced to `tokio_tcp::TcpStream`.
34 #[derive(Debug, Derefable)]
35 pub struct Socks5Stream {
36     #[deref(mutable)]
37     tcp: TcpStream,
38     target: TargetAddr<'static>,
39 }
40 
41 impl Socks5Stream {
42     /// Connects to a target server through a SOCKS5 proxy.
43     ///
44     /// # Error
45     ///
46     /// It propagates the error that occurs in the conversion from `T` to
47     /// `TargetAddr`.
connect<'t, P, T>(proxy: P, target: T) -> Result<Self> where P: ToProxyAddrs, T: IntoTargetAddr<'t>,48     pub async fn connect<'t, P, T>(proxy: P, target: T) -> Result<Self>
49     where
50         P: ToProxyAddrs,
51         T: IntoTargetAddr<'t>,
52     {
53         Self::execute_command(proxy, target, Authentication::None, Command::Connect).await
54     }
55 
56     /// Connects to a target server through a SOCKS5 proxy using given username
57     /// and password.
58     ///
59     /// # Error
60     ///
61     /// It propagates the error that occurs in the conversion from `T` to
62     /// `TargetAddr`.
connect_with_password<'a, 't, P, T>( proxy: P, target: T, username: &'a str, password: &'a str, ) -> Result<Self> where P: ToProxyAddrs, T: IntoTargetAddr<'t>,63     pub async fn connect_with_password<'a, 't, P, T>(
64         proxy: P,
65         target: T,
66         username: &'a str,
67         password: &'a str,
68     ) -> Result<Self>
69     where
70         P: ToProxyAddrs,
71         T: IntoTargetAddr<'t>,
72     {
73         Self::execute_command(
74             proxy,
75             target,
76             Authentication::Password { username, password },
77             Command::Connect,
78         )
79         .await
80     }
81 
validate_auth<'a>(auth: &Authentication<'a>) -> Result<()>82     fn validate_auth<'a>(auth: &Authentication<'a>) -> Result<()> {
83         match auth {
84             Authentication::Password { username, password } => {
85                 let username_len = username.as_bytes().len();
86                 if username_len < 1 || username_len > 255 {
87                     Err(Error::InvalidAuthValues("username length should between 1 to 255"))?
88                 }
89                 let password_len = password.as_bytes().len();
90                 if password_len < 1 || password_len > 255 {
91                     Err(Error::InvalidAuthValues("password length should between 1 to 255"))?
92                 }
93             }
94             Authentication::None => {}
95         }
96         Ok(())
97     }
98 
99     #[cfg(feature = "tor")]
tor_resolve<'t, P, T>(proxy: P, target: T) -> Result<TargetAddr<'static>> where P: ToProxyAddrs, T: IntoTargetAddr<'t>,100     pub async fn tor_resolve<'t, P, T>(proxy: P, target: T) -> Result<TargetAddr<'static>>
101     where
102         P: ToProxyAddrs,
103         T: IntoTargetAddr<'t>,
104     {
105         let sock = Self::execute_command(proxy, target, Authentication::None, Command::TorResolve).await?;
106 
107         Ok(sock.target_addr().to_owned())
108     }
109 
110     #[cfg(feature = "tor")]
tor_resolve_ptr<'t, P, T>(proxy: P, target: T) -> Result<TargetAddr<'static>> where P: ToProxyAddrs, T: IntoTargetAddr<'t>,111     pub async fn tor_resolve_ptr<'t, P, T>(proxy: P, target: T) -> Result<TargetAddr<'static>>
112     where
113         P: ToProxyAddrs,
114         T: IntoTargetAddr<'t>,
115     {
116         let sock = Self::execute_command(proxy, target, Authentication::None, Command::TorResolvePtr).await?;
117 
118         Ok(sock.target_addr().to_owned())
119     }
120 
execute_command<'a, 't, P, T>( proxy: P, target: T, auth: Authentication<'a>, command: Command, ) -> Result<Socks5Stream> where P: ToProxyAddrs, T: IntoTargetAddr<'t>,121     async fn execute_command<'a, 't, P, T>(
122         proxy: P,
123         target: T,
124         auth: Authentication<'a>,
125         command: Command,
126     ) -> Result<Socks5Stream>
127     where
128         P: ToProxyAddrs,
129         T: IntoTargetAddr<'t>,
130     {
131         Self::validate_auth(&auth)?;
132 
133         let sock = SocksConnector::new(auth, command, proxy.to_proxy_addrs().fuse(), target.into_target_addr()?)
134             .execute()
135             .await?;
136 
137         Ok(sock)
138     }
139 
140     /// Consumes the `Socks5Stream`, returning the inner `tokio_tcp::TcpStream`.
into_inner(self) -> TcpStream141     pub fn into_inner(self) -> TcpStream {
142         self.tcp
143     }
144 
145     /// Returns the target address that the proxy server connects to.
target_addr(&self) -> TargetAddr<'_>146     pub fn target_addr(&self) -> TargetAddr<'_> {
147         match &self.target {
148             TargetAddr::Ip(addr) => TargetAddr::Ip(*addr),
149             TargetAddr::Domain(domain, port) => {
150                 let domain: &str = domain.borrow();
151                 TargetAddr::Domain(domain.into(), *port)
152             }
153         }
154     }
155 }
156 
157 /// A `Future` which resolves to a socket to the target server through proxy.
158 pub struct SocksConnector<'a, 't, S> {
159     auth: Authentication<'a>,
160     command: Command,
161     proxy: Fuse<S>,
162     target: TargetAddr<'t>,
163     buf: [u8; 513],
164     ptr: usize,
165     len: usize,
166 }
167 
168 impl<'a, 't, S> SocksConnector<'a, 't, S>
169 where
170     S: Stream<Item = Result<SocketAddr>> + Unpin,
171 {
new(auth: Authentication<'a>, command: Command, proxy: Fuse<S>, target: TargetAddr<'t>) -> Self172     fn new(auth: Authentication<'a>, command: Command, proxy: Fuse<S>, target: TargetAddr<'t>) -> Self {
173         SocksConnector {
174             auth,
175             command,
176             proxy,
177             target,
178             buf: [0; 513],
179             ptr: 0,
180             len: 0,
181         }
182     }
183 
184     /// Connect to the proxy server, authenticate and issue the SOCKS command
execute(&mut self) -> Result<Socks5Stream>185     pub async fn execute(&mut self) -> Result<Socks5Stream> {
186         let next_addr = self.proxy.select_next_some().await?;
187         let mut tcp = TcpStream::connect(next_addr)
188             .await
189             .map_err(|_| Error::ProxyServerUnreachable)?;
190 
191         self.authenticate(&mut tcp).await?;
192 
193         // Send request address that should be proxied
194         self.prepare_send_request();
195         tcp.write_all(&self.buf[self.ptr..self.len]).await?;
196 
197         let target = self.receive_reply(&mut tcp).await?;
198 
199         Ok(Socks5Stream { tcp, target })
200     }
201 
prepare_send_method_selection(&mut self)202     fn prepare_send_method_selection(&mut self) {
203         self.ptr = 0;
204         self.buf[0] = 0x05;
205         match self.auth {
206             Authentication::None => {
207                 self.buf[1..3].copy_from_slice(&[1, 0x00]);
208                 self.len = 3;
209             }
210             Authentication::Password { .. } => {
211                 self.buf[1..4].copy_from_slice(&[2, 0x00, 0x02]);
212                 self.len = 4;
213             }
214         }
215     }
216 
prepare_recv_method_selection(&mut self)217     fn prepare_recv_method_selection(&mut self) {
218         self.ptr = 0;
219         self.len = 2;
220     }
221 
prepare_send_password_auth(&mut self)222     fn prepare_send_password_auth(&mut self) {
223         if let Authentication::Password { username, password } = self.auth {
224             self.ptr = 0;
225             self.buf[0] = 0x01;
226             let username_bytes = username.as_bytes();
227             let username_len = username_bytes.len();
228             self.buf[1] = username_len as u8;
229             self.buf[2..(2 + username_len)].copy_from_slice(username_bytes);
230             let password_bytes = password.as_bytes();
231             let password_len = password_bytes.len();
232             self.len = 3 + username_len + password_len;
233             self.buf[(2 + username_len)] = password_len as u8;
234             self.buf[(3 + username_len)..self.len].copy_from_slice(password_bytes);
235         } else {
236             unreachable!()
237         }
238     }
239 
prepare_recv_password_auth(&mut self)240     fn prepare_recv_password_auth(&mut self) {
241         self.ptr = 0;
242         self.len = 2;
243     }
244 
prepare_send_request(&mut self)245     fn prepare_send_request(&mut self) {
246         self.ptr = 0;
247         self.buf[..3].copy_from_slice(&[0x05, self.command as u8, 0x00]);
248         match &self.target {
249             TargetAddr::Ip(SocketAddr::V4(addr)) => {
250                 self.buf[3] = 0x01;
251                 self.buf[4..8].copy_from_slice(&addr.ip().octets());
252                 self.buf[8..10].copy_from_slice(&addr.port().to_be_bytes());
253                 self.len = 10;
254             }
255             TargetAddr::Ip(SocketAddr::V6(addr)) => {
256                 self.buf[3] = 0x04;
257                 self.buf[4..20].copy_from_slice(&addr.ip().octets());
258                 self.buf[20..22].copy_from_slice(&addr.port().to_be_bytes());
259                 self.len = 22;
260             }
261             TargetAddr::Domain(domain, port) => {
262                 self.buf[3] = 0x03;
263                 let domain = domain.as_bytes();
264                 let len = domain.len();
265                 self.buf[4] = len as u8;
266                 self.buf[5..5 + len].copy_from_slice(domain);
267                 self.buf[(5 + len)..(7 + len)].copy_from_slice(&port.to_be_bytes());
268                 self.len = 7 + len;
269             }
270         }
271     }
272 
prepare_recv_reply(&mut self)273     fn prepare_recv_reply(&mut self) {
274         self.ptr = 0;
275         self.len = 4;
276     }
277 
password_authentication_protocol(&mut self, tcp: &mut TcpStream) -> Result<()>278     async fn password_authentication_protocol(&mut self, tcp: &mut TcpStream) -> Result<()> {
279         self.prepare_send_password_auth();
280         tcp.write_all(&self.buf[self.ptr..self.len]).await?;
281 
282         self.prepare_recv_password_auth();
283         tcp.read_exact(&mut self.buf[self.ptr..self.len]).await?;
284 
285         if self.buf[0] != 0x01 {
286             return Err(Error::InvalidResponseVersion);
287         }
288         if self.buf[1] != 0x00 {
289             return Err(Error::PasswordAuthFailure(self.buf[1]));
290         }
291 
292         Ok(())
293     }
294 
authenticate(&mut self, tcp: &mut TcpStream) -> Result<()>295     async fn authenticate(&mut self, tcp: &mut TcpStream) -> Result<()> {
296         // Write request to connect/authenticate
297         self.prepare_send_method_selection();
298         tcp.write_all(&self.buf[self.ptr..self.len]).await?;
299 
300         // Receive authentication method
301         self.prepare_recv_method_selection();
302         tcp.read_exact(&mut self.buf[self.ptr..self.len]).await?;
303         if self.buf[0] != 0x05 {
304             return Err(Error::InvalidResponseVersion);
305         }
306         match self.buf[1] {
307             0x00 => {
308                 // No auth
309             }
310             0x02 => {
311                 self.password_authentication_protocol(tcp).await?;
312             }
313             0xff => {
314                 return Err(Error::NoAcceptableAuthMethods);
315             }
316             m if m != self.auth.id() => return Err(Error::UnknownAuthMethod),
317             _ => unimplemented!(),
318         }
319 
320         Ok(())
321     }
322 
receive_reply(&mut self, tcp: &mut TcpStream) -> Result<TargetAddr<'static>>323     async fn receive_reply(&mut self, tcp: &mut TcpStream) -> Result<TargetAddr<'static>> {
324         self.prepare_recv_reply();
325         self.ptr += tcp.read_exact(&mut self.buf[self.ptr..self.len]).await?;
326         if self.buf[0] != 0x05 {
327             return Err(Error::InvalidResponseVersion);
328         }
329         if self.buf[2] != 0x00 {
330             return Err(Error::InvalidReservedByte);
331         }
332 
333         match self.buf[1] {
334             0x00 => {} // succeeded
335             0x01 => Err(Error::GeneralSocksServerFailure)?,
336             0x02 => Err(Error::ConnectionNotAllowedByRuleset)?,
337             0x03 => Err(Error::NetworkUnreachable)?,
338             0x04 => Err(Error::HostUnreachable)?,
339             0x05 => Err(Error::ConnectionRefused)?,
340             0x06 => Err(Error::TtlExpired)?,
341             0x07 => Err(Error::CommandNotSupported)?,
342             0x08 => Err(Error::AddressTypeNotSupported)?,
343             _ => Err(Error::UnknownAuthMethod)?,
344         }
345 
346         match self.buf[3] {
347             // IPv4
348             0x01 => {
349                 self.len = 10;
350             }
351             // IPv6
352             0x04 => {
353                 self.len = 22;
354             }
355             // Domain
356             0x03 => {
357                 self.len = 5;
358                 self.ptr += tcp.read_exact(&mut self.buf[self.ptr..self.len]).await?;
359                 self.len += self.buf[4] as usize + 2;
360             }
361             _ => Err(Error::UnknownAddressType)?,
362         }
363 
364         self.ptr += tcp.read_exact(&mut self.buf[self.ptr..self.len]).await?;
365         let target: TargetAddr<'static> = match self.buf[3] {
366             // IPv4
367             0x01 => {
368                 let mut ip = [0; 4];
369                 ip[..].copy_from_slice(&self.buf[4..8]);
370                 let ip = Ipv4Addr::from(ip);
371                 let port = u16::from_be_bytes([self.buf[8], self.buf[9]]);
372                 (ip, port).into_target_addr()?
373             }
374             // IPv6
375             0x04 => {
376                 let mut ip = [0; 16];
377                 ip[..].copy_from_slice(&self.buf[4..20]);
378                 let ip = Ipv6Addr::from(ip);
379                 let port = u16::from_be_bytes([self.buf[20], self.buf[21]]);
380                 (ip, port).into_target_addr()?
381             }
382             // Domain
383             0x03 => {
384                 let domain_bytes = (&self.buf[5..(self.len - 2)]).to_vec();
385                 let domain = String::from_utf8(domain_bytes)
386                     .map_err(|_| Error::InvalidTargetAddress("not a valid UTF-8 string"))?;
387                 let port = u16::from_be_bytes([self.buf[self.len - 2], self.buf[self.len - 1]]);
388                 TargetAddr::Domain(domain.into(), port)
389             }
390             _ => unreachable!(),
391         };
392 
393         Ok(target)
394     }
395 }
396 
397 /// A SOCKS5 BIND client.
398 ///
399 /// Once you get an instance of `Socks5Listener`, you should send the
400 /// `bind_addr` to the remote process via the primary connection. Then, call the
401 /// `accept` function and wait for the other end connecting to the rendezvous
402 /// address.
403 pub struct Socks5Listener {
404     inner: Socks5Stream,
405 }
406 
407 impl Socks5Listener {
408     /// Initiates a BIND request to the specified proxy.
409     ///
410     /// The proxy will filter incoming connections based on the value of
411     /// `target`.
412     ///
413     /// # Error
414     ///
415     /// It propagates the error that occurs in the conversion from `T` to
416     /// `TargetAddr`.
bind<'t, P, T>(proxy: P, target: T) -> Result<Socks5Listener> where P: ToProxyAddrs, T: IntoTargetAddr<'t>,417     pub async fn bind<'t, P, T>(proxy: P, target: T) -> Result<Socks5Listener>
418     where
419         P: ToProxyAddrs,
420         T: IntoTargetAddr<'t>,
421     {
422         Self::bind_with_auth(Authentication::None, proxy, target).await
423     }
424 
425     /// Initiates a BIND request to the specified proxy using given username
426     /// and password.
427     ///
428     /// The proxy will filter incoming connections based on the value of
429     /// `target`.
430     ///
431     /// # Error
432     ///
433     /// It propagates the error that occurs in the conversion from `T` to
434     /// `TargetAddr`.
bind_with_password<'a, 't, P, T>( proxy: P, target: T, username: &'a str, password: &'a str, ) -> Result<Socks5Listener> where P: ToProxyAddrs, T: IntoTargetAddr<'t>,435     pub async fn bind_with_password<'a, 't, P, T>(
436         proxy: P,
437         target: T,
438         username: &'a str,
439         password: &'a str,
440     ) -> Result<Socks5Listener>
441     where
442         P: ToProxyAddrs,
443         T: IntoTargetAddr<'t>,
444     {
445         Self::bind_with_auth(Authentication::Password { username, password }, proxy, target).await
446     }
447 
bind_with_auth<'t, P, T>(auth: Authentication<'_>, proxy: P, target: T) -> Result<Socks5Listener> where P: ToProxyAddrs, T: IntoTargetAddr<'t>,448     async fn bind_with_auth<'t, P, T>(auth: Authentication<'_>, proxy: P, target: T) -> Result<Socks5Listener>
449     where
450         P: ToProxyAddrs,
451         T: IntoTargetAddr<'t>,
452     {
453         let socket = SocksConnector::new(
454             auth,
455             Command::Bind,
456             proxy.to_proxy_addrs().fuse(),
457             target.into_target_addr()?,
458         )
459         .execute()
460         .await?;
461 
462         Ok(Socks5Listener { inner: socket })
463     }
464 
465     /// Returns the address of the proxy-side TCP listener.
466     ///
467     /// This should be forwarded to the remote process, which should open a
468     /// connection to it.
bind_addr(&self) -> TargetAddr469     pub fn bind_addr(&self) -> TargetAddr {
470         self.inner.target_addr()
471     }
472 
473     /// Consumes this listener, returning a `Future` which resolves to the
474     /// `Socks5Stream` connected to the target server through the proxy.
475     ///
476     /// The value of `bind_addr` should be forwarded to the remote process
477     /// before this method is called.
accept(mut self) -> Result<Socks5Stream>478     pub async fn accept(mut self) -> Result<Socks5Stream> {
479         let mut connector = SocksConnector {
480             auth: Authentication::None,
481             command: Command::Bind,
482             proxy: stream::empty().fuse(),
483             target: self.inner.target,
484             buf: [0; 513],
485             ptr: 0,
486             len: 0,
487         };
488 
489         let target = connector.receive_reply(&mut self.inner.tcp).await?;
490 
491         Ok(Socks5Stream {
492             tcp: self.inner.tcp,
493             target,
494         })
495     }
496 }
497 
498 impl AsyncRead for Socks5Stream {
prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool499     unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
500         AsyncRead::prepare_uninitialized_buffer(&self.tcp, buf)
501     }
502 
poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>>503     fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
504         AsyncRead::poll_read(Pin::new(&mut self.tcp), cx, buf)
505     }
506 }
507 
508 impl AsyncWrite for Socks5Stream {
poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>>509     fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
510         AsyncWrite::poll_write(Pin::new(&mut self.tcp), cx, buf)
511     }
512 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>513     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
514         AsyncWrite::poll_flush(Pin::new(&mut self.tcp), cx)
515     }
516 
poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>517     fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
518         AsyncWrite::poll_shutdown(Pin::new(&mut self.tcp), cx)
519     }
520 }
521