1 // Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2 //
3 // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4 // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5 // http://opensource.org/licenses/MIT>, at your option. This file may not be
6 // copied, modified, or distributed except according to those terms.
7 
8 //! `DnsMultiplexer` and associated types implement the state machines for sending DNS messages while using the underlying streams.
9 
10 use std::borrow::Borrow;
11 use std::collections::hash_map::Entry;
12 use std::collections::HashMap;
13 use std::fmt::{self, Display};
14 use std::pin::Pin;
15 use std::sync::Arc;
16 use std::task::{Context, Poll};
17 use std::time::{Duration, SystemTime, UNIX_EPOCH};
18 
19 use futures::channel::oneshot;
20 use futures::stream::{Stream, StreamExt};
21 use futures::{ready, Future, FutureExt};
22 use log::{debug, warn};
23 use rand;
24 use rand::distributions::{Distribution, Standard};
25 use smallvec::SmallVec;
26 use tokio::{self, time::Delay};
27 
28 use crate::error::*;
29 use crate::op::{Message, MessageFinalizer, OpCode};
30 use crate::xfer::{
31     ignore_send, DnsClientStream, DnsRequest, DnsRequestOptions, DnsRequestSender, DnsResponse,
32     SerialMessage,
33 };
34 use crate::DnsStreamHandle;
35 
/// Maximum number of messages read from the underlying stream in a single `poll_next` call
/// before the multiplexer gives control back to the executor.
const QOS_MAX_RECEIVE_MSGS: usize = 100;
37 
/// Bookkeeping for a single in-flight query sent through a `DnsMultiplexer`.
struct ActiveRequest {
    // the completion is the channel for a response to the original request
    completion: oneshot::Sender<Result<DnsResponse, ProtoError>>,
    // the DNS message id assigned to the outgoing query; also its key in `active_requests`
    request_id: u16,
    // options supplied with the request, e.g. whether multiple responses are expected
    request_options: DnsRequestOptions,
    // most requests pass a single Message response directly through to the completion;
    //  this small vec will not allocate unless the request expects more than one
    //  response (e.g. a DNS-SD request)
    // TODO: change the completion above to a Stream, and don't hold messages...
    responses: SmallVec<[Message; 1]>,
    // becomes ready once the request has waited its full timeout duration
    timeout: Delay,
}
50 
51 impl ActiveRequest {
new( completion: oneshot::Sender<Result<DnsResponse, ProtoError>>, request_id: u16, request_options: DnsRequestOptions, timeout: Delay, ) -> Self52     fn new(
53         completion: oneshot::Sender<Result<DnsResponse, ProtoError>>,
54         request_id: u16,
55         request_options: DnsRequestOptions,
56         timeout: Delay,
57     ) -> Self {
58         ActiveRequest {
59             completion,
60             request_id,
61             request_options,
62             // request,
63             responses: SmallVec::new(),
64             timeout,
65         }
66     }
67 
68     /// polls the timeout and converts the error
poll_timeout(&mut self, cx: &mut Context) -> Poll<()>69     fn poll_timeout(&mut self, cx: &mut Context) -> Poll<()> {
70         self.timeout.poll_unpin(cx)
71     }
72 
73     /// Returns true of the other side canceled the request
is_canceled(&self) -> bool74     fn is_canceled(&self) -> bool {
75         self.completion.is_canceled()
76     }
77 
78     /// Adds the response to the request such that it can be later sent to the client
add_response(&mut self, message: Message)79     fn add_response(&mut self, message: Message) {
80         self.responses.push(message);
81     }
82 
83     /// the request id of the message that was sent
request_id(&self) -> u1684     fn request_id(&self) -> u16 {
85         self.request_id
86     }
87 
88     /// the request options from the message that was sent
request_options(&self) -> &DnsRequestOptions89     fn request_options(&self) -> &DnsRequestOptions {
90         &self.request_options
91     }
92 
93     /// Sends an error
complete_with_error(self, error: ProtoError)94     fn complete_with_error(self, error: ProtoError) {
95         ignore_send(self.completion.send(Err(error)));
96     }
97 
98     /// sends any registered responses to the requestor
99     ///
100     /// Any error sending will be logged and ignored. This must only be called after associating a response,
101     ///   otherwise an error will always be returned.
complete(self)102     fn complete(self) {
103         if self.responses.is_empty() {
104             self.complete_with_error("no responses received, should have timedout".into());
105         } else {
106             ignore_send(self.completion.send(Ok(self.responses.into())));
107         }
108     }
109 }
110 
/// A DNS Client implemented over futures-rs.
///
/// This Client is generic and capable of wrapping UDP, TCP, and other underlying DNS protocol
///  implementations. This should be used for underlying protocols that do not natively support
///  multiplexed sessions.
///
/// Each outgoing request is given a random query id not currently in use and is tracked in
/// `active_requests` until a response with that id arrives, it times out, or the requestor
/// cancels it.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexer<S, MF, D = Box<dyn DnsStreamHandle>>
where
    D: Send + 'static,
    S: DnsClientStream + 'static,
    MF: MessageFinalizer,
{
    // stream on which responses are received; also supplies the name server address
    stream: S,
    // how long each request may wait for a response before it is timed out
    timeout_duration: Duration,
    // handle used to send serialized requests
    stream_handle: D,
    // in-flight requests, keyed by DNS query id
    active_requests: HashMap<u16, ActiveRequest>,
    // optional finalizer used to sign Update requests
    signer: Option<Arc<MF>>,
    // set by `shutdown`; afterwards `send_message` panics and the stream ends once
    //  `active_requests` is empty
    is_shutdown: bool,
}
130 
131 impl<S, MF> DnsMultiplexer<S, MF, Box<dyn DnsStreamHandle>>
132 where
133     S: DnsClientStream + Unpin + 'static,
134     MF: MessageFinalizer,
135 {
136     /// Spawns a new DnsMultiplexer Stream. This uses a default timeout of 5 seconds for all requests.
137     ///
138     /// # Arguments
139     ///
140     /// * `stream` - A stream of bytes that can be used to send/receive DNS messages
141     ///              (see TcpClientStream or UdpClientStream)
142     /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received.
143     /// * `signer` - An optional signer for requests, needed for Updates with Sig0, otherwise not needed
144     #[allow(clippy::new_ret_no_self)]
new<F>( stream: F, stream_handle: Box<dyn DnsStreamHandle>, signer: Option<Arc<MF>>, ) -> DnsMultiplexerConnect<F, S, MF> where F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,145     pub fn new<F>(
146         stream: F,
147         stream_handle: Box<dyn DnsStreamHandle>,
148         signer: Option<Arc<MF>>,
149     ) -> DnsMultiplexerConnect<F, S, MF>
150     where
151         F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
152     {
153         Self::with_timeout(stream, stream_handle, Duration::from_secs(5), signer)
154     }
155 
156     /// Spawns a new DnsMultiplexer Stream.
157     ///
158     /// # Arguments
159     ///
160     /// * `stream` - A stream of bytes that can be used to send/receive DNS messages
161     ///              (see TcpClientStream or UdpClientStream)
162     /// * `timeout_duration` - All requests may fail due to lack of response, this is the time to
163     ///                        wait for a response before canceling the request.
164     /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received.
165     /// * `signer` - An optional signer for requests, needed for Updates with Sig0, otherwise not needed
with_timeout<F>( stream: F, stream_handle: Box<dyn DnsStreamHandle>, timeout_duration: Duration, signer: Option<Arc<MF>>, ) -> DnsMultiplexerConnect<F, S, MF> where F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,166     pub fn with_timeout<F>(
167         stream: F,
168         stream_handle: Box<dyn DnsStreamHandle>,
169         timeout_duration: Duration,
170         signer: Option<Arc<MF>>,
171     ) -> DnsMultiplexerConnect<F, S, MF>
172     where
173         F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
174     {
175         DnsMultiplexerConnect {
176             stream,
177             stream_handle: Some(stream_handle),
178             timeout_duration,
179             signer,
180         }
181     }
182 
183     /// loop over active_requests and remove cancelled requests
184     ///  this should free up space if we already had 4096 active requests
drop_cancelled(&mut self, cx: &mut Context)185     fn drop_cancelled(&mut self, cx: &mut Context) {
186         let mut canceled = HashMap::<u16, ProtoError>::new();
187         for (&id, ref mut active_req) in &mut self.active_requests {
188             if active_req.is_canceled() {
189                 canceled.insert(id, ProtoError::from("requestor canceled"));
190             }
191 
192             // check for timeouts...
193             match active_req.poll_timeout(cx) {
194                 Poll::Ready(()) => {
195                     debug!("request timed out: {}", id);
196                     canceled.insert(id, ProtoError::from(ProtoErrorKind::Timeout));
197                 }
198                 Poll::Pending => (),
199             }
200         }
201 
202         // drop all the canceled requests
203         for (id, error) in canceled {
204             if let Some(active_request) = self.active_requests.remove(&id) {
205                 if active_request.responses.is_empty() {
206                     // complete the request, it's failed...
207                     active_request.complete_with_error(error);
208                 } else {
209                     // this is a timeout waiting for multiple responses...
210                     active_request.complete();
211                 }
212             }
213         }
214     }
215 
216     /// creates random query_id, validates against all active queries
next_random_query_id(&self, cx: &mut Context) -> Poll<u16>217     fn next_random_query_id(&self, cx: &mut Context) -> Poll<u16> {
218         let mut rand = rand::thread_rng();
219 
220         for _ in 0..100 {
221             let id: u16 = Standard.sample(&mut rand); // the range is [0 ... u16::max]
222 
223             if !self.active_requests.contains_key(&id) {
224                 return Poll::Ready(id);
225             }
226         }
227 
228         cx.waker().wake_by_ref();
229         Poll::Pending
230     }
231 
232     /// Closes all outstanding completes with a closed stream error
stream_closed_close_all(&mut self)233     fn stream_closed_close_all(&mut self) {
234         if !self.active_requests.is_empty() {
235             warn!(
236                 "stream closed before response received: {}",
237                 self.stream.name_server_addr()
238             );
239         }
240 
241         let error = ProtoError::from("stream closed before response received");
242 
243         for (_, active_request) in self.active_requests.drain() {
244             if active_request.responses.is_empty() {
245                 // complete the request, it's failed...
246                 active_request.complete_with_error(error.clone());
247             } else {
248                 // this is a timeout waiting for multiple responses...
249                 active_request.complete();
250             }
251         }
252     }
253 }
254 
/// A future that resolves to a `DnsMultiplexer` once the underlying stream future completes.
///
/// Returned by `DnsMultiplexer::new` and `DnsMultiplexer::with_timeout`.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexerConnect<F, S, MF>
where
    F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
    S: Stream<Item = Result<SerialMessage, ProtoError>> + Unpin,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    // future resolving to the connected stream `S`, or a connection error
    stream: F,
    // send handle for the stream; moved into the multiplexer on completion, hence `Option`
    stream_handle: Option<Box<dyn DnsStreamHandle>>,
    // per-request timeout handed to the multiplexer
    timeout_duration: Duration,
    // optional signer handed to the multiplexer
    signer: Option<Arc<MF>>,
}
268 
269 impl<F, S, MF> Future for DnsMultiplexerConnect<F, S, MF>
270 where
271     F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
272     S: DnsClientStream + Unpin + 'static,
273     MF: MessageFinalizer + Send + Sync + 'static,
274 {
275     type Output = Result<DnsMultiplexer<S, MF, Box<dyn DnsStreamHandle>>, ProtoError>;
276 
poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>277     fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
278         let stream: S = ready!(self.stream.poll_unpin(cx))?;
279 
280         Poll::Ready(Ok(DnsMultiplexer {
281             stream,
282             timeout_duration: self.timeout_duration,
283             stream_handle: self
284                 .stream_handle
285                 .take()
286                 .expect("must not poll after complete"),
287             active_requests: HashMap::new(),
288             signer: self.signer.clone(),
289             is_shutdown: false,
290         }))
291     }
292 }
293 
294 impl<S, MF> Display for DnsMultiplexer<S, MF>
295 where
296     S: DnsClientStream + 'static,
297     MF: MessageFinalizer + Send + Sync + 'static,
298 {
fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error>299     fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> {
300         write!(formatter, "{}", self.stream)
301     }
302 }
303 
304 impl<S, MF> DnsRequestSender for DnsMultiplexer<S, MF>
305 where
306     S: DnsClientStream + Unpin + 'static,
307     MF: MessageFinalizer + Send + Sync + 'static,
308 {
309     type DnsResponseFuture = DnsMultiplexerSerialResponse;
310 
send_message(&mut self, request: DnsRequest, cx: &mut Context) -> Self::DnsResponseFuture311     fn send_message(&mut self, request: DnsRequest, cx: &mut Context) -> Self::DnsResponseFuture {
312         if self.is_shutdown {
313             panic!("can not send messages after stream is shutdown")
314         }
315 
316         // TODO: handle the pending case with future::poll_fn
317         // get next query_id
318         let query_id: u16 = match self.next_random_query_id(cx) {
319             Poll::Ready(id) => id,
320             Poll::Pending => {
321                 return DnsMultiplexerSerialResponseInner::Err(Some(ProtoError::from(
322                     "id space exhausted, consider filing an issue",
323                 )))
324                 .into()
325             }
326         };
327 
328         let (mut request, request_options) = request.unwrap();
329         request.set_id(query_id);
330 
331         let now = match SystemTime::now()
332             .duration_since(UNIX_EPOCH)
333             .map_err(|_| ProtoErrorKind::Message("Current time is before the Unix epoch.").into())
334         {
335             Ok(now) => now.as_secs(),
336             Err(err) => return DnsMultiplexerSerialResponseInner::Err(Some(err)).into(),
337         };
338 
339         // TODO: truncates u64 to u32, error on overflow?
340         let now = now as u32;
341 
342         // update messages need to be signed.
343         if let OpCode::Update = request.op_code() {
344             if let Some(ref signer) = self.signer {
345                 if let Err(e) = request.finalize::<MF>(signer.borrow(), now) {
346                     debug!("could not sign message: {}", e);
347                     return DnsMultiplexerSerialResponseInner::Err(Some(e)).into();
348                 }
349             }
350         }
351 
352         // store a Timeout for this message before sending
353         let timeout = tokio::time::delay_for(self.timeout_duration);
354 
355         let (complete, receiver) = oneshot::channel();
356 
357         // send the message
358         let active_request = ActiveRequest::new(complete, request.id(), request_options, timeout);
359 
360         match request.to_vec() {
361             Ok(buffer) => {
362                 debug!("sending message id: {}", active_request.request_id());
363                 let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr());
364 
365                 // add to the map -after- the client send b/c we don't want to put it in the map if
366                 //  we ended up returning an error from the send.
367                 match self.stream_handle.send(serial_message) {
368                     Ok(()) => self
369                         .active_requests
370                         .insert(active_request.request_id(), active_request),
371                     Err(err) => return DnsMultiplexerSerialResponseInner::Err(Some(err)).into(),
372                 };
373             }
374             Err(e) => {
375                 debug!(
376                     "error message id: {} error: {}",
377                     active_request.request_id(),
378                     e
379                 );
380                 // complete with the error, don't add to the map of active requests
381                 return DnsMultiplexerSerialResponseInner::Err(Some(e)).into();
382             }
383         }
384 
385         DnsMultiplexerSerialResponseInner::Completion(receiver).into()
386     }
387 
error_response(error: ProtoError) -> Self::DnsResponseFuture388     fn error_response(error: ProtoError) -> Self::DnsResponseFuture {
389         DnsMultiplexerSerialResponseInner::Err(Some(error)).into()
390     }
391 
shutdown(&mut self)392     fn shutdown(&mut self) {
393         self.is_shutdown = true;
394     }
395 
is_shutdown(&self) -> bool396     fn is_shutdown(&self) -> bool {
397         self.is_shutdown
398     }
399 }
400 
401 impl<S, MF> Stream for DnsMultiplexer<S, MF>
402 where
403     S: DnsClientStream + Unpin + 'static,
404     MF: MessageFinalizer + Send + Sync + 'static,
405 {
406     type Item = Result<(), ProtoError>;
407 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>>408     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
409         // Always drop the cancelled queries first
410         self.drop_cancelled(cx);
411 
412         if self.is_shutdown && self.active_requests.is_empty() {
413             debug!("stream is done: {}", self);
414             return Poll::Ready(None);
415         }
416 
417         // Collect all inbound requests, max 100 at a time for QoS
418         //   by having a max we will guarantee that the client can't be DOSed in this loop
419         // TODO: make the QoS configurable
420         let mut messages_received = 0;
421         for i in 0..QOS_MAX_RECEIVE_MSGS {
422             match self.stream.poll_next_unpin(cx)? {
423                 Poll::Ready(Some(buffer)) => {
424                     messages_received = i;
425 
426                     //   deserialize or log decode_error
427                     match buffer.to_message() {
428                         Ok(message) => match self.active_requests.entry(message.id()) {
429                             Entry::Occupied(mut request_entry) => {
430                                 // first add the response to the active_requests responses
431                                 let complete = {
432                                     let active_request = request_entry.get_mut();
433                                     active_request.add_response(message);
434 
435                                     // determine if this is complete
436                                     !active_request.request_options().expects_multiple_responses
437                                 };
438 
439                                 // now check if the request is complete
440                                 if complete {
441                                     let active_request = request_entry.remove();
442                                     active_request.complete();
443                                 }
444                             }
445                             Entry::Vacant(..) => debug!("unexpected request_id: {}", message.id()),
446                         },
447                         // TODO: return src address for diagnostics
448                         Err(e) => debug!("error decoding message: {}", e),
449                     }
450                 }
451                 Poll::Ready(None) => {
452                     debug!("io_stream closed by other side: {}", self.stream);
453                     self.stream_closed_close_all();
454                     return Poll::Ready(None);
455                 }
456                 Poll::Pending => break,
457             }
458         }
459 
460         // If still active, then if the qos (for _ in 0..100 loop) limit
461         // was hit then "yield". This'll make sure that the future is
462         // woken up immediately on the next turn of the event loop.
463         if messages_received == QOS_MAX_RECEIVE_MSGS {
464             // FIXME: this was a task::current().notify(); is this right?
465             cx.waker().wake_by_ref();
466         }
467 
468         // Finally, return not ready to keep the 'driver task' alive.
469         Poll::Pending
470     }
471 }
472 
/// A future that resolves into a DnsResponse
///
/// It either waits on the completion channel of a request tracked by the multiplexer, or
/// yields an error that was detected before the request could be registered.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexerSerialResponse(DnsMultiplexerSerialResponseInner);
476 
477 impl DnsMultiplexerSerialResponse {
478     /// Returns a new future with the oneshot completion
completion(complete: oneshot::Receiver<ProtoResult<DnsResponse>>) -> Self479     pub fn completion(complete: oneshot::Receiver<ProtoResult<DnsResponse>>) -> Self {
480         DnsMultiplexerSerialResponseInner::Completion(complete).into()
481     }
482 }
483 
484 impl Future for DnsMultiplexerSerialResponse {
485     type Output = Result<DnsResponse, ProtoError>;
486 
poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>487     fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
488         self.0.poll_unpin(cx)
489     }
490 }
491 
492 impl From<DnsMultiplexerSerialResponseInner> for DnsMultiplexerSerialResponse {
from(inner: DnsMultiplexerSerialResponseInner) -> Self493     fn from(inner: DnsMultiplexerSerialResponseInner) -> Self {
494         DnsMultiplexerSerialResponse(inner)
495     }
496 }
497 
/// Internal state of a `DnsMultiplexerSerialResponse`.
enum DnsMultiplexerSerialResponseInner {
    // waiting on the completion channel of a request tracked by the multiplexer
    Completion(oneshot::Receiver<ProtoResult<DnsResponse>>),
    // an error detected before the request was registered; wrapped in `Option` so it can be
    //  moved out on the first poll (polling again afterwards panics)
    Err(Option<ProtoError>),
}
502 
503 impl Future for DnsMultiplexerSerialResponseInner {
504     type Output = Result<DnsResponse, ProtoError>;
505 
poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>506     fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
507         match *self {
508             // The inner type of the completion might have been an error
509             //   we need to unwrap that, and translate to be the Future's error
510             DnsMultiplexerSerialResponseInner::Completion(ref mut complete) => {
511                 complete.poll_unpin(cx).map(|r| {
512                     r.map_err(|_| ProtoError::from("the completion was canceled"))
513                         .and_then(|r| r)
514                 })
515             }
516             DnsMultiplexerSerialResponseInner::Err(ref mut err) => {
517                 Poll::Ready(Err(err.take().expect("cannot poll after complete")))
518             }
519         }
520     }
521 }
522