1 use crate::codec::{Decoder, Encoder}; 2 3 use futures_core::Stream; 4 use tokio::{io::ReadBuf, net::UdpSocket}; 5 6 use bytes::{BufMut, BytesMut}; 7 use futures_core::ready; 8 use futures_sink::Sink; 9 use std::pin::Pin; 10 use std::task::{Context, Poll}; 11 use std::{ 12 borrow::Borrow, 13 net::{Ipv4Addr, SocketAddr, SocketAddrV4}, 14 }; 15 use std::{io, mem::MaybeUninit}; 16 17 /// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using 18 /// the `Encoder` and `Decoder` traits to encode and decode frames. 19 /// 20 /// Raw UDP sockets work with datagrams, but higher-level code usually wants to 21 /// batch these into meaningful chunks, called "frames". This method layers 22 /// framing on top of this socket by using the `Encoder` and `Decoder` traits to 23 /// handle encoding and decoding of messages frames. Note that the incoming and 24 /// outgoing frame types may be distinct. 25 /// 26 /// This function returns a *single* object that is both [`Stream`] and [`Sink`]; 27 /// grouping this into a single object is often useful for layering things which 28 /// require both read and write access to the underlying object. 29 /// 30 /// If you want to work more directly with the streams and sink, consider 31 /// calling [`split`] on the `UdpFramed` returned by this method, which will break 32 /// them into separate objects, allowing them to interact more easily. 33 /// 34 /// [`Stream`]: futures_core::Stream 35 /// [`Sink`]: futures_sink::Sink 36 /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split 37 #[must_use = "sinks do nothing unless polled"] 38 #[derive(Debug)] 39 pub struct UdpFramed<C, T = UdpSocket> { 40 socket: T, 41 codec: C, 42 rd: BytesMut, 43 wr: BytesMut, 44 out_addr: SocketAddr, 45 flushed: bool, 46 is_readable: bool, 47 current_addr: Option<SocketAddr>, 48 } 49 50 const INITIAL_RD_CAPACITY: usize = 64 * 1024; 51 const INITIAL_WR_CAPACITY: usize = 8 * 1024; 52 53 impl<C, T> Unpin for UdpFramed<C, T> {} 54 55 impl<C, T> Stream for UdpFramed<C, T> 56 where 57 T: Borrow<UdpSocket>, 58 C: Decoder, 59 { 60 type Item = Result<(C::Item, SocketAddr), C::Error>; 61 poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>62 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 63 let pin = self.get_mut(); 64 65 pin.rd.reserve(INITIAL_RD_CAPACITY); 66 67 loop { 68 // Are there still bytes left in the read buffer to decode? 69 if pin.is_readable { 70 if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? { 71 let current_addr = pin 72 .current_addr 73 .expect("will always be set before this line is called"); 74 75 return Poll::Ready(Some(Ok((frame, current_addr)))); 76 } 77 78 // if this line has been reached then decode has returned `None`. 79 pin.is_readable = false; 80 pin.rd.clear(); 81 } 82 83 // We're out of data. Try and fetch more data to decode 84 let addr = unsafe { 85 // Convert `&mut [MaybeUnit<u8>]` to `&mut [u8]` because we will be 86 // writing to it via `poll_recv_from` and therefore initializing the memory. 87 let buf = &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]); 88 let mut read = ReadBuf::uninit(buf); 89 let ptr = read.filled().as_ptr(); 90 let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read)); 91 92 assert_eq!(ptr, read.filled().as_ptr()); 93 let addr = res?; 94 pin.rd.advance_mut(read.filled().len()); 95 addr 96 }; 97 98 pin.current_addr = Some(addr); 99 pin.is_readable = true; 100 } 101 } 102 } 103 104 impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T> 105 where 106 T: Borrow<UdpSocket>, 107 C: Encoder<I>, 108 { 109 type Error = C::Error; 110 poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>111 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 112 if !self.flushed { 113 match self.poll_flush(cx)? { 114 Poll::Ready(()) => {} 115 Poll::Pending => return Poll::Pending, 116 } 117 } 118 119 Poll::Ready(Ok(())) 120 } 121 start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error>122 fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> { 123 let (frame, out_addr) = item; 124 125 let pin = self.get_mut(); 126 127 pin.codec.encode(frame, &mut pin.wr)?; 128 pin.out_addr = out_addr; 129 pin.flushed = false; 130 131 Ok(()) 132 } 133 poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>134 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 135 if self.flushed { 136 return Poll::Ready(Ok(())); 137 } 138 139 let Self { 140 ref socket, 141 ref mut out_addr, 142 ref mut wr, 143 .. 144 } = *self; 145 146 let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?; 147 148 let wrote_all = n == self.wr.len(); 149 self.wr.clear(); 150 self.flushed = true; 151 152 let res = if wrote_all { 153 Ok(()) 154 } else { 155 Err(io::Error::new( 156 io::ErrorKind::Other, 157 "failed to write entire datagram to socket", 158 ) 159 .into()) 160 }; 161 162 Poll::Ready(res) 163 } 164 poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>165 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 166 ready!(self.poll_flush(cx))?; 167 Poll::Ready(Ok(())) 168 } 169 } 170 171 impl<C, T> UdpFramed<C, T> 172 where 173 T: Borrow<UdpSocket>, 174 { 175 /// Create a new `UdpFramed` backed by the given socket and codec. 176 /// 177 /// See struct level documentation for more details. new(socket: T, codec: C) -> UdpFramed<C, T>178 pub fn new(socket: T, codec: C) -> UdpFramed<C, T> { 179 Self { 180 socket, 181 codec, 182 out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)), 183 rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY), 184 wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY), 185 flushed: true, 186 is_readable: false, 187 current_addr: None, 188 } 189 } 190 191 /// Returns a reference to the underlying I/O stream wrapped by `Framed`. 192 /// 193 /// # Note 194 /// 195 /// Care should be taken to not tamper with the underlying stream of data 196 /// coming in as it may corrupt the stream of frames otherwise being worked 197 /// with. get_ref(&self) -> &T198 pub fn get_ref(&self) -> &T { 199 &self.socket 200 } 201 202 /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`. 203 /// 204 /// # Note 205 /// 206 /// Care should be taken to not tamper with the underlying stream of data 207 /// coming in as it may corrupt the stream of frames otherwise being worked 208 /// with. get_mut(&mut self) -> &mut T209 pub fn get_mut(&mut self) -> &mut T { 210 &mut self.socket 211 } 212 213 /// Returns a reference to the underlying codec wrapped by 214 /// `Framed`. 215 /// 216 /// Note that care should be taken to not tamper with the underlying codec 217 /// as it may corrupt the stream of frames otherwise being worked with. codec(&self) -> &C218 pub fn codec(&self) -> &C { 219 &self.codec 220 } 221 222 /// Returns a mutable reference to the underlying codec wrapped by 223 /// `UdpFramed`. 224 /// 225 /// Note that care should be taken to not tamper with the underlying codec 226 /// as it may corrupt the stream of frames otherwise being worked with. codec_mut(&mut self) -> &mut C227 pub fn codec_mut(&mut self) -> &mut C { 228 &mut self.codec 229 } 230 231 /// Returns a reference to the read buffer. read_buffer(&self) -> &BytesMut232 pub fn read_buffer(&self) -> &BytesMut { 233 &self.rd 234 } 235 236 /// Returns a mutable reference to the read buffer. read_buffer_mut(&mut self) -> &mut BytesMut237 pub fn read_buffer_mut(&mut self) -> &mut BytesMut { 238 &mut self.rd 239 } 240 241 /// Consumes the `Framed`, returning its underlying I/O stream. into_inner(self) -> T242 pub fn into_inner(self) -> T { 243 self.socket 244 } 245 } 246