1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3 
4 use tokio::net::{TcpListener, TcpStream};
5 use tokio::sync::{mpsc, oneshot};
6 use tokio_test::assert_ok;
7 
8 use std::net::{IpAddr, SocketAddr};
9 
/// Generates one `#[tokio::test]` per `(name, bind_target)` pair.
///
/// Each generated test binds a `TcpListener` to the given target, accepts a
/// single connection on a spawned task, connects to the listener's local
/// address from the test task, and checks that the client's local address
/// equals the peer address reported by the accepted socket.
macro_rules! test_accept {
    ($(($name:ident, $bind_target:expr),)*) => {
        $(
            #[tokio::test]
            async fn $name() {
                let mut server = assert_ok!(TcpListener::bind($bind_target).await);
                let server_addr = server.local_addr().unwrap();

                // Carries the accepted socket from the spawned task back to the test.
                let (accepted_tx, accepted_rx) = oneshot::channel();

                tokio::spawn(async move {
                    let (accepted, _peer) = assert_ok!(server.accept().await);
                    assert_ok!(accepted_tx.send(accepted));
                });

                let client = assert_ok!(TcpStream::connect(&server_addr).await);
                let accepted = assert_ok!(accepted_rx.await);

                // Both ends must agree on the client's address.
                let client_local = client.local_addr().unwrap();
                let accepted_peer = accepted.peer_addr().unwrap();
                assert_eq!(client_local, accepted_peer);
            }
        )*
    };
}
33 
34 test_accept! {
35     (ip_str, "127.0.0.1:0"),
36     (host_str, "localhost:0"),
37     (socket_addr, "127.0.0.1:0".parse::<SocketAddr>().unwrap()),
38     (str_port_tuple, ("127.0.0.1", 0)),
39     (ip_port_tuple, ("127.0.0.1".parse::<IpAddr>().unwrap(), 0)),
40 }
41 
42 use pin_project_lite::pin_project;
43 use std::pin::Pin;
44 use std::sync::{
45     atomic::{AtomicUsize, Ordering::SeqCst},
46     Arc,
47 };
48 use std::task::{Context, Poll};
49 use tokio::stream::{Stream, StreamExt};
50 
51 pin_project! {
52     struct TrackPolls<S> {
53         npolls: Arc<AtomicUsize>,
54         #[pin]
55         s: S,
56     }
57 }
58 
59 impl<S> Stream for TrackPolls<S>
60 where
61     S: Stream,
62 {
63     type Item = S::Item;
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>64     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
65         let this = self.project();
66         this.npolls.fetch_add(1, SeqCst);
67         this.s.poll_next(cx)
68     }
69 }
70 
71 #[tokio::test]
no_extra_poll()72 async fn no_extra_poll() {
73     let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
74     let addr = listener.local_addr().unwrap();
75 
76     let (tx, rx) = oneshot::channel();
77     let (accepted_tx, mut accepted_rx) = mpsc::unbounded_channel();
78 
79     tokio::spawn(async move {
80         let mut incoming = TrackPolls {
81             npolls: Arc::new(AtomicUsize::new(0)),
82             s: listener.incoming(),
83         };
84         assert_ok!(tx.send(Arc::clone(&incoming.npolls)));
85         while incoming.next().await.is_some() {
86             accepted_tx.send(()).unwrap();
87         }
88     });
89 
90     let npolls = assert_ok!(rx.await);
91     tokio::task::yield_now().await;
92 
93     // should have been polled exactly once: the initial poll
94     assert_eq!(npolls.load(SeqCst), 1);
95 
96     let _ = assert_ok!(TcpStream::connect(&addr).await);
97     accepted_rx.next().await.unwrap();
98 
99     // should have been polled twice more: once to yield Some(), then once to yield Pending
100     assert_eq!(npolls.load(SeqCst), 1 + 2);
101 }
102