1 // Copyright 2015 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10 
11 use std::cmp;
12 use std::fmt;
13 use std::io;
14 use std::io::{Read, Write};
15 use std::mem;
16 use std::net::Shutdown;
17 use std::net::{self, Ipv4Addr, Ipv6Addr};
18 use std::os::windows::prelude::*;
19 use std::ptr;
20 use std::sync::Once;
21 use std::time::Duration;
22 
23 use winapi::ctypes::{c_char, c_long, c_ulong};
24 use winapi::shared::in6addr::*;
25 use winapi::shared::inaddr::*;
26 use winapi::shared::minwindef::DWORD;
27 use winapi::shared::ntdef::{HANDLE, ULONG};
28 use winapi::shared::ws2def::{self, *};
29 use winapi::shared::ws2ipdef::*;
30 use winapi::um::handleapi::SetHandleInformation;
31 use winapi::um::processthreadsapi::GetCurrentProcessId;
32 use winapi::um::winbase::INFINITE;
33 use winapi::um::winsock2 as sock;
34 
35 use crate::SockAddr;
36 
// Windows constants used by this module, declared locally with the same
// names and values as the Windows SDK definitions.

// `SetHandleInformation` flag: the handle is inherited by child processes.
const HANDLE_FLAG_INHERIT: DWORD = 0x00000001;
// `recv`/`recvfrom` flag: return data without removing it from the queue.
const MSG_PEEK: c_int = 0x2;
// `how` arguments accepted by `shutdown`.
const SD_BOTH: c_int = 2;
const SD_RECEIVE: c_int = 0;
const SD_SEND: c_int = 1;
// `WSAIoctl` control code used with a `tcp_keepalive` argument block by the
// keep-alive methods below.
const SIO_KEEPALIVE_VALS: DWORD = 0x98000004;
// `WSASocketW` flag: create a socket that supports overlapped I/O.
const WSA_FLAG_OVERLAPPED: DWORD = 0x01;
44 
45 pub use winapi::ctypes::c_int;
46 
47 // Used in `Domain`.
48 pub(crate) use winapi::shared::ws2def::{AF_INET, AF_INET6};
49 // Used in `Type`.
50 pub(crate) use winapi::shared::ws2def::{SOCK_DGRAM, SOCK_RAW, SOCK_SEQPACKET, SOCK_STREAM};
51 // Used in `Protocol`.
// `winapi`'s protocol numbers, cast to plain `c_int` so the crate-level
// `Protocol` type can use them directly.
pub(crate) const IPPROTO_ICMP: c_int = winapi::shared::ws2def::IPPROTO_ICMP as c_int;
pub(crate) const IPPROTO_ICMPV6: c_int = winapi::shared::ws2def::IPPROTO_ICMPV6 as c_int;
pub(crate) const IPPROTO_TCP: c_int = winapi::shared::ws2def::IPPROTO_TCP as c_int;
pub(crate) const IPPROTO_UDP: c_int = winapi::shared::ws2def::IPPROTO_UDP as c_int;
56 
// `Debug` implementations for the crate's `Domain`, `Type` and `Protocol`
// wrappers, listing the Windows constants each may hold.
// NOTE(review): `impl_debug!` is defined elsewhere; presumably it prints the
// name of the listed constant that matches the wrapped value — confirm.
impl_debug!(
    crate::Domain,
    ws2def::AF_INET,
    ws2def::AF_INET6,
    ws2def::AF_UNIX,
    ws2def::AF_UNSPEC, // = 0.
);

impl_debug!(
    crate::Type,
    ws2def::SOCK_STREAM,
    ws2def::SOCK_DGRAM,
    ws2def::SOCK_RAW,
    ws2def::SOCK_RDM,
    ws2def::SOCK_SEQPACKET,
);

impl_debug!(
    crate::Protocol,
    self::IPPROTO_ICMP,
    self::IPPROTO_ICMPV6,
    self::IPPROTO_TCP,
    self::IPPROTO_UDP,
);
81 
/// Argument block for the `SIO_KEEPALIVE_VALS` `WSAIoctl` control code.
///
/// `#[repr(C)]` keeps the C field order and layout Winsock expects.
/// `set_keepalive` sends it as the ioctl input buffer; `keepalive` passes it
/// as the output buffer and reads `keepaliveinterval` as milliseconds.
#[repr(C)]
struct tcp_keepalive {
    // Zero means keep-alive is off; `set_keepalive` writes 1 to enable it.
    onoff: c_ulong,
    // `set_keepalive` writes the same millisecond value here and in
    // `keepaliveinterval`.
    keepalivetime: c_ulong,
    keepaliveinterval: c_ulong,
}
88 
/// Ensures Winsock has been initialized for this process.
///
/// Instead of calling `WSAStartup` directly, this relies on the standard
/// library: creating any `std::net` socket makes libstd initialize Winsock
/// first. The throwaway socket is created at most once (guarded by `Once`).
fn init() {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        // Initialize winsock through the standard library by just creating a
        // dummy socket. Whether this is successful or not we drop the result as
        // libstd will be sure to have initialized winsock.
        //
        // Port 0 lets the OS pick an ephemeral port, so this never contends
        // with (or briefly occupies) a fixed port another program may need.
        let _ = net::UdpSocket::bind((Ipv4Addr::LOCALHOST, 0));
    });
}
99 
last_error() -> io::Error100 fn last_error() -> io::Error {
101     io::Error::from_raw_os_error(unsafe { sock::WSAGetLastError() })
102 }
103 
/// Windows implementation of the crate's socket: an owned Winsock `SOCKET`
/// handle.
///
/// The handle is closed with `closesocket` when this value is dropped, while
/// `into_raw_socket` hands it out without closing it.
pub struct Socket {
    socket: sock::SOCKET,
}
107 
108 impl Socket {
new(family: c_int, ty: c_int, protocol: c_int) -> io::Result<Socket>109     pub fn new(family: c_int, ty: c_int, protocol: c_int) -> io::Result<Socket> {
110         init();
111         unsafe {
112             let socket = match sock::WSASocketW(
113                 family,
114                 ty,
115                 protocol,
116                 ptr::null_mut(),
117                 0,
118                 WSA_FLAG_OVERLAPPED,
119             ) {
120                 sock::INVALID_SOCKET => return Err(last_error()),
121                 socket => socket,
122             };
123             let socket = Socket::from_raw_socket(socket as RawSocket);
124             socket.set_no_inherit()?;
125             Ok(socket)
126         }
127     }
128 
bind(&self, addr: &SockAddr) -> io::Result<()>129     pub fn bind(&self, addr: &SockAddr) -> io::Result<()> {
130         unsafe {
131             if sock::bind(self.socket, addr.as_ptr(), addr.len()) == 0 {
132                 Ok(())
133             } else {
134                 Err(last_error())
135             }
136         }
137     }
138 
listen(&self, backlog: i32) -> io::Result<()>139     pub fn listen(&self, backlog: i32) -> io::Result<()> {
140         unsafe {
141             if sock::listen(self.socket, backlog) == 0 {
142                 Ok(())
143             } else {
144                 Err(last_error())
145             }
146         }
147     }
148 
connect(&self, addr: &SockAddr) -> io::Result<()>149     pub fn connect(&self, addr: &SockAddr) -> io::Result<()> {
150         unsafe {
151             if sock::connect(self.socket, addr.as_ptr(), addr.len()) == 0 {
152                 Ok(())
153             } else {
154                 Err(last_error())
155             }
156         }
157     }
158 
connect_timeout(&self, addr: &SockAddr, timeout: Duration) -> io::Result<()>159     pub fn connect_timeout(&self, addr: &SockAddr, timeout: Duration) -> io::Result<()> {
160         self.set_nonblocking(true)?;
161         let r = self.connect(addr);
162         self.set_nonblocking(false)?;
163 
164         match r {
165             Ok(()) => return Ok(()),
166             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
167             Err(e) => return Err(e),
168         }
169 
170         if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
171             return Err(io::Error::new(
172                 io::ErrorKind::InvalidInput,
173                 "cannot set a 0 duration timeout",
174             ));
175         }
176 
177         let mut timeout = sock::timeval {
178             tv_sec: timeout.as_secs() as c_long,
179             tv_usec: (timeout.subsec_nanos() / 1000) as c_long,
180         };
181         if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
182             timeout.tv_usec = 1;
183         }
184 
185         let fds = unsafe {
186             let mut fds = mem::zeroed::<sock::fd_set>();
187             fds.fd_count = 1;
188             fds.fd_array[0] = self.socket;
189             fds
190         };
191 
192         let mut writefds = fds;
193         let mut errorfds = fds;
194 
195         match unsafe { sock::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout) } {
196             sock::SOCKET_ERROR => return Err(io::Error::last_os_error()),
197             0 => {
198                 return Err(io::Error::new(
199                     io::ErrorKind::TimedOut,
200                     "connection timed out",
201                 ))
202             }
203             _ => {
204                 if writefds.fd_count != 1 {
205                     if let Some(e) = self.take_error()? {
206                         return Err(e);
207                     }
208                 }
209                 Ok(())
210             }
211         }
212     }
213 
local_addr(&self) -> io::Result<SockAddr>214     pub fn local_addr(&self) -> io::Result<SockAddr> {
215         unsafe {
216             let mut storage: SOCKADDR_STORAGE = mem::zeroed();
217             let mut len = mem::size_of_val(&storage) as c_int;
218             if sock::getsockname(self.socket, &mut storage as *mut _ as *mut _, &mut len) != 0 {
219                 return Err(last_error());
220             }
221             Ok(SockAddr::from_raw_parts(
222                 &storage as *const _ as *const _,
223                 len,
224             ))
225         }
226     }
227 
peer_addr(&self) -> io::Result<SockAddr>228     pub fn peer_addr(&self) -> io::Result<SockAddr> {
229         unsafe {
230             let mut storage: SOCKADDR_STORAGE = mem::zeroed();
231             let mut len = mem::size_of_val(&storage) as c_int;
232             if sock::getpeername(self.socket, &mut storage as *mut _ as *mut _, &mut len) != 0 {
233                 return Err(last_error());
234             }
235             Ok(SockAddr::from_raw_parts(
236                 &storage as *const _ as *const _,
237                 len,
238             ))
239         }
240     }
241 
try_clone(&self) -> io::Result<Socket>242     pub fn try_clone(&self) -> io::Result<Socket> {
243         unsafe {
244             let mut info: sock::WSAPROTOCOL_INFOW = mem::zeroed();
245             let r = sock::WSADuplicateSocketW(self.socket, GetCurrentProcessId(), &mut info);
246             if r != 0 {
247                 return Err(io::Error::last_os_error());
248             }
249             let socket = sock::WSASocketW(
250                 info.iAddressFamily,
251                 info.iSocketType,
252                 info.iProtocol,
253                 &mut info,
254                 0,
255                 WSA_FLAG_OVERLAPPED,
256             );
257             let socket = match socket {
258                 sock::INVALID_SOCKET => return Err(last_error()),
259                 n => Socket::from_raw_socket(n as RawSocket),
260             };
261             socket.set_no_inherit()?;
262             Ok(socket)
263         }
264     }
265 
accept(&self) -> io::Result<(Socket, SockAddr)>266     pub fn accept(&self) -> io::Result<(Socket, SockAddr)> {
267         unsafe {
268             let mut storage: SOCKADDR_STORAGE = mem::zeroed();
269             let mut len = mem::size_of_val(&storage) as c_int;
270             let socket = { sock::accept(self.socket, &mut storage as *mut _ as *mut _, &mut len) };
271             let socket = match socket {
272                 sock::INVALID_SOCKET => return Err(last_error()),
273                 socket => Socket::from_raw_socket(socket as RawSocket),
274             };
275             socket.set_no_inherit()?;
276             let addr = SockAddr::from_raw_parts(&storage as *const _ as *const _, len);
277             Ok((socket, addr))
278         }
279     }
280 
take_error(&self) -> io::Result<Option<io::Error>>281     pub fn take_error(&self) -> io::Result<Option<io::Error>> {
282         unsafe {
283             let raw: c_int = self.getsockopt(SOL_SOCKET, SO_ERROR)?;
284             if raw == 0 {
285                 Ok(None)
286             } else {
287                 Ok(Some(io::Error::from_raw_os_error(raw as i32)))
288             }
289         }
290     }
291 
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>292     pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
293         unsafe {
294             let mut nonblocking = nonblocking as c_ulong;
295             let r = sock::ioctlsocket(self.socket, sock::FIONBIO as c_int, &mut nonblocking);
296             if r == 0 {
297                 Ok(())
298             } else {
299                 Err(io::Error::last_os_error())
300             }
301         }
302     }
303 
shutdown(&self, how: Shutdown) -> io::Result<()>304     pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
305         let how = match how {
306             Shutdown::Write => SD_SEND,
307             Shutdown::Read => SD_RECEIVE,
308             Shutdown::Both => SD_BOTH,
309         };
310         if unsafe { sock::shutdown(self.socket, how) == 0 } {
311             Ok(())
312         } else {
313             Err(last_error())
314         }
315     }
316 
recv(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize>317     pub fn recv(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
318         unsafe {
319             let n = {
320                 sock::recv(
321                     self.socket,
322                     buf.as_mut_ptr() as *mut c_char,
323                     clamp(buf.len()),
324                     flags,
325                 )
326             };
327             match n {
328                 sock::SOCKET_ERROR if sock::WSAGetLastError() == sock::WSAESHUTDOWN as i32 => Ok(0),
329                 sock::SOCKET_ERROR => Err(last_error()),
330                 n => Ok(n as usize),
331             }
332         }
333     }
334 
peek(&self, buf: &mut [u8]) -> io::Result<usize>335     pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
336         unsafe {
337             let n = {
338                 sock::recv(
339                     self.socket,
340                     buf.as_mut_ptr() as *mut c_char,
341                     clamp(buf.len()),
342                     MSG_PEEK,
343                 )
344             };
345             match n {
346                 sock::SOCKET_ERROR if sock::WSAGetLastError() == sock::WSAESHUTDOWN as i32 => Ok(0),
347                 sock::SOCKET_ERROR => Err(last_error()),
348                 n => Ok(n as usize),
349             }
350         }
351     }
352 
peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)>353     pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
354         self.recv_from(buf, MSG_PEEK)
355     }
356 
recv_from(&self, buf: &mut [u8], flags: c_int) -> io::Result<(usize, SockAddr)>357     pub fn recv_from(&self, buf: &mut [u8], flags: c_int) -> io::Result<(usize, SockAddr)> {
358         unsafe {
359             let mut storage: SOCKADDR_STORAGE = mem::zeroed();
360             let mut addrlen = mem::size_of_val(&storage) as c_int;
361 
362             let n = {
363                 sock::recvfrom(
364                     self.socket,
365                     buf.as_mut_ptr() as *mut c_char,
366                     clamp(buf.len()),
367                     flags,
368                     &mut storage as *mut _ as *mut _,
369                     &mut addrlen,
370                 )
371             };
372             let n = match n {
373                 sock::SOCKET_ERROR if sock::WSAGetLastError() == sock::WSAESHUTDOWN as i32 => 0,
374                 sock::SOCKET_ERROR => return Err(last_error()),
375                 n => n as usize,
376             };
377             let addr = SockAddr::from_raw_parts(&storage as *const _ as *const _, addrlen);
378             Ok((n, addr))
379         }
380     }
381 
send(&self, buf: &[u8], flags: c_int) -> io::Result<usize>382     pub fn send(&self, buf: &[u8], flags: c_int) -> io::Result<usize> {
383         unsafe {
384             let n = {
385                 sock::send(
386                     self.socket,
387                     buf.as_ptr() as *const c_char,
388                     clamp(buf.len()),
389                     flags,
390                 )
391             };
392             if n == sock::SOCKET_ERROR {
393                 Err(last_error())
394             } else {
395                 Ok(n as usize)
396             }
397         }
398     }
399 
send_to(&self, buf: &[u8], flags: c_int, addr: &SockAddr) -> io::Result<usize>400     pub fn send_to(&self, buf: &[u8], flags: c_int, addr: &SockAddr) -> io::Result<usize> {
401         unsafe {
402             let n = {
403                 sock::sendto(
404                     self.socket,
405                     buf.as_ptr() as *const c_char,
406                     clamp(buf.len()),
407                     flags,
408                     addr.as_ptr(),
409                     addr.len(),
410                 )
411             };
412             if n == sock::SOCKET_ERROR {
413                 Err(last_error())
414             } else {
415                 Ok(n as usize)
416             }
417         }
418     }
419 
420     // ================================================
421 
ttl(&self) -> io::Result<u32>422     pub fn ttl(&self) -> io::Result<u32> {
423         unsafe {
424             let raw: c_int = self.getsockopt(IPPROTO_IP, IP_TTL)?;
425             Ok(raw as u32)
426         }
427     }
428 
set_ttl(&self, ttl: u32) -> io::Result<()>429     pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
430         unsafe { self.setsockopt(IPPROTO_IP, IP_TTL, ttl as c_int) }
431     }
432 
unicast_hops_v6(&self) -> io::Result<u32>433     pub fn unicast_hops_v6(&self) -> io::Result<u32> {
434         unsafe {
435             let raw: c_int = self.getsockopt(IPPROTO_IPV6 as c_int, IPV6_UNICAST_HOPS)?;
436             Ok(raw as u32)
437         }
438     }
439 
set_unicast_hops_v6(&self, hops: u32) -> io::Result<()>440     pub fn set_unicast_hops_v6(&self, hops: u32) -> io::Result<()> {
441         unsafe { self.setsockopt(IPPROTO_IPV6 as c_int, IPV6_UNICAST_HOPS, hops as c_int) }
442     }
443 
only_v6(&self) -> io::Result<bool>444     pub fn only_v6(&self) -> io::Result<bool> {
445         unsafe {
446             let raw: c_int = self.getsockopt(IPPROTO_IPV6 as c_int, IPV6_V6ONLY)?;
447             Ok(raw != 0)
448         }
449     }
450 
set_only_v6(&self, only_v6: bool) -> io::Result<()>451     pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
452         unsafe { self.setsockopt(IPPROTO_IPV6 as c_int, IPV6_V6ONLY, only_v6 as c_int) }
453     }
454 
read_timeout(&self) -> io::Result<Option<Duration>>455     pub fn read_timeout(&self) -> io::Result<Option<Duration>> {
456         unsafe { Ok(ms2dur(self.getsockopt(SOL_SOCKET, SO_RCVTIMEO)?)) }
457     }
458 
set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()>459     pub fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
460         unsafe { self.setsockopt(SOL_SOCKET, SO_RCVTIMEO, dur2ms(dur)?) }
461     }
462 
write_timeout(&self) -> io::Result<Option<Duration>>463     pub fn write_timeout(&self) -> io::Result<Option<Duration>> {
464         unsafe { Ok(ms2dur(self.getsockopt(SOL_SOCKET, SO_SNDTIMEO)?)) }
465     }
466 
set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()>467     pub fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
468         unsafe { self.setsockopt(SOL_SOCKET, SO_SNDTIMEO, dur2ms(dur)?) }
469     }
470 
nodelay(&self) -> io::Result<bool>471     pub fn nodelay(&self) -> io::Result<bool> {
472         unsafe {
473             let raw: c_char = self.getsockopt(IPPROTO_TCP, TCP_NODELAY)?;
474             Ok(raw != 0)
475         }
476     }
477 
set_nodelay(&self, nodelay: bool) -> io::Result<()>478     pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
479         unsafe { self.setsockopt(IPPROTO_TCP, TCP_NODELAY, nodelay as c_char) }
480     }
481 
broadcast(&self) -> io::Result<bool>482     pub fn broadcast(&self) -> io::Result<bool> {
483         unsafe {
484             let raw: c_int = self.getsockopt(SOL_SOCKET, SO_BROADCAST)?;
485             Ok(raw != 0)
486         }
487     }
488 
set_broadcast(&self, broadcast: bool) -> io::Result<()>489     pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
490         unsafe { self.setsockopt(SOL_SOCKET, SO_BROADCAST, broadcast as c_int) }
491     }
492 
multicast_loop_v4(&self) -> io::Result<bool>493     pub fn multicast_loop_v4(&self) -> io::Result<bool> {
494         unsafe {
495             let raw: c_int = self.getsockopt(IPPROTO_IP, IP_MULTICAST_LOOP)?;
496             Ok(raw != 0)
497         }
498     }
499 
set_multicast_loop_v4(&self, multicast_loop_v4: bool) -> io::Result<()>500     pub fn set_multicast_loop_v4(&self, multicast_loop_v4: bool) -> io::Result<()> {
501         unsafe { self.setsockopt(IPPROTO_IP, IP_MULTICAST_LOOP, multicast_loop_v4 as c_int) }
502     }
503 
multicast_ttl_v4(&self) -> io::Result<u32>504     pub fn multicast_ttl_v4(&self) -> io::Result<u32> {
505         unsafe {
506             let raw: c_int = self.getsockopt(IPPROTO_IP, IP_MULTICAST_TTL)?;
507             Ok(raw as u32)
508         }
509     }
510 
set_multicast_ttl_v4(&self, multicast_ttl_v4: u32) -> io::Result<()>511     pub fn set_multicast_ttl_v4(&self, multicast_ttl_v4: u32) -> io::Result<()> {
512         unsafe { self.setsockopt(IPPROTO_IP, IP_MULTICAST_TTL, multicast_ttl_v4 as c_int) }
513     }
514 
multicast_hops_v6(&self) -> io::Result<u32>515     pub fn multicast_hops_v6(&self) -> io::Result<u32> {
516         unsafe {
517             let raw: c_int = self.getsockopt(IPPROTO_IPV6 as c_int, IPV6_MULTICAST_HOPS)?;
518             Ok(raw as u32)
519         }
520     }
521 
set_multicast_hops_v6(&self, hops: u32) -> io::Result<()>522     pub fn set_multicast_hops_v6(&self, hops: u32) -> io::Result<()> {
523         unsafe { self.setsockopt(IPPROTO_IPV6 as c_int, IPV6_MULTICAST_HOPS, hops as c_int) }
524     }
525 
multicast_if_v4(&self) -> io::Result<Ipv4Addr>526     pub fn multicast_if_v4(&self) -> io::Result<Ipv4Addr> {
527         unsafe {
528             let imr_interface: IN_ADDR = self.getsockopt(IPPROTO_IP, IP_MULTICAST_IF)?;
529             Ok(from_s_addr(imr_interface.S_un))
530         }
531     }
532 
set_multicast_if_v4(&self, interface: &Ipv4Addr) -> io::Result<()>533     pub fn set_multicast_if_v4(&self, interface: &Ipv4Addr) -> io::Result<()> {
534         let interface = to_s_addr(interface);
535         let imr_interface = IN_ADDR { S_un: interface };
536 
537         unsafe { self.setsockopt(IPPROTO_IP, IP_MULTICAST_IF, imr_interface) }
538     }
539 
multicast_if_v6(&self) -> io::Result<u32>540     pub fn multicast_if_v6(&self) -> io::Result<u32> {
541         unsafe {
542             let raw: c_int = self.getsockopt(IPPROTO_IPV6 as c_int, IPV6_MULTICAST_IF)?;
543             Ok(raw as u32)
544         }
545     }
546 
set_multicast_if_v6(&self, interface: u32) -> io::Result<()>547     pub fn set_multicast_if_v6(&self, interface: u32) -> io::Result<()> {
548         unsafe { self.setsockopt(IPPROTO_IPV6 as c_int, IPV6_MULTICAST_IF, interface as c_int) }
549     }
550 
multicast_loop_v6(&self) -> io::Result<bool>551     pub fn multicast_loop_v6(&self) -> io::Result<bool> {
552         unsafe {
553             let raw: c_int = self.getsockopt(IPPROTO_IPV6 as c_int, IPV6_MULTICAST_LOOP)?;
554             Ok(raw != 0)
555         }
556     }
557 
set_multicast_loop_v6(&self, multicast_loop_v6: bool) -> io::Result<()>558     pub fn set_multicast_loop_v6(&self, multicast_loop_v6: bool) -> io::Result<()> {
559         unsafe {
560             self.setsockopt(
561                 IPPROTO_IPV6 as c_int,
562                 IPV6_MULTICAST_LOOP,
563                 multicast_loop_v6 as c_int,
564             )
565         }
566     }
567 
join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()>568     pub fn join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
569         let multiaddr = to_s_addr(multiaddr);
570         let interface = to_s_addr(interface);
571         let mreq = IP_MREQ {
572             imr_multiaddr: IN_ADDR { S_un: multiaddr },
573             imr_interface: IN_ADDR { S_un: interface },
574         };
575         unsafe { self.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq) }
576     }
577 
join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()>578     pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
579         let multiaddr = to_in6_addr(multiaddr);
580         let mreq = IPV6_MREQ {
581             ipv6mr_multiaddr: multiaddr,
582             ipv6mr_interface: interface,
583         };
584         unsafe { self.setsockopt(IPPROTO_IP, IPV6_ADD_MEMBERSHIP, mreq) }
585     }
586 
leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()>587     pub fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
588         let multiaddr = to_s_addr(multiaddr);
589         let interface = to_s_addr(interface);
590         let mreq = IP_MREQ {
591             imr_multiaddr: IN_ADDR { S_un: multiaddr },
592             imr_interface: IN_ADDR { S_un: interface },
593         };
594         unsafe { self.setsockopt(IPPROTO_IP, IP_DROP_MEMBERSHIP, mreq) }
595     }
596 
leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()>597     pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
598         let multiaddr = to_in6_addr(multiaddr);
599         let mreq = IPV6_MREQ {
600             ipv6mr_multiaddr: multiaddr,
601             ipv6mr_interface: interface,
602         };
603         unsafe { self.setsockopt(IPPROTO_IP, IPV6_DROP_MEMBERSHIP, mreq) }
604     }
605 
linger(&self) -> io::Result<Option<Duration>>606     pub fn linger(&self) -> io::Result<Option<Duration>> {
607         unsafe { Ok(linger2dur(self.getsockopt(SOL_SOCKET, SO_LINGER)?)) }
608     }
609 
set_linger(&self, dur: Option<Duration>) -> io::Result<()>610     pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
611         unsafe { self.setsockopt(SOL_SOCKET, SO_LINGER, dur2linger(dur)) }
612     }
613 
set_reuse_address(&self, reuse: bool) -> io::Result<()>614     pub fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
615         unsafe { self.setsockopt(SOL_SOCKET, SO_REUSEADDR, reuse as c_int) }
616     }
617 
reuse_address(&self) -> io::Result<bool>618     pub fn reuse_address(&self) -> io::Result<bool> {
619         unsafe {
620             let raw: c_int = self.getsockopt(SOL_SOCKET, SO_REUSEADDR)?;
621             Ok(raw != 0)
622         }
623     }
624 
recv_buffer_size(&self) -> io::Result<usize>625     pub fn recv_buffer_size(&self) -> io::Result<usize> {
626         unsafe {
627             let raw: c_int = self.getsockopt(SOL_SOCKET, SO_RCVBUF)?;
628             Ok(raw as usize)
629         }
630     }
631 
set_recv_buffer_size(&self, size: usize) -> io::Result<()>632     pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
633         unsafe {
634             // TODO: casting usize to a c_int should be a checked cast
635             self.setsockopt(SOL_SOCKET, SO_RCVBUF, size as c_int)
636         }
637     }
638 
send_buffer_size(&self) -> io::Result<usize>639     pub fn send_buffer_size(&self) -> io::Result<usize> {
640         unsafe {
641             let raw: c_int = self.getsockopt(SOL_SOCKET, SO_SNDBUF)?;
642             Ok(raw as usize)
643         }
644     }
645 
set_send_buffer_size(&self, size: usize) -> io::Result<()>646     pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
647         unsafe {
648             // TODO: casting usize to a c_int should be a checked cast
649             self.setsockopt(SOL_SOCKET, SO_SNDBUF, size as c_int)
650         }
651     }
652 
keepalive(&self) -> io::Result<Option<Duration>>653     pub fn keepalive(&self) -> io::Result<Option<Duration>> {
654         let mut ka = tcp_keepalive {
655             onoff: 0,
656             keepalivetime: 0,
657             keepaliveinterval: 0,
658         };
659         let n = unsafe {
660             sock::WSAIoctl(
661                 self.socket,
662                 SIO_KEEPALIVE_VALS,
663                 0 as *mut _,
664                 0,
665                 &mut ka as *mut _ as *mut _,
666                 mem::size_of_val(&ka) as DWORD,
667                 0 as *mut _,
668                 0 as *mut _,
669                 None,
670             )
671         };
672         if n == 0 {
673             Ok(if ka.onoff == 0 {
674                 None
675             } else if ka.keepaliveinterval == 0 {
676                 None
677             } else {
678                 let seconds = ka.keepaliveinterval / 1000;
679                 let nanos = (ka.keepaliveinterval % 1000) * 1_000_000;
680                 Some(Duration::new(seconds as u64, nanos as u32))
681             })
682         } else {
683             Err(last_error())
684         }
685     }
686 
set_keepalive(&self, keepalive: Option<Duration>) -> io::Result<()>687     pub fn set_keepalive(&self, keepalive: Option<Duration>) -> io::Result<()> {
688         let ms = dur2ms(keepalive)?;
689         // TODO: checked casts here
690         let ka = tcp_keepalive {
691             onoff: keepalive.is_some() as c_ulong,
692             keepalivetime: ms as c_ulong,
693             keepaliveinterval: ms as c_ulong,
694         };
695         let mut out = 0;
696         let n = unsafe {
697             sock::WSAIoctl(
698                 self.socket,
699                 SIO_KEEPALIVE_VALS,
700                 &ka as *const _ as *mut _,
701                 mem::size_of_val(&ka) as DWORD,
702                 0 as *mut _,
703                 0,
704                 &mut out,
705                 0 as *mut _,
706                 None,
707             )
708         };
709         if n == 0 {
710             Ok(())
711         } else {
712             Err(last_error())
713         }
714     }
715 
out_of_band_inline(&self) -> io::Result<bool>716     pub fn out_of_band_inline(&self) -> io::Result<bool> {
717         unsafe {
718             let raw: c_int = self.getsockopt(SOL_SOCKET, SO_OOBINLINE)?;
719             Ok(raw != 0)
720         }
721     }
722 
set_out_of_band_inline(&self, oob_inline: bool) -> io::Result<()>723     pub fn set_out_of_band_inline(&self, oob_inline: bool) -> io::Result<()> {
724         unsafe { self.setsockopt(SOL_SOCKET, SO_OOBINLINE, oob_inline as c_int) }
725     }
726 
setsockopt<T>(&self, opt: c_int, val: c_int, payload: T) -> io::Result<()> where T: Copy,727     unsafe fn setsockopt<T>(&self, opt: c_int, val: c_int, payload: T) -> io::Result<()>
728     where
729         T: Copy,
730     {
731         let payload = &payload as *const T as *const c_char;
732         if sock::setsockopt(self.socket, opt, val, payload, mem::size_of::<T>() as c_int) == 0 {
733             Ok(())
734         } else {
735             Err(last_error())
736         }
737     }
738 
getsockopt<T: Copy>(&self, opt: c_int, val: c_int) -> io::Result<T>739     unsafe fn getsockopt<T: Copy>(&self, opt: c_int, val: c_int) -> io::Result<T> {
740         let mut slot: T = mem::zeroed();
741         let mut len = mem::size_of::<T>() as c_int;
742         if sock::getsockopt(
743             self.socket,
744             opt,
745             val,
746             &mut slot as *mut _ as *mut _,
747             &mut len,
748         ) == 0
749         {
750             assert_eq!(len as usize, mem::size_of::<T>());
751             Ok(slot)
752         } else {
753             Err(last_error())
754         }
755     }
756 
set_no_inherit(&self) -> io::Result<()>757     fn set_no_inherit(&self) -> io::Result<()> {
758         unsafe {
759             let r = SetHandleInformation(self.socket as HANDLE, HANDLE_FLAG_INHERIT, 0);
760             if r == 0 {
761                 Err(io::Error::last_os_error())
762             } else {
763                 Ok(())
764             }
765         }
766     }
767 }
768 
769 impl Read for Socket {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>770     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
771         <&Socket>::read(&mut &*self, buf)
772     }
773 }
774 
775 impl<'a> Read for &'a Socket {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>776     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
777         self.recv(buf, 0)
778     }
779 }
780 
781 impl Write for Socket {
write(&mut self, buf: &[u8]) -> io::Result<usize>782     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
783         <&Socket>::write(&mut &*self, buf)
784     }
785 
flush(&mut self) -> io::Result<()>786     fn flush(&mut self) -> io::Result<()> {
787         <&Socket>::flush(&mut &*self)
788     }
789 }
790 
791 impl<'a> Write for &'a Socket {
write(&mut self, buf: &[u8]) -> io::Result<usize>792     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
793         self.send(buf, 0)
794     }
795 
flush(&mut self) -> io::Result<()>796     fn flush(&mut self) -> io::Result<()> {
797         Ok(())
798     }
799 }
800 
801 impl fmt::Debug for Socket {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result802     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
803         let mut f = f.debug_struct("Socket");
804         f.field("socket", &self.socket);
805         if let Ok(addr) = self.local_addr() {
806             f.field("local_addr", &addr);
807         }
808         if let Ok(addr) = self.peer_addr() {
809             f.field("peer_addr", &addr);
810         }
811         f.finish()
812     }
813 }
814 
815 impl AsRawSocket for Socket {
as_raw_socket(&self) -> RawSocket816     fn as_raw_socket(&self) -> RawSocket {
817         self.socket as RawSocket
818     }
819 }
820 
821 impl IntoRawSocket for Socket {
into_raw_socket(self) -> RawSocket822     fn into_raw_socket(self) -> RawSocket {
823         let socket = self.socket;
824         mem::forget(self);
825         socket as RawSocket
826     }
827 }
828 
829 impl FromRawSocket for Socket {
from_raw_socket(socket: RawSocket) -> Socket830     unsafe fn from_raw_socket(socket: RawSocket) -> Socket {
831         Socket {
832             socket: socket as sock::SOCKET,
833         }
834     }
835 }
836 
837 impl AsRawSocket for crate::Socket {
as_raw_socket(&self) -> RawSocket838     fn as_raw_socket(&self) -> RawSocket {
839         self.inner.as_raw_socket()
840     }
841 }
842 
843 impl IntoRawSocket for crate::Socket {
into_raw_socket(self) -> RawSocket844     fn into_raw_socket(self) -> RawSocket {
845         self.inner.into_raw_socket()
846     }
847 }
848 
849 impl FromRawSocket for crate::Socket {
from_raw_socket(socket: RawSocket) -> crate::Socket850     unsafe fn from_raw_socket(socket: RawSocket) -> crate::Socket {
851         crate::Socket {
852             inner: Socket::from_raw_socket(socket),
853         }
854     }
855 }
856 
857 impl Drop for Socket {
drop(&mut self)858     fn drop(&mut self) {
859         unsafe {
860             let _ = sock::closesocket(self.socket);
861         }
862     }
863 }
864 
865 impl From<Socket> for net::TcpStream {
from(socket: Socket) -> net::TcpStream866     fn from(socket: Socket) -> net::TcpStream {
867         unsafe { net::TcpStream::from_raw_socket(socket.into_raw_socket()) }
868     }
869 }
870 
871 impl From<Socket> for net::TcpListener {
from(socket: Socket) -> net::TcpListener872     fn from(socket: Socket) -> net::TcpListener {
873         unsafe { net::TcpListener::from_raw_socket(socket.into_raw_socket()) }
874     }
875 }
876 
877 impl From<Socket> for net::UdpSocket {
from(socket: Socket) -> net::UdpSocket878     fn from(socket: Socket) -> net::UdpSocket {
879         unsafe { net::UdpSocket::from_raw_socket(socket.into_raw_socket()) }
880     }
881 }
882 
883 impl From<net::TcpStream> for Socket {
from(socket: net::TcpStream) -> Socket884     fn from(socket: net::TcpStream) -> Socket {
885         unsafe { Socket::from_raw_socket(socket.into_raw_socket()) }
886     }
887 }
888 
889 impl From<net::TcpListener> for Socket {
from(socket: net::TcpListener) -> Socket890     fn from(socket: net::TcpListener) -> Socket {
891         unsafe { Socket::from_raw_socket(socket.into_raw_socket()) }
892     }
893 }
894 
895 impl From<net::UdpSocket> for Socket {
from(socket: net::UdpSocket) -> Socket896     fn from(socket: net::UdpSocket) -> Socket {
897         unsafe { Socket::from_raw_socket(socket.into_raw_socket()) }
898     }
899 }
900 
clamp(input: usize) -> c_int901 fn clamp(input: usize) -> c_int {
902     cmp::min(input, <c_int>::max_value() as usize) as c_int
903 }
904 
dur2ms(dur: Option<Duration>) -> io::Result<DWORD>905 fn dur2ms(dur: Option<Duration>) -> io::Result<DWORD> {
906     match dur {
907         Some(dur) => {
908             // Note that a duration is a (u64, u32) (seconds, nanoseconds)
909             // pair, and the timeouts in windows APIs are typically u32
910             // milliseconds. To translate, we have two pieces to take care of:
911             //
912             // * Nanosecond precision is rounded up
913             // * Greater than u32::MAX milliseconds (50 days) is rounded up to
914             //   INFINITE (never time out).
915             let ms = dur
916                 .as_secs()
917                 .checked_mul(1000)
918                 .and_then(|ms| ms.checked_add((dur.subsec_nanos() as u64) / 1_000_000))
919                 .and_then(|ms| {
920                     ms.checked_add(if dur.subsec_nanos() % 1_000_000 > 0 {
921                         1
922                     } else {
923                         0
924                     })
925                 })
926                 .map(|ms| {
927                     if ms > <DWORD>::max_value() as u64 {
928                         INFINITE
929                     } else {
930                         ms as DWORD
931                     }
932                 })
933                 .unwrap_or(INFINITE);
934             if ms == 0 {
935                 return Err(io::Error::new(
936                     io::ErrorKind::InvalidInput,
937                     "cannot set a 0 duration timeout",
938                 ));
939             }
940             Ok(ms)
941         }
942         None => Ok(0),
943     }
944 }
945 
ms2dur(raw: DWORD) -> Option<Duration>946 fn ms2dur(raw: DWORD) -> Option<Duration> {
947     if raw == 0 {
948         None
949     } else {
950         let secs = raw / 1000;
951         let nsec = (raw % 1000) * 1000000;
952         Some(Duration::new(secs as u64, nsec as u32))
953     }
954 }
955 
to_s_addr(addr: &Ipv4Addr) -> in_addr_S_un956 fn to_s_addr(addr: &Ipv4Addr) -> in_addr_S_un {
957     let octets = addr.octets();
958     let res = crate::hton(
959         ((octets[0] as ULONG) << 24)
960             | ((octets[1] as ULONG) << 16)
961             | ((octets[2] as ULONG) << 8)
962             | ((octets[3] as ULONG) << 0),
963     );
964     let mut new_addr: in_addr_S_un = unsafe { mem::zeroed() };
965     unsafe { *(new_addr.S_addr_mut()) = res };
966     new_addr
967 }
968 
from_s_addr(in_addr: in_addr_S_un) -> Ipv4Addr969 fn from_s_addr(in_addr: in_addr_S_un) -> Ipv4Addr {
970     let h_addr = crate::ntoh(unsafe { *in_addr.S_addr() });
971 
972     let a: u8 = (h_addr >> 24) as u8;
973     let b: u8 = (h_addr >> 16) as u8;
974     let c: u8 = (h_addr >> 8) as u8;
975     let d: u8 = (h_addr >> 0) as u8;
976 
977     Ipv4Addr::new(a, b, c, d)
978 }
979 
to_in6_addr(addr: &Ipv6Addr) -> in6_addr980 fn to_in6_addr(addr: &Ipv6Addr) -> in6_addr {
981     let mut ret_addr: in6_addr_u = unsafe { mem::zeroed() };
982     unsafe { *(ret_addr.Byte_mut()) = addr.octets() };
983     let mut ret: in6_addr = unsafe { mem::zeroed() };
984     ret.u = ret_addr;
985     ret
986 }
987 
linger2dur(linger_opt: sock::linger) -> Option<Duration>988 fn linger2dur(linger_opt: sock::linger) -> Option<Duration> {
989     if linger_opt.l_onoff == 0 {
990         None
991     } else {
992         Some(Duration::from_secs(linger_opt.l_linger as u64))
993     }
994 }
995 
dur2linger(dur: Option<Duration>) -> sock::linger996 fn dur2linger(dur: Option<Duration>) -> sock::linger {
997     match dur {
998         Some(d) => sock::linger {
999             l_onoff: 1,
1000             l_linger: d.as_secs() as u16,
1001         },
1002         None => sock::linger {
1003             l_onoff: 0,
1004             l_linger: 0,
1005         },
1006     }
1007 }
1008 
#[test]
fn test_ip() {
    // A plain round trip cannot catch a byte-order bug that both conversions
    // share, so also check the stored `S_addr` directly. In network byte order
    // the first octet sits at the lowest address, i.e. the native-endian bytes
    // of the stored value equal the address octets.
    let cases = [
        Ipv4Addr::new(127, 0, 0, 1),
        Ipv4Addr::new(0, 0, 0, 0),
        Ipv4Addr::new(255, 255, 255, 255),
        Ipv4Addr::new(192, 168, 1, 42),
        Ipv4Addr::new(10, 20, 30, 40),
    ];
    for ip in cases.iter() {
        let raw = to_s_addr(ip);
        let stored = unsafe { *raw.S_addr() };
        assert_eq!(stored.to_ne_bytes(), ip.octets());
        assert_eq!(*ip, from_s_addr(raw));
    }
}
1014 
#[test]
fn test_out_of_band_inline() {
    let socket = Socket::new(AF_INET, SOCK_STREAM, 0).unwrap();
    // A freshly created TCP socket reports the option as disabled...
    assert!(!socket.out_of_band_inline().unwrap());

    // ...and the setter's value is observable through the getter.
    socket.set_out_of_band_inline(true).unwrap();
    assert!(socket.out_of_band_inline().unwrap());
}
1023