1 use std::io;
2 use std::convert::TryInto;
3 use std::mem::size_of;
4 use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5 use std::time::Duration;
6 use std::ptr;
7 use std::os::windows::io::FromRawSocket;
8 use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64.
9 
10 use winapi::ctypes::{c_char, c_int, c_ushort, c_ulong};
11 use winapi::shared::ws2def::{SOCKADDR_STORAGE, AF_INET, AF_INET6, SOCKADDR_IN};
12 use winapi::shared::ws2ipdef::SOCKADDR_IN6_LH;
13 use winapi::shared::mstcpip;
14 
15 use winapi::shared::minwindef::{BOOL, TRUE, FALSE, DWORD, LPVOID, LPDWORD};
16 use winapi::um::winsock2::{
17     self, closesocket, linger, setsockopt, getsockopt, getsockname, PF_INET, PF_INET6, SOCKET, SOCKET_ERROR,
18     SOCK_STREAM, SOL_SOCKET, SO_LINGER, SO_REUSEADDR, SO_RCVBUF, SO_SNDBUF, SO_KEEPALIVE, WSAIoctl, LPWSAOVERLAPPED,
19 };
20 
21 use crate::sys::windows::net::{init, new_socket, socket_addr};
22 use crate::net::TcpKeepalive;
23 
24 pub(crate) type TcpSocket = SOCKET;
25 
new_v4_socket() -> io::Result<TcpSocket>26 pub(crate) fn new_v4_socket() -> io::Result<TcpSocket> {
27     init();
28     new_socket(PF_INET, SOCK_STREAM)
29 }
30 
new_v6_socket() -> io::Result<TcpSocket>31 pub(crate) fn new_v6_socket() -> io::Result<TcpSocket> {
32     init();
33     new_socket(PF_INET6, SOCK_STREAM)
34 }
35 
bind(socket: TcpSocket, addr: SocketAddr) -> io::Result<()>36 pub(crate) fn bind(socket: TcpSocket, addr: SocketAddr) -> io::Result<()> {
37     use winsock2::bind;
38 
39     let (raw_addr, raw_addr_length) = socket_addr(&addr);
40     syscall!(
41         bind(socket, raw_addr.as_ptr(), raw_addr_length),
42         PartialEq::eq,
43         SOCKET_ERROR
44     )?;
45     Ok(())
46 }
47 
connect(socket: TcpSocket, addr: SocketAddr) -> io::Result<net::TcpStream>48 pub(crate) fn connect(socket: TcpSocket, addr: SocketAddr) -> io::Result<net::TcpStream> {
49     use winsock2::connect;
50 
51     let (raw_addr, raw_addr_length) = socket_addr(&addr);
52 
53     let res = syscall!(
54         connect(socket, raw_addr.as_ptr(), raw_addr_length),
55         PartialEq::eq,
56         SOCKET_ERROR
57     );
58 
59     match res {
60         Err(err) if err.kind() != io::ErrorKind::WouldBlock => {
61             Err(err)
62         }
63         _ => {
64             Ok(unsafe { net::TcpStream::from_raw_socket(socket as StdSocket) })
65         }
66     }
67 }
68 
listen(socket: TcpSocket, backlog: u32) -> io::Result<net::TcpListener>69 pub(crate) fn listen(socket: TcpSocket, backlog: u32) -> io::Result<net::TcpListener> {
70     use winsock2::listen;
71     use std::convert::TryInto;
72 
73     let backlog = backlog.try_into().unwrap_or(i32::max_value());
74     syscall!(listen(socket, backlog), PartialEq::eq, SOCKET_ERROR)?;
75     Ok(unsafe { net::TcpListener::from_raw_socket(socket as StdSocket) })
76 }
77 
close(socket: TcpSocket)78 pub(crate) fn close(socket: TcpSocket) {
79     let _ = unsafe { closesocket(socket) };
80 }
81 
set_reuseaddr(socket: TcpSocket, reuseaddr: bool) -> io::Result<()>82 pub(crate) fn set_reuseaddr(socket: TcpSocket, reuseaddr: bool) -> io::Result<()> {
83     let val: BOOL = if reuseaddr { TRUE } else { FALSE };
84 
85     match unsafe { setsockopt(
86         socket,
87         SOL_SOCKET,
88         SO_REUSEADDR,
89         &val as *const _ as *const c_char,
90         size_of::<BOOL>() as c_int,
91     ) } {
92         SOCKET_ERROR => Err(io::Error::last_os_error()),
93         _ => Ok(()),
94     }
95 }
96 
get_reuseaddr(socket: TcpSocket) -> io::Result<bool>97 pub(crate) fn get_reuseaddr(socket: TcpSocket) -> io::Result<bool> {
98     let mut optval: c_char = 0;
99     let mut optlen = size_of::<BOOL>() as c_int;
100 
101     match unsafe { getsockopt(
102         socket,
103         SOL_SOCKET,
104         SO_REUSEADDR,
105         &mut optval as *mut _ as *mut _,
106         &mut optlen,
107     ) } {
108         SOCKET_ERROR => Err(io::Error::last_os_error()),
109         _ => Ok(optval != 0),
110     }
111 }
112 
get_localaddr(socket: TcpSocket) -> io::Result<SocketAddr>113 pub(crate) fn get_localaddr(socket: TcpSocket) -> io::Result<SocketAddr> {
114     let mut storage: SOCKADDR_STORAGE = unsafe { std::mem::zeroed() };
115     let mut length = std::mem::size_of_val(&storage) as c_int;
116 
117     match unsafe { getsockname(
118         socket,
119         &mut storage as *mut _ as *mut _,
120         &mut length
121     ) } {
122         SOCKET_ERROR => Err(io::Error::last_os_error()),
123         _ => {
124             if storage.ss_family as c_int == AF_INET {
125                 // Safety: if the ss_family field is AF_INET then storage must be a sockaddr_in.
126                 let addr: &SOCKADDR_IN = unsafe { &*(&storage as *const _ as *const SOCKADDR_IN) };
127                 let ip_bytes = unsafe { addr.sin_addr.S_un.S_un_b() };
128                 let ip = Ipv4Addr::from([ip_bytes.s_b1, ip_bytes.s_b2, ip_bytes.s_b3, ip_bytes.s_b4]);
129                 let port = u16::from_be(addr.sin_port);
130                 Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
131             } else if storage.ss_family as c_int == AF_INET6 {
132                 // Safety: if the ss_family field is AF_INET6 then storage must be a sockaddr_in6.
133                 let addr: &SOCKADDR_IN6_LH = unsafe { &*(&storage as *const _ as *const SOCKADDR_IN6_LH) };
134                 let ip = Ipv6Addr::from(*unsafe { addr.sin6_addr.u.Byte() });
135                 let port = u16::from_be(addr.sin6_port);
136                 let scope_id = unsafe { *addr.u.sin6_scope_id() };
137                 Ok(SocketAddr::V6(SocketAddrV6::new(ip, port, addr.sin6_flowinfo, scope_id)))
138             } else {
139                 Err(std::io::ErrorKind::InvalidInput.into())
140             }
141         },
142     }
143 }
144 
set_linger(socket: TcpSocket, dur: Option<Duration>) -> io::Result<()>145 pub(crate) fn set_linger(socket: TcpSocket, dur: Option<Duration>) -> io::Result<()> {
146     let val: linger = linger {
147         l_onoff: if dur.is_some() { 1 } else { 0 },
148         l_linger: dur.map(|dur| dur.as_secs() as c_ushort).unwrap_or_default(),
149     };
150 
151     match unsafe { setsockopt(
152         socket,
153         SOL_SOCKET,
154         SO_LINGER,
155         &val as *const _ as *const c_char,
156         size_of::<linger>() as c_int,
157     ) } {
158         SOCKET_ERROR => Err(io::Error::last_os_error()),
159         _ => Ok(()),
160     }
161 }
162 
get_linger(socket: TcpSocket) -> io::Result<Option<Duration>>163 pub(crate) fn get_linger(socket: TcpSocket) -> io::Result<Option<Duration>> {
164     let mut val: linger = unsafe { std::mem::zeroed() };
165     let mut len = size_of::<linger>() as c_int;
166 
167     match unsafe { getsockopt(
168         socket,
169         SOL_SOCKET,
170         SO_LINGER,
171         &mut val as *mut _ as *mut _,
172         &mut len,
173     ) } {
174         SOCKET_ERROR => Err(io::Error::last_os_error()),
175         _ => {
176             if val.l_onoff == 0 {
177                 Ok(None)
178             } else {
179                 Ok(Some(Duration::from_secs(val.l_linger as u64)))
180             }
181         },
182     }
183 }
184 
185 
set_recv_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()>186 pub(crate) fn set_recv_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> {
187     let size = size.try_into().ok().unwrap_or_else(i32::max_value);
188     match unsafe { setsockopt(
189         socket,
190         SOL_SOCKET,
191         SO_RCVBUF,
192         &size as *const _ as *const c_char,
193         size_of::<c_int>() as c_int
194     ) } {
195         SOCKET_ERROR => Err(io::Error::last_os_error()),
196         _ => Ok(()),
197     }
198 }
199 
get_recv_buffer_size(socket: TcpSocket) -> io::Result<u32>200 pub(crate) fn get_recv_buffer_size(socket: TcpSocket) -> io::Result<u32> {
201     let mut optval: c_int = 0;
202     let mut optlen = size_of::<c_int>() as c_int;
203     match unsafe { getsockopt(
204         socket,
205         SOL_SOCKET,
206         SO_RCVBUF,
207         &mut optval as *mut _ as *mut _,
208         &mut optlen as *mut _,
209     ) } {
210         SOCKET_ERROR => Err(io::Error::last_os_error()),
211         _ => Ok(optval as u32),
212     }
213 }
214 
set_send_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()>215 pub(crate) fn set_send_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> {
216     let size = size.try_into().ok().unwrap_or_else(i32::max_value);
217     match unsafe { setsockopt(
218         socket,
219         SOL_SOCKET,
220         SO_SNDBUF,
221         &size as *const _ as *const c_char,
222         size_of::<c_int>() as c_int
223     ) } {
224         SOCKET_ERROR => Err(io::Error::last_os_error()),
225         _ => Ok(()),
226     }
227 }
228 
get_send_buffer_size(socket: TcpSocket) -> io::Result<u32>229 pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result<u32> {
230     let mut optval: c_int = 0;
231     let mut optlen = size_of::<c_int>() as c_int;
232     match unsafe { getsockopt(
233         socket,
234         SOL_SOCKET,
235         SO_SNDBUF,
236         &mut optval as *mut _ as *mut _,
237         &mut optlen as *mut _,
238     ) } {
239         SOCKET_ERROR => Err(io::Error::last_os_error()),
240         _ => Ok(optval as u32),
241     }
242 }
243 
set_keepalive(socket: TcpSocket, keepalive: bool) -> io::Result<()>244 pub(crate) fn set_keepalive(socket: TcpSocket, keepalive: bool) -> io::Result<()> {
245     let val: BOOL = if keepalive { TRUE } else { FALSE };
246     match unsafe { setsockopt(
247         socket,
248         SOL_SOCKET,
249         SO_KEEPALIVE,
250         &val as *const _ as *const c_char,
251         size_of::<BOOL>() as c_int
252     ) } {
253         SOCKET_ERROR => Err(io::Error::last_os_error()),
254         _ => Ok(()),
255     }
256 }
257 
get_keepalive(socket: TcpSocket) -> io::Result<bool>258 pub(crate) fn get_keepalive(socket: TcpSocket) -> io::Result<bool> {
259     let mut optval: c_char = 0;
260     let mut optlen = size_of::<BOOL>() as c_int;
261 
262     match unsafe { getsockopt(
263         socket,
264         SOL_SOCKET,
265         SO_KEEPALIVE,
266         &mut optval as *mut _ as *mut _,
267         &mut optlen,
268     ) } {
269         SOCKET_ERROR => Err(io::Error::last_os_error()),
270         _ => Ok(optval != FALSE as c_char),
271     }
272 }
273 
set_keepalive_params(socket: TcpSocket, keepalive: TcpKeepalive) -> io::Result<()>274 pub(crate) fn set_keepalive_params(socket: TcpSocket, keepalive: TcpKeepalive) -> io::Result<()> {
275     /// Windows configures keepalive time/interval in a u32 of milliseconds.
276     fn dur_to_ulong_ms(dur: Duration) -> c_ulong {
277         dur.as_millis().try_into().ok().unwrap_or_else(u32::max_value)
278     }
279 
280     // If any of the fields on the `tcp_keepalive` struct were not provided by
281     // the user, just leaving them zero will clobber any existing value.
282     // Unfortunately, we can't access the current value, so we will use the
283     // defaults if a value for the time or interval was not not provided.
284     let time = keepalive.time.unwrap_or_else(|| {
285         // The default value is two hours, as per
286         // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-keepalive-vals
287         let two_hours = 2 * 60 * 60;
288         Duration::from_secs(two_hours)
289     });
290 
291     let interval = keepalive.interval.unwrap_or_else(|| {
292         // The default value is one second, as per
293         // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-keepalive-vals
294         Duration::from_secs(1)
295     });
296 
297     let mut keepalive = mstcpip::tcp_keepalive {
298         // Enable keepalive
299         onoff: 1,
300         keepalivetime: dur_to_ulong_ms(time),
301         keepaliveinterval: dur_to_ulong_ms(interval),
302     };
303 
304     let mut out = 0;
305     match unsafe { WSAIoctl(
306         socket,
307         mstcpip::SIO_KEEPALIVE_VALS,
308         &mut keepalive as *mut _ as LPVOID,
309         size_of::<mstcpip::tcp_keepalive>() as DWORD,
310         ptr::null_mut() as LPVOID,
311         0 as DWORD,
312         &mut out as *mut _ as LPDWORD,
313         0 as LPWSAOVERLAPPED,
314         None,
315     ) } {
316         0 => Ok(()),
317         _ => Err(io::Error::last_os_error())
318     }
319 }
320 
/// Accepts one pending connection on `listener`.
///
/// This delegates to `std::net::TcpListener::accept`. According to the WinSock
/// `accept` remarks, the accepted socket inherits the listener's non-blocking
/// state, so no additional configuration is applied here. See
/// https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#remarks.
pub(crate) fn accept(listener: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> {
    net::TcpListener::accept(listener)
}
326