1 // Copyright 2015 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10 
11 use std::cmp;
12 use std::fmt;
13 use std::io;
14 use std::io::{Read, Write};
15 use std::mem;
16 use std::net::Shutdown;
17 use std::net::{self, Ipv4Addr, Ipv6Addr};
18 use std::os::windows::prelude::*;
19 use std::ptr;
20 use std::sync::{Once, ONCE_INIT};
21 use std::time::Duration;
22 
23 use winapi::ctypes::{c_char, c_int, c_long, c_ulong};
24 use winapi::shared::in6addr::*;
25 use winapi::shared::inaddr::*;
26 use winapi::shared::minwindef::DWORD;
27 use winapi::shared::ntdef::{HANDLE, ULONG};
28 use winapi::shared::ws2def;
29 use winapi::shared::ws2def::*;
30 use winapi::shared::ws2ipdef::*;
31 use winapi::um::handleapi::SetHandleInformation;
32 use winapi::um::processthreadsapi::GetCurrentProcessId;
33 use winapi::um::winbase::INFINITE;
34 use winapi::um::winsock2 as sock;
35 
36 use crate::SockAddr;
37 
// Win32 / Winsock constants defined locally rather than imported from winapi.
// Their values are meant to mirror the identically named Windows SDK constants.

// `SetHandleInformation` flag controlling whether child processes inherit the handle.
const HANDLE_FLAG_INHERIT: DWORD = 0x00000001;
// `recv`/`recvfrom` flag: return queued data without removing it.
const MSG_PEEK: c_int = 0x2;
// `how` arguments for `shutdown`.
const SD_BOTH: c_int = 2;
const SD_RECEIVE: c_int = 0;
const SD_SEND: c_int = 1;
// `WSAIoctl` control code used with a `tcp_keepalive` payload (see
// `Socket::keepalive` / `Socket::set_keepalive`).
const SIO_KEEPALIVE_VALS: DWORD = 0x98000004;
// `WSASocketW` flag requesting a socket that supports overlapped I/O.
const WSA_FLAG_OVERLAPPED: DWORD = 0x01;

// Protocol and socket-type constants re-exported as plain `i32`; presumably
// consumed by the platform-independent part of the crate (verify at callers).
pub const IPPROTO_ICMP: i32 = ws2def::IPPROTO_ICMP as i32;
pub const IPPROTO_ICMPV6: i32 = ws2def::IPPROTO_ICMPV6 as i32;
pub const IPPROTO_TCP: i32 = ws2def::IPPROTO_TCP as i32;
pub const IPPROTO_UDP: i32 = ws2def::IPPROTO_UDP as i32;
pub const SOCK_SEQPACKET: i32 = ws2def::SOCK_SEQPACKET as i32;
pub const SOCK_RAW: i32 = ws2def::SOCK_RAW as i32;
52 
/// In/out buffer for the `SIO_KEEPALIVE_VALS` ioctl.
///
/// This is passed to `WSAIoctl` by pointer, so `#[repr(C)]` and the field
/// order must match the Windows `tcp_keepalive` structure.
#[repr(C)]
struct tcp_keepalive {
    /// Non-zero when keepalive is enabled.
    onoff: c_ulong,
    /// Keepalive time, in milliseconds.
    keepalivetime: c_ulong,
    /// Keepalive interval, in milliseconds.
    keepaliveinterval: c_ulong,
}
59 
init()60 fn init() {
61     static INIT: Once = ONCE_INIT;
62 
63     INIT.call_once(|| {
64         // Initialize winsock through the standard library by just creating a
65         // dummy socket. Whether this is successful or not we drop the result as
66         // libstd will be sure to have initialized winsock.
67         let _ = net::UdpSocket::bind("127.0.0.1:34254");
68     });
69 }
70 
last_error() -> io::Error71 fn last_error() -> io::Error {
72     io::Error::from_raw_os_error(unsafe { sock::WSAGetLastError() })
73 }
74 
/// Owned Winsock socket handle; the handle is closed with `closesocket` when
/// the value is dropped.
pub struct Socket {
    /// The raw Winsock handle owned by this value.
    socket: sock::SOCKET,
}
78 
79 impl Socket {
new(family: c_int, ty: c_int, protocol: c_int) -> io::Result<Socket>80     pub fn new(family: c_int, ty: c_int, protocol: c_int) -> io::Result<Socket> {
81         init();
82         unsafe {
83             let socket = match sock::WSASocketW(
84                 family,
85                 ty,
86                 protocol,
87                 ptr::null_mut(),
88                 0,
89                 WSA_FLAG_OVERLAPPED,
90             ) {
91                 sock::INVALID_SOCKET => return Err(last_error()),
92                 socket => socket,
93             };
94             let socket = Socket::from_raw_socket(socket as RawSocket);
95             socket.set_no_inherit()?;
96             Ok(socket)
97         }
98     }
99 
bind(&self, addr: &SockAddr) -> io::Result<()>100     pub fn bind(&self, addr: &SockAddr) -> io::Result<()> {
101         unsafe {
102             if sock::bind(self.socket, addr.as_ptr(), addr.len()) == 0 {
103                 Ok(())
104             } else {
105                 Err(last_error())
106             }
107         }
108     }
109 
listen(&self, backlog: i32) -> io::Result<()>110     pub fn listen(&self, backlog: i32) -> io::Result<()> {
111         unsafe {
112             if sock::listen(self.socket, backlog) == 0 {
113                 Ok(())
114             } else {
115                 Err(last_error())
116             }
117         }
118     }
119 
connect(&self, addr: &SockAddr) -> io::Result<()>120     pub fn connect(&self, addr: &SockAddr) -> io::Result<()> {
121         unsafe {
122             if sock::connect(self.socket, addr.as_ptr(), addr.len()) == 0 {
123                 Ok(())
124             } else {
125                 Err(last_error())
126             }
127         }
128     }
129 
connect_timeout(&self, addr: &SockAddr, timeout: Duration) -> io::Result<()>130     pub fn connect_timeout(&self, addr: &SockAddr, timeout: Duration) -> io::Result<()> {
131         self.set_nonblocking(true)?;
132         let r = self.connect(addr);
133         self.set_nonblocking(false)?;
134 
135         match r {
136             Ok(()) => return Ok(()),
137             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
138             Err(e) => return Err(e),
139         }
140 
141         if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
142             return Err(io::Error::new(
143                 io::ErrorKind::InvalidInput,
144                 "cannot set a 0 duration timeout",
145             ));
146         }
147 
148         let mut timeout = sock::timeval {
149             tv_sec: timeout.as_secs() as c_long,
150             tv_usec: (timeout.subsec_nanos() / 1000) as c_long,
151         };
152         if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
153             timeout.tv_usec = 1;
154         }
155 
156         let fds = unsafe {
157             let mut fds = mem::zeroed::<sock::fd_set>();
158             fds.fd_count = 1;
159             fds.fd_array[0] = self.socket;
160             fds
161         };
162 
163         let mut writefds = fds;
164         let mut errorfds = fds;
165 
166         match unsafe { sock::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout) } {
167             sock::SOCKET_ERROR => return Err(io::Error::last_os_error()),
168             0 => {
169                 return Err(io::Error::new(
170                     io::ErrorKind::TimedOut,
171                     "connection timed out",
172                 ))
173             }
174             _ => {
175                 if writefds.fd_count != 1 {
176                     if let Some(e) = self.take_error()? {
177                         return Err(e);
178                     }
179                 }
180                 Ok(())
181             }
182         }
183     }
184 
local_addr(&self) -> io::Result<SockAddr>185     pub fn local_addr(&self) -> io::Result<SockAddr> {
186         unsafe {
187             let mut storage: SOCKADDR_STORAGE = mem::zeroed();
188             let mut len = mem::size_of_val(&storage) as c_int;
189             if sock::getsockname(self.socket, &mut storage as *mut _ as *mut _, &mut len) != 0 {
190                 return Err(last_error());
191             }
192             Ok(SockAddr::from_raw_parts(
193                 &storage as *const _ as *const _,
194                 len,
195             ))
196         }
197     }
198 
peer_addr(&self) -> io::Result<SockAddr>199     pub fn peer_addr(&self) -> io::Result<SockAddr> {
200         unsafe {
201             let mut storage: SOCKADDR_STORAGE = mem::zeroed();
202             let mut len = mem::size_of_val(&storage) as c_int;
203             if sock::getpeername(self.socket, &mut storage as *mut _ as *mut _, &mut len) != 0 {
204                 return Err(last_error());
205             }
206             Ok(SockAddr::from_raw_parts(
207                 &storage as *const _ as *const _,
208                 len,
209             ))
210         }
211     }
212 
try_clone(&self) -> io::Result<Socket>213     pub fn try_clone(&self) -> io::Result<Socket> {
214         unsafe {
215             let mut info: sock::WSAPROTOCOL_INFOW = mem::zeroed();
216             let r = sock::WSADuplicateSocketW(self.socket, GetCurrentProcessId(), &mut info);
217             if r != 0 {
218                 return Err(io::Error::last_os_error());
219             }
220             let socket = sock::WSASocketW(
221                 info.iAddressFamily,
222                 info.iSocketType,
223                 info.iProtocol,
224                 &mut info,
225                 0,
226                 WSA_FLAG_OVERLAPPED,
227             );
228             let socket = match socket {
229                 sock::INVALID_SOCKET => return Err(last_error()),
230                 n => Socket::from_raw_socket(n as RawSocket),
231             };
232             socket.set_no_inherit()?;
233             Ok(socket)
234         }
235     }
236 
accept(&self) -> io::Result<(Socket, SockAddr)>237     pub fn accept(&self) -> io::Result<(Socket, SockAddr)> {
238         unsafe {
239             let mut storage: SOCKADDR_STORAGE = mem::zeroed();
240             let mut len = mem::size_of_val(&storage) as c_int;
241             let socket = { sock::accept(self.socket, &mut storage as *mut _ as *mut _, &mut len) };
242             let socket = match socket {
243                 sock::INVALID_SOCKET => return Err(last_error()),
244                 socket => Socket::from_raw_socket(socket as RawSocket),
245             };
246             socket.set_no_inherit()?;
247             let addr = SockAddr::from_raw_parts(&storage as *const _ as *const _, len);
248             Ok((socket, addr))
249         }
250     }
251 
take_error(&self) -> io::Result<Option<io::Error>>252     pub fn take_error(&self) -> io::Result<Option<io::Error>> {
253         unsafe {
254             let raw: c_int = self.getsockopt(SOL_SOCKET, SO_ERROR)?;
255             if raw == 0 {
256                 Ok(None)
257             } else {
258                 Ok(Some(io::Error::from_raw_os_error(raw as i32)))
259             }
260         }
261     }
262 
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>263     pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
264         unsafe {
265             let mut nonblocking = nonblocking as c_ulong;
266             let r = sock::ioctlsocket(self.socket, sock::FIONBIO as c_int, &mut nonblocking);
267             if r == 0 {
268                 Ok(())
269             } else {
270                 Err(io::Error::last_os_error())
271             }
272         }
273     }
274 
shutdown(&self, how: Shutdown) -> io::Result<()>275     pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
276         let how = match how {
277             Shutdown::Write => SD_SEND,
278             Shutdown::Read => SD_RECEIVE,
279             Shutdown::Both => SD_BOTH,
280         };
281         if unsafe { sock::shutdown(self.socket, how) == 0 } {
282             Ok(())
283         } else {
284             Err(last_error())
285         }
286     }
287 
recv(&self, buf: &mut [u8]) -> io::Result<usize>288     pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
289         unsafe {
290             let n = {
291                 sock::recv(
292                     self.socket,
293                     buf.as_mut_ptr() as *mut c_char,
294                     clamp(buf.len()),
295                     0,
296                 )
297             };
298             match n {
299                 sock::SOCKET_ERROR if sock::WSAGetLastError() == sock::WSAESHUTDOWN as i32 => Ok(0),
300                 sock::SOCKET_ERROR => Err(last_error()),
301                 n => Ok(n as usize),
302             }
303         }
304     }
305 
peek(&self, buf: &mut [u8]) -> io::Result<usize>306     pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
307         unsafe {
308             let n = {
309                 sock::recv(
310                     self.socket,
311                     buf.as_mut_ptr() as *mut c_char,
312                     clamp(buf.len()),
313                     MSG_PEEK,
314                 )
315             };
316             match n {
317                 sock::SOCKET_ERROR if sock::WSAGetLastError() == sock::WSAESHUTDOWN as i32 => Ok(0),
318                 sock::SOCKET_ERROR => Err(last_error()),
319                 n => Ok(n as usize),
320             }
321         }
322     }
323 
recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)>324     pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
325         self.recvfrom(buf, 0)
326     }
327 
peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)>328     pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
329         self.recvfrom(buf, MSG_PEEK)
330     }
331 
recvfrom(&self, buf: &mut [u8], flags: c_int) -> io::Result<(usize, SockAddr)>332     fn recvfrom(&self, buf: &mut [u8], flags: c_int) -> io::Result<(usize, SockAddr)> {
333         unsafe {
334             let mut storage: SOCKADDR_STORAGE = mem::zeroed();
335             let mut addrlen = mem::size_of_val(&storage) as c_int;
336 
337             let n = {
338                 sock::recvfrom(
339                     self.socket,
340                     buf.as_mut_ptr() as *mut c_char,
341                     clamp(buf.len()),
342                     flags,
343                     &mut storage as *mut _ as *mut _,
344                     &mut addrlen,
345                 )
346             };
347             let n = match n {
348                 sock::SOCKET_ERROR if sock::WSAGetLastError() == sock::WSAESHUTDOWN as i32 => 0,
349                 sock::SOCKET_ERROR => return Err(last_error()),
350                 n => n as usize,
351             };
352             let addr = SockAddr::from_raw_parts(&storage as *const _ as *const _, addrlen);
353             Ok((n, addr))
354         }
355     }
356 
send(&self, buf: &[u8]) -> io::Result<usize>357     pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
358         unsafe {
359             let n = {
360                 sock::send(
361                     self.socket,
362                     buf.as_ptr() as *const c_char,
363                     clamp(buf.len()),
364                     0,
365                 )
366             };
367             if n == sock::SOCKET_ERROR {
368                 Err(last_error())
369             } else {
370                 Ok(n as usize)
371             }
372         }
373     }
374 
send_to(&self, buf: &[u8], addr: &SockAddr) -> io::Result<usize>375     pub fn send_to(&self, buf: &[u8], addr: &SockAddr) -> io::Result<usize> {
376         unsafe {
377             let n = {
378                 sock::sendto(
379                     self.socket,
380                     buf.as_ptr() as *const c_char,
381                     clamp(buf.len()),
382                     0,
383                     addr.as_ptr(),
384                     addr.len(),
385                 )
386             };
387             if n == sock::SOCKET_ERROR {
388                 Err(last_error())
389             } else {
390                 Ok(n as usize)
391             }
392         }
393     }
394 
395     // ================================================
396 
ttl(&self) -> io::Result<u32>397     pub fn ttl(&self) -> io::Result<u32> {
398         unsafe {
399             let raw: c_int = self.getsockopt(IPPROTO_IP, IP_TTL)?;
400             Ok(raw as u32)
401         }
402     }
403 
set_ttl(&self, ttl: u32) -> io::Result<()>404     pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
405         unsafe { self.setsockopt(IPPROTO_IP, IP_TTL, ttl as c_int) }
406     }
407 
unicast_hops_v6(&self) -> io::Result<u32>408     pub fn unicast_hops_v6(&self) -> io::Result<u32> {
409         unsafe {
410             let raw: c_int = self.getsockopt(IPPROTO_IPV6 as c_int, IPV6_UNICAST_HOPS)?;
411             Ok(raw as u32)
412         }
413     }
414 
set_unicast_hops_v6(&self, hops: u32) -> io::Result<()>415     pub fn set_unicast_hops_v6(&self, hops: u32) -> io::Result<()> {
416         unsafe { self.setsockopt(IPPROTO_IPV6 as c_int, IPV6_UNICAST_HOPS, hops as c_int) }
417     }
418 
only_v6(&self) -> io::Result<bool>419     pub fn only_v6(&self) -> io::Result<bool> {
420         unsafe {
421             let raw: c_int = self.getsockopt(IPPROTO_IPV6 as c_int, IPV6_V6ONLY)?;
422             Ok(raw != 0)
423         }
424     }
425 
set_only_v6(&self, only_v6: bool) -> io::Result<()>426     pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
427         unsafe { self.setsockopt(IPPROTO_IPV6 as c_int, IPV6_V6ONLY, only_v6 as c_int) }
428     }
429 
read_timeout(&self) -> io::Result<Option<Duration>>430     pub fn read_timeout(&self) -> io::Result<Option<Duration>> {
431         unsafe { Ok(ms2dur(self.getsockopt(SOL_SOCKET, SO_RCVTIMEO)?)) }
432     }
433 
set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()>434     pub fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
435         unsafe { self.setsockopt(SOL_SOCKET, SO_RCVTIMEO, dur2ms(dur)?) }
436     }
437 
write_timeout(&self) -> io::Result<Option<Duration>>438     pub fn write_timeout(&self) -> io::Result<Option<Duration>> {
439         unsafe { Ok(ms2dur(self.getsockopt(SOL_SOCKET, SO_SNDTIMEO)?)) }
440     }
441 
set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()>442     pub fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
443         unsafe { self.setsockopt(SOL_SOCKET, SO_SNDTIMEO, dur2ms(dur)?) }
444     }
445 
nodelay(&self) -> io::Result<bool>446     pub fn nodelay(&self) -> io::Result<bool> {
447         unsafe {
448             let raw: c_char = self.getsockopt(IPPROTO_TCP, TCP_NODELAY)?;
449             Ok(raw != 0)
450         }
451     }
452 
set_nodelay(&self, nodelay: bool) -> io::Result<()>453     pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
454         unsafe { self.setsockopt(IPPROTO_TCP, TCP_NODELAY, nodelay as c_char) }
455     }
456 
broadcast(&self) -> io::Result<bool>457     pub fn broadcast(&self) -> io::Result<bool> {
458         unsafe {
459             let raw: c_int = self.getsockopt(SOL_SOCKET, SO_BROADCAST)?;
460             Ok(raw != 0)
461         }
462     }
463 
set_broadcast(&self, broadcast: bool) -> io::Result<()>464     pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
465         unsafe { self.setsockopt(SOL_SOCKET, SO_BROADCAST, broadcast as c_int) }
466     }
467 
multicast_loop_v4(&self) -> io::Result<bool>468     pub fn multicast_loop_v4(&self) -> io::Result<bool> {
469         unsafe {
470             let raw: c_int = self.getsockopt(IPPROTO_IP, IP_MULTICAST_LOOP)?;
471             Ok(raw != 0)
472         }
473     }
474 
set_multicast_loop_v4(&self, multicast_loop_v4: bool) -> io::Result<()>475     pub fn set_multicast_loop_v4(&self, multicast_loop_v4: bool) -> io::Result<()> {
476         unsafe { self.setsockopt(IPPROTO_IP, IP_MULTICAST_LOOP, multicast_loop_v4 as c_int) }
477     }
478 
multicast_ttl_v4(&self) -> io::Result<u32>479     pub fn multicast_ttl_v4(&self) -> io::Result<u32> {
480         unsafe {
481             let raw: c_int = self.getsockopt(IPPROTO_IP, IP_MULTICAST_TTL)?;
482             Ok(raw as u32)
483         }
484     }
485 
set_multicast_ttl_v4(&self, multicast_ttl_v4: u32) -> io::Result<()>486     pub fn set_multicast_ttl_v4(&self, multicast_ttl_v4: u32) -> io::Result<()> {
487         unsafe { self.setsockopt(IPPROTO_IP, IP_MULTICAST_TTL, multicast_ttl_v4 as c_int) }
488     }
489 
multicast_hops_v6(&self) -> io::Result<u32>490     pub fn multicast_hops_v6(&self) -> io::Result<u32> {
491         unsafe {
492             let raw: c_int = self.getsockopt(IPPROTO_IPV6 as c_int, IPV6_MULTICAST_HOPS)?;
493             Ok(raw as u32)
494         }
495     }
496 
set_multicast_hops_v6(&self, hops: u32) -> io::Result<()>497     pub fn set_multicast_hops_v6(&self, hops: u32) -> io::Result<()> {
498         unsafe { self.setsockopt(IPPROTO_IPV6 as c_int, IPV6_MULTICAST_HOPS, hops as c_int) }
499     }
500 
multicast_if_v4(&self) -> io::Result<Ipv4Addr>501     pub fn multicast_if_v4(&self) -> io::Result<Ipv4Addr> {
502         unsafe {
503             let imr_interface: IN_ADDR = self.getsockopt(IPPROTO_IP, IP_MULTICAST_IF)?;
504             Ok(from_s_addr(imr_interface.S_un))
505         }
506     }
507 
set_multicast_if_v4(&self, interface: &Ipv4Addr) -> io::Result<()>508     pub fn set_multicast_if_v4(&self, interface: &Ipv4Addr) -> io::Result<()> {
509         let interface = to_s_addr(interface);
510         let imr_interface = IN_ADDR { S_un: interface };
511 
512         unsafe { self.setsockopt(IPPROTO_IP, IP_MULTICAST_IF, imr_interface) }
513     }
514 
multicast_if_v6(&self) -> io::Result<u32>515     pub fn multicast_if_v6(&self) -> io::Result<u32> {
516         unsafe {
517             let raw: c_int = self.getsockopt(IPPROTO_IPV6 as c_int, IPV6_MULTICAST_IF)?;
518             Ok(raw as u32)
519         }
520     }
521 
set_multicast_if_v6(&self, interface: u32) -> io::Result<()>522     pub fn set_multicast_if_v6(&self, interface: u32) -> io::Result<()> {
523         unsafe { self.setsockopt(IPPROTO_IPV6 as c_int, IPV6_MULTICAST_IF, interface as c_int) }
524     }
525 
multicast_loop_v6(&self) -> io::Result<bool>526     pub fn multicast_loop_v6(&self) -> io::Result<bool> {
527         unsafe {
528             let raw: c_int = self.getsockopt(IPPROTO_IPV6 as c_int, IPV6_MULTICAST_LOOP)?;
529             Ok(raw != 0)
530         }
531     }
532 
set_multicast_loop_v6(&self, multicast_loop_v6: bool) -> io::Result<()>533     pub fn set_multicast_loop_v6(&self, multicast_loop_v6: bool) -> io::Result<()> {
534         unsafe {
535             self.setsockopt(
536                 IPPROTO_IPV6 as c_int,
537                 IPV6_MULTICAST_LOOP,
538                 multicast_loop_v6 as c_int,
539             )
540         }
541     }
542 
join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()>543     pub fn join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
544         let multiaddr = to_s_addr(multiaddr);
545         let interface = to_s_addr(interface);
546         let mreq = IP_MREQ {
547             imr_multiaddr: IN_ADDR { S_un: multiaddr },
548             imr_interface: IN_ADDR { S_un: interface },
549         };
550         unsafe { self.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq) }
551     }
552 
join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()>553     pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
554         let multiaddr = to_in6_addr(multiaddr);
555         let mreq = IPV6_MREQ {
556             ipv6mr_multiaddr: multiaddr,
557             ipv6mr_interface: interface,
558         };
559         unsafe { self.setsockopt(IPPROTO_IP, IPV6_ADD_MEMBERSHIP, mreq) }
560     }
561 
leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()>562     pub fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
563         let multiaddr = to_s_addr(multiaddr);
564         let interface = to_s_addr(interface);
565         let mreq = IP_MREQ {
566             imr_multiaddr: IN_ADDR { S_un: multiaddr },
567             imr_interface: IN_ADDR { S_un: interface },
568         };
569         unsafe { self.setsockopt(IPPROTO_IP, IP_DROP_MEMBERSHIP, mreq) }
570     }
571 
leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()>572     pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
573         let multiaddr = to_in6_addr(multiaddr);
574         let mreq = IPV6_MREQ {
575             ipv6mr_multiaddr: multiaddr,
576             ipv6mr_interface: interface,
577         };
578         unsafe { self.setsockopt(IPPROTO_IP, IPV6_DROP_MEMBERSHIP, mreq) }
579     }
580 
linger(&self) -> io::Result<Option<Duration>>581     pub fn linger(&self) -> io::Result<Option<Duration>> {
582         unsafe { Ok(linger2dur(self.getsockopt(SOL_SOCKET, SO_LINGER)?)) }
583     }
584 
set_linger(&self, dur: Option<Duration>) -> io::Result<()>585     pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
586         unsafe { self.setsockopt(SOL_SOCKET, SO_LINGER, dur2linger(dur)) }
587     }
588 
set_reuse_address(&self, reuse: bool) -> io::Result<()>589     pub fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
590         unsafe { self.setsockopt(SOL_SOCKET, SO_REUSEADDR, reuse as c_int) }
591     }
592 
reuse_address(&self) -> io::Result<bool>593     pub fn reuse_address(&self) -> io::Result<bool> {
594         unsafe {
595             let raw: c_int = self.getsockopt(SOL_SOCKET, SO_REUSEADDR)?;
596             Ok(raw != 0)
597         }
598     }
599 
recv_buffer_size(&self) -> io::Result<usize>600     pub fn recv_buffer_size(&self) -> io::Result<usize> {
601         unsafe {
602             let raw: c_int = self.getsockopt(SOL_SOCKET, SO_RCVBUF)?;
603             Ok(raw as usize)
604         }
605     }
606 
set_recv_buffer_size(&self, size: usize) -> io::Result<()>607     pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
608         unsafe {
609             // TODO: casting usize to a c_int should be a checked cast
610             self.setsockopt(SOL_SOCKET, SO_RCVBUF, size as c_int)
611         }
612     }
613 
send_buffer_size(&self) -> io::Result<usize>614     pub fn send_buffer_size(&self) -> io::Result<usize> {
615         unsafe {
616             let raw: c_int = self.getsockopt(SOL_SOCKET, SO_SNDBUF)?;
617             Ok(raw as usize)
618         }
619     }
620 
set_send_buffer_size(&self, size: usize) -> io::Result<()>621     pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
622         unsafe {
623             // TODO: casting usize to a c_int should be a checked cast
624             self.setsockopt(SOL_SOCKET, SO_SNDBUF, size as c_int)
625         }
626     }
627 
keepalive(&self) -> io::Result<Option<Duration>>628     pub fn keepalive(&self) -> io::Result<Option<Duration>> {
629         let mut ka = tcp_keepalive {
630             onoff: 0,
631             keepalivetime: 0,
632             keepaliveinterval: 0,
633         };
634         let n = unsafe {
635             sock::WSAIoctl(
636                 self.socket,
637                 SIO_KEEPALIVE_VALS,
638                 0 as *mut _,
639                 0,
640                 &mut ka as *mut _ as *mut _,
641                 mem::size_of_val(&ka) as DWORD,
642                 0 as *mut _,
643                 0 as *mut _,
644                 None,
645             )
646         };
647         if n == 0 {
648             Ok(if ka.onoff == 0 {
649                 None
650             } else if ka.keepaliveinterval == 0 {
651                 None
652             } else {
653                 let seconds = ka.keepaliveinterval / 1000;
654                 let nanos = (ka.keepaliveinterval % 1000) * 1_000_000;
655                 Some(Duration::new(seconds as u64, nanos as u32))
656             })
657         } else {
658             Err(last_error())
659         }
660     }
661 
set_keepalive(&self, keepalive: Option<Duration>) -> io::Result<()>662     pub fn set_keepalive(&self, keepalive: Option<Duration>) -> io::Result<()> {
663         let ms = dur2ms(keepalive)?;
664         // TODO: checked casts here
665         let ka = tcp_keepalive {
666             onoff: keepalive.is_some() as c_ulong,
667             keepalivetime: ms as c_ulong,
668             keepaliveinterval: ms as c_ulong,
669         };
670         let mut out = 0;
671         let n = unsafe {
672             sock::WSAIoctl(
673                 self.socket,
674                 SIO_KEEPALIVE_VALS,
675                 &ka as *const _ as *mut _,
676                 mem::size_of_val(&ka) as DWORD,
677                 0 as *mut _,
678                 0,
679                 &mut out,
680                 0 as *mut _,
681                 None,
682             )
683         };
684         if n == 0 {
685             Ok(())
686         } else {
687             Err(last_error())
688         }
689     }
690 
setsockopt<T>(&self, opt: c_int, val: c_int, payload: T) -> io::Result<()> where T: Copy,691     unsafe fn setsockopt<T>(&self, opt: c_int, val: c_int, payload: T) -> io::Result<()>
692     where
693         T: Copy,
694     {
695         let payload = &payload as *const T as *const c_char;
696         if sock::setsockopt(self.socket, opt, val, payload, mem::size_of::<T>() as c_int) == 0 {
697             Ok(())
698         } else {
699             Err(last_error())
700         }
701     }
702 
getsockopt<T: Copy>(&self, opt: c_int, val: c_int) -> io::Result<T>703     unsafe fn getsockopt<T: Copy>(&self, opt: c_int, val: c_int) -> io::Result<T> {
704         let mut slot: T = mem::zeroed();
705         let mut len = mem::size_of::<T>() as c_int;
706         if sock::getsockopt(
707             self.socket,
708             opt,
709             val,
710             &mut slot as *mut _ as *mut _,
711             &mut len,
712         ) == 0
713         {
714             assert_eq!(len as usize, mem::size_of::<T>());
715             Ok(slot)
716         } else {
717             Err(last_error())
718         }
719     }
720 
set_no_inherit(&self) -> io::Result<()>721     fn set_no_inherit(&self) -> io::Result<()> {
722         unsafe {
723             let r = SetHandleInformation(self.socket as HANDLE, HANDLE_FLAG_INHERIT, 0);
724             if r == 0 {
725                 Err(io::Error::last_os_error())
726             } else {
727                 Ok(())
728             }
729         }
730     }
731 }
732 
733 impl Read for Socket {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>734     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
735         <&Socket>::read(&mut &*self, buf)
736     }
737 }
738 
739 impl<'a> Read for &'a Socket {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>740     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
741         self.recv(buf)
742     }
743 }
744 
745 impl Write for Socket {
write(&mut self, buf: &[u8]) -> io::Result<usize>746     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
747         <&Socket>::write(&mut &*self, buf)
748     }
749 
flush(&mut self) -> io::Result<()>750     fn flush(&mut self) -> io::Result<()> {
751         <&Socket>::flush(&mut &*self)
752     }
753 }
754 
755 impl<'a> Write for &'a Socket {
write(&mut self, buf: &[u8]) -> io::Result<usize>756     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
757         self.send(buf)
758     }
759 
flush(&mut self) -> io::Result<()>760     fn flush(&mut self) -> io::Result<()> {
761         Ok(())
762     }
763 }
764 
765 impl fmt::Debug for Socket {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result766     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
767         let mut f = f.debug_struct("Socket");
768         f.field("socket", &self.socket);
769         if let Ok(addr) = self.local_addr() {
770             f.field("local_addr", &addr);
771         }
772         if let Ok(addr) = self.peer_addr() {
773             f.field("peer_addr", &addr);
774         }
775         f.finish()
776     }
777 }
778 
779 impl AsRawSocket for Socket {
as_raw_socket(&self) -> RawSocket780     fn as_raw_socket(&self) -> RawSocket {
781         self.socket as RawSocket
782     }
783 }
784 
785 impl IntoRawSocket for Socket {
into_raw_socket(self) -> RawSocket786     fn into_raw_socket(self) -> RawSocket {
787         let socket = self.socket;
788         mem::forget(self);
789         socket as RawSocket
790     }
791 }
792 
793 impl FromRawSocket for Socket {
from_raw_socket(socket: RawSocket) -> Socket794     unsafe fn from_raw_socket(socket: RawSocket) -> Socket {
795         Socket {
796             socket: socket as sock::SOCKET,
797         }
798     }
799 }
800 
801 impl AsRawSocket for crate::Socket {
as_raw_socket(&self) -> RawSocket802     fn as_raw_socket(&self) -> RawSocket {
803         self.inner.as_raw_socket()
804     }
805 }
806 
807 impl IntoRawSocket for crate::Socket {
into_raw_socket(self) -> RawSocket808     fn into_raw_socket(self) -> RawSocket {
809         self.inner.into_raw_socket()
810     }
811 }
812 
813 impl FromRawSocket for crate::Socket {
from_raw_socket(socket: RawSocket) -> crate::Socket814     unsafe fn from_raw_socket(socket: RawSocket) -> crate::Socket {
815         crate::Socket {
816             inner: Socket::from_raw_socket(socket),
817         }
818     }
819 }
820 
821 impl Drop for Socket {
drop(&mut self)822     fn drop(&mut self) {
823         unsafe {
824             let _ = sock::closesocket(self.socket);
825         }
826     }
827 }
828 
829 impl From<Socket> for net::TcpStream {
from(socket: Socket) -> net::TcpStream830     fn from(socket: Socket) -> net::TcpStream {
831         unsafe { net::TcpStream::from_raw_socket(socket.into_raw_socket()) }
832     }
833 }
834 
835 impl From<Socket> for net::TcpListener {
from(socket: Socket) -> net::TcpListener836     fn from(socket: Socket) -> net::TcpListener {
837         unsafe { net::TcpListener::from_raw_socket(socket.into_raw_socket()) }
838     }
839 }
840 
841 impl From<Socket> for net::UdpSocket {
from(socket: Socket) -> net::UdpSocket842     fn from(socket: Socket) -> net::UdpSocket {
843         unsafe { net::UdpSocket::from_raw_socket(socket.into_raw_socket()) }
844     }
845 }
846 
847 impl From<net::TcpStream> for Socket {
from(socket: net::TcpStream) -> Socket848     fn from(socket: net::TcpStream) -> Socket {
849         unsafe { Socket::from_raw_socket(socket.into_raw_socket()) }
850     }
851 }
852 
853 impl From<net::TcpListener> for Socket {
from(socket: net::TcpListener) -> Socket854     fn from(socket: net::TcpListener) -> Socket {
855         unsafe { Socket::from_raw_socket(socket.into_raw_socket()) }
856     }
857 }
858 
859 impl From<net::UdpSocket> for Socket {
from(socket: net::UdpSocket) -> Socket860     fn from(socket: net::UdpSocket) -> Socket {
861         unsafe { Socket::from_raw_socket(socket.into_raw_socket()) }
862     }
863 }
864 
/// Converts a buffer length to `c_int`, saturating at `c_int::max_value()`
/// so that a large `usize` never wraps around to a negative length.
fn clamp(input: usize) -> c_int {
    let limit = <c_int>::max_value() as usize;
    if input > limit {
        <c_int>::max_value()
    } else {
        input as c_int
    }
}
868 
dur2ms(dur: Option<Duration>) -> io::Result<DWORD>869 fn dur2ms(dur: Option<Duration>) -> io::Result<DWORD> {
870     match dur {
871         Some(dur) => {
872             // Note that a duration is a (u64, u32) (seconds, nanoseconds)
873             // pair, and the timeouts in windows APIs are typically u32
874             // milliseconds. To translate, we have two pieces to take care of:
875             //
876             // * Nanosecond precision is rounded up
877             // * Greater than u32::MAX milliseconds (50 days) is rounded up to
878             //   INFINITE (never time out).
879             let ms = dur
880                 .as_secs()
881                 .checked_mul(1000)
882                 .and_then(|ms| ms.checked_add((dur.subsec_nanos() as u64) / 1_000_000))
883                 .and_then(|ms| {
884                     ms.checked_add(if dur.subsec_nanos() % 1_000_000 > 0 {
885                         1
886                     } else {
887                         0
888                     })
889                 })
890                 .map(|ms| {
891                     if ms > <DWORD>::max_value() as u64 {
892                         INFINITE
893                     } else {
894                         ms as DWORD
895                     }
896                 })
897                 .unwrap_or(INFINITE);
898             if ms == 0 {
899                 return Err(io::Error::new(
900                     io::ErrorKind::InvalidInput,
901                     "cannot set a 0 duration timeout",
902                 ));
903             }
904             Ok(ms)
905         }
906         None => Ok(0),
907     }
908 }
909 
ms2dur(raw: DWORD) -> Option<Duration>910 fn ms2dur(raw: DWORD) -> Option<Duration> {
911     if raw == 0 {
912         None
913     } else {
914         let secs = raw / 1000;
915         let nsec = (raw % 1000) * 1000000;
916         Some(Duration::new(secs as u64, nsec as u32))
917     }
918 }
919 
to_s_addr(addr: &Ipv4Addr) -> in_addr_S_un920 fn to_s_addr(addr: &Ipv4Addr) -> in_addr_S_un {
921     let octets = addr.octets();
922     let res = crate::hton(
923         ((octets[0] as ULONG) << 24)
924             | ((octets[1] as ULONG) << 16)
925             | ((octets[2] as ULONG) << 8)
926             | ((octets[3] as ULONG) << 0),
927     );
928     let mut new_addr: in_addr_S_un = unsafe { mem::zeroed() };
929     unsafe { *(new_addr.S_addr_mut()) = res };
930     new_addr
931 }
932 
from_s_addr(in_addr: in_addr_S_un) -> Ipv4Addr933 fn from_s_addr(in_addr: in_addr_S_un) -> Ipv4Addr {
934     let h_addr = crate::ntoh(unsafe { *in_addr.S_addr() });
935 
936     let a: u8 = (h_addr >> 24) as u8;
937     let b: u8 = (h_addr >> 16) as u8;
938     let c: u8 = (h_addr >> 8) as u8;
939     let d: u8 = (h_addr >> 0) as u8;
940 
941     Ipv4Addr::new(a, b, c, d)
942 }
943 
to_in6_addr(addr: &Ipv6Addr) -> in6_addr944 fn to_in6_addr(addr: &Ipv6Addr) -> in6_addr {
945     let mut ret_addr: in6_addr_u = unsafe { mem::zeroed() };
946     unsafe { *(ret_addr.Byte_mut()) = addr.octets() };
947     let mut ret: in6_addr = unsafe { mem::zeroed() };
948     ret.u = ret_addr;
949     ret
950 }
951 
linger2dur(linger_opt: sock::linger) -> Option<Duration>952 fn linger2dur(linger_opt: sock::linger) -> Option<Duration> {
953     if linger_opt.l_onoff == 0 {
954         None
955     } else {
956         Some(Duration::from_secs(linger_opt.l_linger as u64))
957     }
958 }
959 
dur2linger(dur: Option<Duration>) -> sock::linger960 fn dur2linger(dur: Option<Duration>) -> sock::linger {
961     match dur {
962         Some(d) => sock::linger {
963             l_onoff: 1,
964             l_linger: d.as_secs() as u16,
965         },
966         None => sock::linger {
967             l_onoff: 0,
968             l_linger: 0,
969         },
970     }
971 }
972 
/// Checks the IPv4 <-> `in_addr` conversions.
///
/// A round trip alone cannot catch a byte-order mistake made identically in
/// both directions. This test therefore also checks that the stored value,
/// converted back to host order, equals the address's big-endian numeric value.
#[test]
fn test_ip() {
    let cases = [
        Ipv4Addr::new(127, 0, 0, 1),
        Ipv4Addr::new(0, 0, 0, 0),
        Ipv4Addr::new(255, 255, 255, 255),
        Ipv4Addr::new(192, 168, 1, 42),
    ];
    for ip in cases.iter() {
        // The round trip must be lossless.
        assert_eq!(*ip, from_s_addr(to_s_addr(ip)));
        // The stored representation must be in network byte order.
        let raw = unsafe { *to_s_addr(ip).S_addr() };
        assert_eq!(crate::ntoh(raw), u32::from(*ip));
    }
}
978