1 //! Wayland socket manipulation
2 
3 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
4 
5 use nix::sys::socket;
6 use nix::sys::uio;
7 use nix::Result as NixResult;
8 
9 use wire::{ArgumentType, Message, MessageParseError, MessageWriteError};
10 
/// Maximum number of file descriptors that can be sent in a single socket message.
///
/// Sending more than this in one message risks the receiving end losing some of them.
pub const MAX_FDS_OUT: usize = 28;
/// Maximum number of bytes that can be sent in a single socket message.
///
/// This is also the byte capacity of the outgoing data buffer of a `BufferedSocket`.
pub const MAX_BYTES_OUT: usize = 4 * 1024;
15 
16 /*
17  * Socket
18  */
19 
/// A wayland socket.
///
/// Thin owner of a raw unix socket file descriptor; the descriptor is closed
/// when the `Socket` is dropped.
pub struct Socket {
    /// The owned socket file descriptor.
    fd: RawFd,
}
24 
25 impl Socket {
26     /// Send a single message to the socket
27     ///
28     /// A single socket message can contain several wayland messages
29     ///
30     /// The `fds` slice should not be longer than `MAX_FDS_OUT`, and the `bytes`
31     /// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving
32     /// end may lose some data.
send_msg(&self, bytes: &[u8], fds: &[RawFd]) -> NixResult<()>33     pub fn send_msg(&self, bytes: &[u8], fds: &[RawFd]) -> NixResult<()> {
34         let iov = [uio::IoVec::from_slice(bytes)];
35         if !fds.is_empty() {
36             let cmsgs = [socket::ControlMessage::ScmRights(fds)];
37             socket::sendmsg(self.fd, &iov, &cmsgs, socket::MsgFlags::MSG_DONTWAIT, None)?;
38         } else {
39             socket::sendmsg(self.fd, &iov, &[], socket::MsgFlags::MSG_DONTWAIT, None)?;
40         };
41         Ok(())
42     }
43 
44     /// Receive a single message from the socket
45     ///
46     /// Return the number of bytes received and the number of Fds received.
47     ///
48     /// Errors with `WouldBlock` is no message is available.
49     ///
50     /// A single socket message can contain several wayland messages.
51     ///
52     /// The `buffer` slice should be at least `MAX_BYTES_OUT` long and the `fds`
53     /// slice `MAX_FDS_OUT` long, otherwise some data of the received message may
54     /// be lost.
rcv_msg(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> NixResult<(usize, usize)>55     pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> NixResult<(usize, usize)> {
56         let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
57         let iov = [uio::IoVec::from_mut_slice(buffer)];
58 
59         let msg = socket::recvmsg(self.fd, &iov[..], Some(&mut cmsg), socket::MsgFlags::MSG_DONTWAIT)?;
60 
61         let mut fd_count = 0;
62         let received_fds = msg.cmsgs().flat_map(|cmsg| {
63             match cmsg {
64                 socket::ControlMessageOwned::ScmRights(s) => s,
65                 _ => Vec::new(),
66             }
67         });
68         for (fd, place) in received_fds.zip(fds.iter_mut()) {
69             fd_count += 1;
70             *place = fd;
71         }
72         Ok((msg.bytes, fd_count))
73     }
74 }
75 
76 impl FromRawFd for Socket {
from_raw_fd(fd: RawFd) -> Socket77     unsafe fn from_raw_fd(fd: RawFd) -> Socket {
78         Socket { fd }
79     }
80 }
81 
82 impl AsRawFd for Socket {
as_raw_fd(&self) -> RawFd83     fn as_raw_fd(&self) -> RawFd {
84         self.fd
85     }
86 }
87 
88 impl IntoRawFd for Socket {
into_raw_fd(self) -> RawFd89     fn into_raw_fd(self) -> RawFd {
90         self.fd
91     }
92 }
93 
94 impl Drop for Socket {
drop(&mut self)95     fn drop(&mut self) {
96         let _ = ::nix::unistd::close(self.fd);
97     }
98 }
99 
100 /*
101  * BufferedSocket
102  */
103 
104 /// An adapter around a raw Socket that directly handles buffering and
105 /// conversion from/to wayland messages
106 pub struct BufferedSocket {
107     socket: Socket,
108     in_data: Buffer<u32>,
109     in_fds: Buffer<RawFd>,
110     out_data: Buffer<u32>,
111     out_fds: Buffer<RawFd>,
112 }
113 
114 impl BufferedSocket {
115     /// Wrap a Socket into a Buffered Socket
new(socket: Socket) -> BufferedSocket116     pub fn new(socket: Socket) -> BufferedSocket {
117         BufferedSocket {
118             socket,
119             in_data: Buffer::new(2 * MAX_BYTES_OUT / 4), // Incoming buffers are twice as big in order to be
120             in_fds: Buffer::new(2 * MAX_FDS_OUT),        // able to store leftover data if needed
121             out_data: Buffer::new(MAX_BYTES_OUT / 4),
122             out_fds: Buffer::new(MAX_FDS_OUT),
123         }
124     }
125 
126     /// Get direct access to the underlying socket
get_socket(&mut self) -> &mut Socket127     pub fn get_socket(&mut self) -> &mut Socket {
128         &mut self.socket
129     }
130 
131     /// Retrieve ownership of the underlying Socket
132     ///
133     /// Any leftover content in the internal buffers will be lost
into_socket(self) -> Socket134     pub fn into_socket(self) -> Socket {
135         self.socket
136     }
137 
138     /// Flush the contents of the outgoing buffer into the socket
flush(&mut self) -> NixResult<()>139     pub fn flush(&mut self) -> NixResult<()> {
140         {
141             let words = self.out_data.get_contents();
142             let bytes = unsafe { ::std::slice::from_raw_parts(words.as_ptr() as *const u8, words.len() * 4) };
143             let fds = self.out_fds.get_contents();
144             self.socket.send_msg(bytes, fds)?;
145             for &fd in fds {
146                 // once the fds are sent, we can close them
147                 let _ = ::nix::unistd::close(fd);
148             }
149         }
150         self.out_data.clear();
151         self.out_fds.clear();
152         Ok(())
153     }
154 
155     // internal method
156     //
157     // attempts to write a message in the internal out buffers,
158     // returns true if successful
159     //
160     // if false is returned, it means there is not enough space
161     // in the buffer
attempt_write_message(&mut self, msg: &Message) -> NixResult<bool>162     fn attempt_write_message(&mut self, msg: &Message) -> NixResult<bool> {
163         match msg.write_to_buffers(
164             self.out_data.get_writable_storage(),
165             self.out_fds.get_writable_storage(),
166         ) {
167             Ok((bytes_out, fds_out)) => {
168                 self.out_data.advance(bytes_out);
169                 self.out_fds.advance(fds_out);
170                 Ok(true)
171             }
172             Err(MessageWriteError::BufferTooSmall) => Ok(false),
173             Err(MessageWriteError::DupFdFailed(e)) => Err(e),
174         }
175     }
176 
177     /// Write a message to the outgoing buffer
178     ///
179     /// This method may flush the internal buffer if necessary (if it is full).
180     ///
181     /// If the message is too big to fit in the buffer, the error `Error::Sys(E2BIG)`
182     /// will be returned.
write_message(&mut self, msg: &Message) -> NixResult<()>183     pub fn write_message(&mut self, msg: &Message) -> NixResult<()> {
184         if !self.attempt_write_message(msg)? {
185             // the attempt failed, there is not enough space in the buffer
186             // we need to flush it
187             self.flush()?;
188             if !self.attempt_write_message(msg)? {
189                 // If this fails again, this means the message is too big
190                 // to be transmitted at all
191                 return Err(::nix::Error::Sys(::nix::errno::Errno::E2BIG));
192             }
193         }
194         Ok(())
195     }
196 
197     /// Try to fill the incoming buffers of this socket, to prepare
198     /// a new round of parsing.
fill_incoming_buffers(&mut self) -> NixResult<()>199     pub fn fill_incoming_buffers(&mut self) -> NixResult<()> {
200         // clear the buffers if they have no content
201         if !self.in_data.has_content() {
202             self.in_data.clear();
203         }
204         if !self.in_fds.has_content() {
205             self.in_fds.clear();
206         }
207         // receive a message
208         let (in_bytes, in_fds) = {
209             let words = self.in_data.get_writable_storage();
210             let bytes =
211                 unsafe { ::std::slice::from_raw_parts_mut(words.as_ptr() as *mut u8, words.len() * 4) };
212             let fds = self.in_fds.get_writable_storage();
213             self.socket.rcv_msg(bytes, fds)?
214         };
215         if in_bytes == 0 {
216             // the other end of the socket was closed
217             return Err(::nix::Error::Sys(::nix::errno::Errno::EPIPE));
218         }
219         // advance the storage
220         self.in_data
221             .advance(in_bytes / 4 + if in_bytes % 4 > 0 { 1 } else { 0 });
222         self.in_fds.advance(in_fds);
223         Ok(())
224     }
225 
226     /// Read and deserialize a single message from the incoming buffers socket
227     ///
228     /// This method requires one closure that given an object id and an opcode,
229     /// must provide the signature of the associated request/event, in the form of
230     /// a `&'static [ArgumentType]`. If it returns `None`, meaning that
231     /// the couple object/opcode does not exist, an error will be returned.
232     ///
233     /// There are 3 possibilities of return value:
234     ///
235     /// - `Ok(Ok(msg))`: no error occurred, this is the message
236     /// - `Ok(Err(e))`: either a malformed message was encountered or we need more data,
237     ///    in the latter case you need to try calling `fill_incoming_buffers()`.
238     /// - `Err(e)`: an I/O error occurred reading from the socked, details are in `e`
239     ///   (this can be a "wouldblock" error, which just means that no message is available
240     ///   to read)
read_one_message<F>(&mut self, mut signature: F) -> Result<Message, MessageParseError> where F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,241     pub fn read_one_message<F>(&mut self, mut signature: F) -> Result<Message, MessageParseError>
242     where
243         F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
244     {
245         let (msg, read_data, read_fd) = {
246             let data = self.in_data.get_contents();
247             let fds = self.in_fds.get_contents();
248             if data.len() < 2 {
249                 return Err(MessageParseError::MissingData);
250             }
251             let object_id = data[0];
252             let opcode = (data[1] & 0x0000_FFFF) as u16;
253             if let Some(sig) = signature(object_id, opcode) {
254                 match Message::from_raw(data, sig, fds) {
255                     Ok((msg, rest_data, rest_fds)) => {
256                         (msg, data.len() - rest_data.len(), fds.len() - rest_fds.len())
257                     }
258                     // TODO: gracefully handle wayland messages split across unix messages ?
259                     Err(e) => return Err(e),
260                 }
261             } else {
262                 // no signature found ?
263                 return Err(MessageParseError::Malformed);
264             }
265         };
266 
267         self.in_data.offset(read_data);
268         self.in_fds.offset(read_fd);
269 
270         Ok(msg)
271     }
272 
273     /// Read and deserialize messages from the socket
274     ///
275     /// This method requires two closures:
276     ///
277     /// - The first one, given an object id and an opcode, must provide
278     ///   the signature of the associated request/event, in the form of
279     ///   a `&'static [ArgumentType]`. If it returns `None`, meaning that
280     ///   the couple object/opcode does not exist, the parsing will be
281     ///   prematurely interrupted and this method will return a
282     ///   `MessageParseError::Malformed` error.
283     /// - The second closure is charged to process the parsed message. If it
284     ///   returns `false`, the iteration will be prematurely stopped.
285     ///
286     /// In both cases of early stopping, the remaining unused data will be left
287     /// in the buffers, and will start to be processed at the next call of this
288     /// method.
289     ///
290     /// There are 3 possibilities of return value:
291     ///
292     /// - `Ok(Ok(n))`: no error occurred, `n` messages where processed
293     /// - `Ok(Err(MessageParseError::Malformed))`: a malformed message was encountered
294     ///   (this is a protocol error and is supposed to be fatal to the connection).
295     /// - `Err(e)`: an I/O error occurred reading from the socked, details are in `e`
296     ///   (this can be a "wouldblock" error, which just means that no message is available
297     ///   to read)
read_messages<F1, F2>( &mut self, mut signature: F1, mut callback: F2, ) -> NixResult<Result<usize, MessageParseError>> where F1: FnMut(u32, u16) -> Option<&'static [ArgumentType]>, F2: FnMut(Message) -> bool,298     pub fn read_messages<F1, F2>(
299         &mut self,
300         mut signature: F1,
301         mut callback: F2,
302     ) -> NixResult<Result<usize, MessageParseError>>
303     where
304         F1: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
305         F2: FnMut(Message) -> bool,
306     {
307         // message parsing
308         let mut dispatched = 0;
309 
310         loop {
311             let mut err = None;
312             // first parse any leftover messages
313             loop {
314                 match self.read_one_message(&mut signature) {
315                     Ok(msg) => {
316                         let keep_going = callback(msg);
317                         dispatched += 1;
318                         if !keep_going {
319                             break;
320                         }
321                     }
322                     Err(e) => {
323                         err = Some(e);
324                         break;
325                     }
326                 }
327             }
328 
329             // copy back any leftover content to the front of the buffer
330             self.in_data.move_to_front();
331             self.in_fds.move_to_front();
332 
333             if let Some(MessageParseError::Malformed) = err {
334                 // early stop here
335                 return Ok(Err(MessageParseError::Malformed));
336             }
337 
338             if err.is_none() && self.in_data.has_content() {
339                 // we stopped reading without error while there is content? That means
340                 // the user requested an early stopping
341                 return Ok(Ok(dispatched));
342             }
343 
344             // now, try to get more data
345             match self.fill_incoming_buffers() {
346                 Ok(()) => (),
347                 Err(e @ ::nix::Error::Sys(::nix::errno::Errno::EAGAIN)) => {
348                     // stop looping, returning Ok() or EAGAIN depending on whether messages
349                     // were dispatched
350                     if dispatched == 0 {
351                         return Err(e);
352                     } else {
353                         break;
354                     }
355                 }
356                 Err(e) => return Err(e),
357             }
358         }
359 
360         Ok(Ok(dispatched))
361     }
362 }
363 
364 /*
365  * Buffer
366  */
367 
/// A fixed-capacity buffer with a read cursor (`offset`) and a write cursor (`occupied`).
///
/// Unread contents are `storage[offset..occupied]` and the writable space is
/// `storage[occupied..]`; callers must keep `offset <= occupied <= storage.len()`.
struct Buffer<T> {
    storage: Vec<T>,
    occupied: usize,
    offset: usize,
}

impl<T: Copy + Default> Buffer<T> {
    /// Create a buffer holding `size` elements, all initialized to `T::default()`.
    fn new(size: usize) -> Buffer<T> {
        Buffer {
            storage: vec![T::default(); size],
            occupied: 0,
            offset: 0,
        }
    }

    /// Check if this buffer has content to read
    fn has_content(&self) -> bool {
        self.occupied > self.offset
    }

    /// Advance the internal counter of occupied space by `count` elements
    fn advance(&mut self, count: usize) {
        self.occupied += count;
    }

    /// Advance the read offset of current occupied space by `count` elements
    fn offset(&mut self, count: usize) {
        self.offset += count;
    }

    /// Clears the contents of the buffer
    ///
    /// This only sets the counter of occupied space back to zero,
    /// allowing previous content to be overwritten.
    fn clear(&mut self) {
        self.occupied = 0;
        self.offset = 0;
    }

    /// Get the current contents of the occupied space of the buffer
    fn get_contents(&self) -> &[T] {
        &self.storage[(self.offset)..(self.occupied)]
    }

    /// Get mutable access to the unoccupied space of the buffer
    fn get_writable_storage(&mut self) -> &mut [T] {
        &mut self.storage[(self.occupied)..]
    }

    /// Move the unread contents of the buffer to the front, to ensure
    /// maximal write space availability
    fn move_to_front(&mut self) {
        let remaining = self.occupied - self.offset;
        // Skip the copy when already at the front or when nothing is left to keep.
        // This also avoids forming a pointer/index at `storage[offset]` when
        // `offset == storage.len()` (fully filled and fully consumed buffer).
        if self.offset > 0 && remaining > 0 {
            let base = self.storage.as_mut_ptr();
            // SAFETY: `offset < occupied <= storage.len()`, so `base.add(offset)` is inside the
            // allocation and both the source and destination ranges of `remaining` elements are
            // in bounds. `ptr::copy` handles the overlap, and `T: Copy` rules out double drops.
            unsafe {
                ::std::ptr::copy(base.add(self.offset), base, remaining);
            }
        }
        self.occupied = remaining;
        self.offset = 0;
    }
}
431 
// Round-trip tests: messages written by one end of a unix socket pair are read
// back by the other end and compared to the originals.
#[cfg(test)]
mod tests {
    use super::*;
    use wire::{Argument, ArgumentType, Message};

    use std::ffi::CString;

    // true if both fds refer to the same file (same device and inode)
    fn same_file(a: RawFd, b: RawFd) -> bool {
        let stat1 = ::nix::sys::stat::fstat(a).unwrap();
        let stat2 = ::nix::sys::stat::fstat(b).unwrap();
        stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
    }

    // check if two messages are equal
    //
    // if arguments contain FDs, check that the fd point to
    // the same file, rather than are the same number.
    fn assert_eq_msgs(msg1: &Message, msg2: &Message) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
            if let (&Argument::Fd(fd1), &Argument::Fd(fd2)) = (arg1, arg2) {
                assert!(same_file(fd1, fd2));
            } else {
                assert_eq!(arg1, arg2);
            }
        }
    }

    // a single message using every non-fd argument type
    #[test]
    fn write_read_cycle() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: vec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(CString::new(&b"I like trains!"[..]).unwrap()),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
        let mut server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &'static [ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str,
            ArgumentType::Array,
            ArgumentType::Object,
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(ret, 1);
    }

    // a single message carrying file descriptors (passed as SCM_RIGHTS)
    #[test]
    fn write_read_cycle_fd() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: vec![
                Argument::Fd(1), // stdout
                Argument::Fd(0), // stdin
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
        let mut server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &'static [ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];

        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(ret, 1);
    }

    // several messages, some with fds, flushed together in one socket message
    #[test]
    fn write_read_cycle_multiple() {
        let messages = [
            Message {
                sender_id: 42,
                opcode: 0,
                args: vec![
                    Argument::Int(42),
                    Argument::Str(CString::new(&b"I like trains"[..]).unwrap()),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: vec![
                    Argument::Fd(1), // stdout
                    Argument::Fd(0), // stdin
                ],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: vec![
                    Argument::Uint(3),
                    Argument::Fd(2), // stderr
                ],
            },
        ];

        static SIGNATURES: &'static [&'static [ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
        let mut server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });

        for msg in &messages {
            client.write_message(msg).unwrap();
        }
        client.flush().unwrap();

        let mut recv_msgs = Vec::new();
        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 42 {
                        Some(SIGNATURES[opcode as usize])
                    } else {
                        None
                    }
                },
                |message| {
                    recv_msgs.push(message);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(ret, 3);
        assert_eq!(recv_msgs.len(), 3);
        for (msg1, msg2) in messages.iter().zip(recv_msgs.iter()) {
            assert_eq_msgs(msg1, msg2);
        }
    }

    // a string whose length ("wl_shell", 8 bytes) is a multiple of 4, so the
    // encoded form needs an extra padding word for the NUL terminator
    #[test]
    fn parse_with_string_len_multiple_of_4() {
        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: vec![
                Argument::Uint(18),
                Argument::Str(CString::new(&b"wl_shell"[..]).unwrap()),
                Argument::Uint(1),
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
        let mut server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &'static [ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str, ArgumentType::Uint];

        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 2 && opcode == 0 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(ret, 1);
    }
}
666