1 use std::convert::TryFrom;
2 
3 use enumflags2::{bitflags, BitFlags};
4 use once_cell::sync::OnceCell;
5 use serde::{Deserialize, Serialize};
6 use serde_repr::{Deserialize_repr, Serialize_repr};
7 
8 use static_assertions::assert_impl_all;
9 use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, UniqueName};
10 use zvariant::{EncodingContext, ObjectPath, Signature, Type};
11 
12 use crate::{Error, MessageField, MessageFieldCode, MessageFields};
13 
/// Size in bytes of the fixed-length primary header: four `u8` fields followed by two `u32`s.
pub(crate) const PRIMARY_HEADER_SIZE: usize = 12;
/// Smallest buffer worth decoding: the primary header plus the `u32` that immediately follows it
/// (the header-fields length read by `MessagePrimaryHeader::read`).
pub(crate) const MIN_MESSAGE_SIZE: usize = PRIMARY_HEADER_SIZE + std::mem::size_of::<u32>();
16 
17 /// D-Bus code for endianness.
18 #[repr(u8)]
19 #[derive(Debug, Copy, Clone, Deserialize_repr, PartialEq, Serialize_repr, Type)]
20 pub enum EndianSig {
21     /// The D-Bus message is in big-endian (network) byte order.
22     Big = b'B',
23 
24     /// The D-Bus message is in little-endian byte order.
25     Little = b'l',
26 }
27 
28 assert_impl_all!(EndianSig: Send, Sync, Unpin);
29 
30 // Such a shame I've to do this manually
31 impl TryFrom<u8> for EndianSig {
32     type Error = Error;
33 
try_from(val: u8) -> Result<EndianSig, Error>34     fn try_from(val: u8) -> Result<EndianSig, Error> {
35         match val {
36             b'B' => Ok(EndianSig::Big),
37             b'l' => Ok(EndianSig::Little),
38             _ => Err(Error::IncorrectEndian),
39         }
40     }
41 }
42 
43 #[cfg(target_endian = "big")]
44 /// Signature of the target's native endian.
45 pub const NATIVE_ENDIAN_SIG: EndianSig = EndianSig::Big;
46 #[cfg(target_endian = "little")]
47 /// Signature of the target's native endian.
48 pub const NATIVE_ENDIAN_SIG: EndianSig = EndianSig::Little;
49 
50 /// Message header representing the D-Bus type of the message.
51 #[repr(u8)]
52 #[derive(Debug, Copy, Clone, Deserialize_repr, PartialEq, Serialize_repr, Type)]
53 pub enum MessageType {
54     /// Invalid message type. All unknown types on received messages are treated as invalid.
55     Invalid = 0,
56     /// Method call. This message type may prompt a reply (and typically does).
57     MethodCall = 1,
58     /// A reply to a method call.
59     MethodReturn = 2,
60     /// An error in response to a method call.
61     Error = 3,
62     /// Signal emission.
63     Signal = 4,
64 }
65 
66 assert_impl_all!(MessageType: Send, Sync, Unpin);
67 
68 // Such a shame I've to do this manually
69 impl From<u8> for MessageType {
from(val: u8) -> MessageType70     fn from(val: u8) -> MessageType {
71         match val {
72             1 => MessageType::MethodCall,
73             2 => MessageType::MethodReturn,
74             3 => MessageType::Error,
75             4 => MessageType::Signal,
76             _ => MessageType::Invalid,
77         }
78     }
79 }
80 
81 /// Pre-defined flags that can be passed in Message header.
82 #[bitflags]
83 #[repr(u8)]
84 #[derive(Debug, Copy, Clone, PartialEq, Type)]
85 pub enum MessageFlags {
86     /// This message does not expect method return replies or error replies, even if it is of a type
87     /// that can have a reply; the reply should be omitted.
88     ///
89     /// Note that `MessageType::MethodCall` is the only message type currently defined in the
90     /// specification that can expect a reply, so the presence or absence of this flag in the other
91     /// three message types that are currently documented is meaningless: replies to those message
92     /// types should not be sent, whether this flag is present or not.
93     NoReplyExpected = 0x1,
94     /// The bus must not launch an owner for the destination name in response to this message.
95     NoAutoStart = 0x2,
96     /// This flag may be set on a method call message to inform the receiving side that the caller
97     /// is prepared to wait for interactive authorization, which might take a considerable time to
98     /// complete. For instance, if this flag is set, it would be appropriate to query the user for
99     /// passwords or confirmation via Polkit or a similar framework.
100     AllowInteractiveAuth = 0x4,
101 }
102 
103 assert_impl_all!(MessageFlags: Send, Sync, Unpin);
104 
105 #[derive(Clone, Debug)]
106 struct SerialNum(OnceCell<u32>);
107 
108 // FIXME: Can use `zvariant::Type` macro after `zvariant` provides a blanket implementation for
109 // `OnceCell<T>`.
110 impl zvariant::Type for SerialNum {
signature() -> Signature<'static>111     fn signature() -> Signature<'static> {
112         u32::signature()
113     }
114 }
115 
116 // Unfortunately Serde doesn't provide a blanket impl. for `Cell<T>` so we have to implement manually.
117 //
118 // https://github.com/serde-rs/serde/issues/1952
119 impl Serialize for SerialNum {
serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer,120     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
121     where
122         S: serde::Serializer,
123     {
124         // `Message` serializes the PrimaryHeader at construct time before the user has the
125         // time to tweak it and set a correct serial_num. We should probably avoid this but
126         // for now, let's silently use a default serialized value.
127         self.0
128             .get()
129             .cloned()
130             .unwrap_or_default()
131             .serialize(serializer)
132     }
133 }
134 
135 impl<'de> Deserialize<'de> for SerialNum {
deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: serde::Deserializer<'de>,136     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
137     where
138         D: serde::Deserializer<'de>,
139     {
140         Ok(SerialNum(OnceCell::from(u32::deserialize(deserializer)?)))
141     }
142 }
143 
/// The primary message header, which is present in all D-Bus messages.
///
/// This header contains all the essential information about a message, regardless of its type.
// Field order matters: the derived Serialize/Deserialize impls process fields in declaration
// order, so this order is the on-the-wire layout.
#[derive(Clone, Debug, Serialize, Deserialize, Type)]
pub struct MessagePrimaryHeader {
    // Byte order the rest of the message is encoded in.
    endian_sig: EndianSig,
    // Kind of message (method call, reply, error, signal).
    msg_type: MessageType,
    // Set of `MessageFlags`, encoded as a single byte.
    flags: BitFlags<MessageFlags>,
    // Major protocol version; `new` sets this to 1.
    protocol_version: u8,
    // Length of the message body in bytes.
    body_len: u32,
    // Serial number; may be unset until assigned (serialized as 0 while unset).
    serial_num: SerialNum,
}

assert_impl_all!(MessagePrimaryHeader: Send, Sync, Unpin);
158 
159 impl MessagePrimaryHeader {
160     /// Create a new `MessagePrimaryHeader` instance.
new(msg_type: MessageType, body_len: u32) -> Self161     pub fn new(msg_type: MessageType, body_len: u32) -> Self {
162         Self {
163             endian_sig: NATIVE_ENDIAN_SIG,
164             msg_type,
165             flags: BitFlags::empty(),
166             protocol_version: 1,
167             body_len,
168             serial_num: SerialNum(OnceCell::new()),
169         }
170     }
171 
read(buf: &[u8]) -> Result<(MessagePrimaryHeader, u32), Error>172     pub(crate) fn read(buf: &[u8]) -> Result<(MessagePrimaryHeader, u32), Error> {
173         let ctx = EncodingContext::<byteorder::NativeEndian>::new_dbus(0);
174         let primary_header = zvariant::from_slice(buf, ctx)?;
175         let fields_len = zvariant::from_slice(&buf[PRIMARY_HEADER_SIZE..], ctx)?;
176         Ok((primary_header, fields_len))
177     }
178 
179     /// D-Bus code for bytorder encoding of the message.
endian_sig(&self) -> EndianSig180     pub fn endian_sig(&self) -> EndianSig {
181         self.endian_sig
182     }
183 
184     /// Set the D-Bus code for bytorder encoding of the message.
set_endian_sig(&mut self, sig: EndianSig)185     pub fn set_endian_sig(&mut self, sig: EndianSig) {
186         self.endian_sig = sig;
187     }
188 
189     /// The message type.
msg_type(&self) -> MessageType190     pub fn msg_type(&self) -> MessageType {
191         self.msg_type
192     }
193 
194     /// Set the message type.
set_msg_type(&mut self, msg_type: MessageType)195     pub fn set_msg_type(&mut self, msg_type: MessageType) {
196         self.msg_type = msg_type;
197     }
198 
199     /// The message flags.
flags(&self) -> BitFlags<MessageFlags>200     pub fn flags(&self) -> BitFlags<MessageFlags> {
201         self.flags
202     }
203 
204     /// Set the message flags.
set_flags(&mut self, flags: BitFlags<MessageFlags>)205     pub fn set_flags(&mut self, flags: BitFlags<MessageFlags>) {
206         self.flags = flags;
207     }
208 
209     /// The major version of the protocol the message is compliant to.
210     ///
211     /// Currently only `1` is valid.
protocol_version(&self) -> u8212     pub fn protocol_version(&self) -> u8 {
213         self.protocol_version
214     }
215 
216     /// Set the major version of the protocol the message is compliant to.
217     ///
218     /// Currently only `1` is valid.
set_protocol_version(&mut self, version: u8)219     pub fn set_protocol_version(&mut self, version: u8) {
220         self.protocol_version = version;
221     }
222 
223     /// The byte length of the message body.
body_len(&self) -> u32224     pub fn body_len(&self) -> u32 {
225         self.body_len
226     }
227 
228     /// Set the byte length of the message body.
set_body_len(&mut self, len: u32)229     pub fn set_body_len(&mut self, len: u32) {
230         self.body_len = len;
231     }
232 
233     /// The serial number of the message (if set).
234     ///
235     /// This is used to match a reply to a method call.
236     ///
237     /// **Note:** There is no setter provided for this in the public API since this is set by the
238     /// [`Connection`](struct.Connection.html) the message is sent over.
serial_num(&self) -> Option<&u32>239     pub fn serial_num(&self) -> Option<&u32> {
240         self.serial_num.0.get()
241     }
242 
serial_num_or_init<F>(&mut self, f: F) -> &u32 where F: FnOnce() -> u32,243     pub(crate) fn serial_num_or_init<F>(&mut self, f: F) -> &u32
244     where
245         F: FnOnce() -> u32,
246     {
247         self.serial_num.0.get_or_init(f)
248     }
249 }
250 
/// The message header, containing all the metadata about the message.
///
/// This includes both the [`MessagePrimaryHeader`] and [`MessageFields`].
///
/// [`MessagePrimaryHeader`]: struct.MessagePrimaryHeader.html
/// [`MessageFields`]: struct.MessageFields.html
// Field order is the serialization order: primary header, then the fields, then `end`.
#[derive(Debug, Clone, Serialize, Deserialize, Type)]
pub struct MessageHeader<'m> {
    // Fixed-size part common to every message.
    primary: MessagePrimaryHeader,
    // Header fields; may borrow from the buffer they were deserialized from.
    #[serde(borrow)]
    fields: MessageFields<'m>,
    end: ((),), // Carries no data; present so the encoded header ends on an 8-byte boundary.
}

assert_impl_all!(MessageHeader<'_>: Send, Sync, Unpin);
266 
// Look up header field `MessageFieldCode::$kind` in `$self.fields()`.
//
// Evaluates to:
// - `Ok(Some(..))` when the field is present and holds the matching `MessageField::$kind`
//   variant; the value is passed through `$closure` (identity by default);
// - `Err(Error::InvalidField)` when the stored field has a different variant;
// - `Ok(None)` when the field is absent.
macro_rules! get_field {
    ($self:ident, $kind:ident) => {
        get_field!($self, $kind, (|v| v))
    };
    ($self:ident, $kind:ident, $closure:tt) => {
        // `$closure` is a closure literal that is called immediately, which clippy would flag.
        #[allow(clippy::redundant_closure_call)]
        match $self.fields().get_field(MessageFieldCode::$kind) {
            Some(MessageField::$kind(value)) => Ok(Some($closure(value))),
            Some(_) => Err(Error::InvalidField),
            None => Ok(None),
        }
    };
}

// Like `get_field!`, but for `u32`-valued fields: copies the value out instead of returning a
// reference to it.
macro_rules! get_field_u32 {
    ($self:ident, $kind:ident) => {
        get_field!($self, $kind, (|v: &u32| *v))
    };
}
286 
287 impl<'m> MessageHeader<'m> {
288     /// Create a new `MessageHeader` instance.
new(primary: MessagePrimaryHeader, fields: MessageFields<'m>) -> Self289     pub fn new(primary: MessagePrimaryHeader, fields: MessageFields<'m>) -> Self {
290         Self {
291             primary,
292             fields,
293             end: ((),),
294         }
295     }
296 
297     /// Get a reference to the primary header.
primary(&self) -> &MessagePrimaryHeader298     pub fn primary(&self) -> &MessagePrimaryHeader {
299         &self.primary
300     }
301 
302     /// Get a mutable reference to the primary header.
primary_mut(&mut self) -> &mut MessagePrimaryHeader303     pub fn primary_mut(&mut self) -> &mut MessagePrimaryHeader {
304         &mut self.primary
305     }
306 
307     /// Get the primary header, consuming `self`.
into_primary(self) -> MessagePrimaryHeader308     pub fn into_primary(self) -> MessagePrimaryHeader {
309         self.primary
310     }
311 
312     /// Get a reference to the message fields.
fields<'s>(&'s self) -> &'s MessageFields<'m>313     pub fn fields<'s>(&'s self) -> &'s MessageFields<'m> {
314         &self.fields
315     }
316 
317     /// Get a mutable reference to the message fields.
fields_mut<'s>(&'s mut self) -> &'s mut MessageFields<'m>318     pub fn fields_mut<'s>(&'s mut self) -> &'s mut MessageFields<'m> {
319         &mut self.fields
320     }
321 
322     /// Get the message fields, consuming `self`.
into_fields(self) -> MessageFields<'m>323     pub fn into_fields(self) -> MessageFields<'m> {
324         self.fields
325     }
326 
327     /// The message type
message_type(&self) -> Result<MessageType, Error>328     pub fn message_type(&self) -> Result<MessageType, Error> {
329         Ok(self.primary().msg_type())
330     }
331 
332     /// The object to send a call to, or the object a signal is emitted from.
path<'s>(&'s self) -> Result<Option<&ObjectPath<'m>>, Error>333     pub fn path<'s>(&'s self) -> Result<Option<&ObjectPath<'m>>, Error> {
334         get_field!(self, Path)
335     }
336 
337     /// The interface to invoke a method call on, or that a signal is emitted from.
interface<'s>(&'s self) -> Result<Option<&InterfaceName<'m>>, Error>338     pub fn interface<'s>(&'s self) -> Result<Option<&InterfaceName<'m>>, Error> {
339         get_field!(self, Interface)
340     }
341 
342     /// The member, either the method name or signal name.
member<'s>(&'s self) -> Result<Option<&MemberName<'m>>, Error>343     pub fn member<'s>(&'s self) -> Result<Option<&MemberName<'m>>, Error> {
344         get_field!(self, Member)
345     }
346 
347     /// The name of the error that occurred, for errors.
error_name<'s>(&'s self) -> Result<Option<&ErrorName<'m>>, Error>348     pub fn error_name<'s>(&'s self) -> Result<Option<&ErrorName<'m>>, Error> {
349         get_field!(self, ErrorName)
350     }
351 
352     /// The serial number of the message this message is a reply to.
reply_serial(&self) -> Result<Option<u32>, Error>353     pub fn reply_serial(&self) -> Result<Option<u32>, Error> {
354         get_field_u32!(self, ReplySerial)
355     }
356 
357     /// The name of the connection this message is intended for.
destination<'s>(&'s self) -> Result<Option<&BusName<'m>>, Error>358     pub fn destination<'s>(&'s self) -> Result<Option<&BusName<'m>>, Error> {
359         get_field!(self, Destination)
360     }
361 
362     /// Unique name of the sending connection.
sender<'s>(&'s self) -> Result<Option<&UniqueName<'m>>, Error>363     pub fn sender<'s>(&'s self) -> Result<Option<&UniqueName<'m>>, Error> {
364         get_field!(self, Sender)
365     }
366 
367     /// The signature of the message body.
signature(&self) -> Result<Option<&Signature<'m>>, Error>368     pub fn signature(&self) -> Result<Option<&Signature<'m>>, Error> {
369         get_field!(self, Signature)
370     }
371 
372     /// The number of Unix file descriptors that accompany the message.
unix_fds(&self) -> Result<Option<u32>, Error>373     pub fn unix_fds(&self) -> Result<Option<u32>, Error> {
374         get_field_u32!(self, UnixFDs)
375     }
376 }
377 
#[cfg(test)]
mod tests {
    use crate::{MessageField, MessageFields, MessageHeader, MessagePrimaryHeader, MessageType};

    use std::{
        convert::{TryFrom, TryInto},
        error::Error,
        result::Result,
    };
    use test_log::test;
    use zbus_names::{InterfaceName, MemberName};
    use zvariant::{ObjectPath, Signature};

    #[test]
    fn header() -> Result<(), Box<dyn Error>> {
        // --- Signal header: path, interface, member and sender are set. ---
        let obj_path = ObjectPath::try_from("/some/path")?;
        let iface_name = InterfaceName::try_from("some.interface")?;
        let member_name = MemberName::try_from("Member")?;

        let mut signal_fields = MessageFields::new();
        signal_fields.add(MessageField::Path(obj_path.clone()));
        signal_fields.add(MessageField::Interface(iface_name.clone()));
        signal_fields.add(MessageField::Member(member_name.clone()));
        signal_fields.add(MessageField::Sender(":1.84".try_into()?));
        let signal_primary = MessagePrimaryHeader::new(MessageType::Signal, 77);
        let signal = MessageHeader::new(signal_primary, signal_fields);

        assert_eq!(signal.message_type()?, MessageType::Signal);
        // Present fields.
        assert_eq!(signal.path()?, Some(&obj_path));
        assert_eq!(signal.interface()?, Some(&iface_name));
        assert_eq!(signal.member()?, Some(&member_name));
        // Absent fields.
        assert_eq!(signal.error_name()?, None);
        assert_eq!(signal.destination()?, None);
        assert_eq!(signal.reply_serial()?, None);
        assert_eq!(signal.sender()?.unwrap(), ":1.84");
        assert_eq!(signal.signature()?, None);
        assert_eq!(signal.unix_fds()?, None);

        // --- Method-return header: the complementary set of fields. ---
        let mut reply_fields = MessageFields::new();
        reply_fields.add(MessageField::ErrorName("org.zbus.Error".try_into()?));
        reply_fields.add(MessageField::Destination(":1.11".try_into()?));
        reply_fields.add(MessageField::ReplySerial(88));
        reply_fields.add(MessageField::Signature(Signature::from_str_unchecked(
            "say",
        )));
        reply_fields.add(MessageField::UnixFDs(12));
        let reply_primary = MessagePrimaryHeader::new(MessageType::MethodReturn, 77);
        let reply = MessageHeader::new(reply_primary, reply_fields);

        assert_eq!(reply.message_type()?, MessageType::MethodReturn);
        assert_eq!(reply.path()?, None);
        assert_eq!(reply.interface()?, None);
        assert_eq!(reply.member()?, None);
        assert_eq!(reply.error_name()?.unwrap(), "org.zbus.Error");
        assert_eq!(reply.destination()?.unwrap(), ":1.11");
        assert_eq!(reply.reply_serial()?, Some(88));
        assert_eq!(reply.sender()?, None);
        assert_eq!(reply.signature()?, Some(&Signature::from_str_unchecked("say")));
        assert_eq!(reply.unix_fds()?, Some(12));

        Ok(())
    }
}
438