1 // Copyright 2015 The Rust Project Developers.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8
9 use std::cmp::min;
10 use std::io::{self, IoSlice};
11 use std::marker::PhantomData;
12 use std::mem::{self, size_of, MaybeUninit};
13 use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
14 use std::os::windows::prelude::*;
15 use std::sync::Once;
16 use std::time::{Duration, Instant};
17 use std::{ptr, slice};
18
19 use winapi::ctypes::c_long;
20 use winapi::shared::in6addr::*;
21 use winapi::shared::inaddr::*;
22 use winapi::shared::minwindef::DWORD;
23 use winapi::shared::minwindef::ULONG;
24 use winapi::shared::mstcpip::{tcp_keepalive, SIO_KEEPALIVE_VALS};
25 use winapi::shared::ntdef::HANDLE;
26 use winapi::shared::ws2def;
27 use winapi::shared::ws2def::WSABUF;
28 use winapi::um::handleapi::SetHandleInformation;
29 use winapi::um::processthreadsapi::GetCurrentProcessId;
30 use winapi::um::winbase::{self, INFINITE};
31 use winapi::um::winsock2::{
32 self as sock, u_long, POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM, SD_BOTH, SD_RECEIVE, SD_SEND,
33 WSAPOLLFD,
34 };
35
36 use crate::{RecvFlags, SockAddr, TcpKeepalive, Type};
37
38 pub(crate) use winapi::ctypes::c_int;
39
40 /// Fake MSG_TRUNC flag for the [`RecvFlags`] struct.
41 ///
42 /// The flag is enabled when a `WSARecv[From]` call returns `WSAEMSGSIZE`. The
43 /// value of the flag is defined by us.
44 pub(crate) const MSG_TRUNC: c_int = 0x01;
45
46 // Used in `Domain`.
47 pub(crate) use winapi::shared::ws2def::{AF_INET, AF_INET6};
48 // Used in `Type`.
49 pub(crate) use winapi::shared::ws2def::{SOCK_DGRAM, SOCK_STREAM};
50 #[cfg(feature = "all")]
51 pub(crate) use winapi::shared::ws2def::{SOCK_RAW, SOCK_SEQPACKET};
52 // Used in `Protocol`.
53 pub(crate) const IPPROTO_ICMP: c_int = winapi::shared::ws2def::IPPROTO_ICMP as c_int;
54 pub(crate) const IPPROTO_ICMPV6: c_int = winapi::shared::ws2def::IPPROTO_ICMPV6 as c_int;
55 pub(crate) const IPPROTO_TCP: c_int = winapi::shared::ws2def::IPPROTO_TCP as c_int;
56 pub(crate) const IPPROTO_UDP: c_int = winapi::shared::ws2def::IPPROTO_UDP as c_int;
57 // Used in `SockAddr`.
58 pub(crate) use winapi::shared::ws2def::{
59 ADDRESS_FAMILY as sa_family_t, SOCKADDR as sockaddr, SOCKADDR_IN as sockaddr_in,
60 SOCKADDR_STORAGE as sockaddr_storage,
61 };
62 pub(crate) use winapi::shared::ws2ipdef::SOCKADDR_IN6_LH as sockaddr_in6;
63 pub(crate) use winapi::um::ws2tcpip::socklen_t;
64 // Used in `Socket`.
65 pub(crate) use winapi::shared::ws2def::{
66 IPPROTO_IP, SOL_SOCKET, SO_BROADCAST, SO_ERROR, SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE,
67 SO_RCVBUF, SO_RCVTIMEO, SO_REUSEADDR, SO_SNDBUF, SO_SNDTIMEO, TCP_NODELAY,
68 };
69 pub(crate) use winapi::shared::ws2ipdef::{
70 IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_MREQ as Ipv6Mreq, IPV6_MULTICAST_HOPS,
71 IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, IPV6_UNICAST_HOPS, IPV6_V6ONLY, IP_ADD_MEMBERSHIP,
72 IP_DROP_MEMBERSHIP, IP_MREQ as IpMreq, IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL,
73 IP_TTL,
74 };
75 pub(crate) use winapi::um::winsock2::{linger, MSG_OOB, MSG_PEEK};
76 pub(crate) const IPPROTO_IPV6: c_int = winapi::shared::ws2def::IPPROTO_IPV6 as c_int;
77
78 /// Type used in set/getsockopt to retrieve the `TCP_NODELAY` option.
79 ///
80 /// NOTE: <https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-getsockopt>
81 /// documents that options such as `TCP_NODELAY` and `SO_KEEPALIVE` expect a
82 /// `BOOL` (alias for `c_int`, 4 bytes), however in practice this turns out to
83 /// be false (or misleading) as a `BOOLEAN` (`c_uchar`, 1 byte) is returned by
84 /// `getsockopt`.
85 pub(crate) type Bool = winapi::shared::ntdef::BOOLEAN;
86
/// Maximum size of a buffer passed to system calls like `recv` and `send`.
88 const MAX_BUF_LEN: usize = <c_int>::max_value() as usize;
89
/// Helper macro to execute a system call that returns an `io::Result`.
///
/// Calls `sock::$fn` with the given arguments inside an `unsafe` block, then
/// evaluates `$err_test(&res, &$err_value)`. When that test returns `true` the
/// macro yields `Err(io::Error::last_os_error())`, otherwise `Ok(res)`.
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{
        // Some call sites already sit inside an `unsafe` block, hence the allow.
        #[allow(unused_unsafe)]
        let res = unsafe { sock::$fn($($arg, )*) };
        // The error test decides whether `res` signals failure.
        if $err_test(&res, &$err_value) {
            Err(io::Error::last_os_error())
        } else {
            Ok(res)
        }
    }};
}
102
// Invokes the crate's `impl_debug!` macro for `crate::Domain`, listing the
// address family constants it should know about.
impl_debug!(
    crate::Domain,
    ws2def::AF_INET,
    ws2def::AF_INET6,
    ws2def::AF_UNIX,
    ws2def::AF_UNSPEC, // = 0.
);
110
111 /// Windows only API.
112 impl Type {
113 /// Our custom flag to set `WSA_FLAG_NO_HANDLE_INHERIT` on socket creation.
114 /// Trying to mimic `Type::cloexec` on windows.
115 const NO_INHERIT: c_int = 1 << ((size_of::<c_int>() * 8) - 1); // Last bit.
116
117 /// Set `WSA_FLAG_NO_HANDLE_INHERIT` on the socket.
118 #[cfg(feature = "all")]
no_inherit(self) -> Type119 pub const fn no_inherit(self) -> Type {
120 self._no_inherit()
121 }
122
_no_inherit(self) -> Type123 pub(crate) const fn _no_inherit(self) -> Type {
124 Type(self.0 | Type::NO_INHERIT)
125 }
126 }
127
// Invokes the crate's `impl_debug!` macro for `crate::Type`, listing the
// socket type constants it should know about.
impl_debug!(
    crate::Type,
    ws2def::SOCK_STREAM,
    ws2def::SOCK_DGRAM,
    ws2def::SOCK_RAW,
    ws2def::SOCK_RDM,
    ws2def::SOCK_SEQPACKET,
);
136
// Invokes the crate's `impl_debug!` macro for `crate::Protocol`, using the
// `c_int` protocol constants defined in this module.
impl_debug!(
    crate::Protocol,
    self::IPPROTO_ICMP,
    self::IPPROTO_ICMPV6,
    self::IPPROTO_TCP,
    self::IPPROTO_UDP,
);
144
145 impl std::fmt::Debug for RecvFlags {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 f.debug_struct("RecvFlags")
148 .field("is_truncated", &self.is_truncated())
149 .finish()
150 }
151 }
152
/// A borrowed buffer of possibly uninitialised bytes, stored as a `WSABUF`.
///
/// The type is `#[repr(transparent)]` over its `WSABUF` field; the `'a`
/// lifetime ties it to the mutable slice it was created from.
#[repr(transparent)]
pub struct MaybeUninitSlice<'a> {
    // Pointer/length pair describing the buffer.
    vec: WSABUF,
    // Zero-sized marker that carries the borrow of the original slice.
    _lifetime: PhantomData<&'a mut [MaybeUninit<u8>]>,
}
158
159 impl<'a> MaybeUninitSlice<'a> {
new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a>160 pub fn new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a> {
161 assert!(buf.len() <= ULONG::MAX as usize);
162 MaybeUninitSlice {
163 vec: WSABUF {
164 len: buf.len() as ULONG,
165 buf: buf.as_mut_ptr().cast(),
166 },
167 _lifetime: PhantomData,
168 }
169 }
170
as_slice(&self) -> &[MaybeUninit<u8>]171 pub fn as_slice(&self) -> &[MaybeUninit<u8>] {
172 unsafe { slice::from_raw_parts(self.vec.buf.cast(), self.vec.len as usize) }
173 }
174
as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>]175 pub fn as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>] {
176 unsafe { slice::from_raw_parts_mut(self.vec.buf.cast(), self.vec.len as usize) }
177 }
178 }
179
/// Makes sure winsock has been initialised, doing the work at most once.
fn init() {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        // Winsock is initialised through the standard library simply by
        // creating a dummy socket. The outcome of the bind is irrelevant and
        // discarded: libstd will have initialised winsock either way.
        drop(net::UdpSocket::bind("127.0.0.1:34254"));
    });
}
190
/// Raw socket handle type used by the functions in this module (alias for
/// winsock's `SOCKET`).
pub(crate) type Socket = sock::SOCKET;
192
socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket>193 pub(crate) fn socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result<Socket> {
194 init();
195
196 // Check if we set our custom flag.
197 let flags = if ty & Type::NO_INHERIT != 0 {
198 ty = ty & !Type::NO_INHERIT;
199 sock::WSA_FLAG_NO_HANDLE_INHERIT
200 } else {
201 0
202 };
203
204 syscall!(
205 WSASocketW(
206 family,
207 ty,
208 protocol,
209 ptr::null_mut(),
210 0,
211 sock::WSA_FLAG_OVERLAPPED | flags,
212 ),
213 PartialEq::eq,
214 sock::INVALID_SOCKET
215 )
216 }
217
bind(socket: Socket, addr: &SockAddr) -> io::Result<()>218 pub(crate) fn bind(socket: Socket, addr: &SockAddr) -> io::Result<()> {
219 syscall!(bind(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
220 }
221
connect(socket: Socket, addr: &SockAddr) -> io::Result<()>222 pub(crate) fn connect(socket: Socket, addr: &SockAddr) -> io::Result<()> {
223 syscall!(connect(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
224 }
225
poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()>226 pub(crate) fn poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()> {
227 let start = Instant::now();
228
229 let mut fd_array = WSAPOLLFD {
230 fd: socket.inner,
231 events: POLLRDNORM | POLLWRNORM,
232 revents: 0,
233 };
234
235 loop {
236 let elapsed = start.elapsed();
237 if elapsed >= timeout {
238 return Err(io::ErrorKind::TimedOut.into());
239 }
240
241 let timeout = (timeout - elapsed).as_millis();
242 let timeout = clamp(timeout, 1, c_int::max_value() as u128) as c_int;
243
244 match syscall!(
245 WSAPoll(&mut fd_array, 1, timeout),
246 PartialEq::eq,
247 sock::SOCKET_ERROR
248 ) {
249 Ok(0) => return Err(io::ErrorKind::TimedOut.into()),
250 Ok(_) => {
251 // Error or hang up indicates an error (or failure to connect).
252 if (fd_array.revents & POLLERR) != 0 || (fd_array.revents & POLLHUP) != 0 {
253 match socket.take_error() {
254 Ok(Some(err)) => return Err(err),
255 Ok(None) => {
256 return Err(io::Error::new(
257 io::ErrorKind::Other,
258 "no error set after POLLHUP",
259 ))
260 }
261 Err(err) => return Err(err),
262 }
263 }
264 return Ok(());
265 }
266 // Got interrupted, try again.
267 Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
268 Err(err) => return Err(err),
269 }
270 }
271 }
272
// TODO: use clamp from std lib, stable since 1.50.
/// Restricts `value` to the range `min..=max`.
///
/// The lower bound is checked first, so `min` is returned whenever
/// `value <= min`, even if `min` is greater than `max`.
fn clamp<T>(value: T, min: T, max: T) -> T
where
    T: Ord,
{
    if value <= min {
        return min;
    }
    if value >= max {
        return max;
    }
    value
}
286
listen(socket: Socket, backlog: c_int) -> io::Result<()>287 pub(crate) fn listen(socket: Socket, backlog: c_int) -> io::Result<()> {
288 syscall!(listen(socket, backlog), PartialEq::ne, 0).map(|_| ())
289 }
290
accept(socket: Socket) -> io::Result<(Socket, SockAddr)>291 pub(crate) fn accept(socket: Socket) -> io::Result<(Socket, SockAddr)> {
292 // Safety: `accept` initialises the `SockAddr` for us.
293 unsafe {
294 SockAddr::init(|storage, len| {
295 syscall!(
296 accept(socket, storage.cast(), len),
297 PartialEq::eq,
298 sock::INVALID_SOCKET
299 )
300 })
301 }
302 }
303
getsockname(socket: Socket) -> io::Result<SockAddr>304 pub(crate) fn getsockname(socket: Socket) -> io::Result<SockAddr> {
305 // Safety: `getsockname` initialises the `SockAddr` for us.
306 unsafe {
307 SockAddr::init(|storage, len| {
308 syscall!(
309 getsockname(socket, storage.cast(), len),
310 PartialEq::eq,
311 sock::SOCKET_ERROR
312 )
313 })
314 }
315 .map(|(_, addr)| addr)
316 }
317
getpeername(socket: Socket) -> io::Result<SockAddr>318 pub(crate) fn getpeername(socket: Socket) -> io::Result<SockAddr> {
319 // Safety: `getpeername` initialises the `SockAddr` for us.
320 unsafe {
321 SockAddr::init(|storage, len| {
322 syscall!(
323 getpeername(socket, storage.cast(), len),
324 PartialEq::eq,
325 sock::SOCKET_ERROR
326 )
327 })
328 }
329 .map(|(_, addr)| addr)
330 }
331
try_clone(socket: Socket) -> io::Result<Socket>332 pub(crate) fn try_clone(socket: Socket) -> io::Result<Socket> {
333 let mut info: MaybeUninit<sock::WSAPROTOCOL_INFOW> = MaybeUninit::uninit();
334 syscall!(
335 WSADuplicateSocketW(socket, GetCurrentProcessId(), info.as_mut_ptr()),
336 PartialEq::eq,
337 sock::SOCKET_ERROR
338 )?;
339 // Safety: `WSADuplicateSocketW` intialised `info` for us.
340 let mut info = unsafe { info.assume_init() };
341
342 syscall!(
343 WSASocketW(
344 info.iAddressFamily,
345 info.iSocketType,
346 info.iProtocol,
347 &mut info,
348 0,
349 sock::WSA_FLAG_OVERLAPPED | sock::WSA_FLAG_NO_HANDLE_INHERIT,
350 ),
351 PartialEq::eq,
352 sock::INVALID_SOCKET
353 )
354 }
355
set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()>356 pub(crate) fn set_nonblocking(socket: Socket, nonblocking: bool) -> io::Result<()> {
357 let mut nonblocking = nonblocking as u_long;
358 ioctlsocket(socket, sock::FIONBIO, &mut nonblocking)
359 }
360
shutdown(socket: Socket, how: Shutdown) -> io::Result<()>361 pub(crate) fn shutdown(socket: Socket, how: Shutdown) -> io::Result<()> {
362 let how = match how {
363 Shutdown::Write => SD_SEND,
364 Shutdown::Read => SD_RECEIVE,
365 Shutdown::Both => SD_BOTH,
366 };
367 syscall!(shutdown(socket, how), PartialEq::eq, sock::SOCKET_ERROR).map(|_| ())
368 }
369
recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize>370 pub(crate) fn recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
371 let res = syscall!(
372 recv(
373 socket,
374 buf.as_mut_ptr().cast(),
375 min(buf.len(), MAX_BUF_LEN) as c_int,
376 flags,
377 ),
378 PartialEq::eq,
379 sock::SOCKET_ERROR
380 );
381 match res {
382 Ok(n) => Ok(n as usize),
383 Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => Ok(0),
384 Err(err) => Err(err),
385 }
386 }
387
recv_vectored( socket: Socket, bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags)>388 pub(crate) fn recv_vectored(
389 socket: Socket,
390 bufs: &mut [crate::MaybeUninitSlice<'_>],
391 flags: c_int,
392 ) -> io::Result<(usize, RecvFlags)> {
393 let mut nread = 0;
394 let mut flags = flags as DWORD;
395 let res = syscall!(
396 WSARecv(
397 socket,
398 bufs.as_mut_ptr().cast(),
399 min(bufs.len(), DWORD::max_value() as usize) as DWORD,
400 &mut nread,
401 &mut flags,
402 ptr::null_mut(),
403 None,
404 ),
405 PartialEq::eq,
406 sock::SOCKET_ERROR
407 );
408 match res {
409 Ok(_) => Ok((nread as usize, RecvFlags(0))),
410 Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => {
411 Ok((0, RecvFlags(0)))
412 }
413 Err(ref err) if err.raw_os_error() == Some(sock::WSAEMSGSIZE as i32) => {
414 Ok((nread as usize, RecvFlags(MSG_TRUNC)))
415 }
416 Err(err) => Err(err),
417 }
418 }
419
recv_from( socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int, ) -> io::Result<(usize, SockAddr)>420 pub(crate) fn recv_from(
421 socket: Socket,
422 buf: &mut [MaybeUninit<u8>],
423 flags: c_int,
424 ) -> io::Result<(usize, SockAddr)> {
425 // Safety: `recvfrom` initialises the `SockAddr` for us.
426 unsafe {
427 SockAddr::init(|storage, addrlen| {
428 let res = syscall!(
429 recvfrom(
430 socket,
431 buf.as_mut_ptr().cast(),
432 min(buf.len(), MAX_BUF_LEN) as c_int,
433 flags,
434 storage.cast(),
435 addrlen,
436 ),
437 PartialEq::eq,
438 sock::SOCKET_ERROR
439 );
440 match res {
441 Ok(n) => Ok(n as usize),
442 Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => Ok(0),
443 Err(err) => Err(err),
444 }
445 })
446 }
447 }
448
recv_from_vectored( socket: Socket, bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags, SockAddr)>449 pub(crate) fn recv_from_vectored(
450 socket: Socket,
451 bufs: &mut [crate::MaybeUninitSlice<'_>],
452 flags: c_int,
453 ) -> io::Result<(usize, RecvFlags, SockAddr)> {
454 // Safety: `recvfrom` initialises the `SockAddr` for us.
455 unsafe {
456 SockAddr::init(|storage, addrlen| {
457 let mut nread = 0;
458 let mut flags = flags as DWORD;
459 let res = syscall!(
460 WSARecvFrom(
461 socket,
462 bufs.as_mut_ptr().cast(),
463 min(bufs.len(), DWORD::max_value() as usize) as DWORD,
464 &mut nread,
465 &mut flags,
466 storage.cast(),
467 addrlen,
468 ptr::null_mut(),
469 None,
470 ),
471 PartialEq::eq,
472 sock::SOCKET_ERROR
473 );
474 match res {
475 Ok(_) => Ok((nread as usize, RecvFlags(0))),
476 Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => {
477 Ok((nread as usize, RecvFlags(0)))
478 }
479 Err(ref err) if err.raw_os_error() == Some(sock::WSAEMSGSIZE as i32) => {
480 Ok((nread as usize, RecvFlags(MSG_TRUNC)))
481 }
482 Err(err) => Err(err),
483 }
484 })
485 }
486 .map(|((n, recv_flags), addr)| (n, recv_flags, addr))
487 }
488
send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize>489 pub(crate) fn send(socket: Socket, buf: &[u8], flags: c_int) -> io::Result<usize> {
490 syscall!(
491 send(
492 socket,
493 buf.as_ptr().cast(),
494 min(buf.len(), MAX_BUF_LEN) as c_int,
495 flags,
496 ),
497 PartialEq::eq,
498 sock::SOCKET_ERROR
499 )
500 .map(|n| n as usize)
501 }
502
send_vectored( socket: Socket, bufs: &[IoSlice<'_>], flags: c_int, ) -> io::Result<usize>503 pub(crate) fn send_vectored(
504 socket: Socket,
505 bufs: &[IoSlice<'_>],
506 flags: c_int,
507 ) -> io::Result<usize> {
508 let mut nsent = 0;
509 syscall!(
510 WSASend(
511 socket,
512 // FIXME: From the `WSASend` docs [1]:
513 // > For a Winsock application, once the WSASend function is called,
514 // > the system owns these buffers and the application may not
515 // > access them.
516 //
517 // So what we're doing is actually UB as `bufs` needs to be `&mut
518 // [IoSlice<'_>]`.
519 //
520 // Tracking issue: https://github.com/rust-lang/socket2-rs/issues/129.
521 //
522 // NOTE: `send_to_vectored` has the same problem.
523 //
524 // [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend
525 bufs.as_ptr() as *mut _,
526 min(bufs.len(), DWORD::max_value() as usize) as DWORD,
527 &mut nsent,
528 flags as DWORD,
529 std::ptr::null_mut(),
530 None,
531 ),
532 PartialEq::eq,
533 sock::SOCKET_ERROR
534 )
535 .map(|_| nsent as usize)
536 }
537
send_to( socket: Socket, buf: &[u8], addr: &SockAddr, flags: c_int, ) -> io::Result<usize>538 pub(crate) fn send_to(
539 socket: Socket,
540 buf: &[u8],
541 addr: &SockAddr,
542 flags: c_int,
543 ) -> io::Result<usize> {
544 syscall!(
545 sendto(
546 socket,
547 buf.as_ptr().cast(),
548 min(buf.len(), MAX_BUF_LEN) as c_int,
549 flags,
550 addr.as_ptr(),
551 addr.len(),
552 ),
553 PartialEq::eq,
554 sock::SOCKET_ERROR
555 )
556 .map(|n| n as usize)
557 }
558
send_to_vectored( socket: Socket, bufs: &[IoSlice<'_>], addr: &SockAddr, flags: c_int, ) -> io::Result<usize>559 pub(crate) fn send_to_vectored(
560 socket: Socket,
561 bufs: &[IoSlice<'_>],
562 addr: &SockAddr,
563 flags: c_int,
564 ) -> io::Result<usize> {
565 let mut nsent = 0;
566 syscall!(
567 WSASendTo(
568 socket,
569 // FIXME: Same problem as in `send_vectored`.
570 bufs.as_ptr() as *mut _,
571 bufs.len().min(DWORD::MAX as usize) as DWORD,
572 &mut nsent,
573 flags as DWORD,
574 addr.as_ptr(),
575 addr.len(),
576 ptr::null_mut(),
577 None,
578 ),
579 PartialEq::eq,
580 sock::SOCKET_ERROR
581 )
582 .map(|_| nsent as usize)
583 }
584
585 /// Wrapper around `getsockopt` to deal with platform specific timeouts.
timeout_opt(fd: Socket, lvl: c_int, name: c_int) -> io::Result<Option<Duration>>586 pub(crate) fn timeout_opt(fd: Socket, lvl: c_int, name: c_int) -> io::Result<Option<Duration>> {
587 unsafe { getsockopt(fd, lvl, name).map(from_ms) }
588 }
589
from_ms(duration: DWORD) -> Option<Duration>590 fn from_ms(duration: DWORD) -> Option<Duration> {
591 if duration == 0 {
592 None
593 } else {
594 let secs = duration / 1000;
595 let nsec = (duration % 1000) * 1000000;
596 Some(Duration::new(secs as u64, nsec as u32))
597 }
598 }
599
600 /// Wrapper around `setsockopt` to deal with platform specific timeouts.
set_timeout_opt( fd: Socket, level: c_int, optname: c_int, duration: Option<Duration>, ) -> io::Result<()>601 pub(crate) fn set_timeout_opt(
602 fd: Socket,
603 level: c_int,
604 optname: c_int,
605 duration: Option<Duration>,
606 ) -> io::Result<()> {
607 let duration = into_ms(duration);
608 unsafe { setsockopt(fd, level, optname, duration) }
609 }
610
into_ms(duration: Option<Duration>) -> DWORD611 fn into_ms(duration: Option<Duration>) -> DWORD {
612 // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the
613 // timeouts in windows APIs are typically u32 milliseconds. To translate, we
614 // have two pieces to take care of:
615 //
616 // * Nanosecond precision is rounded up
617 // * Greater than u32::MAX milliseconds (50 days) is rounded up to
618 // INFINITE (never time out).
619 duration
620 .map(|duration| min(duration.as_millis(), INFINITE as u128) as DWORD)
621 .unwrap_or(0)
622 }
623
set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()>624 pub(crate) fn set_tcp_keepalive(socket: Socket, keepalive: &TcpKeepalive) -> io::Result<()> {
625 let mut keepalive = tcp_keepalive {
626 onoff: 1,
627 keepalivetime: into_ms(keepalive.time),
628 keepaliveinterval: into_ms(keepalive.interval),
629 };
630 let mut out = 0;
631 syscall!(
632 WSAIoctl(
633 socket,
634 SIO_KEEPALIVE_VALS,
635 &mut keepalive as *mut _ as *mut _,
636 size_of::<tcp_keepalive>() as _,
637 ptr::null_mut(),
638 0,
639 &mut out,
640 ptr::null_mut(),
641 None,
642 ),
643 PartialEq::eq,
644 sock::SOCKET_ERROR
645 )
646 .map(|_| ())
647 }
648
649 /// Caller must ensure `T` is the correct type for `level` and `optname`.
getsockopt<T>(socket: Socket, level: c_int, optname: c_int) -> io::Result<T>650 pub(crate) unsafe fn getsockopt<T>(socket: Socket, level: c_int, optname: c_int) -> io::Result<T> {
651 let mut optval: MaybeUninit<T> = MaybeUninit::uninit();
652 let mut optlen = mem::size_of::<T>() as c_int;
653 syscall!(
654 getsockopt(
655 socket,
656 level,
657 optname,
658 optval.as_mut_ptr().cast(),
659 &mut optlen,
660 ),
661 PartialEq::eq,
662 sock::SOCKET_ERROR
663 )
664 .map(|_| {
665 debug_assert_eq!(optlen as usize, mem::size_of::<T>());
666 // Safety: `getsockopt` initialised `optval` for us.
667 optval.assume_init()
668 })
669 }
670
671 /// Caller must ensure `T` is the correct type for `level` and `optname`.
setsockopt<T>( socket: Socket, level: c_int, optname: c_int, optval: T, ) -> io::Result<()>672 pub(crate) unsafe fn setsockopt<T>(
673 socket: Socket,
674 level: c_int,
675 optname: c_int,
676 optval: T,
677 ) -> io::Result<()> {
678 syscall!(
679 setsockopt(
680 socket,
681 level,
682 optname,
683 (&optval as *const T).cast(),
684 mem::size_of::<T>() as c_int,
685 ),
686 PartialEq::eq,
687 sock::SOCKET_ERROR
688 )
689 .map(|_| ())
690 }
691
ioctlsocket(socket: Socket, cmd: c_long, payload: &mut u_long) -> io::Result<()>692 fn ioctlsocket(socket: Socket, cmd: c_long, payload: &mut u_long) -> io::Result<()> {
693 syscall!(
694 ioctlsocket(socket, cmd, payload),
695 PartialEq::eq,
696 sock::SOCKET_ERROR
697 )
698 .map(|_| ())
699 }
700
close(socket: Socket)701 pub(crate) fn close(socket: Socket) {
702 unsafe {
703 let _ = sock::closesocket(socket);
704 }
705 }
706
to_in_addr(addr: &Ipv4Addr) -> IN_ADDR707 pub(crate) fn to_in_addr(addr: &Ipv4Addr) -> IN_ADDR {
708 let mut s_un: in_addr_S_un = unsafe { mem::zeroed() };
709 // `S_un` is stored as BE on all machines, and the array is in BE order. So
710 // the native endian conversion method is used so that it's never swapped.
711 unsafe { *(s_un.S_addr_mut()) = u32::from_ne_bytes(addr.octets()) };
712 IN_ADDR { S_un: s_un }
713 }
714
from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr715 pub(crate) fn from_in_addr(in_addr: IN_ADDR) -> Ipv4Addr {
716 Ipv4Addr::from(unsafe { *in_addr.S_un.S_addr() }.to_ne_bytes())
717 }
718
to_in6_addr(addr: &Ipv6Addr) -> in6_addr719 pub(crate) fn to_in6_addr(addr: &Ipv6Addr) -> in6_addr {
720 let mut ret_addr: in6_addr_u = unsafe { mem::zeroed() };
721 unsafe { *(ret_addr.Byte_mut()) = addr.octets() };
722 let mut ret: in6_addr = unsafe { mem::zeroed() };
723 ret.u = ret_addr;
724 ret
725 }
726
from_in6_addr(addr: in6_addr) -> Ipv6Addr727 pub(crate) fn from_in6_addr(addr: in6_addr) -> Ipv6Addr {
728 Ipv6Addr::from(*unsafe { addr.u.Byte() })
729 }
730
731 /// Windows only API.
732 impl crate::Socket {
733 /// Sets `HANDLE_FLAG_INHERIT` using `SetHandleInformation`.
734 #[cfg(feature = "all")]
set_no_inherit(&self, no_inherit: bool) -> io::Result<()>735 pub fn set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
736 self._set_no_inherit(no_inherit)
737 }
738
_set_no_inherit(&self, no_inherit: bool) -> io::Result<()>739 pub(crate) fn _set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
740 // NOTE: can't use `syscall!` because it expects the function in the
741 // `sock::` path.
742 let res = unsafe {
743 SetHandleInformation(
744 self.inner as HANDLE,
745 winbase::HANDLE_FLAG_INHERIT,
746 !no_inherit as _,
747 )
748 };
749 if res == 0 {
750 // Zero means error.
751 Err(io::Error::last_os_error())
752 } else {
753 Ok(())
754 }
755 }
756 }
757
758 impl AsRawSocket for crate::Socket {
as_raw_socket(&self) -> RawSocket759 fn as_raw_socket(&self) -> RawSocket {
760 self.inner as RawSocket
761 }
762 }
763
764 impl IntoRawSocket for crate::Socket {
into_raw_socket(self) -> RawSocket765 fn into_raw_socket(self) -> RawSocket {
766 let socket = self.inner;
767 mem::forget(self);
768 socket as RawSocket
769 }
770 }
771
772 impl FromRawSocket for crate::Socket {
from_raw_socket(socket: RawSocket) -> crate::Socket773 unsafe fn from_raw_socket(socket: RawSocket) -> crate::Socket {
774 crate::Socket {
775 inner: socket as Socket,
776 }
777 }
778 }
779
// Checks the raw `S_addr` value produced by `to_in_addr` and that
// `from_in_addr` restores the original address.
#[test]
fn in_addr_convertion() {
    let ip = Ipv4Addr::new(127, 0, 0, 1);
    let raw = to_in_addr(&ip);
    // The first octet is expected in the lowest byte of `S_addr`; the literal
    // matches the native-endian value on little-endian targets.
    assert_eq!(unsafe { *raw.S_un.S_addr() }, 127 << 0 | 1 << 24);
    assert_eq!(from_in_addr(raw), ip);

    // Same check with four distinct octets.
    let ip = Ipv4Addr::new(127, 34, 4, 12);
    let raw = to_in_addr(&ip);
    assert_eq!(
        unsafe { *raw.S_un.S_addr() },
        127 << 0 | 34 << 8 | 4 << 16 | 12 << 24
    );
    assert_eq!(from_in_addr(raw), ip);
}
795
// Checks the raw 16-bit words produced by `to_in6_addr` and that
// `from_in6_addr` restores the original address.
#[test]
fn in6_addr_convertion() {
    let ip = Ipv6Addr::new(0x2000, 1, 2, 3, 4, 5, 6, 7);
    let raw = to_in6_addr(&ip);
    // Each segment is expected to be stored in big-endian order.
    let want = [
        0x2000u16.to_be(),
        1u16.to_be(),
        2u16.to_be(),
        3u16.to_be(),
        4u16.to_be(),
        5u16.to_be(),
        6u16.to_be(),
        7u16.to_be(),
    ];
    assert_eq!(unsafe { *raw.u.Word() }, want);
    assert_eq!(from_in6_addr(raw), ip);
}
813