1 // Copyright 2015 The Rust Project Developers.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8 
9 use std::cmp::min;
10 use std::io::{self, IoSlice};
11 use std::marker::PhantomData;
12 use std::mem::{self, size_of, MaybeUninit};
13 use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
14 use std::os::windows::prelude::*;
15 use std::sync::Once;
16 use std::time::{Duration, Instant};
17 use std::{ptr, slice};
18 
19 use winapi::ctypes::c_long;
20 use winapi::shared::in6addr::*;
21 use winapi::shared::inaddr::*;
22 use winapi::shared::minwindef::DWORD;
23 use winapi::shared::minwindef::ULONG;
24 use winapi::shared::mstcpip::{tcp_keepalive, SIO_KEEPALIVE_VALS};
25 use winapi::shared::ntdef::HANDLE;
26 use winapi::shared::ws2def;
27 use winapi::shared::ws2def::WSABUF;
28 use winapi::um::handleapi::SetHandleInformation;
29 use winapi::um::processthreadsapi::GetCurrentProcessId;
30 use winapi::um::winbase::{self, INFINITE};
31 use winapi::um::winsock2::{
32     self as sock, u_long, POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM, SD_BOTH, SD_RECEIVE, SD_SEND,
33     WSAPOLLFD,
34 };
35 
36 use crate::{RecvFlags, SockAddr, TcpKeepalive, Type};
37 
38 pub(crate) use winapi::ctypes::c_int;
39 
/// Fake `MSG_TRUNC` flag for the [`RecvFlags`] struct.
///
/// Winsock reports truncation through an error rather than a flag, so this
/// flag is set by us when a `WSARecv[From]` call returns `WSAEMSGSIZE`. The
/// value of the flag is defined by us.
pub(crate) const MSG_TRUNC: c_int = 0x01;
45 
46 // Used in `Domain`.
47 pub(crate) use winapi::shared::ws2def::{AF_INET, AF_INET6};
48 // Used in `Type`.
49 pub(crate) use winapi::shared::ws2def::{SOCK_DGRAM, SOCK_STREAM};
50 #[cfg(feature = "all")]
51 pub(crate) use winapi::shared::ws2def::{SOCK_RAW, SOCK_SEQPACKET};
52 // Used in `Protocol`.
// `winapi` does not declare these protocol numbers as `c_int`, so they are
// cast here to the integer type used by `Protocol`.
pub(crate) const IPPROTO_ICMP: c_int = winapi::shared::ws2def::IPPROTO_ICMP as c_int;
pub(crate) const IPPROTO_ICMPV6: c_int = winapi::shared::ws2def::IPPROTO_ICMPV6 as c_int;
pub(crate) const IPPROTO_TCP: c_int = winapi::shared::ws2def::IPPROTO_TCP as c_int;
pub(crate) const IPPROTO_UDP: c_int = winapi::shared::ws2def::IPPROTO_UDP as c_int;
57 // Used in `SockAddr`.
58 pub(crate) use winapi::shared::ws2def::{
59     ADDRESS_FAMILY as sa_family_t, SOCKADDR as sockaddr, SOCKADDR_IN as sockaddr_in,
60     SOCKADDR_STORAGE as sockaddr_storage,
61 };
62 pub(crate) use winapi::shared::ws2ipdef::SOCKADDR_IN6_LH as sockaddr_in6;
63 pub(crate) use winapi::um::ws2tcpip::socklen_t;
64 // Used in `Socket`.
65 pub(crate) use winapi::shared::ws2def::{
66     IPPROTO_IP, SOL_SOCKET, SO_BROADCAST, SO_ERROR, SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE,
67     SO_RCVBUF, SO_RCVTIMEO, SO_REUSEADDR, SO_SNDBUF, SO_SNDTIMEO, TCP_NODELAY,
68 };
69 pub(crate) use winapi::shared::ws2ipdef::{
70     IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_MREQ as Ipv6Mreq, IPV6_MULTICAST_HOPS,
71     IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, IPV6_UNICAST_HOPS, IPV6_V6ONLY, IP_ADD_MEMBERSHIP,
72     IP_DROP_MEMBERSHIP, IP_MREQ as IpMreq, IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL,
73     IP_TTL,
74 };
75 pub(crate) use winapi::um::winsock2::{linger, MSG_OOB, MSG_PEEK};
// `winapi` does not declare `IPPROTO_IPV6` as a `c_int`, hence the cast.
pub(crate) const IPPROTO_IPV6: c_int = winapi::shared::ws2def::IPPROTO_IPV6 as c_int;
77 
/// Type used in `setsockopt`/`getsockopt` to set and retrieve boolean options
/// such as `TCP_NODELAY`.
///
/// NOTE: <https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-getsockopt>
/// documents that options such as `TCP_NODELAY` and `SO_KEEPALIVE` expect a
/// `BOOL` (alias for `c_int`, 4 bytes), however in practice this turns out to
/// be false (or misleading) as a `BOOLEAN` (`c_uchar`, 1 byte) is returned by
/// `getsockopt`.
pub(crate) type Bool = winapi::shared::ntdef::BOOLEAN;
86 
/// Largest buffer length handed to system calls such as `recv` and `send`.
///
/// Their length parameter is a `c_int`, so callers clamp longer buffers to
/// this value before making the call.
const MAX_BUF_LEN: usize = c_int::max_value() as usize;
89 
/// Helper macro to execute a system call that returns an `io::Result`.
///
/// `$fn` is called through the `sock` (winsock2) module with the given
/// arguments. Its raw return value `res` is checked with
/// `$err_test(&res, &$err_value)`, e.g. `PartialEq::eq` against
/// `SOCKET_ERROR`, or `PartialEq::ne` against `0`. If the test is true the
/// call is treated as failed and `io::Error::last_os_error()` is returned;
/// otherwise the raw value is returned as `Ok(res)`.
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{
        // The call site may already be inside an `unsafe` block or `unsafe fn`.
        #[allow(unused_unsafe)]
        let res = unsafe { sock::$fn($($arg, )*) };
        if $err_test(&res, &$err_value) {
            Err(io::Error::last_os_error())
        } else {
            Ok(res)
        }
    }};
}
102 
// Presumably implements `Debug` for `Domain` (see the crate's `impl_debug!`
// macro), recognising the winsock address-family constants listed here.
impl_debug!(
    crate::Domain,
    ws2def::AF_INET,
    ws2def::AF_INET6,
    ws2def::AF_UNIX,
    ws2def::AF_UNSPEC, // = 0.
);
110 
111 /// Windows only API.
112 impl Type {
113     /// Our custom flag to set `WSA_FLAG_NO_HANDLE_INHERIT` on socket creation.
114     /// Trying to mimic `Type::cloexec` on windows.
115     const NO_INHERIT: c_int = 1 << ((size_of::<c_int>() * 8) - 1); // Last bit.
116 
117     /// Set `WSA_FLAG_NO_HANDLE_INHERIT` on the socket.
118     #[cfg(feature = "all")]
no_inherit(self) -> Type119     pub const fn no_inherit(self) -> Type {
120         self._no_inherit()
121     }
122 
_no_inherit(self) -> Type123     pub(crate) const fn _no_inherit(self) -> Type {
124         Type(self.0 | Type::NO_INHERIT)
125     }
126 }
127 
// Presumably implements `Debug` for `Type` (see the crate's `impl_debug!`
// macro), recognising the winsock socket-type constants listed here.
impl_debug!(
    crate::Type,
    ws2def::SOCK_STREAM,
    ws2def::SOCK_DGRAM,
    ws2def::SOCK_RAW,
    ws2def::SOCK_RDM,
    ws2def::SOCK_SEQPACKET,
);
136 
// Presumably implements `Debug` for `Protocol` (see the crate's `impl_debug!`
// macro), recognising the `c_int` protocol constants defined in this module.
impl_debug!(
    crate::Protocol,
    self::IPPROTO_ICMP,
    self::IPPROTO_ICMPV6,
    self::IPPROTO_TCP,
    self::IPPROTO_UDP,
);
144 
145 impl std::fmt::Debug for RecvFlags {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result146     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147         f.debug_struct("RecvFlags")
148             .field("is_truncated", &self.is_truncated())
149             .finish()
150     }
151 }
152 
/// A possibly uninitialised byte buffer described by a winsock `WSABUF`.
///
/// `#[repr(transparent)]` gives this type the same layout as `WSABUF`, which
/// is what allows slices of buffers to be cast and passed to calls such as
/// `WSARecv` (see `recv_vectored`).
#[repr(transparent)]
pub struct MaybeUninitSlice<'a> {
    vec: WSABUF,
    // Ties the raw pointer stored in `vec` to the lifetime of the borrowed
    // buffer it was created from.
    _lifetime: PhantomData<&'a mut [MaybeUninit<u8>]>,
}
158 
159 impl<'a> MaybeUninitSlice<'a> {
new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a>160     pub fn new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a> {
161         assert!(buf.len() <= ULONG::MAX as usize);
162         MaybeUninitSlice {
163             vec: WSABUF {
164                 len: buf.len() as ULONG,
165                 buf: buf.as_mut_ptr().cast(),
166             },
167             _lifetime: PhantomData,
168         }
169     }
170 
as_slice(&self) -> &[MaybeUninit<u8>]171     pub fn as_slice(&self) -> &[MaybeUninit<u8>] {
172         unsafe { slice::from_raw_parts(self.vec.buf.cast(), self.vec.len as usize) }
173     }
174 
as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>]175     pub fn as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>] {
176         unsafe { slice::from_raw_parts_mut(self.vec.buf.cast(), self.vec.len as usize) }
177     }
178 }
179 
/// Makes sure winsock has been initialised before any socket is created.
///
/// The initialisation is performed at most once per process.
fn init() {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        // Initialize winsock through the standard library by just creating a
        // dummy socket. Whether this is successful or not we drop the result as
        // libstd will be sure to have initialized winsock.
        //
        // Port 0 lets the OS pick a free ephemeral port. Binding a fixed port,
        // even briefly, could make a concurrent bind of that port by another
        // process fail.
        let _ = net::UdpSocket::bind("127.0.0.1:0");
    });
}
190 
/// Raw winsock socket handle, as used by the functions in this module.
pub(crate) type Socket = sock::SOCKET;
192 
socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket>193 pub(crate) fn socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket> {
194     init();
195 
196     // Check if we set our custom flag.
197     let flags = if ty & Type::NO_INHERIT != 0 {
198         ty = ty & !Type::NO_INHERIT;
199         sock::WSA_FLAG_NO_HANDLE_INHERIT
200     } else {
201         0
202     };
203 
204     syscall!(
205         WSASocketW(
206             family,
207             ty,
208             protocol,
209             ptr::null_mut(),
210             0,
211             sock::WSA_FLAG_OVERLAPPED | flags,
212         ),
213         PartialEq::eq,
214         sock::INVALID_SOCKET
215     )
216 }
217 
bind(socket: Socket, addr: &SockAddr) -> io::Result<()>218 pub(crate) fn bind(socket: Socket, addr: &SockAddr) -> io::Result<()> {
219     syscall!(bind(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
220 }
221 
connect(socket: Socket, addr: &SockAddr) -> io::Result<()>222 pub(crate) fn connect(socket: Socket, addr: &SockAddr) -> io::Result<()> {
223     syscall!(connect(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
224 }
225 
poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()>226 pub(crate) fn poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()> {
227     let start = Instant::now();
228 
229     let mut fd_array = WSAPOLLFD {
230         fd: socket.inner,
231         events: POLLRDNORM | POLLWRNORM,
232         revents: 0,
233     };
234 
235     loop {
236         let elapsed = start.elapsed();
237         if elapsed >= timeout {
238             return Err(io::ErrorKind::TimedOut.into());
239         }
240 
241         let timeout = (timeout - elapsed).as_millis();
242         let timeout = clamp(timeout, 1, c_int::max_value() as u128) as c_int;
243 
244         match syscall!(
245             WSAPoll(&mut fd_array, 1, timeout),
246             PartialEq::eq,
247             sock::SOCKET_ERROR
248         ) {
249             Ok(0) => return Err(io::ErrorKind::TimedOut.into()),
250             Ok(_) => {
251                 // Error or hang up indicates an error (or failure to connect).
252                 if (fd_array.revents & POLLERR) != 0 || (fd_array.revents & POLLHUP) != 0 {
253                     match socket.take_error() {
254                         Ok(Some(err)) => return Err(err),
255                         Ok(None) => {
256                             return Err(io::Error::new(
257                                 io::ErrorKind::Other,
258                                 "no error set after POLLHUP",
259                             ))
260                         }
261                         Err(err) => return Err(err),
262                     }
263                 }
264                 return Ok(());
265             }
266             // Got interrupted, try again.
267             Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
268             Err(err) => return Err(err),
269         }
270     }
271 }
272 
// TODO: use `Ord::clamp` from the standard library, stable since 1.50.
/// Restricts `value` to `min..=max`.
///
/// The lower bound is checked first, so if `min > max` the result is `min`
/// whenever `value <= min`.
fn clamp<T>(value: T, min: T, max: T) -> T
where
    T: Ord,
{
    match value {
        v if v <= min => min,
        v if v >= max => max,
        v => v,
    }
}
286 
listen(socket: Socket, backlog: c_int) -> io::Result<()>287 pub(crate) fn listen(socket: Socket, backlog: c_int) -> io::Result<()> {
288     syscall!(listen(socket, backlog), PartialEq::ne, 0).map(|_| ())
289 }
290 
accept(socket: Socket) -> io::Result<(Socket, SockAddr)>291 pub(crate) fn accept(socket: Socket) -> io::Result<(Socket, SockAddr)> {
292     // Safety: `accept` initialises the `SockAddr` for us.
293     unsafe {
294         SockAddr::init(|storage, len| {
295             syscall!(
296                 accept(socket, storage.cast(), len),
297                 PartialEq::eq,
298                 sock::INVALID_SOCKET
299             )
300         })
301     }
302 }
303 
getsockname(socket: Socket) -> io::Result<SockAddr>304 pub(crate) fn getsockname(socket: Socket) -> io::Result<SockAddr> {
305     // Safety: `getsockname` initialises the `SockAddr` for us.
306     unsafe {
307         SockAddr::init(|storage, len| {
308             syscall!(
309                 getsockname(socket, storage.cast(), len),
310                 PartialEq::eq,
311                 sock::SOCKET_ERROR
312             )
313         })
314     }
315     .map(|(_, addr)| addr)
316 }
317 
getpeername(socket: Socket) -> io::Result<SockAddr>318 pub(crate) fn getpeername(socket: Socket) -> io::Result<SockAddr> {
319     // Safety: `getpeername` initialises the `SockAddr` for us.
320     unsafe {
321         SockAddr::init(|storage, len| {
322             syscall!(
323                 getpeername(socket, storage.cast(), len),
324                 PartialEq::eq,
325                 sock::SOCKET_ERROR
326             )
327         })
328     }
329     .map(|(_, addr)| addr)
330 }
331 
try_clone(socket: Socket) -> io::Result<Socket>332 pub(crate) fn try_clone(socket: Socket) -> io::Result<Socket> {
333     let mut info: MaybeUninit<sock::WSAPROTOCOL_INFOW> = MaybeUninit::uninit();
334     syscall!(
335         WSADuplicateSocketW(socket, GetCurrentProcessId(), info.as_mut_ptr()),
336         PartialEq::eq,
337         sock::SOCKET_ERROR
338     )?;
339     // Safety: `WSADuplicateSocketW` intialised `info` for us.
340     let mut info = unsafe { info.assume_init() };
341 
342     syscall!(
343         WSASocketW(
344             info.iAddressFamily,
345             info.iSocketType,
346             info.iProtocol,
347             &mut info,
348             0,
349             sock::WSA_FLAG_OVERLAPPED | sock::WSA_FLAG_NO_HANDLE_INHERIT,
350         ),
351         PartialEq::eq,
352         sock::INVALID_SOCKET
353     )
354 }
355 
set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()>356 pub(crate) fn set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()> {
357     let mut nonblocking = nonblocking as u_long;
358     ioctlsocket(socket, sock::FIONBIO, &mut nonblocking)
359 }
360 
shutdown(socket: Socket, how: Shutdown) -> io::Result<()>361 pub(crate) fn shutdown(socket: Socket, how: Shutdown) -> io::Result<()> {
362     let how = match how {
363         Shutdown::Write => SD_SEND,
364         Shutdown::Read => SD_RECEIVE,
365         Shutdown::Both => SD_BOTH,
366     };
367     syscall!(shutdown(socket, how), PartialEq::eq, sock::SOCKET_ERROR).map(|_| ())
368 }
369 
recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize>370 pub(crate) fn recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
371     let res = syscall!(
372         recv(
373             socket,
374             buf.as_mut_ptr().cast(),
375             min(buf.len(), MAX_BUF_LEN) as c_int,
376             flags,
377         ),
378         PartialEq::eq,
379         sock::SOCKET_ERROR
380     );
381     match res {
382         Ok(n) => Ok(n as usize),
383         Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => Ok(0),
384         Err(err) => Err(err),
385     }
386 }
387 
recv_vectored( socket: Socket, bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags)>388 pub(crate) fn recv_vectored(
389     socket: Socket,
390     bufs: &mut [crate::MaybeUninitSlice<'_>],
391     flags: c_int,
392 ) -> io::Result<(usize, RecvFlags)> {
393     let mut nread = 0;
394     let mut flags = flags as DWORD;
395     let res = syscall!(
396         WSARecv(
397             socket,
398             bufs.as_mut_ptr().cast(),
399             min(bufs.len(), DWORD::max_value() as usize) as DWORD,
400             &mut nread,
401             &mut flags,
402             ptr::null_mut(),
403             None,
404         ),
405         PartialEq::eq,
406         sock::SOCKET_ERROR
407     );
408     match res {
409         Ok(_) => Ok((nread as usize, RecvFlags(0))),
410         Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => {
411             Ok((0, RecvFlags(0)))
412         }
413         Err(ref err) if err.raw_os_error() == Some(sock::WSAEMSGSIZE as i32) => {
414             Ok((nread as usize, RecvFlags(MSG_TRUNC)))
415         }
416         Err(err) => Err(err),
417     }
418 }
419 
recv_from( socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int, ) -> io::Result<(usize, SockAddr)>420 pub(crate) fn recv_from(
421     socket: Socket,
422     buf: &mut [MaybeUninit<u8>],
423     flags: c_int,
424 ) -> io::Result<(usize, SockAddr)> {
425     // Safety: `recvfrom` initialises the `SockAddr` for us.
426     unsafe {
427         SockAddr::init(|storage, addrlen| {
428             let res = syscall!(
429                 recvfrom(
430                     socket,
431                     buf.as_mut_ptr().cast(),
432                     min(buf.len(), MAX_BUF_LEN) as c_int,
433                     flags,
434                     storage.cast(),
435                     addrlen,
436                 ),
437                 PartialEq::eq,
438                 sock::SOCKET_ERROR
439             );
440             match res {
441                 Ok(n) => Ok(n as usize),
442                 Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => Ok(0),
443                 Err(err) => Err(err),
444             }
445         })
446     }
447 }
448 
recv_from_vectored( socket: Socket, bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags, SockAddr)>449 pub(crate) fn recv_from_vectored(
450     socket: Socket,
451     bufs: &mut [crate::MaybeUninitSlice<'_>],
452     flags: c_int,
453 ) -> io::Result<(usize, RecvFlags, SockAddr)> {
454     // Safety: `recvfrom` initialises the `SockAddr` for us.
455     unsafe {
456         SockAddr::init(|storage, addrlen| {
457             let mut nread = 0;
458             let mut flags = flags as DWORD;
459             let res = syscall!(
460                 WSARecvFrom(
461                     socket,
462                     bufs.as_mut_ptr().cast(),
463                     min(bufs.len(), DWORD::max_value() as usize) as DWORD,
464                     &mut nread,
465                     &mut flags,
466                     storage.cast(),
467                     addrlen,
468                     ptr::null_mut(),
469                     None,
470                 ),
471                 PartialEq::eq,
472                 sock::SOCKET_ERROR
473             );
474             match res {
475                 Ok(_) => Ok((nread as usize, RecvFlags(0))),
476                 Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => {
477                     Ok((nread as usize, RecvFlags(0)))
478                 }
479                 Err(ref err) if err.raw_os_error() == Some(sock::WSAEMSGSIZE as i32) => {
480                     Ok((nread as usize, RecvFlags(MSG_TRUNC)))
481                 }
482                 Err(err) => Err(err),
483             }
484         })
485     }
486     .map(|((n, recv_flags), addr)| (n, recv_flags, addr))
487 }
488 
send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize>489 pub(crate) fn send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize> {
490     syscall!(
491         send(
492             socket,
493             buf.as_ptr().cast(),
494             min(buf.len(), MAX_BUF_LEN) as c_int,
495             flags,
496         ),
497         PartialEq::eq,
498         sock::SOCKET_ERROR
499     )
500     .map(|n| n as usize)
501 }
502 
send_vectored( socket: Socket, bufs: &[IoSlice<'_>], flags: c_int, ) -> io::Result<usize>503 pub(crate) fn send_vectored(
504     socket: Socket,
505     bufs: &[IoSlice<'_>],
506     flags: c_int,
507 ) -> io::Result<usize> {
508     let mut nsent = 0;
509     syscall!(
510         WSASend(
511             socket,
512             // FIXME: From the `WSASend` docs [1]:
513             // > For a Winsock application, once the WSASend function is called,
514             // > the system owns these buffers and the application may not
515             // > access them.
516             //
517             // So what we're doing is actually UB as `bufs` needs to be `&mut
518             // [IoSlice<'_>]`.
519             //
520             // Tracking issue: https://github.com/rust-lang/socket2-rs/issues/129.
521             //
522             // NOTE: `send_to_vectored` has the same problem.
523             //
524             // [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend
525             bufs.as_ptr() as *mut _,
526             min(bufs.len(), DWORD::max_value() as usize) as DWORD,
527             &mut nsent,
528             flags as DWORD,
529             std::ptr::null_mut(),
530             None,
531         ),
532         PartialEq::eq,
533         sock::SOCKET_ERROR
534     )
535     .map(|_| nsent as usize)
536 }
537 
send_to( socket: Socket, buf: &[u8], addr: &SockAddr, flags: c_int, ) -> io::Result<usize>538 pub(crate) fn send_to(
539     socket: Socket,
540     buf: &[u8],
541     addr: &SockAddr,
542     flags: c_int,
543 ) -> io::Result<usize> {
544     syscall!(
545         sendto(
546             socket,
547             buf.as_ptr().cast(),
548             min(buf.len(), MAX_BUF_LEN) as c_int,
549             flags,
550             addr.as_ptr(),
551             addr.len(),
552         ),
553         PartialEq::eq,
554         sock::SOCKET_ERROR
555     )
556     .map(|n| n as usize)
557 }
558 
send_to_vectored( socket: Socket, bufs: &[IoSlice<'_>], addr: &SockAddr, flags: c_int, ) -> io::Result<usize>559 pub(crate) fn send_to_vectored(
560     socket: Socket,
561     bufs: &[IoSlice<'_>],
562     addr: &SockAddr,
563     flags: c_int,
564 ) -> io::Result<usize> {
565     let mut nsent = 0;
566     syscall!(
567         WSASendTo(
568             socket,
569             // FIXME: Same problem as in `send_vectored`.
570             bufs.as_ptr() as *mut _,
571             bufs.len().min(DWORD::MAX as usize) as DWORD,
572             &mut nsent,
573             flags as DWORD,
574             addr.as_ptr(),
575             addr.len(),
576             ptr::null_mut(),
577             None,
578         ),
579         PartialEq::eq,
580         sock::SOCKET_ERROR
581     )
582     .map(|_| nsent as usize)
583 }
584 
585 /// Wrapper around `getsockopt` to deal with platform specific timeouts.
timeout_opt(fd: Socket, lvl: c_int, name: c_int) -> io::Result<Option<Duration>>586 pub(crate) fn timeout_opt(fd: Socket, lvl: c_int, name: c_int) -> io::Result<Option<Duration>> {
587     unsafe { getsockopt(fd, lvl, name).map(from_ms) }
588 }
589 
from_ms(duration: DWORD) -> Option<Duration>590 fn from_ms(duration: DWORD) -> Option<Duration> {
591     if duration == 0 {
592         None
593     } else {
594         let secs = duration / 1000;
595         let nsec = (duration % 1000) * 1000000;
596         Some(Duration::new(secs as u64, nsec as u32))
597     }
598 }
599 
600 /// Wrapper around `setsockopt` to deal with platform specific timeouts.
set_timeout_opt( fd: Socket, level: c_int, optname: c_int, duration: Option<Duration>, ) -> io::Result<()>601 pub(crate) fn set_timeout_opt(
602     fd: Socket,
603     level: c_int,
604     optname: c_int,
605     duration: Option<Duration>,
606 ) -> io::Result<()> {
607     let duration = into_ms(duration);
608     unsafe { setsockopt(fd, level, optname, duration) }
609 }
610 
into_ms(duration: Option<Duration>) -> DWORD611 fn into_ms(duration: Option<Duration>) -> DWORD {
612     // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the
613     // timeouts in windows APIs are typically u32 milliseconds. To translate, we
614     // have two pieces to take care of:
615     //
616     // * Nanosecond precision is rounded up
617     // * Greater than u32::MAX milliseconds (50 days) is rounded up to
618     //   INFINITE (never time out).
619     duration
620         .map(|duration| min(duration.as_millis(), INFINITE as u128) as DWORD)
621         .unwrap_or(0)
622 }
623 
set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()>624 pub(crate) fn set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()> {
625     let mut keepalive = tcp_keepalive {
626         onoff: 1,
627         keepalivetime: into_ms(keepalive.time),
628         keepaliveinterval: into_ms(keepalive.interval),
629     };
630     let mut out = 0;
631     syscall!(
632         WSAIoctl(
633             socket,
634             SIO_KEEPALIVE_VALS,
635             &mut keepalive as *mut _ as *mut _,
636             size_of::<tcp_keepalive>() as _,
637             ptr::null_mut(),
638             0,
639             &mut out,
640             ptr::null_mut(),
641             None,
642         ),
643         PartialEq::eq,
644         sock::SOCKET_ERROR
645     )
646     .map(|_| ())
647 }
648 
/// Reads the socket option `optname` at `level` as a value of type `T`.
///
/// # Safety
///
/// Caller must ensure `T` is the correct type for `level` and `optname`.
pub(crate) unsafe fn getsockopt<T>(socket: Socket, level: c_int, optname: c_int) -> io::Result<T> {
    let mut optval: MaybeUninit<T> = MaybeUninit::uninit();
    // Passed in as the buffer size; holds the written length after the call.
    let mut optlen = mem::size_of::<T>() as c_int;
    syscall!(
        getsockopt(
            socket,
            level,
            optname,
            optval.as_mut_ptr().cast(),
            &mut optlen,
        ),
        PartialEq::eq,
        sock::SOCKET_ERROR
    )
    .map(|_| {
        // NOTE(review): the written length is only checked in debug builds.
        // In a release build, if fewer than `size_of::<T>()` bytes were
        // written, `assume_init` below would read partially uninitialised
        // memory. Callers must pick a `T` that matches the option's real size
        // (compare the `Bool` type in this module).
        debug_assert_eq!(optlen as usize, mem::size_of::<T>());
        // Safety: `getsockopt` initialised `optval` for us.
        optval.assume_init()
    })
}
670 
671 /// Caller must ensure `T` is the correct type for `level` and `optname`.
setsockopt<T>( socket: Socket, level: c_int, optname: c_int, optval: T, ) -> io::Result<()>672 pub(crate) unsafe fn setsockopt<T>(
673     socket: Socket,
674     level: c_int,
675     optname: c_int,
676     optval: T,
677 ) -> io::Result<()> {
678     syscall!(
679         setsockopt(
680             socket,
681             level,
682             optname,
683             (&optval as *const T).cast(),
684             mem::size_of::<T>() as c_int,
685         ),
686         PartialEq::eq,
687         sock::SOCKET_ERROR
688     )
689     .map(|_| ())
690 }
691 
ioctlsocket(socket: Socket, cmd: c_long, payload: &mut u_long) -> io::Result<()>692 fn ioctlsocket(socket: Socket, cmd: c_long, payload: &mut u_long) -> io::Result<()> {
693     syscall!(
694         ioctlsocket(socket, cmd, payload),
695         PartialEq::eq,
696         sock::SOCKET_ERROR
697     )
698     .map(|_| ())
699 }
700 
close(socket: Socket)701 pub(crate) fn close(socket: Socket) {
702     unsafe {
703         let _ = sock::closesocket(socket);
704     }
705 }
706 
to_in_addr(addr: &Ipv4Addr) -> IN_ADDR707 pub(crate) fn to_in_addr(addr: &Ipv4Addr) -> IN_ADDR {
708     let mut s_un: in_addr_S_un = unsafe { mem::zeroed() };
709     // `S_un` is stored as BE on all machines, and the array is in BE order. So
710     // the native endian conversion method is used so that it's never swapped.
711     unsafe { *(s_un.S_addr_mut()) = u32::from_ne_bytes(addr.octets()) };
712     IN_ADDR { S_un: s_un }
713 }
714 
from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr715 pub(crate) fn from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr {
716     Ipv4Addr::from(unsafe { *in_addr.S_un.S_addr() }.to_ne_bytes())
717 }
718 
to_in6_addr(addr: &Ipv6Addr) -> in6_addr719 pub(crate) fn to_in6_addr(addr: &Ipv6Addr) -> in6_addr {
720     let mut ret_addr: in6_addr_u = unsafe { mem::zeroed() };
721     unsafe { *(ret_addr.Byte_mut()) = addr.octets() };
722     let mut ret: in6_addr = unsafe { mem::zeroed() };
723     ret.u = ret_addr;
724     ret
725 }
726 
from_in6_addr(addr: in6_addr) -> Ipv6Addr727 pub(crate) fn from_in6_addr(addr: in6_addr) -> Ipv6Addr {
728     Ipv6Addr::from(*unsafe { addr.u.Byte() })
729 }
730 
731 /// Windows only API.
732 impl crate::Socket {
733     /// Sets `HANDLE_FLAG_INHERIT` using `SetHandleInformation`.
734     #[cfg(feature = "all")]
set_no_inherit(&self, no_inherit: bool) -> io::Result<()>735     pub fn set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
736         self._set_no_inherit(no_inherit)
737     }
738 
_set_no_inherit(&self, no_inherit: bool) -> io::Result<()>739     pub(crate) fn _set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
740         // NOTE: can't use `syscall!` because it expects the function in the
741         // `sock::` path.
742         let res = unsafe {
743             SetHandleInformation(
744                 self.inner as HANDLE,
745                 winbase::HANDLE_FLAG_INHERIT,
746                 !no_inherit as _,
747             )
748         };
749         if res == 0 {
750             // Zero means error.
751             Err(io::Error::last_os_error())
752         } else {
753             Ok(())
754         }
755     }
756 }
757 
758 impl AsRawSocket for crate::Socket {
as_raw_socket(&self) -> RawSocket759     fn as_raw_socket(&self) -> RawSocket {
760         self.inner as RawSocket
761     }
762 }
763 
764 impl IntoRawSocket for crate::Socket {
into_raw_socket(self) -> RawSocket765     fn into_raw_socket(self) -> RawSocket {
766         let socket = self.inner;
767         mem::forget(self);
768         socket as RawSocket
769     }
770 }
771 
772 impl FromRawSocket for crate::Socket {
from_raw_socket(socket: RawSocket) -> crate::Socket773     unsafe fn from_raw_socket(socket: RawSocket) -> crate::Socket {
774         crate::Socket {
775             inner: socket as Socket,
776         }
777     }
778 }
779 
/// Round-trips IPv4 addresses through `IN_ADDR` and checks the raw
/// (network byte order) representation.
#[test]
fn in_addr_convertion() {
    let cases = [
        (Ipv4Addr::new(127, 0, 0, 1), 127 << 0 | 1 << 24),
        (
            Ipv4Addr::new(127, 34, 4, 12),
            127 << 0 | 34 << 8 | 4 << 16 | 12 << 24,
        ),
    ];
    for &(ip, expected_raw) in cases.iter() {
        let raw = to_in_addr(&ip);
        assert_eq!(unsafe { *raw.S_un.S_addr() }, expected_raw);
        assert_eq!(from_in_addr(raw), ip);
    }
}
795 
/// Round-trips an IPv6 address through `in6_addr` and checks that every
/// 16-bit word is stored big-endian.
#[test]
fn in6_addr_convertion() {
    let segments = [0x2000u16, 1, 2, 3, 4, 5, 6, 7];
    let ip = Ipv6Addr::from(segments);
    let raw = to_in6_addr(&ip);

    let mut want = [0u16; 8];
    for (slot, segment) in want.iter_mut().zip(segments.iter()) {
        *slot = segment.to_be();
    }

    assert_eq!(unsafe { *raw.u.Word() }, want);
    assert_eq!(from_in6_addr(raw), ip);
}
813