1 use ffi;
2 use libc;
3 use super::Connection;
4 
5 use std::mem;
6 use std::sync::{Mutex, RwLock};
7 use std::os::unix::io::{RawFd, AsRawFd};
8 use std::os::raw::{c_void, c_uint};
9 
10 /// A file descriptor to watch for incoming events (for async I/O).
11 ///
12 /// # Example
13 /// ```
14 /// extern crate libc;
15 /// extern crate dbus;
16 /// fn main() {
17 ///     use dbus::{Connection, BusType, WatchEvent};
18 ///     let c = Connection::get_private(BusType::Session).unwrap();
19 ///
20 ///     // Get a list of fds to poll for
21 ///     let mut fds: Vec<_> = c.watch_fds().iter().map(|w| w.to_pollfd()).collect();
22 ///
23 ///     // Poll them with a 1 s timeout
///     let r = unsafe { libc::poll(fds.as_mut_ptr(), fds.len() as libc::nfds_t, 1000) };
25 ///     assert!(r >= 0);
26 ///
27 ///     // And handle incoming events
28 ///     for pfd in fds.iter().filter(|pfd| pfd.revents != 0) {
29 ///         for item in c.watch_handle(pfd.fd, WatchEvent::from_revents(pfd.revents)) {
30 ///             // Handle item
31 ///             println!("Received ConnectionItem: {:?}", item);
32 ///         }
33 ///     }
34 /// }
35 /// ```
36 
#[repr(C)]
#[derive(Debug, PartialEq, Copy, Clone)]
/// The enum is here mostly for backwards compatibility.
///
/// It should really be bitflags instead. Each variant's discriminant is the
/// matching `ffi::DBUS_WATCH_*` constant, so casting a variant with
/// `as c_uint` yields the raw flag value (see `WatchEvent::from_revents`).
pub enum WatchEvent {
    /// The fd is readable
    Readable = ffi::DBUS_WATCH_READABLE as isize,
    /// The fd is writable
    Writable = ffi::DBUS_WATCH_WRITABLE as isize,
    /// An error occurred on the fd
    Error = ffi::DBUS_WATCH_ERROR as isize,
    /// The fd received a hangup.
    Hangup = ffi::DBUS_WATCH_HANGUP as isize,
}
52 
53 impl WatchEvent {
54     /// After running poll, this transforms the revents into a parameter you can send into `Connection::watch_handle`
from_revents(revents: libc::c_short) -> c_uint55     pub fn from_revents(revents: libc::c_short) -> c_uint {
56         0 +
57         if (revents & libc::POLLIN) != 0 { WatchEvent::Readable as c_uint } else { 0 } +
58         if (revents & libc::POLLOUT) != 0 { WatchEvent::Writable as c_uint } else { 0 } +
59         if (revents & libc::POLLERR) != 0 { WatchEvent::Error as c_uint } else { 0 } +
60         if (revents & libc::POLLHUP) != 0 { WatchEvent::Hangup as c_uint } else { 0 }
61     }
62 }
63 
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
/// A file descriptor, and an indication whether it should be read from, written to, or both.
///
/// Both flags may be false: `WatchList::update` reports such a value through
/// its callback when an fd that was previously enabled no longer needs polling.
pub struct Watch {
    /// The file descriptor to poll.
    fd: RawFd,
    /// Whether `POLLIN` should be requested for `fd` (see `to_pollfd`).
    read: bool,
    /// Whether `POLLOUT` should be requested for `fd` (see `to_pollfd`).
    write: bool,
}
71 
72 impl Watch {
73     /// Get the RawFd this Watch is for
fd(&self) -> RawFd74     pub fn fd(&self) -> RawFd { self.fd }
75     /// Add POLLIN to events to listen for
readable(&self) -> bool76     pub fn readable(&self) -> bool { self.read }
77     /// Add POLLOUT to events to listen for
writable(&self) -> bool78     pub fn writable(&self) -> bool { self.write }
79     /// Returns the current watch as a libc::pollfd, to use with libc::poll
to_pollfd(&self) -> libc::pollfd80     pub fn to_pollfd(&self) -> libc::pollfd {
81         libc::pollfd { fd: self.fd, revents: 0, events: libc::POLLERR + libc::POLLHUP +
82             if self.readable() { libc::POLLIN } else { 0 } +
83             if self.writable() { libc::POLLOUT } else { 0 },
84         }
85     }
86 /*
87     pub (crate) unsafe fn from_raw(watch: *mut ffi::DBusWatch) -> Self {
88         let mut w = Watch { fd: ffi::dbus_watch_get_unix_fd(watch), read: false, write: false};
89         let enabled = ffi::dbus_watch_get_enabled(watch) != 0;
90         if enabled {
91             let flags = ffi::dbus_watch_get_flags(watch);
92             w.read = (flags & WatchEvent::Readable as c_uint) != 0;
93             w.write = (flags & WatchEvent::Writable as c_uint) != 0;
94         }
95         w
96     }
97 */
98 }
99 
100 impl AsRawFd for Watch {
as_raw_fd(&self) -> RawFd101     fn as_raw_fd(&self) -> RawFd { self.fd }
102 }
103 
/// Note - internal struct, not to be used outside API. Moving it outside its box will break things.
///
/// `WatchList::new` passes the address of the boxed value to libdbus as the
/// user data for the watch callbacks, so that address has to stay valid.
pub struct WatchList {
    /// Raw libdbus watches currently registered: pushed by `add_watch_cb`,
    /// removed by `remove_watch_cb`.
    watches: RwLock<Vec<*mut ffi::DBusWatch>>,
    /// One entry per fd that currently wants polling, with the read/write
    /// flags of all watches on that fd merged (maintained by `update`).
    enabled_fds: Mutex<Vec<Watch>>,
    /// Callback invoked by `update` whenever the entry for an fd changes.
    on_update: Mutex<Box<Fn(Watch) + Send>>,
}
110 
111 impl WatchList {
new(c: &Connection, on_update: Box<Fn(Watch) + Send>) -> Box<WatchList>112     pub fn new(c: &Connection, on_update: Box<Fn(Watch) + Send>) -> Box<WatchList> {
113         let w = Box::new(WatchList { on_update: Mutex::new(on_update), watches: RwLock::new(vec!()), enabled_fds: Mutex::new(vec!()) });
114         if unsafe { ffi::dbus_connection_set_watch_functions(super::connection::conn_handle(c),
115             Some(add_watch_cb), Some(remove_watch_cb), Some(toggled_watch_cb), &*w as *const _ as *mut _, None) } == 0 {
116             panic!("dbus_connection_set_watch_functions failed");
117         }
118         w
119     }
120 
set_on_update(&self, on_update: Box<Fn(Watch) + Send>)121     pub fn set_on_update(&self, on_update: Box<Fn(Watch) + Send>) { *self.on_update.lock().unwrap() = on_update; }
122 
watch_handle(&self, fd: RawFd, flags: c_uint)123     pub fn watch_handle(&self, fd: RawFd, flags: c_uint) {
124         // println!("watch_handle {} flags {}", fd, flags);
125         for &q in self.watches.read().unwrap().iter() {
126             let w = self.get_watch(q);
127             if w.fd != fd { continue };
128             if unsafe { ffi::dbus_watch_handle(q, flags) } == 0 {
129                 panic!("dbus_watch_handle failed");
130             }
131             self.update(q);
132         };
133     }
134 
get_enabled_fds(&self) -> Vec<Watch>135     pub fn get_enabled_fds(&self) -> Vec<Watch> {
136         self.enabled_fds.lock().unwrap().clone()
137     }
138 
get_watch(&self, watch: *mut ffi::DBusWatch) -> Watch139     fn get_watch(&self, watch: *mut ffi::DBusWatch) -> Watch {
140         let mut w = Watch { fd: unsafe { ffi::dbus_watch_get_unix_fd(watch) }, read: false, write: false};
141         let enabled = self.watches.read().unwrap().contains(&watch) && unsafe { ffi::dbus_watch_get_enabled(watch) != 0 };
142         let flags = unsafe { ffi::dbus_watch_get_flags(watch) };
143         if enabled {
144             w.read = (flags & WatchEvent::Readable as c_uint) != 0;
145             w.write = (flags & WatchEvent::Writable as c_uint) != 0;
146         }
147         // println!("Get watch fd {:?} ptr {:?} enabled {:?} flags {:?}", w, watch, enabled, flags);
148         w
149     }
150 
update(&self, watch: *mut ffi::DBusWatch)151     fn update(&self, watch: *mut ffi::DBusWatch) {
152         let mut w = self.get_watch(watch);
153 
154         for &q in self.watches.read().unwrap().iter() {
155             if q == watch { continue };
156             let ww = self.get_watch(q);
157             if ww.fd != w.fd { continue };
158             w.read |= ww.read;
159             w.write |= ww.write;
160         }
161         // println!("Updated sum: {:?}", w);
162 
163         {
164             let mut fdarr = self.enabled_fds.lock().unwrap();
165 
166             if w.write || w.read {
167                 if fdarr.contains(&w) { return; } // Nothing changed
168             }
169             else if !fdarr.iter().any(|q| w.fd == q.fd) { return; } // Nothing changed
170 
171             fdarr.retain(|f| f.fd != w.fd);
172             if w.write || w.read { fdarr.push(w) };
173         }
174         let func = self.on_update.lock().unwrap();
175         (*func)(w);
176     }
177 }
178 
add_watch_cb(watch: *mut ffi::DBusWatch, data: *mut c_void) -> u32179 extern "C" fn add_watch_cb(watch: *mut ffi::DBusWatch, data: *mut c_void) -> u32 {
180     let wlist: &WatchList = unsafe { mem::transmute(data) };
181     // println!("Add watch {:?}", watch);
182     wlist.watches.write().unwrap().push(watch);
183     wlist.update(watch);
184     1
185 }
186 
remove_watch_cb(watch: *mut ffi::DBusWatch, data: *mut c_void)187 extern "C" fn remove_watch_cb(watch: *mut ffi::DBusWatch, data: *mut c_void) {
188     let wlist: &WatchList = unsafe { mem::transmute(data) };
189     // println!("Removed watch {:?}", watch);
190     wlist.watches.write().unwrap().retain(|w| *w != watch);
191     wlist.update(watch);
192 }
193 
toggled_watch_cb(watch: *mut ffi::DBusWatch, data: *mut c_void)194 extern "C" fn toggled_watch_cb(watch: *mut ffi::DBusWatch, data: *mut c_void) {
195     let wlist: &WatchList = unsafe { mem::transmute(data) };
196     // println!("Toggled watch {:?}", watch);
197     wlist.update(watch);
198 }
199 
#[cfg(test)]
mod test {
    use libc;
    use super::super::{Connection, Message, BusType, WatchEvent, ConnectionItem, MessageType};

    // Integration test: opens a private session-bus connection, sends a method
    // call to its own unique name, answers it, and checks the reply, all driven
    // by polling the connection's watch fds.
    #[test]
    fn async() {
        let c = Connection::get_private(BusType::Session).unwrap();
        c.register_object_path("/test").unwrap();
        // Address the call to ourselves so the same connection both receives
        // the call and the reply.
        let m = Message::new_method_call(&c.unique_name(), "/test", "com.example.asynctest", "AsyncTest").unwrap();
        let serial = c.send(m).unwrap();
        println!("Async: sent serial {}", serial);

        let mut fds: Vec<_> = c.watch_fds().iter().map(|w| w.to_pollfd()).collect();
        // Never set to `Some` anywhere in this test, so the refresh below is
        // currently inert.
        let mut new_fds = None;
        let mut i = 0;
        let mut success = false;
        while !success {
            i += 1;
            if let Some(q) = new_fds { fds = q; new_fds = None };

            // Clear results from the previous poll round.
            for f in fds.iter_mut() { f.revents = 0 };

            // Fails the test on poll error or on a 1 s timeout with no events.
            assert!(unsafe { libc::poll(fds.as_mut_ptr(), fds.len() as libc::nfds_t, 1000) } > 0);

            for f in fds.iter().filter(|pfd| pfd.revents != 0) {
                let m = WatchEvent::from_revents(f.revents);
                println!("Async: fd {}, revents {} -> {}", f.fd, f.revents, m);
                // Only readable/writable events are expected here, not bare
                // error or hangup.
                assert!(f.revents & libc::POLLIN != 0 || f.revents & libc::POLLOUT != 0);

                for e in c.watch_handle(f.fd, m) {
                    println!("Async: got {:?}", e);
                    match e {
                        // Our own call arrived: verify its headers and reply.
                        ConnectionItem::MethodCall(m) => {
                            assert_eq!(m.headers(), (MessageType::MethodCall, Some("/test".to_string()),
                                Some("com.example.asynctest".into()), Some("AsyncTest".to_string())));
                            let mut mr = Message::new_method_return(&m).unwrap();
                            mr.append_items(&["Goodies".into()]);
                            c.send(mr).unwrap();
                        }
                        // The reply came back: it must match the serial we
                        // sent and carry the expected payload.
                        ConnectionItem::MethodReturn(m) => {
                            assert_eq!(m.headers().0, MessageType::MethodReturn);
                            assert_eq!(m.get_reply_serial().unwrap(), serial);
                            let i = m.get_items();
                            let s: &str = i[0].inner().unwrap();
                            assert_eq!(s, "Goodies");
                            success = true;
                        }
                        _ => (),
                    }
                }
                // Give up after 100 poll rounds rather than looping forever.
                if i > 100 { panic!() };
            }
        }
    }
}
256