1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3 
4 use tokio::net::{TcpListener, TcpStream};
5 use tokio::sync::{mpsc, oneshot};
6 use tokio_test::assert_ok;
7 
8 use std::io;
9 use std::net::{IpAddr, SocketAddr};
10 
// Generates one `#[tokio::test]` per `(name, target)` pair.
//
// Each generated test:
// - binds a `TcpListener` to `target`;
// - accepts a single connection on a spawned task;
// - connects a client to the bound address;
// - checks that the client's local address is the peer address seen by the
//   accepted server-side socket.
//
// A trailing comma after the last pair is optional.
macro_rules! test_accept {
    ($(($ident:ident, $target:expr)),* $(,)?) => {
        $(
            #[tokio::test]
            async fn $ident() {
                let listener = assert_ok!(TcpListener::bind($target).await);
                let addr = listener.local_addr().unwrap();

                // Accept on a separate task and return the socket through its
                // `JoinHandle`. If `accept` fails, the panic is reported via the
                // resulting `JoinError` instead of a bare closed-channel error.
                let server = tokio::spawn(async move {
                    let (socket, _) = assert_ok!(listener.accept().await);
                    socket
                });

                let cli = assert_ok!(TcpStream::connect(&addr).await);
                let srv = assert_ok!(server.await);

                assert_eq!(cli.local_addr().unwrap(), srv.peer_addr().unwrap());
            }
        )*
    }
}
34 
35 test_accept! {
36     (ip_str, "127.0.0.1:0"),
37     (host_str, "localhost:0"),
38     (socket_addr, "127.0.0.1:0".parse::<SocketAddr>().unwrap()),
39     (str_port_tuple, ("127.0.0.1", 0)),
40     (ip_port_tuple, ("127.0.0.1".parse::<IpAddr>().unwrap(), 0)),
41 }
42 
43 use std::pin::Pin;
44 use std::sync::{
45     atomic::{AtomicUsize, Ordering::SeqCst},
46     Arc,
47 };
48 use std::task::{Context, Poll};
49 use tokio_stream::{Stream, StreamExt};
50 
51 struct TrackPolls<'a> {
52     npolls: Arc<AtomicUsize>,
53     listener: &'a mut TcpListener,
54 }
55 
56 impl<'a> Stream for TrackPolls<'a> {
57     type Item = io::Result<(TcpStream, SocketAddr)>;
58 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>59     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
60         self.npolls.fetch_add(1, SeqCst);
61         self.listener.poll_accept(cx).map(Some)
62     }
63 }
64 
65 #[tokio::test]
no_extra_poll()66 async fn no_extra_poll() {
67     let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
68     let addr = listener.local_addr().unwrap();
69 
70     let (tx, rx) = oneshot::channel();
71     let (accepted_tx, mut accepted_rx) = mpsc::unbounded_channel();
72 
73     tokio::spawn(async move {
74         let mut incoming = TrackPolls {
75             npolls: Arc::new(AtomicUsize::new(0)),
76             listener: &mut listener,
77         };
78         assert_ok!(tx.send(Arc::clone(&incoming.npolls)));
79         while incoming.next().await.is_some() {
80             accepted_tx.send(()).unwrap();
81         }
82     });
83 
84     let npolls = assert_ok!(rx.await);
85     tokio::task::yield_now().await;
86 
87     // should have been polled exactly once: the initial poll
88     assert_eq!(npolls.load(SeqCst), 1);
89 
90     let _ = assert_ok!(TcpStream::connect(&addr).await);
91     accepted_rx.recv().await.unwrap();
92 
93     // should have been polled twice more: once to yield Some(), then once to yield Pending
94     assert_eq!(npolls.load(SeqCst), 1 + 2);
95 }
96 
97 #[tokio::test]
accept_many()98 async fn accept_many() {
99     use futures::future::poll_fn;
100     use std::future::Future;
101     use std::sync::atomic::AtomicBool;
102 
103     const N: usize = 50;
104 
105     let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
106     let listener = Arc::new(listener);
107     let addr = listener.local_addr().unwrap();
108     let connected = Arc::new(AtomicBool::new(false));
109 
110     let (pending_tx, mut pending_rx) = mpsc::unbounded_channel();
111     let (notified_tx, mut notified_rx) = mpsc::unbounded_channel();
112 
113     for _ in 0..N {
114         let listener = listener.clone();
115         let connected = connected.clone();
116         let pending_tx = pending_tx.clone();
117         let notified_tx = notified_tx.clone();
118 
119         tokio::spawn(async move {
120             let accept = listener.accept();
121             tokio::pin!(accept);
122 
123             let mut polled = false;
124 
125             poll_fn(|cx| {
126                 if !polled {
127                     polled = true;
128                     assert!(Pin::new(&mut accept).poll(cx).is_pending());
129                     pending_tx.send(()).unwrap();
130                     Poll::Pending
131                 } else if connected.load(SeqCst) {
132                     notified_tx.send(()).unwrap();
133                     Poll::Ready(())
134                 } else {
135                     Poll::Pending
136                 }
137             })
138             .await;
139 
140             pending_tx.send(()).unwrap();
141         });
142     }
143 
144     // Wait for all tasks to have polled at least once
145     for _ in 0..N {
146         pending_rx.recv().await.unwrap();
147     }
148 
149     // Establish a TCP connection
150     connected.store(true, SeqCst);
151     let _sock = TcpStream::connect(addr).await.unwrap();
152 
153     // Wait for all notifications
154     for _ in 0..N {
155         notified_rx.recv().await.unwrap();
156     }
157 }
158