1 // Copyright 2015 The Rust Project Developers.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8 
9 use std::cmp::min;
10 use std::io::{self, IoSlice};
11 use std::marker::PhantomData;
12 use std::mem::{self, size_of, MaybeUninit};
13 use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
14 use std::os::windows::prelude::*;
15 use std::sync::Once;
16 use std::time::{Duration, Instant};
17 use std::{ptr, slice};
18 
19 use winapi::ctypes::c_long;
20 use winapi::shared::in6addr::*;
21 use winapi::shared::inaddr::*;
22 use winapi::shared::minwindef::DWORD;
23 use winapi::shared::minwindef::ULONG;
24 use winapi::shared::mstcpip::{tcp_keepalive, SIO_KEEPALIVE_VALS};
25 use winapi::shared::ntdef::HANDLE;
26 use winapi::shared::ws2def;
27 use winapi::shared::ws2def::WSABUF;
28 use winapi::um::handleapi::SetHandleInformation;
29 use winapi::um::processthreadsapi::GetCurrentProcessId;
30 use winapi::um::winbase::{self, INFINITE};
31 use winapi::um::winsock2::{
32     self as sock, u_long, POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM, SD_BOTH, SD_RECEIVE, SD_SEND,
33     WSAPOLLFD,
34 };
35 
36 use crate::{RecvFlags, SockAddr, TcpKeepalive, Type};
37 
38 pub(crate) use winapi::ctypes::c_int;
39 
/// Synthetic `MSG_TRUNC` bit stored inside [`RecvFlags`].
///
/// Winsock reports a truncated datagram by failing `WSARecv[From]` with
/// `WSAEMSGSIZE` rather than through a flag; the vectored receive functions in
/// this module translate that error into this bit. The value is chosen by this
/// crate, not by the operating system.
pub(crate) const MSG_TRUNC: c_int = 0b1;
45 
46 // Used in `Domain`.
47 pub(crate) use winapi::shared::ws2def::{AF_INET, AF_INET6};
48 // Used in `Type`.
49 pub(crate) use winapi::shared::ws2def::{SOCK_DGRAM, SOCK_STREAM};
50 #[cfg(feature = "all")]
51 pub(crate) use winapi::shared::ws2def::{SOCK_RAW, SOCK_SEQPACKET};
52 // Used in `Protocol`.
53 pub(crate) const IPPROTO_ICMP: c_int = winapi::shared::ws2def::IPPROTO_ICMP as c_int;
54 pub(crate) const IPPROTO_ICMPV6: c_int = winapi::shared::ws2def::IPPROTO_ICMPV6 as c_int;
55 pub(crate) const IPPROTO_TCP: c_int = winapi::shared::ws2def::IPPROTO_TCP as c_int;
56 pub(crate) const IPPROTO_UDP: c_int = winapi::shared::ws2def::IPPROTO_UDP as c_int;
57 // Used in `SockAddr`.
58 pub(crate) use winapi::shared::ws2def::{
59     ADDRESS_FAMILY as sa_family_t, SOCKADDR as sockaddr, SOCKADDR_IN as sockaddr_in,
60     SOCKADDR_STORAGE as sockaddr_storage,
61 };
62 pub(crate) use winapi::shared::ws2ipdef::SOCKADDR_IN6_LH as sockaddr_in6;
63 pub(crate) use winapi::um::ws2tcpip::socklen_t;
64 // Used in `Socket`.
65 pub(crate) use winapi::shared::ws2def::{
66     IPPROTO_IP, SOL_SOCKET, SO_BROADCAST, SO_ERROR, SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE,
67     SO_RCVBUF, SO_RCVTIMEO, SO_REUSEADDR, SO_SNDBUF, SO_SNDTIMEO, SO_TYPE, TCP_NODELAY,
68 };
69 pub(crate) use winapi::shared::ws2ipdef::{
70     IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_MREQ as Ipv6Mreq, IPV6_MULTICAST_HOPS,
71     IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, IPV6_UNICAST_HOPS, IPV6_V6ONLY, IP_ADD_MEMBERSHIP,
72     IP_DROP_MEMBERSHIP, IP_MREQ as IpMreq, IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL,
73     IP_TOS, IP_TTL,
74 };
75 pub(crate) use winapi::um::winsock2::{linger, MSG_OOB, MSG_PEEK};
76 pub(crate) const IPPROTO_IPV6: c_int = winapi::shared::ws2def::IPPROTO_IPV6 as c_int;
77 
78 /// Type used in set/getsockopt to retrieve the `TCP_NODELAY` option.
79 ///
80 /// NOTE: <https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-getsockopt>
81 /// documents that options such as `TCP_NODELAY` and `SO_KEEPALIVE` expect a
82 /// `BOOL` (alias for `c_int`, 4 bytes), however in practice this turns out to
83 /// be false (or misleading) as a `BOOLEAN` (`c_uchar`, 1 byte) is returned by
84 /// `getsockopt`.
85 pub(crate) type Bool = winapi::shared::ntdef::BOOLEAN;
86 
87 /// Maximum size of a buffer passed to system call like `recv` and `send`.
88 const MAX_BUF_LEN: usize = <c_int>::max_value() as usize;
89 
/// Helper macro to execute a system call that returns an `io::Result`.
///
/// `syscall!(name(args...), err_test, err_value)` calls `sock::name(args...)`
/// inside an `unsafe` block and evaluates `err_test(&result, &err_value)`.
/// When that returns `true` the call is treated as failed and
/// `Err(io::Error::last_os_error())` is returned; otherwise `Ok(result)`.
/// Typical pairs are `PartialEq::eq` with `sock::SOCKET_ERROR` /
/// `sock::INVALID_SOCKET`, or `PartialEq::ne` with `0`.
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{
        // Some callers (e.g. `unsafe fn getsockopt`) are already in an unsafe
        // context, which would otherwise trigger an `unused_unsafe` warning.
        #[allow(unused_unsafe)]
        let res = unsafe { sock::$fn($($arg, )*) };
        if $err_test(&res, &$err_value) {
            Err(io::Error::last_os_error())
        } else {
            Ok(res)
        }
    }};
}
102 
// `Debug` for `crate::Domain`, generated from the Winsock address families
// listed here (see the `impl_debug!` macro, defined elsewhere in the crate,
// for the exact output format).
impl_debug!(
    crate::Domain,
    ws2def::AF_INET,
    ws2def::AF_INET6,
    ws2def::AF_UNIX,
    ws2def::AF_UNSPEC, // = 0.
);
110 
111 /// Windows only API.
112 impl Type {
113     /// Our custom flag to set `WSA_FLAG_NO_HANDLE_INHERIT` on socket creation.
114     /// Trying to mimic `Type::cloexec` on windows.
115     const NO_INHERIT: c_int = 1 << ((size_of::<c_int>() * 8) - 1); // Last bit.
116 
117     /// Set `WSA_FLAG_NO_HANDLE_INHERIT` on the socket.
118     #[cfg(feature = "all")]
119     #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))]
no_inherit(self) -> Type120     pub const fn no_inherit(self) -> Type {
121         self._no_inherit()
122     }
123 
_no_inherit(self) -> Type124     pub(crate) const fn _no_inherit(self) -> Type {
125         Type(self.0 | Type::NO_INHERIT)
126     }
127 }
128 
// `Debug` for `crate::Type`, generated from the Winsock socket types listed
// here. The `ws2def::` paths are used directly, so this compiles even when
// `SOCK_RAW`/`SOCK_SEQPACKET` are not re-exported (they are gated on the
// `all` feature above).
impl_debug!(
    crate::Type,
    ws2def::SOCK_STREAM,
    ws2def::SOCK_DGRAM,
    ws2def::SOCK_RAW,
    ws2def::SOCK_RDM,
    ws2def::SOCK_SEQPACKET,
);
137 
// `Debug` for `crate::Protocol`, generated from the `c_int` protocol
// constants defined in this module.
impl_debug!(
    crate::Protocol,
    self::IPPROTO_ICMP,
    self::IPPROTO_ICMPV6,
    self::IPPROTO_TCP,
    self::IPPROTO_UDP,
);
145 
146 impl std::fmt::Debug for RecvFlags {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result147     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148         f.debug_struct("RecvFlags")
149             .field("is_truncated", &self.is_truncated())
150             .finish()
151     }
152 }
153 
/// A `WSABUF` (length + pointer) describing a borrowed, possibly
/// uninitialised byte buffer.
///
/// `repr(transparent)` keeps the layout identical to `WSABUF`; presumably this
/// is so a slice of these can be handed to Winsock where it expects an array
/// of `WSABUF` (see the pointer `.cast()` in `recv_vectored`) — confirm.
#[repr(transparent)]
pub struct MaybeUninitSlice<'a> {
    // Raw Winsock buffer descriptor.
    vec: WSABUF,
    // Ties the raw pointer in `vec` to the mutable borrow it came from.
    _lifetime: PhantomData<&'a mut [MaybeUninit<u8>]>,
}
159 
160 impl<'a> MaybeUninitSlice<'a> {
new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a>161     pub fn new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a> {
162         assert!(buf.len() <= ULONG::MAX as usize);
163         MaybeUninitSlice {
164             vec: WSABUF {
165                 len: buf.len() as ULONG,
166                 buf: buf.as_mut_ptr().cast(),
167             },
168             _lifetime: PhantomData,
169         }
170     }
171 
as_slice(&self) -> &[MaybeUninit<u8>]172     pub fn as_slice(&self) -> &[MaybeUninit<u8>] {
173         unsafe { slice::from_raw_parts(self.vec.buf.cast(), self.vec.len as usize) }
174     }
175 
as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>]176     pub fn as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>] {
177         unsafe { slice::from_raw_parts_mut(self.vec.buf.cast(), self.vec.len as usize) }
178     }
179 }
180 
/// Makes sure winsock has been initialised, exactly once per process.
fn init() {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        // Initialise winsock through the standard library by creating a
        // throwaway socket: libstd initialises winsock itself, so the result
        // of the bind can be ignored either way.
        //
        // Bind to port 0 (an OS-assigned ephemeral port) rather than a fixed
        // port, so we never briefly occupy, or race against, a well-known port
        // another application may be trying to bind.
        let _ = net::UdpSocket::bind("127.0.0.1:0");
    });
}
191 
/// Raw Winsock socket handle (`SOCKET`) used by every function in this module.
pub(crate) type Socket = sock::SOCKET;
193 
socket_from_raw(socket: Socket) -> crate::socket::Inner194 pub(crate) unsafe fn socket_from_raw(socket: Socket) -> crate::socket::Inner {
195     crate::socket::Inner::from_raw_socket(socket as RawSocket)
196 }
197 
socket_as_raw(socket: &crate::socket::Inner) -> Socket198 pub(crate) fn socket_as_raw(socket: &crate::socket::Inner) -> Socket {
199     socket.as_raw_socket() as Socket
200 }
201 
socket_into_raw(socket: crate::socket::Inner) -> Socket202 pub(crate) fn socket_into_raw(socket: crate::socket::Inner) -> Socket {
203     socket.into_raw_socket() as Socket
204 }
205 
socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket>206 pub(crate) fn socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket> {
207     init();
208 
209     // Check if we set our custom flag.
210     let flags = if ty & Type::NO_INHERIT != 0 {
211         ty = ty & !Type::NO_INHERIT;
212         sock::WSA_FLAG_NO_HANDLE_INHERIT
213     } else {
214         0
215     };
216 
217     syscall!(
218         WSASocketW(
219             family,
220             ty,
221             protocol,
222             ptr::null_mut(),
223             0,
224             sock::WSA_FLAG_OVERLAPPED | flags,
225         ),
226         PartialEq::eq,
227         sock::INVALID_SOCKET
228     )
229 }
230 
bind(socket: Socket, addr: &SockAddr) -> io::Result<()>231 pub(crate) fn bind(socket: Socket, addr: &SockAddr) -> io::Result<()> {
232     syscall!(bind(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
233 }
234 
connect(socket: Socket, addr: &SockAddr) -> io::Result<()>235 pub(crate) fn connect(socket: Socket, addr: &SockAddr) -> io::Result<()> {
236     syscall!(connect(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
237 }
238 
poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()>239 pub(crate) fn poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()> {
240     let start = Instant::now();
241 
242     let mut fd_array = WSAPOLLFD {
243         fd: socket.as_raw(),
244         events: POLLRDNORM | POLLWRNORM,
245         revents: 0,
246     };
247 
248     loop {
249         let elapsed = start.elapsed();
250         if elapsed >= timeout {
251             return Err(io::ErrorKind::TimedOut.into());
252         }
253 
254         let timeout = (timeout - elapsed).as_millis();
255         let timeout = clamp(timeout, 1, c_int::max_value() as u128) as c_int;
256 
257         match syscall!(
258             WSAPoll(&mut fd_array, 1, timeout),
259             PartialEq::eq,
260             sock::SOCKET_ERROR
261         ) {
262             Ok(0) => return Err(io::ErrorKind::TimedOut.into()),
263             Ok(_) => {
264                 // Error or hang up indicates an error (or failure to connect).
265                 if (fd_array.revents & POLLERR) != 0 || (fd_array.revents & POLLHUP) != 0 {
266                     match socket.take_error() {
267                         Ok(Some(err)) => return Err(err),
268                         Ok(None) => {
269                             return Err(io::Error::new(
270                                 io::ErrorKind::Other,
271                                 "no error set after POLLHUP",
272                             ))
273                         }
274                         Err(err) => return Err(err),
275                     }
276                 }
277                 return Ok(());
278             }
279             // Got interrupted, try again.
280             Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
281             Err(err) => return Err(err),
282         }
283     }
284 }
285 
// TODO: use clamp from std lib, stable since 1.50.
/// Restricts `value` to `min..=max`.
///
/// Unlike `Ord::clamp` this never panics: the lower bound is checked first, so
/// `min` wins if `min > max`.
fn clamp<T>(value: T, min: T, max: T) -> T
where
    T: Ord,
{
    if value > min {
        if value < max {
            value
        } else {
            max
        }
    } else {
        min
    }
}
299 
listen(socket: Socket, backlog: c_int) -> io::Result<()>300 pub(crate) fn listen(socket: Socket, backlog: c_int) -> io::Result<()> {
301     syscall!(listen(socket, backlog), PartialEq::ne, 0).map(|_| ())
302 }
303 
accept(socket: Socket) -> io::Result<(Socket, SockAddr)>304 pub(crate) fn accept(socket: Socket) -> io::Result<(Socket, SockAddr)> {
305     // Safety: `accept` initialises the `SockAddr` for us.
306     unsafe {
307         SockAddr::init(|storage, len| {
308             syscall!(
309                 accept(socket, storage.cast(), len),
310                 PartialEq::eq,
311                 sock::INVALID_SOCKET
312             )
313         })
314     }
315 }
316 
getsockname(socket: Socket) -> io::Result<SockAddr>317 pub(crate) fn getsockname(socket: Socket) -> io::Result<SockAddr> {
318     // Safety: `getsockname` initialises the `SockAddr` for us.
319     unsafe {
320         SockAddr::init(|storage, len| {
321             syscall!(
322                 getsockname(socket, storage.cast(), len),
323                 PartialEq::eq,
324                 sock::SOCKET_ERROR
325             )
326         })
327     }
328     .map(|(_, addr)| addr)
329 }
330 
getpeername(socket: Socket) -> io::Result<SockAddr>331 pub(crate) fn getpeername(socket: Socket) -> io::Result<SockAddr> {
332     // Safety: `getpeername` initialises the `SockAddr` for us.
333     unsafe {
334         SockAddr::init(|storage, len| {
335             syscall!(
336                 getpeername(socket, storage.cast(), len),
337                 PartialEq::eq,
338                 sock::SOCKET_ERROR
339             )
340         })
341     }
342     .map(|(_, addr)| addr)
343 }
344 
try_clone(socket: Socket) -> io::Result<Socket>345 pub(crate) fn try_clone(socket: Socket) -> io::Result<Socket> {
346     let mut info: MaybeUninit<sock::WSAPROTOCOL_INFOW> = MaybeUninit::uninit();
347     syscall!(
348         WSADuplicateSocketW(socket, GetCurrentProcessId(), info.as_mut_ptr()),
349         PartialEq::eq,
350         sock::SOCKET_ERROR
351     )?;
352     // Safety: `WSADuplicateSocketW` intialised `info` for us.
353     let mut info = unsafe { info.assume_init() };
354 
355     syscall!(
356         WSASocketW(
357             info.iAddressFamily,
358             info.iSocketType,
359             info.iProtocol,
360             &mut info,
361             0,
362             sock::WSA_FLAG_OVERLAPPED | sock::WSA_FLAG_NO_HANDLE_INHERIT,
363         ),
364         PartialEq::eq,
365         sock::INVALID_SOCKET
366     )
367 }
368 
set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()>369 pub(crate) fn set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()> {
370     let mut nonblocking = nonblocking as u_long;
371     ioctlsocket(socket, sock::FIONBIO, &mut nonblocking)
372 }
373 
shutdown(socket: Socket, how: Shutdown) -> io::Result<()>374 pub(crate) fn shutdown(socket: Socket, how: Shutdown) -> io::Result<()> {
375     let how = match how {
376         Shutdown::Write => SD_SEND,
377         Shutdown::Read => SD_RECEIVE,
378         Shutdown::Both => SD_BOTH,
379     };
380     syscall!(shutdown(socket, how), PartialEq::eq, sock::SOCKET_ERROR).map(|_| ())
381 }
382 
recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize>383 pub(crate) fn recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
384     let res = syscall!(
385         recv(
386             socket,
387             buf.as_mut_ptr().cast(),
388             min(buf.len(), MAX_BUF_LEN) as c_int,
389             flags,
390         ),
391         PartialEq::eq,
392         sock::SOCKET_ERROR
393     );
394     match res {
395         Ok(n) => Ok(n as usize),
396         Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => Ok(0),
397         Err(err) => Err(err),
398     }
399 }
400 
recv_vectored( socket: Socket, bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags)>401 pub(crate) fn recv_vectored(
402     socket: Socket,
403     bufs: &mut [crate::MaybeUninitSlice<'_>],
404     flags: c_int,
405 ) -> io::Result<(usize, RecvFlags)> {
406     let mut nread = 0;
407     let mut flags = flags as DWORD;
408     let res = syscall!(
409         WSARecv(
410             socket,
411             bufs.as_mut_ptr().cast(),
412             min(bufs.len(), DWORD::max_value() as usize) as DWORD,
413             &mut nread,
414             &mut flags,
415             ptr::null_mut(),
416             None,
417         ),
418         PartialEq::eq,
419         sock::SOCKET_ERROR
420     );
421     match res {
422         Ok(_) => Ok((nread as usize, RecvFlags(0))),
423         Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => {
424             Ok((0, RecvFlags(0)))
425         }
426         Err(ref err) if err.raw_os_error() == Some(sock::WSAEMSGSIZE as i32) => {
427             Ok((nread as usize, RecvFlags(MSG_TRUNC)))
428         }
429         Err(err) => Err(err),
430     }
431 }
432 
recv_from( socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int, ) -> io::Result<(usize, SockAddr)>433 pub(crate) fn recv_from(
434     socket: Socket,
435     buf: &mut [MaybeUninit<u8>],
436     flags: c_int,
437 ) -> io::Result<(usize, SockAddr)> {
438     // Safety: `recvfrom` initialises the `SockAddr` for us.
439     unsafe {
440         SockAddr::init(|storage, addrlen| {
441             let res = syscall!(
442                 recvfrom(
443                     socket,
444                     buf.as_mut_ptr().cast(),
445                     min(buf.len(), MAX_BUF_LEN) as c_int,
446                     flags,
447                     storage.cast(),
448                     addrlen,
449                 ),
450                 PartialEq::eq,
451                 sock::SOCKET_ERROR
452             );
453             match res {
454                 Ok(n) => Ok(n as usize),
455                 Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => Ok(0),
456                 Err(err) => Err(err),
457             }
458         })
459     }
460 }
461 
recv_from_vectored( socket: Socket, bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags, SockAddr)>462 pub(crate) fn recv_from_vectored(
463     socket: Socket,
464     bufs: &mut [crate::MaybeUninitSlice<'_>],
465     flags: c_int,
466 ) -> io::Result<(usize, RecvFlags, SockAddr)> {
467     // Safety: `recvfrom` initialises the `SockAddr` for us.
468     unsafe {
469         SockAddr::init(|storage, addrlen| {
470             let mut nread = 0;
471             let mut flags = flags as DWORD;
472             let res = syscall!(
473                 WSARecvFrom(
474                     socket,
475                     bufs.as_mut_ptr().cast(),
476                     min(bufs.len(), DWORD::max_value() as usize) as DWORD,
477                     &mut nread,
478                     &mut flags,
479                     storage.cast(),
480                     addrlen,
481                     ptr::null_mut(),
482                     None,
483                 ),
484                 PartialEq::eq,
485                 sock::SOCKET_ERROR
486             );
487             match res {
488                 Ok(_) => Ok((nread as usize, RecvFlags(0))),
489                 Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => {
490                     Ok((nread as usize, RecvFlags(0)))
491                 }
492                 Err(ref err) if err.raw_os_error() == Some(sock::WSAEMSGSIZE as i32) => {
493                     Ok((nread as usize, RecvFlags(MSG_TRUNC)))
494                 }
495                 Err(err) => Err(err),
496             }
497         })
498     }
499     .map(|((n, recv_flags), addr)| (n, recv_flags, addr))
500 }
501 
send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize>502 pub(crate) fn send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize> {
503     syscall!(
504         send(
505             socket,
506             buf.as_ptr().cast(),
507             min(buf.len(), MAX_BUF_LEN) as c_int,
508             flags,
509         ),
510         PartialEq::eq,
511         sock::SOCKET_ERROR
512     )
513     .map(|n| n as usize)
514 }
515 
send_vectored( socket: Socket, bufs: &[IoSlice<'_>], flags: c_int, ) -> io::Result<usize>516 pub(crate) fn send_vectored(
517     socket: Socket,
518     bufs: &[IoSlice<'_>],
519     flags: c_int,
520 ) -> io::Result<usize> {
521     let mut nsent = 0;
522     syscall!(
523         WSASend(
524             socket,
525             // FIXME: From the `WSASend` docs [1]:
526             // > For a Winsock application, once the WSASend function is called,
527             // > the system owns these buffers and the application may not
528             // > access them.
529             //
530             // So what we're doing is actually UB as `bufs` needs to be `&mut
531             // [IoSlice<'_>]`.
532             //
533             // Tracking issue: https://github.com/rust-lang/socket2-rs/issues/129.
534             //
535             // NOTE: `send_to_vectored` has the same problem.
536             //
537             // [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend
538             bufs.as_ptr() as *mut _,
539             min(bufs.len(), DWORD::max_value() as usize) as DWORD,
540             &mut nsent,
541             flags as DWORD,
542             std::ptr::null_mut(),
543             None,
544         ),
545         PartialEq::eq,
546         sock::SOCKET_ERROR
547     )
548     .map(|_| nsent as usize)
549 }
550 
send_to( socket: Socket, buf: &[u8], addr: &SockAddr, flags: c_int, ) -> io::Result<usize>551 pub(crate) fn send_to(
552     socket: Socket,
553     buf: &[u8],
554     addr: &SockAddr,
555     flags: c_int,
556 ) -> io::Result<usize> {
557     syscall!(
558         sendto(
559             socket,
560             buf.as_ptr().cast(),
561             min(buf.len(), MAX_BUF_LEN) as c_int,
562             flags,
563             addr.as_ptr(),
564             addr.len(),
565         ),
566         PartialEq::eq,
567         sock::SOCKET_ERROR
568     )
569     .map(|n| n as usize)
570 }
571 
send_to_vectored( socket: Socket, bufs: &[IoSlice<'_>], addr: &SockAddr, flags: c_int, ) -> io::Result<usize>572 pub(crate) fn send_to_vectored(
573     socket: Socket,
574     bufs: &[IoSlice<'_>],
575     addr: &SockAddr,
576     flags: c_int,
577 ) -> io::Result<usize> {
578     let mut nsent = 0;
579     syscall!(
580         WSASendTo(
581             socket,
582             // FIXME: Same problem as in `send_vectored`.
583             bufs.as_ptr() as *mut _,
584             bufs.len().min(DWORD::MAX as usize) as DWORD,
585             &mut nsent,
586             flags as DWORD,
587             addr.as_ptr(),
588             addr.len(),
589             ptr::null_mut(),
590             None,
591         ),
592         PartialEq::eq,
593         sock::SOCKET_ERROR
594     )
595     .map(|_| nsent as usize)
596 }
597 
598 /// Wrapper around `getsockopt` to deal with platform specific timeouts.
timeout_opt(fd: Socket, lvl: c_int, name: c_int) -> io::Result<Option<Duration>>599 pub(crate) fn timeout_opt(fd: Socket, lvl: c_int, name: c_int) -> io::Result<Option<Duration>> {
600     unsafe { getsockopt(fd, lvl, name).map(from_ms) }
601 }
602 
/// Converts a Winsock millisecond timeout into a `Duration`; 0 means "no
/// timeout" and maps to `None`.
fn from_ms(duration: DWORD) -> Option<Duration> {
    match duration {
        0 => None,
        millis => Some(Duration::from_millis(millis as u64)),
    }
}
612 
613 /// Wrapper around `setsockopt` to deal with platform specific timeouts.
set_timeout_opt( fd: Socket, level: c_int, optname: c_int, duration: Option<Duration>, ) -> io::Result<()>614 pub(crate) fn set_timeout_opt(
615     fd: Socket,
616     level: c_int,
617     optname: c_int,
618     duration: Option<Duration>,
619 ) -> io::Result<()> {
620     let duration = into_ms(duration);
621     unsafe { setsockopt(fd, level, optname, duration) }
622 }
623 
/// Converts an optional timeout to Winsock's `DWORD` milliseconds, where 0
/// means "no timeout" (`None` maps to 0).
fn into_ms(duration: Option<Duration>) -> DWORD {
    // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the
    // timeouts in windows APIs are typically u32 milliseconds. To translate, we
    // have two pieces to take care of:
    //
    // * Nanosecond precision is rounded up, so a short but non-zero timeout
    //   never collapses to 0 (which would mean "no timeout").
    // * Greater than u32::MAX milliseconds (50 days) is rounded up to
    //   INFINITE (never time out).
    duration
        .map(|duration| {
            let mut millis = duration.as_millis();
            if duration.subsec_nanos() % 1_000_000 != 0 {
                // `as_millis` truncates; bump up any sub-millisecond remainder.
                // Cannot overflow: u64 seconds * 1000 + 1000 fits in u128.
                millis += 1;
            }
            min(millis, INFINITE as u128) as DWORD
        })
        .unwrap_or(0)
}
636 
set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()>637 pub(crate) fn set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()> {
638     let mut keepalive = tcp_keepalive {
639         onoff: 1,
640         keepalivetime: into_ms(keepalive.time),
641         keepaliveinterval: into_ms(keepalive.interval),
642     };
643     let mut out = 0;
644     syscall!(
645         WSAIoctl(
646             socket,
647             SIO_KEEPALIVE_VALS,
648             &mut keepalive as *mut _ as *mut _,
649             size_of::<tcp_keepalive>() as _,
650             ptr::null_mut(),
651             0,
652             &mut out,
653             ptr::null_mut(),
654             None,
655         ),
656         PartialEq::eq,
657         sock::SOCKET_ERROR
658     )
659     .map(|_| ())
660 }
661 
662 /// Caller must ensure `T` is the correct type for `level` and `optname`.
getsockopt<T>(socket: Socket, level: c_int, optname: c_int) -> io::Result<T>663 pub(crate) unsafe fn getsockopt<T>(socket: Socket, level: c_int, optname: c_int) -> io::Result<T> {
664     let mut optval: MaybeUninit<T> = MaybeUninit::uninit();
665     let mut optlen = mem::size_of::<T>() as c_int;
666     syscall!(
667         getsockopt(
668             socket,
669             level,
670             optname,
671             optval.as_mut_ptr().cast(),
672             &mut optlen,
673         ),
674         PartialEq::eq,
675         sock::SOCKET_ERROR
676     )
677     .map(|_| {
678         debug_assert_eq!(optlen as usize, mem::size_of::<T>());
679         // Safety: `getsockopt` initialised `optval` for us.
680         optval.assume_init()
681     })
682 }
683 
684 /// Caller must ensure `T` is the correct type for `level` and `optname`.
setsockopt<T>( socket: Socket, level: c_int, optname: c_int, optval: T, ) -> io::Result<()>685 pub(crate) unsafe fn setsockopt<T>(
686     socket: Socket,
687     level: c_int,
688     optname: c_int,
689     optval: T,
690 ) -> io::Result<()> {
691     syscall!(
692         setsockopt(
693             socket,
694             level,
695             optname,
696             (&optval as *const T).cast(),
697             mem::size_of::<T>() as c_int,
698         ),
699         PartialEq::eq,
700         sock::SOCKET_ERROR
701     )
702     .map(|_| ())
703 }
704 
ioctlsocket(socket: Socket, cmd: c_long, payload: &mut u_long) -> io::Result<()>705 fn ioctlsocket(socket: Socket, cmd: c_long, payload: &mut u_long) -> io::Result<()> {
706     syscall!(
707         ioctlsocket(socket, cmd, payload),
708         PartialEq::eq,
709         sock::SOCKET_ERROR
710     )
711     .map(|_| ())
712 }
713 
to_in_addr(addr: &Ipv4Addr) -> IN_ADDR714 pub(crate) fn to_in_addr(addr: &Ipv4Addr) -> IN_ADDR {
715     let mut s_un: in_addr_S_un = unsafe { mem::zeroed() };
716     // `S_un` is stored as BE on all machines, and the array is in BE order. So
717     // the native endian conversion method is used so that it's never swapped.
718     unsafe { *(s_un.S_addr_mut()) = u32::from_ne_bytes(addr.octets()) };
719     IN_ADDR { S_un: s_un }
720 }
721 
from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr722 pub(crate) fn from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr {
723     Ipv4Addr::from(unsafe { *in_addr.S_un.S_addr() }.to_ne_bytes())
724 }
725 
to_in6_addr(addr: &Ipv6Addr) -> in6_addr726 pub(crate) fn to_in6_addr(addr: &Ipv6Addr) -> in6_addr {
727     let mut ret_addr: in6_addr_u = unsafe { mem::zeroed() };
728     unsafe { *(ret_addr.Byte_mut()) = addr.octets() };
729     let mut ret: in6_addr = unsafe { mem::zeroed() };
730     ret.u = ret_addr;
731     ret
732 }
733 
from_in6_addr(addr: in6_addr) -> Ipv6Addr734 pub(crate) fn from_in6_addr(addr: in6_addr) -> Ipv6Addr {
735     Ipv6Addr::from(*unsafe { addr.u.Byte() })
736 }
737 
738 /// Windows only API.
739 impl crate::Socket {
740     /// Sets `HANDLE_FLAG_INHERIT` using `SetHandleInformation`.
741     #[cfg(feature = "all")]
742     #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))]
set_no_inherit(&self, no_inherit: bool) -> io::Result<()>743     pub fn set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
744         self._set_no_inherit(no_inherit)
745     }
746 
_set_no_inherit(&self, no_inherit: bool) -> io::Result<()>747     pub(crate) fn _set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
748         // NOTE: can't use `syscall!` because it expects the function in the
749         // `sock::` path.
750         let res = unsafe {
751             SetHandleInformation(
752                 self.as_raw() as HANDLE,
753                 winbase::HANDLE_FLAG_INHERIT,
754                 !no_inherit as _,
755             )
756         };
757         if res == 0 {
758             // Zero means error.
759             Err(io::Error::last_os_error())
760         } else {
761             Ok(())
762         }
763     }
764 }
765 
766 impl AsRawSocket for crate::Socket {
as_raw_socket(&self) -> RawSocket767     fn as_raw_socket(&self) -> RawSocket {
768         self.as_raw() as RawSocket
769     }
770 }
771 
772 impl IntoRawSocket for crate::Socket {
into_raw_socket(self) -> RawSocket773     fn into_raw_socket(self) -> RawSocket {
774         self.into_raw() as RawSocket
775     }
776 }
777 
778 impl FromRawSocket for crate::Socket {
from_raw_socket(socket: RawSocket) -> crate::Socket779     unsafe fn from_raw_socket(socket: RawSocket) -> crate::Socket {
780         crate::Socket::from_raw(socket as Socket)
781     }
782 }
783 
#[test]
fn in_addr_convertion() {
    // Expected `S_addr` values: the first octet lands in the lowest byte
    // (network order read as a little-endian integer).
    let cases = [
        (Ipv4Addr::new(127, 0, 0, 1), 127 | 1 << 24),
        (
            Ipv4Addr::new(127, 34, 4, 12),
            127 | 34 << 8 | 4 << 16 | 12 << 24,
        ),
    ];
    for &(ip, expected) in cases.iter() {
        let raw = to_in_addr(&ip);
        assert_eq!(unsafe { *raw.S_un.S_addr() }, expected);
        assert_eq!(from_in_addr(raw), ip);
    }
}
799 
#[test]
fn in6_addr_convertion() {
    let ip = Ipv6Addr::new(0x2000, 1, 2, 3, 4, 5, 6, 7);
    let raw = to_in6_addr(&ip);
    // Each 16-bit segment is stored in network (big-endian) byte order.
    let mut want = [0u16; 8];
    for (slot, segment) in want.iter_mut().zip(ip.segments().iter()) {
        *slot = segment.to_be();
    }
    assert_eq!(unsafe { *raw.u.Word() }, want);
    assert_eq!(from_in6_addr(raw), ip);
}
817