use super::afd::{self, Afd, AfdPollInfo};
use super::io_status_block::IoStatusBlock;
use super::Event;
use crate::sys::Events;
use crate::Interest;

use miow::iocp::{CompletionPort, CompletionStatus};
use miow::Overlapped;
use std::collections::VecDeque;
use std::marker::PhantomPinned;
use std::os::windows::io::RawSocket;
use std::pin::Pin;
#[cfg(debug_assertions)]
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::{io, ptr};
use winapi::shared::ntdef::NT_SUCCESS;
use winapi::shared::ntdef::{HANDLE, PVOID};
use winapi::shared::ntstatus::STATUS_CANCELLED;
use winapi::shared::winerror::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, WAIT_TIMEOUT};
use winapi::um::minwinbase::OVERLAPPED;

/// Overlapped value to indicate a `Waker` event.
//
// Note: this must be null, `SelectorInner::feed_events` depends on it.
pub const WAKER_OVERLAPPED: *mut Overlapped = ptr::null_mut();

#[derive(Debug)]
struct AfdGroup {
    cp: Arc<CompletionPort>,
    afd_group: Mutex<Vec<Arc<Afd>>>,
}

impl AfdGroup {
    pub fn new(cp: Arc<CompletionPort>) -> AfdGroup {
        AfdGroup {
            afd_group: Mutex::new(Vec::new()),
            cp,
        }
    }

    pub fn release_unused_afd(&self) {
        let mut afd_group = self.afd_group.lock().unwrap();
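        // An `Afd` instance is kept alive while any socket still references
        // it; once only the `Vec` itself holds a reference (strong count of
        // exactly 1) the handle can be released.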
        afd_group.retain(|g| Arc::strong_count(g) > 1);
    }
}

cfg_net! {
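    // Maximum number of sockets that may share a single `Afd` helper handle;
    // once the newest group reaches this size, `acquire` allocates a fresh
    // `Afd`. The value mirrors wepoll, which this design is modeled on.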
    const POLL_GROUP__MAX_GROUP_SIZE: usize = 32;

    impl AfdGroup {
        pub fn acquire(&self) -> io::Result<Arc<Afd>> {
            let mut afd_group = self.afd_group.lock().unwrap();
            if afd_group.is_empty() {
                self._alloc_afd_group(&mut afd_group)?;
            } else {
                // + 1 reference in Vec
                if Arc::strong_count(afd_group.last().unwrap()) >= POLL_GROUP__MAX_GROUP_SIZE + 1 {
                    self._alloc_afd_group(&mut afd_group)?;
                }
            }

            match afd_group.last() {
                Some(arc) => Ok(arc.clone()),
                None => unreachable!(
                    "Cannot acquire afd, {:#?}, afd_group: {:#?}",
                    self, afd_group
                ),
            }
        }

        fn _alloc_afd_group(&self, afd_group: &mut Vec<Arc<Afd>>) -> io::Result<()> {
            let afd = Afd::new(&self.cp)?;
            let arc = Arc::new(afd);
            afd_group.push(arc);
            Ok(())
        }
    }
}

#[derive(Debug)]
enum SockPollStatus {
    Idle,
    Pending,
    Cancelled,
}

#[derive(Debug)]
pub struct SockState {
    iosb: IoStatusBlock,
    poll_info: AfdPollInfo,
    afd: Arc<Afd>,

    raw_socket: RawSocket,
    base_socket: RawSocket,

    user_evts: u32,
    pending_evts: u32,

    user_data: u64,

    poll_status: SockPollStatus,
    delete_pending: bool,

    // last raw os error
    error: Option<i32>,

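    // The kernel holds pointers into `iosb` and `poll_info` while a poll
    // operation is pending, so a `SockState` must never move in memory.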
    pinned: PhantomPinned,
}

impl SockState {
    fn update(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> io::Result<()> {
        assert!(!self.delete_pending);

        // Make sure to reset any previous error before a new update.
        self.error = None;

        if let SockPollStatus::Pending = self.poll_status {
            if (self.user_evts & afd::KNOWN_EVENTS & !self.pending_evts) == 0 {
                /* All the events the user is interested in are already being
                 * monitored by the pending poll operation. It might spuriously
                 * complete because of an event that we're no longer interested
                 * in; when that happens we'll submit a new poll operation with
                 * the updated event mask. */
            } else {
                /* A poll operation is already pending, but it's not monitoring
                 * for all the events that the user is interested in. Therefore,
                 * cancel the pending poll operation; when we receive its
                 * completion packet, a new poll operation will be submitted
                 * with the correct event mask. */
                if let Err(e) = self.cancel() {
                    self.error = e.raw_os_error();
                    return Err(e);
                }
                return Ok(());
            }
        } else if let SockPollStatus::Cancelled = self.poll_status {
            /* The poll operation has already been cancelled, we're still
             * waiting for it to return. For now, there's nothing that needs to
             * be done. */
        } else if let SockPollStatus::Idle = self.poll_status {
            /* No poll operation is pending; start one. */
            self.poll_info.exclusive = 0;
            self.poll_info.number_of_handles = 1;
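            // Effectively an infinite timeout: completion is driven by socket
            // events or cancellation, never by the timer.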
            *unsafe { self.poll_info.timeout.QuadPart_mut() } = std::i64::MAX;
            self.poll_info.handles[0].handle = self.base_socket as HANDLE;
            self.poll_info.handles[0].status = 0;
            self.poll_info.handles[0].events = self.user_evts | afd::POLL_LOCAL_CLOSE;

            // Increase the ref count as the memory will be used by the kernel.
            let overlapped_ptr = into_overlapped(self_arc.clone());

            let result = unsafe {
                self.afd
                    .poll(&mut self.poll_info, &mut *self.iosb, overlapped_ptr)
            };
            if let Err(e) = result {
                let code = e.raw_os_error().unwrap();
                if code == ERROR_IO_PENDING as i32 {
                    /* Overlapped poll operation in progress; this is expected. */
                } else {
                    // Since the operation failed, the kernel won't be using
                    // the memory any more.
                    drop(from_overlapped(overlapped_ptr as *mut _));
                    if code == ERROR_INVALID_HANDLE as i32 {
                        /* Socket closed; it'll be dropped. */
                        self.mark_delete();
                        return Ok(());
                    } else {
                        self.error = e.raw_os_error();
                        return Err(e);
                    }
                }
            }

            self.poll_status = SockPollStatus::Pending;
            self.pending_evts = self.user_evts;
        } else {
            unreachable!("Invalid poll status during update, {:#?}", self)
        }

        Ok(())
    }

    fn cancel(&mut self) -> io::Result<()> {
        match self.poll_status {
            SockPollStatus::Pending => {}
            _ => unreachable!("Invalid poll status during cancel, {:#?}", self),
        };
        unsafe {
            self.afd.cancel(&mut *self.iosb)?;
        }
        self.poll_status = SockPollStatus::Cancelled;
        self.pending_evts = 0;
        Ok(())
    }

    // This function is called with the `OVERLAPPED` pointer that was created
    // from an `Arc<Mutex<SockState>>` (see `into_overlapped`). Watch out for
    // reference counting.
    fn feed_event(&mut self) -> Option<Event> {
        self.poll_status = SockPollStatus::Idle;
        self.pending_evts = 0;

        let mut afd_events = 0;
        // We use the status info in the `IO_STATUS_BLOCK` to determine the
        // socket poll status; accessing it requires `unsafe` because the
        // kernel writes to the block while the operation is in flight.
        unsafe {
            if self.delete_pending {
                return None;
            } else if self.iosb.u.Status == STATUS_CANCELLED {
                /* The poll request was cancelled by CancelIoEx. */
            } else if !NT_SUCCESS(self.iosb.u.Status) {
                /* The overlapped request itself failed in an unexpected way. */
                afd_events = afd::POLL_CONNECT_FAIL;
            } else if self.poll_info.number_of_handles < 1 {
                /* This poll operation succeeded but didn't report any socket
                 * events. */
            } else if self.poll_info.handles[0].events & afd::POLL_LOCAL_CLOSE != 0 {
                /* The poll operation reported that the socket was closed. */
                self.mark_delete();
                return None;
            } else {
                afd_events = self.poll_info.handles[0].events;
            }
        }

        afd_events &= self.user_evts;

        if afd_events == 0 {
            return None;
        }

        // In mio, we have to simulate edge-triggered behavior to match the
        // API usage on other platforms. The strategy here is to intercept all
        // read/write operations from the user that could return `WouldBlock`,
        // then reregister the socket to reset the interests.
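        // That reregistration happens in `IoSourceState::do_io`, which calls
        // `SelectorInner::reregister` (see below) after the user's I/O
        // operation returns `WouldBlock`.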

        // Reset readable event
        if (afd_events & interests_to_afd_flags(Interest::READABLE)) != 0 {
            self.user_evts &= !(interests_to_afd_flags(Interest::READABLE));
        }
        // Reset writable event
        if (afd_events & interests_to_afd_flags(Interest::WRITABLE)) != 0 {
            self.user_evts &= !interests_to_afd_flags(Interest::WRITABLE);
        }

        Some(Event {
            data: self.user_data,
            flags: afd_events,
        })
    }

    pub fn is_pending_deletion(&self) -> bool {
        self.delete_pending
    }

    pub fn mark_delete(&mut self) {
        if !self.delete_pending {
            if let SockPollStatus::Pending = self.poll_status {
                drop(self.cancel());
            }

            self.delete_pending = true;
        }
    }

    fn has_error(&self) -> bool {
        self.error.is_some()
    }
}

cfg_net! {
    impl SockState {
        fn new(raw_socket: RawSocket, afd: Arc<Afd>) -> io::Result<SockState> {
            Ok(SockState {
                iosb: IoStatusBlock::zeroed(),
                poll_info: AfdPollInfo::zeroed(),
                afd,
                raw_socket,
                base_socket: get_base_socket(raw_socket)?,
                user_evts: 0,
                pending_evts: 0,
                user_data: 0,
                poll_status: SockPollStatus::Idle,
                delete_pending: false,
                error: None,
                pinned: PhantomPinned,
            })
        }

        /// Returns `true` if the sock state needs to be added to the update
        /// queue, `false` otherwise.
        fn set_event(&mut self, ev: Event) -> bool {
            /* afd::POLL_CONNECT_FAIL and afd::POLL_ABORT are always reported,
             * even when not requested by the caller. */
            let events = ev.flags | afd::POLL_CONNECT_FAIL | afd::POLL_ABORT;

            self.user_evts = events;
            self.user_data = ev.data;

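            // Queue an update only if some requested event is not already
            // covered by the pending poll operation.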
            (events & !self.pending_evts) != 0
        }
    }
}

impl Drop for SockState {
    fn drop(&mut self) {
        self.mark_delete();
    }
}

/// Converts the pointer to a `SockState` into a raw pointer.
/// To revert see `from_overlapped`.
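///
/// This intentionally leaks one `Arc` reference: the kernel holds it for the
/// duration of the overlapped operation, and `from_overlapped` reclaims it.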
fn into_overlapped(sock_state: Pin<Arc<Mutex<SockState>>>) -> PVOID {
    let overlapped_ptr: *const Mutex<SockState> =
        unsafe { Arc::into_raw(Pin::into_inner_unchecked(sock_state)) };
    overlapped_ptr as *mut _
}

/// Convert a raw overlapped pointer into a reference to `SockState`.
/// Reverts `into_overlapped`.
fn from_overlapped(ptr: *mut OVERLAPPED) -> Pin<Arc<Mutex<SockState>>> {
    let sock_ptr: *const Mutex<SockState> = ptr as *const _;
    unsafe { Pin::new_unchecked(Arc::from_raw(sock_ptr)) }
}

/// Each Selector has a globally unique(ish) ID associated with it. This ID
/// gets tracked by `TcpStream`, `TcpListener`, etc... when they are first
/// registered with the `Selector`. If a type that is previously associated
/// with a `Selector` attempts to register itself with a different `Selector`,
/// the operation will return with an error. This matches Windows behavior.
#[cfg(debug_assertions)]
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

/// Windows implementation of `sys::Selector`.
///
/// Edge-triggered event notification is simulated by resetting the internal
/// event flags of each socket state (`SockState`), then setting the events
/// back by intercepting all requests that could cause `io::ErrorKind::WouldBlock`.
///
/// This selector currently only supports sockets, because the `Afd` driver is
/// winsock2-specific.
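///
/// A rough sketch of the intended flow, for illustration only (internal API;
/// not a compiled doctest, and it assumes `Events::with_capacity` as the
/// constructor):
///
/// ```ignore
/// let mut selector = Selector::new()?;
/// let mut events = Events::with_capacity(1024);
/// // Blocks until at least one event arrives or the timeout expires.
/// selector.select(&mut events, Some(Duration::from_millis(100)))?;
/// ```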
#[derive(Debug)]
pub struct Selector {
    #[cfg(debug_assertions)]
    id: usize,

    inner: Arc<SelectorInner>,
}

impl Selector {
    pub fn new() -> io::Result<Selector> {
        SelectorInner::new().map(|inner| {
            #[cfg(debug_assertions)]
            let id = NEXT_ID.fetch_add(1, Ordering::Relaxed) + 1;
            Selector {
                #[cfg(debug_assertions)]
                id,
                inner: Arc::new(inner),
            }
        })
    }

    pub fn try_clone(&self) -> io::Result<Selector> {
        Ok(Selector {
            #[cfg(debug_assertions)]
            id: self.id,
            inner: Arc::clone(&self.inner),
        })
    }

    /// # Safety
    ///
    /// This requires a mutable reference to self because only a single thread
    /// can poll IOCP at a time.
    pub fn select(&mut self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        self.inner.select(events, timeout)
    }

    pub(super) fn clone_port(&self) -> Arc<CompletionPort> {
        self.inner.cp.clone()
    }
}

cfg_net! {
    use super::InternalState;
    use crate::Token;

    impl Selector {
        pub(super) fn register(
            &self,
            socket: RawSocket,
            token: Token,
            interests: Interest,
        ) -> io::Result<InternalState> {
            SelectorInner::register(&self.inner, socket, token, interests)
        }

        pub(super) fn reregister(
            &self,
            state: Pin<Arc<Mutex<SockState>>>,
            token: Token,
            interests: Interest,
        ) -> io::Result<()> {
            self.inner.reregister(state, token, interests)
        }

        #[cfg(debug_assertions)]
        pub fn id(&self) -> usize {
            self.id
        }
    }
}

#[derive(Debug)]
pub struct SelectorInner {
    cp: Arc<CompletionPort>,
    update_queue: Mutex<VecDeque<Pin<Arc<Mutex<SockState>>>>>,
    afd_group: AfdGroup,
    is_polling: AtomicBool,
}

// Thread safety is ensured by manual locking: the update queue and every
// `SockState` are wrapped in a `Mutex`.
unsafe impl Sync for SelectorInner {}

impl SelectorInner {
    pub fn new() -> io::Result<SelectorInner> {
        CompletionPort::new(0).map(|cp| {
            let cp = Arc::new(cp);
            let cp_afd = Arc::clone(&cp);

            SelectorInner {
                cp,
                update_queue: Mutex::new(VecDeque::new()),
                afd_group: AfdGroup::new(cp_afd),
                is_polling: AtomicBool::new(false),
            }
        })
    }

    /// # Safety
    ///
    /// May only be called via `Selector::select`.
    pub fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        events.clear();

        if timeout.is_none() {
            loop {
                let len = self.select2(&mut events.statuses, &mut events.events, None)?;
                if len == 0 {
                    continue;
                }
                return Ok(());
            }
        } else {
            self.select2(&mut events.statuses, &mut events.events, timeout)?;
            return Ok(());
        }
    }

    pub fn select2(
        &self,
        statuses: &mut [CompletionStatus],
        events: &mut Vec<Event>,
        timeout: Option<Duration>,
    ) -> io::Result<usize> {
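        // Only one thread may poll the completion port at a time; `select`
        // takes `&mut self` to enforce this, and the `is_polling` flag is
        // read by `update_sockets_events_if_polling` on registering threads.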
        assert!(!self.is_polling.swap(true, Ordering::AcqRel));

        unsafe { self.update_sockets_events() }?;

        let result = self.cp.get_many(statuses, timeout);

        self.is_polling.store(false, Ordering::Relaxed);

        match result {
            Ok(iocp_events) => Ok(unsafe { self.feed_events(events, iocp_events) }),
            Err(ref e) if e.raw_os_error() == Some(WAIT_TIMEOUT as i32) => Ok(0),
            Err(e) => Err(e),
        }
    }

    unsafe fn update_sockets_events(&self) -> io::Result<()> {
        let mut update_queue = self.update_queue.lock().unwrap();
        for sock in update_queue.iter_mut() {
            let mut sock_internal = sock.lock().unwrap();
            if !sock_internal.is_pending_deletion() {
                sock_internal.update(&sock)?;
            }
        }

        // Remove all sock states that do not have an error; they have an AFD
        // poll operation pending.
        update_queue.retain(|sock| sock.lock().unwrap().has_error());

        self.afd_group.release_unused_afd();
        Ok(())
    }

    // Returns the number of processed `iocp_events`, rather than the events
    // themselves.
    unsafe fn feed_events(
        &self,
        events: &mut Vec<Event>,
        iocp_events: &[CompletionStatus],
    ) -> usize {
        let mut n = 0;
        let mut update_queue = self.update_queue.lock().unwrap();
        for iocp_event in iocp_events.iter() {
            if iocp_event.overlapped().is_null() {
                // `Waker` event (`WAKER_OVERLAPPED` is null); we'll add a
                // readable event to match the other platforms.
                events.push(Event {
                    flags: afd::POLL_RECEIVE,
                    data: iocp_event.token() as u64,
                });
                n += 1;
                continue;
            }

            let sock_state = from_overlapped(iocp_event.overlapped());
            let mut sock_guard = sock_state.lock().unwrap();
            if let Some(e) = sock_guard.feed_event() {
                events.push(e);
                n += 1;
            }

            if !sock_guard.is_pending_deletion() {
                update_queue.push_back(sock_state.clone());
            }
        }
        self.afd_group.release_unused_afd();
        n
    }
}

cfg_net! {
    use std::mem::size_of;
    use std::ptr::null_mut;
    use winapi::um::mswsock::SIO_BASE_HANDLE;
    use winapi::um::winsock2::{WSAIoctl, SOCKET_ERROR};

    impl SelectorInner {
        fn register(
            this: &Arc<Self>,
            socket: RawSocket,
            token: Token,
            interests: Interest,
        ) -> io::Result<InternalState> {
            let flags = interests_to_afd_flags(interests);

            let sock = {
                let sock = this._alloc_sock_for_rawsocket(socket)?;
                let event = Event {
                    flags,
                    data: token.0 as u64,
                };
                sock.lock().unwrap().set_event(event);
                sock
            };

            let state = InternalState {
                selector: this.clone(),
                token,
                interests,
                sock_state: sock.clone(),
            };

            this.queue_state(sock);
            unsafe { this.update_sockets_events_if_polling()?; }

            Ok(state)
        }

        // Directly accessed in `IoSourceState::do_io`.
        pub(super) fn reregister(
            &self,
            state: Pin<Arc<Mutex<SockState>>>,
            token: Token,
            interests: Interest,
        ) -> io::Result<()> {
            {
                let event = Event {
                    flags: interests_to_afd_flags(interests),
                    data: token.0 as u64,
                };

                state.lock().unwrap().set_event(event);
            }

            // FIXME: a sock which has_error true should not be re-added to
            // the update queue because it's already there.
            self.queue_state(state);
            unsafe { self.update_sockets_events_if_polling() }
        }

        /// This function is called by `register()` and `reregister()` to
        /// start an IOCTL_AFD_POLL operation corresponding to the registered
        /// events, but only if necessary.
        ///
        /// Since it is not possible to modify or synchronously cancel an
        /// AFD_POLL operation, and there can be only one active AFD_POLL
        /// operation per (socket, completion port) pair at any time, it is
        /// expensive to change a socket's event registration after it has
        /// been submitted to the kernel.
        ///
        /// Therefore, if no other threads are polling when interest in a
        /// socket event is (re)registered, the socket is added to the 'update
        /// queue', but the actual syscall to start the IOCTL_AFD_POLL
        /// operation is deferred until just before the
        /// GetQueuedCompletionStatusEx() syscall is made.
        ///
        /// However, when another thread is already blocked on
        /// GetQueuedCompletionStatusEx() we tell the kernel about the
        /// registered socket event(s) immediately.
        unsafe fn update_sockets_events_if_polling(&self) -> io::Result<()> {
            if self.is_polling.load(Ordering::Acquire) {
                self.update_sockets_events()
            } else {
                Ok(())
            }
        }

        fn queue_state(&self, sock_state: Pin<Arc<Mutex<SockState>>>) {
            let mut update_queue = self.update_queue.lock().unwrap();
            update_queue.push_back(sock_state);
        }

        fn _alloc_sock_for_rawsocket(
            &self,
            raw_socket: RawSocket,
        ) -> io::Result<Pin<Arc<Mutex<SockState>>>> {
            let afd = self.afd_group.acquire()?;
            Ok(Arc::pin(Mutex::new(SockState::new(raw_socket, afd)?)))
        }
    }

    fn get_base_socket(raw_socket: RawSocket) -> io::Result<RawSocket> {
        let mut base_socket: RawSocket = 0;
        let mut bytes: u32 = 0;

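        // SIO_BASE_HANDLE retrieves the socket handle of the base service
        // provider, bypassing any layered service providers (LSPs); the AFD
        // poll operation is performed on this base handle.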
        unsafe {
            if WSAIoctl(
                raw_socket as usize,
                SIO_BASE_HANDLE,
                null_mut(),
                0,
                &mut base_socket as *mut _ as PVOID,
                size_of::<RawSocket>() as u32,
                &mut bytes,
                null_mut(),
                None,
            ) == SOCKET_ERROR
            {
                Err(io::Error::last_os_error())
            } else {
                Ok(base_socket)
            }
        }
    }
}

impl Drop for SelectorInner {
    fn drop(&mut self) {
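        // Drain any queued completion statuses so that the `Arc` references
        // handed to the kernel in `into_overlapped` are reclaimed and freed.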
        loop {
            let events_num: usize;
            let mut statuses: [CompletionStatus; 1024] = [CompletionStatus::zero(); 1024];

            let result = self
                .cp
                .get_many(&mut statuses, Some(std::time::Duration::from_millis(0)));
            match result {
                Ok(iocp_events) => {
                    events_num = iocp_events.iter().len();
                    for iocp_event in iocp_events.iter() {
                        if !iocp_event.overlapped().is_null() {
                            // Drain the sock state to release the memory of
                            // the `Arc` reference.
                            let _sock_state = from_overlapped(iocp_event.overlapped());
                        }
                    }
                }

                Err(_) => {
                    break;
                }
            }

            if events_num == 0 {
                // All completion statuses have been drained; stop looping.
                break;
            }
        }

        self.afd_group.release_unused_afd();
    }
}

fn interests_to_afd_flags(interests: Interest) -> u32 {
    let mut flags = 0;

    if interests.is_readable() {
        // afd::POLL_DISCONNECT for is_read_hup()
        flags |= afd::POLL_RECEIVE | afd::POLL_ACCEPT | afd::POLL_DISCONNECT;
    }

    if interests.is_writable() {
        flags |= afd::POLL_SEND;
    }

    flags
}