//! UDP for IOCP
//!
//! Note that most of this module is quite similar to the TCP module, so if
//! something seems odd you may also want to try the docs over there.
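//!
//! At a high level this backend simulates readiness on top of completion
//! ports: registering readable interest schedules an overlapped `recv_from`
//! into an internal 64KB buffer, the completion callback marks the read state
//! `Ready` and sets readiness, and a later `recv_from` call copies the data
//! out and schedules the next receive. Sends copy the caller's bytes into an
//! owned buffer and report them written as soon as the overlapped send is
//! scheduled.
//!
//! As a rough sketch (not part of this module) of how a mio-0.6-style caller
//! ends up driving this code, assuming `socket` is an already-bound mio
//! `UdpSocket` and `SERVER` is some `Token`:
//!
//! ```ignore
//! let poll = Poll::new()?;
//! let mut events = Events::with_capacity(128);
//! poll.register(&socket, SERVER, Ready::readable(), PollOpt::edge())?;
//! loop {
//!     poll.poll(&mut events, None)?;
//!     for _event in events.iter() {
//!         let mut buf = [0u8; 2048];
//!         match socket.recv_from(&mut buf) {
//!             Ok((n, peer)) => println!("{} bytes from {}", n, peer),
//!             // Spurious wakeups and drained sockets surface as WouldBlock.
//!             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break,
//!             Err(e) => return Err(e),
//!         }
//!     }
//! }
//! ```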

use std::fmt;
use std::io::prelude::*;
use std::io;
use std::mem;
use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::{Mutex, MutexGuard};

#[allow(unused_imports)]
use net2::{UdpBuilder, UdpSocketExt};
use winapi::*;
use miow::iocp::CompletionStatus;
use miow::net::SocketAddrBuf;
use miow::net::UdpSocketExt as MiowUdpSocketExt;

use {poll, Ready, Poll, PollOpt, Token};
use event::Evented;
use sys::windows::from_raw_arc::FromRawArc;
use sys::windows::selector::{Overlapped, ReadyBinding};
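
// The exposed socket handle for this backend: a reference-counted `Imp` plus
// the `poll::Registration` used to surface readiness to the `Poll` that this
// socket gets registered with.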
pub struct UdpSocket {
    imp: Imp,
    registration: Mutex<Option<poll::Registration>>,
}
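
// A cloneable handle to the shared `Io` allocation. Clones bump the
// `FromRawArc` reference count; see `send_to`/`schedule_read_from` for how a
// clone is deliberately leaked while the kernel owns an overlapped operation.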
#[derive(Clone)]
struct Imp {
    inner: FromRawArc<Io>,
}
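
// The shared, reference-counted allocation. The two `Overlapped` headers are
// handed to the kernel for pending reads and writes, which is why they live
// outside the mutex-protected `Inner`.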
struct Io {
    read: Overlapped,
    write: Overlapped,
    socket: net::UdpSocket,
    inner: Mutex<Inner>,
}
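
// Mutable state guarded by `Io::inner`: the readiness binding, one state
// machine per direction, and the out-of-band address buffer filled in by
// `recv_from_overlapped`.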
struct Inner {
    iocp: ReadyBinding,
    read: State<Vec<u8>, Vec<u8>>,
    write: State<Vec<u8>, (Vec<u8>, usize)>,
    read_buf: SocketAddrBuf,
}
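
// Per-direction state machine: `Pending` owns the buffer currently lent to an
// in-flight overlapped operation, `Ready` holds the result of a completed one,
// and `Error` stores a failure to hand back on the next user call.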
enum State<T, U> {
    Empty,
    Pending(T),
    Ready(U),
    Error(io::Error),
}

impl UdpSocket {
    pub fn new(socket: net::UdpSocket) -> io::Result<UdpSocket> {
        Ok(UdpSocket {
            registration: Mutex::new(None),
            imp: Imp {
                inner: FromRawArc::new(Io {
                    read: Overlapped::new(recv_done),
                    write: Overlapped::new(send_done),
                    socket: socket,
                    inner: Mutex::new(Inner {
                        iocp: ReadyBinding::new(),
                        read: State::Empty,
                        write: State::Empty,
                        read_buf: SocketAddrBuf::new(),
                    }),
                }),
            },
        })
    }

    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.imp.inner.socket.local_addr()
    }

    pub fn try_clone(&self) -> io::Result<UdpSocket> {
        self.imp.inner.socket.try_clone().and_then(UdpSocket::new)
    }

    /// Note that unlike `TcpStream::write` this function will not attempt to
    /// continue writing `buf` until it's entirely written.
    ///
    /// TODO: This... may be wrong in the long run. We're reporting that we
    ///       successfully wrote all of the bytes in `buf` but it's possible
    ///       that we don't actually end up writing all of them!
    pub fn send_to(&self, buf: &[u8], target: &SocketAddr)
                   -> io::Result<usize> {
        let mut me = self.inner();
        let me = &mut *me;

        match me.write {
            State::Empty => {}
            _ => return Err(io::ErrorKind::WouldBlock.into()),
        }

        if !me.iocp.registered() {
            return Err(io::ErrorKind::WouldBlock.into())
        }

        let interest = me.iocp.readiness();
        me.iocp.set_readiness(interest - Ready::writable());

        let mut owned_buf = me.iocp.get_buffer(64 * 1024);
        let amt = owned_buf.write(buf)?;
        unsafe {
            trace!("scheduling a send");
            self.imp.inner.socket.send_to_overlapped(&owned_buf, target,
                                                     self.imp.inner.write.as_mut_ptr())
        }?;
        me.write = State::Pending(owned_buf);
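        // Leak a clone of the reference-counted handle: the kernel now holds
        // a raw pointer into `Io` via the overlapped, so it must stay alive
        // until `send_done` reconstructs (and thereby releases) this
        // reference.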
        mem::forget(self.imp.clone());
        Ok(amt)
    }

    /// Note that unlike `TcpStream::write` this function will not attempt to
    /// continue writing `buf` until it's entirely written.
    ///
    /// TODO: This... may be wrong in the long run. We're reporting that we
    ///       successfully wrote all of the bytes in `buf` but it's possible
    ///       that we don't actually end up writing all of them!
    pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
        let mut me = self.inner();
        let me = &mut *me;

        match me.write {
            State::Empty => {}
            _ => return Err(io::ErrorKind::WouldBlock.into()),
        }

        if !me.iocp.registered() {
            return Err(io::ErrorKind::WouldBlock.into())
        }

        let interest = me.iocp.readiness();
        me.iocp.set_readiness(interest - Ready::writable());

        let mut owned_buf = me.iocp.get_buffer(64 * 1024);
        let amt = owned_buf.write(buf)?;
        unsafe {
            trace!("scheduling a send");
            self.imp.inner.socket.send_overlapped(&owned_buf,
                                                  self.imp.inner.write.as_mut_ptr())
        }?;
        me.write = State::Pending(owned_buf);
        mem::forget(self.imp.clone());
        Ok(amt)
    }

    pub fn recv_from(&self, mut buf: &mut [u8])
                     -> io::Result<(usize, SocketAddr)> {
        let mut me = self.inner();
        match mem::replace(&mut me.read, State::Empty) {
            State::Empty => Err(io::ErrorKind::WouldBlock.into()),
            State::Pending(b) => {
                me.read = State::Pending(b);
                Err(io::ErrorKind::WouldBlock.into())
            }
            State::Ready(data) => {
                // If we weren't provided enough space to receive the message
                // then don't actually read any data, just return an error.
                if buf.len() < data.len() {
                    me.read = State::Ready(data);
                    Err(io::Error::from_raw_os_error(WSAEMSGSIZE as i32))
                } else {
                    let r = if let Some(addr) = me.read_buf.to_socket_addr() {
                        buf.write(&data).unwrap();
                        Ok((data.len(), addr))
                    } else {
                        Err(io::Error::new(io::ErrorKind::Other,
                                           "failed to parse socket address"))
                    };
                    me.iocp.put_buffer(data);
                    self.imp.schedule_read_from(&mut me);
                    r
                }
            }
            State::Error(e) => {
                self.imp.schedule_read_from(&mut me);
                Err(e)
            }
        }
    }

    pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
        // Since `recv_from` works on connected sockets as well, just call it
        // and discard the peer address.
        self.recv_from(buf).map(|(size, _)| size)
    }

    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
        self.imp.inner.socket.connect(addr)
    }

    pub fn broadcast(&self) -> io::Result<bool> {
        self.imp.inner.socket.broadcast()
    }

    pub fn set_broadcast(&self, on: bool) -> io::Result<()> {
        self.imp.inner.socket.set_broadcast(on)
    }

    pub fn multicast_loop_v4(&self) -> io::Result<bool> {
        self.imp.inner.socket.multicast_loop_v4()
    }

    pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
        self.imp.inner.socket.set_multicast_loop_v4(on)
    }

    pub fn multicast_ttl_v4(&self) -> io::Result<u32> {
        self.imp.inner.socket.multicast_ttl_v4()
    }

    pub fn set_multicast_ttl_v4(&self, ttl: u32) -> io::Result<()> {
        self.imp.inner.socket.set_multicast_ttl_v4(ttl)
    }

    pub fn multicast_loop_v6(&self) -> io::Result<bool> {
        self.imp.inner.socket.multicast_loop_v6()
    }

    pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
        self.imp.inner.socket.set_multicast_loop_v6(on)
    }

    pub fn ttl(&self) -> io::Result<u32> {
        self.imp.inner.socket.ttl()
    }

    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
        self.imp.inner.socket.set_ttl(ttl)
    }

    pub fn join_multicast_v4(&self,
                             multiaddr: &Ipv4Addr,
                             interface: &Ipv4Addr) -> io::Result<()> {
        self.imp.inner.socket.join_multicast_v4(multiaddr, interface)
    }

    pub fn join_multicast_v6(&self,
                             multiaddr: &Ipv6Addr,
                             interface: u32) -> io::Result<()> {
        self.imp.inner.socket.join_multicast_v6(multiaddr, interface)
    }

    pub fn leave_multicast_v4(&self,
                              multiaddr: &Ipv4Addr,
                              interface: &Ipv4Addr) -> io::Result<()> {
        self.imp.inner.socket.leave_multicast_v4(multiaddr, interface)
    }

    pub fn leave_multicast_v6(&self,
                              multiaddr: &Ipv6Addr,
                              interface: u32) -> io::Result<()> {
        self.imp.inner.socket.leave_multicast_v6(multiaddr, interface)
    }

    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
        self.imp.inner.socket.set_only_v6(only_v6)
    }

    pub fn only_v6(&self) -> io::Result<bool> {
        self.imp.inner.socket.only_v6()
    }

    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        self.imp.inner.socket.take_error()
    }

    fn inner(&self) -> MutexGuard<Inner> {
        self.imp.inner()
    }

    fn post_register(&self, interest: Ready, me: &mut Inner) {
        if interest.is_readable() {
            // We use `recv_from` here since it is well specified for both
            // connected and non-connected sockets, and we can discard the
            // address when calling `recv()`.
            self.imp.schedule_read_from(me);
        }
        // See comments in TcpSocket::post_register for what's going on here
        if interest.is_writable() {
            if let State::Empty = me.write {
                self.imp.add_readiness(me, Ready::writable());
            }
        }
    }
}

impl Imp {
    fn inner(&self) -> MutexGuard<Inner> {
        self.inner.inner.lock().unwrap()
    }
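
    // If no read is currently pending, lease a 64KB buffer, clear readable
    // readiness, and issue an overlapped `recv_from` into it; `recv_done`
    // will flip the state to `Ready` when the kernel completes the receive.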
    fn schedule_read_from(&self, me: &mut Inner) {
        match me.read {
            State::Empty => {}
            _ => return,
        }

        let interest = me.iocp.readiness();
        me.iocp.set_readiness(interest - Ready::readable());

        let mut buf = me.iocp.get_buffer(64 * 1024);
        let res = unsafe {
            trace!("scheduling a read");
            let cap = buf.capacity();
            buf.set_len(cap);
            self.inner.socket.recv_from_overlapped(&mut buf, &mut me.read_buf,
                                                   self.inner.read.as_mut_ptr())
        };
        match res {
            Ok(_) => {
                me.read = State::Pending(buf);
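                // Same leak-until-completion trick as in `send_to`: the
                // reference is recovered in `recv_done`.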
                mem::forget(self.clone());
            }
            Err(e) => {
                me.read = State::Error(e);
                self.add_readiness(me, Ready::readable());
                me.iocp.put_buffer(buf);
            }
        }
    }

    // See comments in tcp::StreamImp::push
    fn add_readiness(&self, me: &Inner, set: Ready) {
        me.iocp.set_readiness(set | me.iocp.readiness());
    }
}

impl Evented for UdpSocket {
    fn register(&self, poll: &Poll, token: Token,
                interest: Ready, opts: PollOpt) -> io::Result<()> {
        let mut me = self.inner();
        me.iocp.register_socket(&self.imp.inner.socket,
                                poll, token, interest, opts,
                                &self.registration)?;
        self.post_register(interest, &mut me);
        Ok(())
    }

    fn reregister(&self, poll: &Poll, token: Token,
                  interest: Ready, opts: PollOpt) -> io::Result<()> {
        let mut me = self.inner();
        me.iocp.reregister_socket(&self.imp.inner.socket,
                                  poll, token, interest,
                                  opts, &self.registration)?;
        self.post_register(interest, &mut me);
        Ok(())
    }

    fn deregister(&self, poll: &Poll) -> io::Result<()> {
        self.inner().iocp.deregister(&self.imp.inner.socket,
                                     poll, &self.registration)
    }
}

impl fmt::Debug for UdpSocket {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("UdpSocket")
            .finish()
    }
}

impl Drop for UdpSocket {
    fn drop(&mut self) {
        let inner = self.inner();

        // If we're still internally reading, we're no longer interested. Note
        // though that we don't cancel any writes which may have been issued to
        // preserve the same semantics as Unix.
        unsafe {
            match inner.read {
                State::Pending(_) => {
                    drop(super::cancel(&self.imp.inner.socket,
                                       &self.imp.inner.read));
                }
                State::Empty |
                State::Ready(_) |
                State::Error(_) => {}
            }
        }
    }
}
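
// Completion callbacks, invoked by the selector when an overlapped send or
// receive finishes. `overlapped2arc!` recovers the reference-counted `Io`
// from the embedded `OVERLAPPED` pointer, balancing the `mem::forget` that
// was done when the operation was scheduled.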
fn send_done(status: &OVERLAPPED_ENTRY) {
    let status = CompletionStatus::from_entry(status);
    trace!("finished a send {}", status.bytes_transferred());
    let me2 = Imp {
        inner: unsafe { overlapped2arc!(status.overlapped(), Io, write) },
    };
    let mut me = me2.inner();
    me.write = State::Empty;
    me2.add_readiness(&mut me, Ready::writable());
}

fn recv_done(status: &OVERLAPPED_ENTRY) {
    let status = CompletionStatus::from_entry(status);
    trace!("finished a recv {}", status.bytes_transferred());
    let me2 = Imp {
        inner: unsafe { overlapped2arc!(status.overlapped(), Io, read) },
    };
    let mut me = me2.inner();
    let mut buf = match mem::replace(&mut me.read, State::Empty) {
        State::Pending(buf) => buf,
        _ => unreachable!(),
    };
    unsafe {
        buf.set_len(status.bytes_transferred() as usize);
    }
    me.read = State::Ready(buf);
    me2.add_readiness(&mut me, Ready::readable());
}