1 //! Wayland socket manipulation
2 
3 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
4 
5 use nix::sys::socket;
6 use nix::sys::uio;
7 use nix::Result as NixResult;
8 
9 use wire::{ArgumentType, Message, MessageParseError, MessageWriteError};
10 
/// Maximum number of file descriptors that can be sent in a single socket message
pub const MAX_FDS_OUT: usize = 28;
/// Maximum number of bytes that can be sent in a single socket message (4 KiB)
pub const MAX_BYTES_OUT: usize = 4 * 1024;
15 
16 /*
17  * Socket
18  */
19 
/// A wayland socket
///
/// Owns a raw socket file descriptor: the descriptor is closed when the
/// `Socket` is dropped (see its `Drop` implementation).
#[derive(Debug)]
pub struct Socket {
    fd: RawFd,
}
24 
25 impl Socket {
26     /// Send a single message to the socket
27     ///
28     /// A single socket message can contain several wayland messages
29     ///
30     /// The `fds` slice should not be longer than `MAX_FDS_OUT`, and the `bytes`
31     /// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving
32     /// end may lose some data.
send_msg(&self, bytes: &[u8], fds: &[RawFd]) -> NixResult<()>33     pub fn send_msg(&self, bytes: &[u8], fds: &[RawFd]) -> NixResult<()> {
34         let iov = [uio::IoVec::from_slice(bytes)];
35         if fds.len() > 0 {
36             let cmsgs = [socket::ControlMessage::ScmRights(fds)];
37             socket::sendmsg(self.fd, &iov, &cmsgs, socket::MsgFlags::MSG_DONTWAIT, None)?;
38         } else {
39             socket::sendmsg(self.fd, &iov, &[], socket::MsgFlags::MSG_DONTWAIT, None)?;
40         };
41         Ok(())
42     }
43 
44     /// Receive a single message from the socket
45     ///
46     /// Return the number of bytes received and the number of Fds received.
47     ///
48     /// Errors with `WouldBlock` is no message is available.
49     ///
50     /// A single socket message can contain several wayland messages.
51     ///
52     /// The `buffer` slice should be at least `MAX_BYTES_OUT` long and the `fds`
53     /// slice `MAX_FDS_OUT` long, otherwise some data of the received message may
54     /// be lost.
rcv_msg(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> NixResult<(usize, usize)>55     pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> NixResult<(usize, usize)> {
56         let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
57         let iov = [uio::IoVec::from_mut_slice(buffer)];
58 
59         let msg = socket::recvmsg(self.fd, &iov[..], Some(&mut cmsg), socket::MsgFlags::MSG_DONTWAIT)?;
60 
61         let mut fd_count = 0;
62         let received_fds = msg.cmsgs().flat_map(|cmsg| {
63             match cmsg {
64                 socket::ControlMessageOwned::ScmRights(s) => s,
65                 _ => Vec::new(),
66             }
67         });
68         for (fd, place) in received_fds.zip(fds.iter_mut()) {
69             fd_count += 1;
70             *place = fd;
71         }
72         Ok((msg.bytes, fd_count))
73     }
74 }
75 
76 impl FromRawFd for Socket {
from_raw_fd(fd: RawFd) -> Socket77     unsafe fn from_raw_fd(fd: RawFd) -> Socket {
78         Socket { fd }
79     }
80 }
81 
82 impl AsRawFd for Socket {
as_raw_fd(&self) -> RawFd83     fn as_raw_fd(&self) -> RawFd {
84         self.fd
85     }
86 }
87 
88 impl IntoRawFd for Socket {
into_raw_fd(self) -> RawFd89     fn into_raw_fd(self) -> RawFd {
90         self.fd
91     }
92 }
93 
94 impl Drop for Socket {
drop(&mut self)95     fn drop(&mut self) {
96         let _ = ::nix::unistd::close(self.fd);
97     }
98 }
99 
100 /*
101  * BufferedSocket
102  */
103 
104 /// An adapter around a raw Socket that directly handles buffering and
105 /// conversion from/to wayland messages
106 pub struct BufferedSocket {
107     socket: Socket,
108     in_data: Buffer<u32>,
109     in_fds: Buffer<RawFd>,
110     out_data: Buffer<u32>,
111     out_fds: Buffer<RawFd>,
112 }
113 
114 impl BufferedSocket {
115     /// Wrap a Socket into a Buffered Socket
new(socket: Socket) -> BufferedSocket116     pub fn new(socket: Socket) -> BufferedSocket {
117         BufferedSocket {
118             socket: socket,
119             in_data: Buffer::new(2 * MAX_BYTES_OUT / 4), // Incoming buffers are twice as big in order to be
120             in_fds: Buffer::new(2 * MAX_FDS_OUT),        // able to store leftover data if needed
121             out_data: Buffer::new(MAX_BYTES_OUT / 4),
122             out_fds: Buffer::new(MAX_FDS_OUT),
123         }
124     }
125 
126     /// Get direct access to the underlying socket
get_socket(&mut self) -> &mut Socket127     pub fn get_socket(&mut self) -> &mut Socket {
128         &mut self.socket
129     }
130 
131     /// Retrieve ownership of the underlying Socket
132     ///
133     /// Any leftover content in the internal buffers will be lost
into_socket(self) -> Socket134     pub fn into_socket(self) -> Socket {
135         self.socket
136     }
137 
138     /// Flush the contents of the outgoing buffer into the socket
flush(&mut self) -> NixResult<()>139     pub fn flush(&mut self) -> NixResult<()> {
140         {
141             let words = self.out_data.get_contents();
142             let bytes = unsafe { ::std::slice::from_raw_parts(words.as_ptr() as *const u8, words.len() * 4) };
143             let fds = self.out_fds.get_contents();
144             self.socket.send_msg(bytes, fds)?;
145         }
146         self.out_data.clear();
147         self.out_fds.clear();
148         Ok(())
149     }
150 
151     // internal method
152     //
153     // attempts to write a message in the internal out buffers,
154     // returns true if successful
155     //
156     // if false is returned, it means there is not enough space
157     // in the buffer
attempt_write_message(&mut self, msg: &Message) -> NixResult<bool>158     fn attempt_write_message(&mut self, msg: &Message) -> NixResult<bool> {
159         match msg.write_to_buffers(
160             self.out_data.get_writable_storage(),
161             self.out_fds.get_writable_storage(),
162         ) {
163             Ok((bytes_out, fds_out)) => {
164                 self.out_data.advance(bytes_out);
165                 self.out_fds.advance(fds_out);
166                 Ok(true)
167             }
168             Err(MessageWriteError::BufferTooSmall) => Ok(false),
169             Err(MessageWriteError::DupFdFailed(e)) => Err(e),
170         }
171     }
172 
173     /// Write a message to the outgoing buffer
174     ///
175     /// This method may flush the internal buffer if necessary (if it is full).
176     ///
177     /// If the message is too big to fit in the buffer, the error `Error::Sys(E2BIG)`
178     /// will be returned.
write_message(&mut self, msg: &Message) -> NixResult<()>179     pub fn write_message(&mut self, msg: &Message) -> NixResult<()> {
180         if !self.attempt_write_message(msg)? {
181             // the attempt failed, there is not enough space in the buffer
182             // we need to flush it
183             self.flush()?;
184             if !self.attempt_write_message(msg)? {
185                 // If this fails again, this means the message is too big
186                 // to be transmitted at all
187                 return Err(::nix::Error::Sys(::nix::errno::Errno::E2BIG));
188             }
189         }
190         Ok(())
191     }
192 
193     /// Try to fill the incoming buffers of this socket, to prepare
194     /// a new round of parsing.
fill_incoming_buffers(&mut self) -> NixResult<()>195     pub fn fill_incoming_buffers(&mut self) -> NixResult<()> {
196         // clear the buffers if they have no content
197         if !self.in_data.has_content() {
198             self.in_data.clear();
199         }
200         if !self.in_fds.has_content() {
201             self.in_fds.clear();
202         }
203         // receive a message
204         let (in_bytes, in_fds) = {
205             let words = self.in_data.get_writable_storage();
206             let bytes =
207                 unsafe { ::std::slice::from_raw_parts_mut(words.as_ptr() as *mut u8, words.len() * 4) };
208             let fds = self.in_fds.get_writable_storage();
209             self.socket.rcv_msg(bytes, fds)?
210         };
211         if in_bytes == 0 {
212             // the other end of the socket was closed
213             return Err(::nix::Error::Sys(::nix::errno::Errno::EPIPE));
214         }
215         // advance the storage
216         self.in_data
217             .advance(in_bytes / 4 + if in_bytes % 4 > 0 { 1 } else { 0 });
218         self.in_fds.advance(in_fds);
219         Ok(())
220     }
221 
222     /// Read and deserialize a single message from the incoming buffers socket
223     ///
224     /// This method requires one closure that given an object id and an opcode,
225     /// must provide the signature of the associated request/event, in the form of
226     /// a `&'static [ArgumentType]`. If it returns `None`, meaning that
227     /// the couple object/opcode does not exist, an error will be returned.
228     ///
229     /// There are 3 possibilities of return value:
230     ///
231     /// - `Ok(Ok(msg))`: no error occurred, this is the message
232     /// - `Ok(Err(e))`: either a malformed message was encountered or we need more data,
233     ///    in the latter case you need to try calling `fill_incoming_buffers()`.
234     /// - `Err(e)`: an I/O error occurred reading from the socked, details are in `e`
235     ///   (this can be a "wouldblock" error, which just means that no message is available
236     ///   to read)
read_one_message<F>(&mut self, mut signature: F) -> Result<Message, MessageParseError> where F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,237     pub fn read_one_message<F>(&mut self, mut signature: F) -> Result<Message, MessageParseError>
238     where
239         F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
240     {
241         let (msg, read_data, read_fd) = {
242             let data = self.in_data.get_contents();
243             let fds = self.in_fds.get_contents();
244             if data.len() < 2 {
245                 return Err(MessageParseError::MissingData);
246             }
247             let object_id = data[0];
248             let opcode = (data[1] & 0x0000FFFF) as u16;
249             if let Some(sig) = signature(object_id, opcode) {
250                 match Message::from_raw(data, sig, fds) {
251                     Ok((msg, rest_data, rest_fds)) => {
252                         (msg, data.len() - rest_data.len(), fds.len() - rest_fds.len())
253                     }
254                     // TODO: gracefully handle wayland messages split across unix messages ?
255                     Err(e) => return Err(e),
256                 }
257             } else {
258                 // no signature found ?
259                 return Err(MessageParseError::Malformed);
260             }
261         };
262 
263         self.in_data.offset(read_data);
264         self.in_fds.offset(read_fd);
265 
266         Ok(msg)
267     }
268 
269     /// Read and deserialize messages from the socket
270     ///
271     /// This method requires two closures:
272     ///
273     /// - The first one, given an object id and an opcode, must provide
274     ///   the signature of the associated request/event, in the form of
275     ///   a `&'static [ArgumentType]`. If it returns `None`, meaning that
276     ///   the couple object/opcode does not exist, the parsing will be
277     ///   prematurely interrupted and this method will return a
278     ///   `MessageParseError::Malformed` error.
279     /// - The second closure is charged to process the parsed message. If it
280     ///   returns `false`, the iteration will be prematurely stopped.
281     ///
282     /// In both cases of early stopping, the remaining unused data will be left
283     /// in the buffers, and will start to be processed at the next call of this
284     /// method.
285     ///
286     /// There are 3 possibilities of return value:
287     ///
288     /// - `Ok(Ok(n))`: no error occurred, `n` messages where processed
289     /// - `Ok(Err(MessageParseError::Malformed))`: a malformed message was encountered
290     ///   (this is a protocol error and is supposed to be fatal to the connection).
291     /// - `Err(e)`: an I/O error occurred reading from the socked, details are in `e`
292     ///   (this can be a "wouldblock" error, which just means that no message is available
293     ///   to read)
read_messages<F1, F2>( &mut self, mut signature: F1, mut callback: F2, ) -> NixResult<Result<usize, MessageParseError>> where F1: FnMut(u32, u16) -> Option<&'static [ArgumentType]>, F2: FnMut(Message) -> bool,294     pub fn read_messages<F1, F2>(
295         &mut self,
296         mut signature: F1,
297         mut callback: F2,
298     ) -> NixResult<Result<usize, MessageParseError>>
299     where
300         F1: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
301         F2: FnMut(Message) -> bool,
302     {
303         // message parsing
304         let mut dispatched = 0;
305 
306         loop {
307             let mut err = None;
308             // first parse any leftover messages
309             loop {
310                 match self.read_one_message(&mut signature) {
311                     Ok(msg) => {
312                         let keep_going = callback(msg);
313                         dispatched += 1;
314                         if !keep_going {
315                             break;
316                         }
317                     }
318                     Err(e) => {
319                         err = Some(e);
320                         break;
321                     }
322                 }
323             }
324 
325             // copy back any leftover content to the front of the buffer
326             self.in_data.move_to_front();
327             self.in_fds.move_to_front();
328 
329             if let Some(MessageParseError::Malformed) = err {
330                 // early stop here
331                 return Ok(Err(MessageParseError::Malformed));
332             }
333 
334             if err.is_none() && self.in_data.has_content() {
335                 // we stopped reading without error while there is content? That means
336                 // the user requested an early stopping
337                 return Ok(Ok(dispatched));
338             }
339 
340             // now, try to get more data
341             match self.fill_incoming_buffers() {
342                 Ok(()) => (),
343                 Err(e @ ::nix::Error::Sys(::nix::errno::Errno::EAGAIN)) => {
344                     // stop looping, returning Ok() or EAGAIN depending on whether messages
345                     // were dispatched
346                     if dispatched == 0 {
347                         return Err(e);
348                     } else {
349                         break;
350                     }
351                 }
352                 Err(e) => return Err(e),
353             }
354         }
355 
356         Ok(Ok(dispatched))
357     }
358 }
359 
360 /*
361  * Buffer
362  */
363 
/// Fixed-capacity buffer with separate read and write cursors.
///
/// `[0, offset)` is already consumed, `[offset, occupied)` is unread content,
/// and `[occupied, storage.len())` is free space available for writing.
struct Buffer<T: Copy> {
    /// Backing storage; its length is the buffer capacity.
    storage: Vec<T>,
    /// Read cursor: index of the first unread element.
    offset: usize,
    /// Write cursor: number of elements written so far.
    occupied: usize,
}
369 
370 impl<T: Copy + Default> Buffer<T> {
new(size: usize) -> Buffer<T>371     fn new(size: usize) -> Buffer<T> {
372         Buffer {
373             storage: vec![T::default(); size],
374             occupied: 0,
375             offset: 0,
376         }
377     }
378 
379     /// Check if this buffer has content to read
has_content(&self) -> bool380     fn has_content(&self) -> bool {
381         self.occupied > self.offset
382     }
383 
384     /// Advance the internal counter of occupied space
advance(&mut self, bytes: usize)385     fn advance(&mut self, bytes: usize) {
386         self.occupied += bytes;
387     }
388 
389     /// Advance the read offset of current occupied space
offset(&mut self, bytes: usize)390     fn offset(&mut self, bytes: usize) {
391         self.offset += bytes;
392     }
393 
394     /// Clears the contents of the buffer
395     ///
396     /// This only sets the counter of occupied space back to zero,
397     /// allowing previous content to be overwritten.
clear(&mut self)398     fn clear(&mut self) {
399         self.occupied = 0;
400         self.offset = 0;
401     }
402 
403     /// Get the current contents of the occupied space of the buffer
get_contents(&self) -> &[T]404     fn get_contents(&self) -> &[T] {
405         &self.storage[(self.offset)..(self.occupied)]
406     }
407 
408     /// Get mutable access to the unoccupied space of the buffer
get_writable_storage(&mut self) -> &mut [T]409     fn get_writable_storage(&mut self) -> &mut [T] {
410         &mut self.storage[(self.occupied)..]
411     }
412 
413     /// Move the unread contents of the buffer to the front, to ensure
414     /// maximal write space availability
move_to_front(&mut self)415     fn move_to_front(&mut self) {
416         unsafe {
417             ::std::ptr::copy(
418                 &self.storage[self.offset] as *const T,
419                 &mut self.storage[0] as *mut T,
420                 self.occupied - self.offset,
421             );
422         }
423         self.occupied -= self.offset;
424         self.offset = 0;
425     }
426 }
427 
#[cfg(test)]
mod tests {
    use super::*;
    use wire::{Argument, ArgumentType, Message};

    use std::ffi::CString;

    // check whether two descriptors refer to the same underlying file
    fn same_file(a: RawFd, b: RawFd) -> bool {
        let stat1 = ::nix::sys::stat::fstat(a).unwrap();
        let stat2 = ::nix::sys::stat::fstat(b).unwrap();
        stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
    }

    // check if two messages are equal
    //
    // if arguments contain FDs, check that the fd point to
    // the same file, rather than are the same number.
    fn assert_eq_msgs(msg1: &Message, msg2: &Message) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
            if let (&Argument::Fd(fd1), &Argument::Fd(fd2)) = (arg1, arg2) {
                assert!(same_file(fd1, fd2));
            } else {
                assert_eq!(arg1, arg2);
            }
        }
    }

    #[test]
    fn write_read_cycle() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: vec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(CString::new(&b"I like trains!"[..]).unwrap()),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
        let mut server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &'static [ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str,
            ArgumentType::Array,
            ArgumentType::Object,
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(ret, 1);
    }

    #[test]
    fn write_read_cycle_fd() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: vec![
                Argument::Fd(1), // stdout
                Argument::Fd(0), // stdin
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
        let mut server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &'static [ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];

        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(ret, 1);
    }

    #[test]
    fn write_read_cycle_multiple() {
        let messages = [
            Message {
                sender_id: 42,
                opcode: 0,
                args: vec![
                    Argument::Int(42),
                    Argument::Str(CString::new(&b"I like trains"[..]).unwrap()),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: vec![
                    Argument::Fd(1), // stdout
                    Argument::Fd(0), // stdin
                ],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: vec![
                    Argument::Uint(3),
                    Argument::Fd(2), // stderr
                ],
            },
        ];

        static SIGNATURES: &'static [&'static [ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
        let mut server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });

        for msg in &messages {
            client.write_message(msg).unwrap();
        }
        client.flush().unwrap();

        let mut recv_msgs = Vec::new();
        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 42 {
                        Some(SIGNATURES[opcode as usize])
                    } else {
                        None
                    }
                },
                |message| {
                    recv_msgs.push(message);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(ret, 3);
        assert_eq!(recv_msgs.len(), 3);
        for (msg1, msg2) in messages.iter().zip(recv_msgs.iter()) {
            assert_eq_msgs(msg1, msg2);
        }
    }

    #[test]
    fn parse_with_string_len_multiple_of_4() {
        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: vec![
                Argument::Uint(18),
                Argument::Str(CString::new(&b"wl_shell"[..]).unwrap()),
                Argument::Uint(1),
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
        let mut server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &'static [ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str, ArgumentType::Uint];

        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 2 && opcode == 0 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(ret, 1);
    }
}
662