1 //! `async-std` integration.
2 use tungstenite::client::IntoClientRequest;
3 use tungstenite::handshake::client::{Request, Response};
4 use tungstenite::protocol::WebSocketConfig;
5 use tungstenite::Error;
6 
7 use async_std::net::TcpStream;
8 
9 use super::{domain, port, WebSocketStream};
10 
11 #[cfg(feature = "async-native-tls")]
12 use futures_io::{AsyncRead, AsyncWrite};
13 
#[cfg(feature = "async-native-tls")]
pub(crate) mod async_native_tls {
    use async_native_tls::TlsConnector as AsyncTlsConnector;
    use async_native_tls::TlsStream;
    use real_async_native_tls as async_native_tls;

    use tungstenite::client::uri_mode;
    use tungstenite::handshake::client::Request;
    use tungstenite::stream::Mode;
    use tungstenite::Error;

    use futures_io::{AsyncRead, AsyncWrite};

    use crate::stream::Stream as StreamSwitcher;
    use crate::{
        client_async_with_config, domain, IntoClientRequest, Response, WebSocketConfig,
        WebSocketStream,
    };

    /// A stream that might be protected with TLS.
    pub type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>;

    /// Stream type produced by [`client_async_tls_with_connector_and_config`]
    /// when the `async-native-tls` backend is active.
    pub type AutoStream<S> = MaybeTlsStream<S>;

    /// TLS connector type accepted by the connection functions of this backend.
    pub type Connector = AsyncTlsConnector;

    /// Wraps `socket` according to `mode`.
    ///
    /// Plain URLs pass the socket through unchanged. TLS URLs perform a TLS
    /// handshake for `domain` using `connector`. When `connector` is `None`, a
    /// fresh `AsyncTlsConnector::new()` is used.
    ///
    /// # Errors
    ///
    /// A failed TLS handshake is reported as `Error::Tls`.
    async fn wrap_stream<S>(
        socket: S,
        domain: String,
        connector: Option<Connector>,
        mode: Mode,
    ) -> Result<AutoStream<S>, Error>
    where
        S: 'static + AsyncRead + AsyncWrite + Unpin,
    {
        match mode {
            Mode::Plain => Ok(StreamSwitcher::Plain(socket)),
            Mode::Tls => {
                let connector = connector.unwrap_or_else(AsyncTlsConnector::new);
                let stream = connector
                    .connect(&domain, socket)
                    .await
                    .map_err(|err| Error::Tls(err.into()))?;
                Ok(StreamSwitcher::Tls(stream))
            }
        }
    }

    /// Creates a WebSocket handshake from a request and a stream,
    /// upgrading the stream to TLS if required and using the given
    /// connector and WebSocket configuration.
    ///
    /// The request's host and URI mode are validated before any TLS work is
    /// done. A `None` connector falls back to `AsyncTlsConnector::new()`.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the request cannot be built;
    /// - the host or URI mode cannot be determined;
    /// - the TLS handshake fails;
    /// - the WebSocket handshake fails.
    pub async fn client_async_tls_with_connector_and_config<R, S>(
        request: R,
        stream: S,
        connector: Option<Connector>,
        config: Option<WebSocketConfig>,
    ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
    where
        R: IntoClientRequest + Unpin,
        S: 'static + AsyncRead + AsyncWrite + Unpin,
        AutoStream<S>: Unpin,
    {
        let request: Request = request.into_client_request()?;

        let domain = domain(&request)?;

        // Make sure we check domain and mode first. URL must be valid.
        let mode = uri_mode(request.uri())?;

        let stream = wrap_stream(stream, domain, connector, mode).await?;
        client_async_with_config(request, stream, config).await
    }
}
93 
94 #[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))]
95 pub(crate) mod dummy_tls {
96     use futures_io::{AsyncRead, AsyncWrite};
97 
98     use tungstenite::client::{uri_mode, IntoClientRequest};
99     use tungstenite::handshake::client::Request;
100     use tungstenite::stream::Mode;
101     use tungstenite::Error;
102 
103     use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream};
104 
105     pub type AutoStream<S> = S;
106     type Connector = ();
107 
wrap_stream<S>( socket: S, _domain: String, _connector: Option<()>, mode: Mode, ) -> Result<AutoStream<S>, Error> where S: 'static + AsyncRead + AsyncWrite + Unpin,108     async fn wrap_stream<S>(
109         socket: S,
110         _domain: String,
111         _connector: Option<()>,
112         mode: Mode,
113     ) -> Result<AutoStream<S>, Error>
114     where
115         S: 'static + AsyncRead + AsyncWrite + Unpin,
116     {
117         match mode {
118             Mode::Plain => Ok(socket),
119             Mode::Tls => Err(Error::Url(
120                 tungstenite::error::UrlError::TlsFeatureNotEnabled,
121             )),
122         }
123     }
124 
125     /// Creates a WebSocket handshake from a request and a stream,
126     /// upgrading the stream to TLS if required and using the given
127     /// connector and WebSocket configuration.
client_async_tls_with_connector_and_config<R, S>( request: R, stream: S, connector: Option<Connector>, config: Option<WebSocketConfig>, ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error> where R: IntoClientRequest + Unpin, S: 'static + AsyncRead + AsyncWrite + Unpin, AutoStream<S>: Unpin,128     pub async fn client_async_tls_with_connector_and_config<R, S>(
129         request: R,
130         stream: S,
131         connector: Option<Connector>,
132         config: Option<WebSocketConfig>,
133     ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
134     where
135         R: IntoClientRequest + Unpin,
136         S: 'static + AsyncRead + AsyncWrite + Unpin,
137         AutoStream<S>: Unpin,
138     {
139         let request: Request = request.into_client_request()?;
140 
141         let domain = domain(&request)?;
142 
143         // Make sure we check domain and mode first. URL must be valid.
144         let mode = uri_mode(request.uri())?;
145 
146         let stream = wrap_stream(stream, domain, connector, mode).await?;
147         client_async_with_config(request, stream, config).await
148     }
149 }
150 
// TLS backend selection. Exactly one of the three cfg groups below is
// compiled. When both `async-tls` and `async-native-tls` are enabled,
// `async-native-tls` takes precedence.

// No TLS backend: only plain connections work; TLS requests are rejected.
#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))]
pub use self::dummy_tls::client_async_tls_with_connector_and_config;
#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))]
use self::dummy_tls::AutoStream;

// `async-tls` backend, selected only when `async-native-tls` is disabled.
#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))]
pub use crate::async_tls::client_async_tls_with_connector_and_config;
#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))]
use crate::async_tls::AutoStream;
#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))]
type Connector = real_async_tls::TlsConnector;

// `async-native-tls` backend.
#[cfg(feature = "async-native-tls")]
pub use self::async_native_tls::client_async_tls_with_connector_and_config;
#[cfg(feature = "async-native-tls")]
use self::async_native_tls::{AutoStream, Connector};

/// Type alias for the stream type of the `client_async()` functions.
///
/// Resolves to the `AutoStream` of whichever TLS backend is compiled in.
pub type ClientStream<S> = AutoStream<S>;
170 
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required.
///
/// Shorthand for [`client_async_tls_with_connector_and_config`] with neither
/// an explicit TLS connector nor a WebSocket configuration.
#[cfg(feature = "async-native-tls")]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + AsyncRead + AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let (connector, config) = (None, None);
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
185 
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// WebSocket configuration.
///
/// No explicit TLS connector is supplied, so the backend's fallback is used.
#[cfg(feature = "async-native-tls")]
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + AsyncRead + AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let no_connector: Option<Connector> = None;
    client_async_tls_with_connector_and_config(request, stream, no_connector, config).await
}
202 
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector.
///
/// No WebSocket configuration is passed (`None`).
#[cfg(feature = "async-native-tls")]
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + AsyncRead + AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    client_async_tls_with_connector_and_config(request, stream, connector, no_config).await
}
219 
/// Type alias for the stream type of the `connect_async()` functions.
///
/// This is the backend-selected [`ClientStream`] wrapped around an
/// `async_std` [`TcpStream`].
pub type ConnectStream = ClientStream<TcpStream>;
222 
223 /// Connect to a given URL.
connect_async<R>( request: R, ) -> Result<(WebSocketStream<ConnectStream>, Response), Error> where R: IntoClientRequest + Unpin,224 pub async fn connect_async<R>(
225     request: R,
226 ) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
227 where
228     R: IntoClientRequest + Unpin,
229 {
230     connect_async_with_config(request, None).await
231 }
232 
233 /// Connect to a given URL with a given WebSocket configuration.
connect_async_with_config<R>( request: R, config: Option<WebSocketConfig>, ) -> Result<(WebSocketStream<ConnectStream>, Response), Error> where R: IntoClientRequest + Unpin,234 pub async fn connect_async_with_config<R>(
235     request: R,
236     config: Option<WebSocketConfig>,
237 ) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
238 where
239     R: IntoClientRequest + Unpin,
240 {
241     let request: Request = request.into_client_request()?;
242 
243     let domain = domain(&request)?;
244     let port = port(&request)?;
245 
246     let try_socket = TcpStream::connect((domain.as_str(), port)).await;
247     let socket = try_socket.map_err(Error::Io)?;
248     client_async_tls_with_connector_and_config(request, socket, None, config).await
249 }
250 
/// Connect to a given URL using the provided TLS connector.
///
/// Equivalent to [`connect_async_with_tls_connector_and_config`] without a
/// WebSocket configuration.
#[cfg(any(feature = "async-tls", feature = "async-native-tls"))]
pub async fn connect_async_with_tls_connector<R>(
    request: R,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    connect_async_with_tls_connector_and_config(request, connector, no_config).await
}
262 
/// Connect to a given URL using the provided TLS connector and WebSocket
/// configuration.
///
/// Opens an `async_std` TCP connection to the request's host and port. It
/// then performs the (possibly TLS-wrapped) WebSocket handshake via
/// [`client_async_tls_with_connector_and_config`], which defines how a `None`
/// connector is handled by the active TLS backend.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the request is invalid or lacks a host or port;
/// - the TCP connection cannot be established (`Error::Io`);
/// - the TLS or WebSocket handshake fails.
#[cfg(any(feature = "async-tls", feature = "async-native-tls"))]
pub async fn connect_async_with_tls_connector_and_config<R>(
    request: R,
    connector: Option<Connector>,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let request: Request = request.into_client_request()?;

    let domain = domain(&request)?;
    let port = port(&request)?;

    let socket = TcpStream::connect((domain.as_str(), port))
        .await
        .map_err(Error::Io)?;
    client_async_tls_with_connector_and_config(request, socket, connector, config).await
}
282