1 // Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com> 2 // 3 // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or 4 // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or 5 // http://opensource.org/licenses/MIT>, at your option. This file may not be 6 // copied, modified, or distributed except according to those terms. 7 8 //! `DnsMultiplexer` and associated types implement the state machines for sending DNS messages while using the underlying streams. 9 10 use std::borrow::Borrow; 11 use std::collections::hash_map::Entry; 12 use std::collections::HashMap; 13 use std::fmt::{self, Display}; 14 use std::pin::Pin; 15 use std::sync::Arc; 16 use std::task::{Context, Poll}; 17 use std::time::{Duration, SystemTime, UNIX_EPOCH}; 18 19 use futures::channel::oneshot; 20 use futures::stream::{Stream, StreamExt}; 21 use futures::{ready, Future, FutureExt}; 22 use log::{debug, warn}; 23 use rand; 24 use rand::distributions::{Distribution, Standard}; 25 use smallvec::SmallVec; 26 use tokio::{self, time::Delay}; 27 28 use crate::error::*; 29 use crate::op::{Message, MessageFinalizer, OpCode}; 30 use crate::xfer::{ 31 ignore_send, DnsClientStream, DnsRequest, DnsRequestOptions, DnsRequestSender, DnsResponse, 32 SerialMessage, 33 }; 34 use crate::DnsStreamHandle; 35 36 const QOS_MAX_RECEIVE_MSGS: usize = 100; // max number of messages to receive from the UDP socket 37 38 struct ActiveRequest { 39 // the completion is the channel for a response to the original request 40 completion: oneshot::Sender<Result<DnsResponse, ProtoError>>, 41 request_id: u16, 42 request_options: DnsRequestOptions, 43 // most requests pass a single Message response directly through to the completion 44 // this small vec will have no allocations, unless the requests is a DNS-SD request 45 // expecting more than one response 46 // TODO: change the completion above to a Stream, and don't hold messages... 47 responses: SmallVec<[Message; 1]>, 48 timeout: Delay, 49 } 50 51 impl ActiveRequest { new( completion: oneshot::Sender<Result<DnsResponse, ProtoError>>, request_id: u16, request_options: DnsRequestOptions, timeout: Delay, ) -> Self52 fn new( 53 completion: oneshot::Sender<Result<DnsResponse, ProtoError>>, 54 request_id: u16, 55 request_options: DnsRequestOptions, 56 timeout: Delay, 57 ) -> Self { 58 ActiveRequest { 59 completion, 60 request_id, 61 request_options, 62 // request, 63 responses: SmallVec::new(), 64 timeout, 65 } 66 } 67 68 /// polls the timeout and converts the error poll_timeout(&mut self, cx: &mut Context) -> Poll<()>69 fn poll_timeout(&mut self, cx: &mut Context) -> Poll<()> { 70 self.timeout.poll_unpin(cx) 71 } 72 73 /// Returns true of the other side canceled the request is_canceled(&self) -> bool74 fn is_canceled(&self) -> bool { 75 self.completion.is_canceled() 76 } 77 78 /// Adds the response to the request such that it can be later sent to the client add_response(&mut self, message: Message)79 fn add_response(&mut self, message: Message) { 80 self.responses.push(message); 81 } 82 83 /// the request id of the message that was sent request_id(&self) -> u1684 fn request_id(&self) -> u16 { 85 self.request_id 86 } 87 88 /// the request options from the message that was sent request_options(&self) -> &DnsRequestOptions89 fn request_options(&self) -> &DnsRequestOptions { 90 &self.request_options 91 } 92 93 /// Sends an error complete_with_error(self, error: ProtoError)94 fn complete_with_error(self, error: ProtoError) { 95 ignore_send(self.completion.send(Err(error))); 96 } 97 98 /// sends any registered responses to the requestor 99 /// 100 /// Any error sending will be logged and ignored. This must only be called after associating a response, 101 /// otherwise an error will always be returned. complete(self)102 fn complete(self) { 103 if self.responses.is_empty() { 104 self.complete_with_error("no responses received, should have timedout".into()); 105 } else { 106 ignore_send(self.completion.send(Ok(self.responses.into()))); 107 } 108 } 109 } 110 111 /// A DNS Client implemented over futures-rs. 112 /// 113 /// This Client is generic and capable of wrapping UDP, TCP, and other underlying DNS protocol 114 /// implementations. This should be used for underlying protocols that do not natively support 115 /// multiplexed sessions. 116 #[must_use = "futures do nothing unless polled"] 117 pub struct DnsMultiplexer<S, MF, D = Box<dyn DnsStreamHandle>> 118 where 119 D: Send + 'static, 120 S: DnsClientStream + 'static, 121 MF: MessageFinalizer, 122 { 123 stream: S, 124 timeout_duration: Duration, 125 stream_handle: D, 126 active_requests: HashMap<u16, ActiveRequest>, 127 signer: Option<Arc<MF>>, 128 is_shutdown: bool, 129 } 130 131 impl<S, MF> DnsMultiplexer<S, MF, Box<dyn DnsStreamHandle>> 132 where 133 S: DnsClientStream + Unpin + 'static, 134 MF: MessageFinalizer, 135 { 136 /// Spawns a new DnsMultiplexer Stream. This uses a default timeout of 5 seconds for all requests. 137 /// 138 /// # Arguments 139 /// 140 /// * `stream` - A stream of bytes that can be used to send/receive DNS messages 141 /// (see TcpClientStream or UdpClientStream) 142 /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received. 143 /// * `signer` - An optional signer for requests, needed for Updates with Sig0, otherwise not needed 144 #[allow(clippy::new_ret_no_self)] new<F>( stream: F, stream_handle: Box<dyn DnsStreamHandle>, signer: Option<Arc<MF>>, ) -> DnsMultiplexerConnect<F, S, MF> where F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,145 pub fn new<F>( 146 stream: F, 147 stream_handle: Box<dyn DnsStreamHandle>, 148 signer: Option<Arc<MF>>, 149 ) -> DnsMultiplexerConnect<F, S, MF> 150 where 151 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static, 152 { 153 Self::with_timeout(stream, stream_handle, Duration::from_secs(5), signer) 154 } 155 156 /// Spawns a new DnsMultiplexer Stream. 157 /// 158 /// # Arguments 159 /// 160 /// * `stream` - A stream of bytes that can be used to send/receive DNS messages 161 /// (see TcpClientStream or UdpClientStream) 162 /// * `timeout_duration` - All requests may fail due to lack of response, this is the time to 163 /// wait for a response before canceling the request. 164 /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received. 165 /// * `signer` - An optional signer for requests, needed for Updates with Sig0, otherwise not needed with_timeout<F>( stream: F, stream_handle: Box<dyn DnsStreamHandle>, timeout_duration: Duration, signer: Option<Arc<MF>>, ) -> DnsMultiplexerConnect<F, S, MF> where F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,166 pub fn with_timeout<F>( 167 stream: F, 168 stream_handle: Box<dyn DnsStreamHandle>, 169 timeout_duration: Duration, 170 signer: Option<Arc<MF>>, 171 ) -> DnsMultiplexerConnect<F, S, MF> 172 where 173 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static, 174 { 175 DnsMultiplexerConnect { 176 stream, 177 stream_handle: Some(stream_handle), 178 timeout_duration, 179 signer, 180 } 181 } 182 183 /// loop over active_requests and remove cancelled requests 184 /// this should free up space if we already had 4096 active requests drop_cancelled(&mut self, cx: &mut Context)185 fn drop_cancelled(&mut self, cx: &mut Context) { 186 let mut canceled = HashMap::<u16, ProtoError>::new(); 187 for (&id, ref mut active_req) in &mut self.active_requests { 188 if active_req.is_canceled() { 189 canceled.insert(id, ProtoError::from("requestor canceled")); 190 } 191 192 // check for timeouts... 193 match active_req.poll_timeout(cx) { 194 Poll::Ready(()) => { 195 debug!("request timed out: {}", id); 196 canceled.insert(id, ProtoError::from(ProtoErrorKind::Timeout)); 197 } 198 Poll::Pending => (), 199 } 200 } 201 202 // drop all the canceled requests 203 for (id, error) in canceled { 204 if let Some(active_request) = self.active_requests.remove(&id) { 205 if active_request.responses.is_empty() { 206 // complete the request, it's failed... 207 active_request.complete_with_error(error); 208 } else { 209 // this is a timeout waiting for multiple responses... 210 active_request.complete(); 211 } 212 } 213 } 214 } 215 216 /// creates random query_id, validates against all active queries next_random_query_id(&self, cx: &mut Context) -> Poll<u16>217 fn next_random_query_id(&self, cx: &mut Context) -> Poll<u16> { 218 let mut rand = rand::thread_rng(); 219 220 for _ in 0..100 { 221 let id: u16 = Standard.sample(&mut rand); // the range is [0 ... u16::max] 222 223 if !self.active_requests.contains_key(&id) { 224 return Poll::Ready(id); 225 } 226 } 227 228 cx.waker().wake_by_ref(); 229 Poll::Pending 230 } 231 232 /// Closes all outstanding completes with a closed stream error stream_closed_close_all(&mut self)233 fn stream_closed_close_all(&mut self) { 234 if !self.active_requests.is_empty() { 235 warn!( 236 "stream closed before response received: {}", 237 self.stream.name_server_addr() 238 ); 239 } 240 241 let error = ProtoError::from("stream closed before response received"); 242 243 for (_, active_request) in self.active_requests.drain() { 244 if active_request.responses.is_empty() { 245 // complete the request, it's failed... 246 active_request.complete_with_error(error.clone()); 247 } else { 248 // this is a timeout waiting for multiple responses... 249 active_request.complete(); 250 } 251 } 252 } 253 } 254 255 /// A wrapper for a future DnsExchange connection 256 #[must_use = "futures do nothing unless polled"] 257 pub struct DnsMultiplexerConnect<F, S, MF> 258 where 259 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static, 260 S: Stream<Item = Result<SerialMessage, ProtoError>> + Unpin, 261 MF: MessageFinalizer + Send + Sync + 'static, 262 { 263 stream: F, 264 stream_handle: Option<Box<dyn DnsStreamHandle>>, 265 timeout_duration: Duration, 266 signer: Option<Arc<MF>>, 267 } 268 269 impl<F, S, MF> Future for DnsMultiplexerConnect<F, S, MF> 270 where 271 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static, 272 S: DnsClientStream + Unpin + 'static, 273 MF: MessageFinalizer + Send + Sync + 'static, 274 { 275 type Output = Result<DnsMultiplexer<S, MF, Box<dyn DnsStreamHandle>>, ProtoError>; 276 poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>277 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { 278 let stream: S = ready!(self.stream.poll_unpin(cx))?; 279 280 Poll::Ready(Ok(DnsMultiplexer { 281 stream, 282 timeout_duration: self.timeout_duration, 283 stream_handle: self 284 .stream_handle 285 .take() 286 .expect("must not poll after complete"), 287 active_requests: HashMap::new(), 288 signer: self.signer.clone(), 289 is_shutdown: false, 290 })) 291 } 292 } 293 294 impl<S, MF> Display for DnsMultiplexer<S, MF> 295 where 296 S: DnsClientStream + 'static, 297 MF: MessageFinalizer + Send + Sync + 'static, 298 { fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error>299 fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { 300 write!(formatter, "{}", self.stream) 301 } 302 } 303 304 impl<S, MF> DnsRequestSender for DnsMultiplexer<S, MF> 305 where 306 S: DnsClientStream + Unpin + 'static, 307 MF: MessageFinalizer + Send + Sync + 'static, 308 { 309 type DnsResponseFuture = DnsMultiplexerSerialResponse; 310 send_message(&mut self, request: DnsRequest, cx: &mut Context) -> Self::DnsResponseFuture311 fn send_message(&mut self, request: DnsRequest, cx: &mut Context) -> Self::DnsResponseFuture { 312 if self.is_shutdown { 313 panic!("can not send messages after stream is shutdown") 314 } 315 316 // TODO: handle the pending case with future::poll_fn 317 // get next query_id 318 let query_id: u16 = match self.next_random_query_id(cx) { 319 Poll::Ready(id) => id, 320 Poll::Pending => { 321 return DnsMultiplexerSerialResponseInner::Err(Some(ProtoError::from( 322 "id space exhausted, consider filing an issue", 323 ))) 324 .into() 325 } 326 }; 327 328 let (mut request, request_options) = request.unwrap(); 329 request.set_id(query_id); 330 331 let now = match SystemTime::now() 332 .duration_since(UNIX_EPOCH) 333 .map_err(|_| ProtoErrorKind::Message("Current time is before the Unix epoch.").into()) 334 { 335 Ok(now) => now.as_secs(), 336 Err(err) => return DnsMultiplexerSerialResponseInner::Err(Some(err)).into(), 337 }; 338 339 // TODO: truncates u64 to u32, error on overflow? 340 let now = now as u32; 341 342 // update messages need to be signed. 343 if let OpCode::Update = request.op_code() { 344 if let Some(ref signer) = self.signer { 345 if let Err(e) = request.finalize::<MF>(signer.borrow(), now) { 346 debug!("could not sign message: {}", e); 347 return DnsMultiplexerSerialResponseInner::Err(Some(e)).into(); 348 } 349 } 350 } 351 352 // store a Timeout for this message before sending 353 let timeout = tokio::time::delay_for(self.timeout_duration); 354 355 let (complete, receiver) = oneshot::channel(); 356 357 // send the message 358 let active_request = ActiveRequest::new(complete, request.id(), request_options, timeout); 359 360 match request.to_vec() { 361 Ok(buffer) => { 362 debug!("sending message id: {}", active_request.request_id()); 363 let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr()); 364 365 // add to the map -after- the client send b/c we don't want to put it in the map if 366 // we ended up returning an error from the send. 367 match self.stream_handle.send(serial_message) { 368 Ok(()) => self 369 .active_requests 370 .insert(active_request.request_id(), active_request), 371 Err(err) => return DnsMultiplexerSerialResponseInner::Err(Some(err)).into(), 372 }; 373 } 374 Err(e) => { 375 debug!( 376 "error message id: {} error: {}", 377 active_request.request_id(), 378 e 379 ); 380 // complete with the error, don't add to the map of active requests 381 return DnsMultiplexerSerialResponseInner::Err(Some(e)).into(); 382 } 383 } 384 385 DnsMultiplexerSerialResponseInner::Completion(receiver).into() 386 } 387 error_response(error: ProtoError) -> Self::DnsResponseFuture388 fn error_response(error: ProtoError) -> Self::DnsResponseFuture { 389 DnsMultiplexerSerialResponseInner::Err(Some(error)).into() 390 } 391 shutdown(&mut self)392 fn shutdown(&mut self) { 393 self.is_shutdown = true; 394 } 395 is_shutdown(&self) -> bool396 fn is_shutdown(&self) -> bool { 397 self.is_shutdown 398 } 399 } 400 401 impl<S, MF> Stream for DnsMultiplexer<S, MF> 402 where 403 S: DnsClientStream + Unpin + 'static, 404 MF: MessageFinalizer + Send + Sync + 'static, 405 { 406 type Item = Result<(), ProtoError>; 407 poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>>408 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { 409 // Always drop the cancelled queries first 410 self.drop_cancelled(cx); 411 412 if self.is_shutdown && self.active_requests.is_empty() { 413 debug!("stream is done: {}", self); 414 return Poll::Ready(None); 415 } 416 417 // Collect all inbound requests, max 100 at a time for QoS 418 // by having a max we will guarantee that the client can't be DOSed in this loop 419 // TODO: make the QoS configurable 420 let mut messages_received = 0; 421 for i in 0..QOS_MAX_RECEIVE_MSGS { 422 match self.stream.poll_next_unpin(cx)? { 423 Poll::Ready(Some(buffer)) => { 424 messages_received = i; 425 426 // deserialize or log decode_error 427 match buffer.to_message() { 428 Ok(message) => match self.active_requests.entry(message.id()) { 429 Entry::Occupied(mut request_entry) => { 430 // first add the response to the active_requests responses 431 let complete = { 432 let active_request = request_entry.get_mut(); 433 active_request.add_response(message); 434 435 // determine if this is complete 436 !active_request.request_options().expects_multiple_responses 437 }; 438 439 // now check if the request is complete 440 if complete { 441 let active_request = request_entry.remove(); 442 active_request.complete(); 443 } 444 } 445 Entry::Vacant(..) => debug!("unexpected request_id: {}", message.id()), 446 }, 447 // TODO: return src address for diagnostics 448 Err(e) => debug!("error decoding message: {}", e), 449 } 450 } 451 Poll::Ready(None) => { 452 debug!("io_stream closed by other side: {}", self.stream); 453 self.stream_closed_close_all(); 454 return Poll::Ready(None); 455 } 456 Poll::Pending => break, 457 } 458 } 459 460 // If still active, then if the qos (for _ in 0..100 loop) limit 461 // was hit then "yield". This'll make sure that the future is 462 // woken up immediately on the next turn of the event loop. 463 if messages_received == QOS_MAX_RECEIVE_MSGS { 464 // FIXME: this was a task::current().notify(); is this right? 465 cx.waker().wake_by_ref(); 466 } 467 468 // Finally, return not ready to keep the 'driver task' alive. 469 Poll::Pending 470 } 471 } 472 473 /// A future that resolves into a DnsResponse 474 #[must_use = "futures do nothing unless polled"] 475 pub struct DnsMultiplexerSerialResponse(DnsMultiplexerSerialResponseInner); 476 477 impl DnsMultiplexerSerialResponse { 478 /// Returns a new future with the oneshot completion completion(complete: oneshot::Receiver<ProtoResult<DnsResponse>>) -> Self479 pub fn completion(complete: oneshot::Receiver<ProtoResult<DnsResponse>>) -> Self { 480 DnsMultiplexerSerialResponseInner::Completion(complete).into() 481 } 482 } 483 484 impl Future for DnsMultiplexerSerialResponse { 485 type Output = Result<DnsResponse, ProtoError>; 486 poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>487 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { 488 self.0.poll_unpin(cx) 489 } 490 } 491 492 impl From<DnsMultiplexerSerialResponseInner> for DnsMultiplexerSerialResponse { from(inner: DnsMultiplexerSerialResponseInner) -> Self493 fn from(inner: DnsMultiplexerSerialResponseInner) -> Self { 494 DnsMultiplexerSerialResponse(inner) 495 } 496 } 497 498 enum DnsMultiplexerSerialResponseInner { 499 Completion(oneshot::Receiver<ProtoResult<DnsResponse>>), 500 Err(Option<ProtoError>), 501 } 502 503 impl Future for DnsMultiplexerSerialResponseInner { 504 type Output = Result<DnsResponse, ProtoError>; 505 poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>506 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { 507 match *self { 508 // The inner type of the completion might have been an error 509 // we need to unwrap that, and translate to be the Future's error 510 DnsMultiplexerSerialResponseInner::Completion(ref mut complete) => { 511 complete.poll_unpin(cx).map(|r| { 512 r.map_err(|_| ProtoError::from("the completion was canceled")) 513 .and_then(|r| r) 514 }) 515 } 516 DnsMultiplexerSerialResponseInner::Err(ref mut err) => { 517 Poll::Ready(Err(err.take().expect("cannot poll after complete"))) 518 } 519 } 520 } 521 } 522