1 use crate::error::TLSError;
2 #[cfg(feature = "logging")]
3 use crate::log::warn;
4 use crate::msgs::enums::{ContentType, HandshakeType};
5 use crate::msgs::message::{Message, MessagePayload};
6 
/// For a Message `$m` and a `HandshakePayload` enum member `$payload_type`,
/// borrow the inner payload if `$m` is a handshake message of that kind.
///
/// Evaluates to `Ok(&payload)` on success. Otherwise:
/// - if `$m` is not a handshake message at all, evaluates to
///   `Err(TLSError::InappropriateMessage)` expecting `ContentType::Handshake`
///   and reporting `$m`'s content type;
/// - if `$m` is a handshake message of a different kind, evaluates to
///   `Err(TLSError::InappropriateHandshakeMessage)` expecting
///   `$handshake_type` and reporting the received handshake type.
///
/// `$m` is evaluated exactly once and only borrowed, so the returned
/// reference lives as long as that borrow of `$m`.
///
/// Crate paths are written with `$crate::` so call sites do not need to
/// import `MessagePayload`, `ContentType` or `TLSError` themselves.
macro_rules! require_handshake_msg(
  ( $m:expr, $handshake_type:path, $payload_type:path ) => ({
    // Bind once so `$m` is not re-evaluated when building the error below.
    let msg = &$m;
    match msg.payload {
        $crate::msgs::message::MessagePayload::Handshake(ref hsp) => match hsp.payload {
            $payload_type(ref hm) => Ok(hm),
            _ => Err($crate::error::TLSError::InappropriateHandshakeMessage {
                     expect_types: vec![ $handshake_type ],
                     got_type: hsp.typ}),
        },
        _ => Err($crate::error::TLSError::InappropriateMessage {
                 expect_types: vec![ $crate::msgs::enums::ContentType::Handshake ],
                 got_type: msg.typ}),
    }
  })
);
26 
/// Like `require_handshake_msg!`, but moves the payload out of `$m` instead of
/// borrowing it: evaluates to `Ok(payload)` by value on success.
///
/// Despite the `_mut` suffix, nothing is mutated. The handshake payload is
/// bound by value, so `$m.payload` is moved.
///
/// The error cases are the same as for `require_handshake_msg!`:
/// `TLSError::InappropriateMessage` when `$m` is not a handshake message, and
/// `TLSError::InappropriateHandshakeMessage` (expecting `$handshake_type`)
/// when it is a handshake message of another kind.
///
/// `$m` is matched as a place (`$m.payload`) so that only the payload field
/// is moved. It is named a second time (`$m.typ`) in the non-handshake error
/// arm, so it should be a simple place expression such as a local variable.
///
/// Crate paths are written with `$crate::` so call sites do not need to
/// import `MessagePayload`, `ContentType` or `TLSError` themselves.
macro_rules! require_handshake_msg_mut(
  ( $m:expr, $handshake_type:path, $payload_type:path ) => (
    match $m.payload {
        $crate::msgs::message::MessagePayload::Handshake(hsp) => match hsp.payload {
            $payload_type(hm) => Ok(hm),
            _ => Err($crate::error::TLSError::InappropriateHandshakeMessage {
                     expect_types: vec![ $handshake_type ],
                     got_type: hsp.typ}),
        },
        _ => Err($crate::error::TLSError::InappropriateMessage {
                 expect_types: vec![ $crate::msgs::enums::ContentType::Handshake ],
                 got_type: $m.typ}),
    }
  )
);
43 
44 /// Validate the message `m`: return an error if:
45 ///
46 /// - the type of m does not appear in `content_types`.
47 /// - if m is a handshake message, the handshake message type does
48 ///   not appear in `handshake_types`.
check_message( m: &Message, content_types: &[ContentType], handshake_types: &[HandshakeType], ) -> Result<(), TLSError>49 pub fn check_message(
50     m: &Message,
51     content_types: &[ContentType],
52     handshake_types: &[HandshakeType],
53 ) -> Result<(), TLSError> {
54     if !content_types.contains(&m.typ) {
55         warn!(
56             "Received a {:?} message while expecting {:?}",
57             m.typ, content_types
58         );
59         return Err(TLSError::InappropriateMessage {
60             expect_types: content_types.to_vec(),
61             got_type: m.typ,
62         });
63     }
64 
65     if let MessagePayload::Handshake(ref hsp) = m.payload {
66         if !handshake_types.is_empty() && !handshake_types.contains(&hsp.typ) {
67             warn!(
68                 "Received a {:?} handshake message while expecting {:?}",
69                 hsp.typ, handshake_types
70             );
71             return Err(TLSError::InappropriateHandshakeMessage {
72                 expect_types: handshake_types.to_vec(),
73                 got_type: hsp.typ,
74             });
75         }
76     }
77 
78     Ok(())
79 }
80