1 use crate::codec::{Decoder, Encoder};
2 
3 use futures_core::Stream;
4 use tokio::{io::ReadBuf, net::UdpSocket};
5 
6 use bytes::{BufMut, BytesMut};
7 use futures_core::ready;
8 use futures_sink::Sink;
9 use std::pin::Pin;
10 use std::task::{Context, Poll};
11 use std::{
12     borrow::Borrow,
13     net::{Ipv4Addr, SocketAddr, SocketAddrV4},
14 };
15 use std::{io, mem::MaybeUninit};
16 
17 /// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
18 /// the `Encoder` and `Decoder` traits to encode and decode frames.
19 ///
20 /// Raw UDP sockets work with datagrams, but higher-level code usually wants to
21 /// batch these into meaningful chunks, called "frames". This method layers
22 /// framing on top of this socket by using the `Encoder` and `Decoder` traits to
23 /// handle encoding and decoding of messages frames. Note that the incoming and
24 /// outgoing frame types may be distinct.
25 ///
26 /// This function returns a *single* object that is both [`Stream`] and [`Sink`];
27 /// grouping this into a single object is often useful for layering things which
28 /// require both read and write access to the underlying object.
29 ///
30 /// If you want to work more directly with the streams and sink, consider
31 /// calling [`split`] on the `UdpFramed` returned by this method, which will break
32 /// them into separate objects, allowing them to interact more easily.
33 ///
34 /// [`Stream`]: futures_core::Stream
35 /// [`Sink`]: futures_sink::Sink
36 /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
37 #[must_use = "sinks do nothing unless polled"]
38 #[derive(Debug)]
39 pub struct UdpFramed<C, T = UdpSocket> {
40     socket: T,
41     codec: C,
42     rd: BytesMut,
43     wr: BytesMut,
44     out_addr: SocketAddr,
45     flushed: bool,
46     is_readable: bool,
47     current_addr: Option<SocketAddr>,
48 }
49 
/// Initial capacity of the read buffer (64 KiB). This is also the capacity
/// reserved before receiving a datagram.
const INITIAL_RD_CAPACITY: usize = 1 << 16;
/// Initial capacity of the buffer that outgoing frames are encoded into (8 KiB).
const INITIAL_WR_CAPACITY: usize = 1 << 13;
52 
53 impl<C, T> Unpin for UdpFramed<C, T> {}
54 
55 impl<C, T> Stream for UdpFramed<C, T>
56 where
57     T: Borrow<UdpSocket>,
58     C: Decoder,
59 {
60     type Item = Result<(C::Item, SocketAddr), C::Error>;
61 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>62     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
63         let pin = self.get_mut();
64 
65         pin.rd.reserve(INITIAL_RD_CAPACITY);
66 
67         loop {
68             // Are there still bytes left in the read buffer to decode?
69             if pin.is_readable {
70                 if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? {
71                     let current_addr = pin
72                         .current_addr
73                         .expect("will always be set before this line is called");
74 
75                     return Poll::Ready(Some(Ok((frame, current_addr))));
76                 }
77 
78                 // if this line has been reached then decode has returned `None`.
79                 pin.is_readable = false;
80                 pin.rd.clear();
81             }
82 
83             // We're out of data. Try and fetch more data to decode
84             let addr = unsafe {
85                 // Convert `&mut [MaybeUnit<u8>]` to `&mut [u8]` because we will be
86                 // writing to it via `poll_recv_from` and therefore initializing the memory.
87                 let buf = &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]);
88                 let mut read = ReadBuf::uninit(buf);
89                 let ptr = read.filled().as_ptr();
90                 let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));
91 
92                 assert_eq!(ptr, read.filled().as_ptr());
93                 let addr = res?;
94                 pin.rd.advance_mut(read.filled().len());
95                 addr
96             };
97 
98             pin.current_addr = Some(addr);
99             pin.is_readable = true;
100         }
101     }
102 }
103 
104 impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T>
105 where
106     T: Borrow<UdpSocket>,
107     C: Encoder<I>,
108 {
109     type Error = C::Error;
110 
poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>111     fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112         if !self.flushed {
113             match self.poll_flush(cx)? {
114                 Poll::Ready(()) => {}
115                 Poll::Pending => return Poll::Pending,
116             }
117         }
118 
119         Poll::Ready(Ok(()))
120     }
121 
start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error>122     fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> {
123         let (frame, out_addr) = item;
124 
125         let pin = self.get_mut();
126 
127         pin.codec.encode(frame, &mut pin.wr)?;
128         pin.out_addr = out_addr;
129         pin.flushed = false;
130 
131         Ok(())
132     }
133 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>134     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
135         if self.flushed {
136             return Poll::Ready(Ok(()));
137         }
138 
139         let Self {
140             ref socket,
141             ref mut out_addr,
142             ref mut wr,
143             ..
144         } = *self;
145 
146         let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?;
147 
148         let wrote_all = n == self.wr.len();
149         self.wr.clear();
150         self.flushed = true;
151 
152         let res = if wrote_all {
153             Ok(())
154         } else {
155             Err(io::Error::new(
156                 io::ErrorKind::Other,
157                 "failed to write entire datagram to socket",
158             )
159             .into())
160         };
161 
162         Poll::Ready(res)
163     }
164 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>165     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
166         ready!(self.poll_flush(cx))?;
167         Poll::Ready(Ok(()))
168     }
169 }
170 
171 impl<C, T> UdpFramed<C, T>
172 where
173     T: Borrow<UdpSocket>,
174 {
175     /// Create a new `UdpFramed` backed by the given socket and codec.
176     ///
177     /// See struct level documentation for more details.
new(socket: T, codec: C) -> UdpFramed<C, T>178     pub fn new(socket: T, codec: C) -> UdpFramed<C, T> {
179         Self {
180             socket,
181             codec,
182             out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
183             rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
184             wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
185             flushed: true,
186             is_readable: false,
187             current_addr: None,
188         }
189     }
190 
191     /// Returns a reference to the underlying I/O stream wrapped by `Framed`.
192     ///
193     /// # Note
194     ///
195     /// Care should be taken to not tamper with the underlying stream of data
196     /// coming in as it may corrupt the stream of frames otherwise being worked
197     /// with.
get_ref(&self) -> &T198     pub fn get_ref(&self) -> &T {
199         &self.socket
200     }
201 
202     /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`.
203     ///
204     /// # Note
205     ///
206     /// Care should be taken to not tamper with the underlying stream of data
207     /// coming in as it may corrupt the stream of frames otherwise being worked
208     /// with.
get_mut(&mut self) -> &mut T209     pub fn get_mut(&mut self) -> &mut T {
210         &mut self.socket
211     }
212 
213     /// Returns a reference to the underlying codec wrapped by
214     /// `Framed`.
215     ///
216     /// Note that care should be taken to not tamper with the underlying codec
217     /// as it may corrupt the stream of frames otherwise being worked with.
codec(&self) -> &C218     pub fn codec(&self) -> &C {
219         &self.codec
220     }
221 
222     /// Returns a mutable reference to the underlying codec wrapped by
223     /// `UdpFramed`.
224     ///
225     /// Note that care should be taken to not tamper with the underlying codec
226     /// as it may corrupt the stream of frames otherwise being worked with.
codec_mut(&mut self) -> &mut C227     pub fn codec_mut(&mut self) -> &mut C {
228         &mut self.codec
229     }
230 
231     /// Returns a reference to the read buffer.
read_buffer(&self) -> &BytesMut232     pub fn read_buffer(&self) -> &BytesMut {
233         &self.rd
234     }
235 
236     /// Returns a mutable reference to the read buffer.
read_buffer_mut(&mut self) -> &mut BytesMut237     pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
238         &mut self.rd
239     }
240 
241     /// Consumes the `Framed`, returning its underlying I/O stream.
into_inner(self) -> T242     pub fn into_inner(self) -> T {
243         self.socket
244     }
245 }
246