1 //! Helper for implementing `RequestConnection::extension_information()`.
2 
3 use std::collections::{hash_map::Entry as HashMapEntry, HashMap};
4 
5 use crate::connection::{RequestConnection, SequenceNumber};
6 use crate::cookie::Cookie;
7 use crate::errors::{ConnectionError, ReplyError};
8 use crate::protocol::xproto::{ConnectionExt, QueryExtensionReply};
9 use crate::x11_utils::{ExtInfoProvider, ExtensionInformation};
10 
/// Helper for implementing `RequestConnection::extension_information()`.
///
/// This helps with implementing `RequestConnection`. Most likely, you do not need this in your own
/// code, unless you really want to implement your own X11 connection.
///
/// The manager keeps, per extension name, the state of the `QueryExtension` request for that
/// extension. Once a request for a name has been sent, later lookups reuse the cached state
/// instead of sending another request.
#[derive(Debug, Default)]
pub struct ExtensionManager(HashMap<&'static str, CheckState>);
17 
/// Per-extension state cached by `ExtensionManager`.
#[derive(Debug)]
enum CheckState {
    /// A `QueryExtension` request was sent with this sequence number, but its reply has not
    /// been fetched yet.
    Prefetched(SequenceNumber),
    /// The reply was received and reported the extension as present.
    Present(ExtensionInformation),
    /// The reply was received and reported the extension as not present.
    Missing,
    /// Fetching the reply failed. Waiting for the same reply again would hang, so later
    /// lookups report an error instead.
    Error,
}
25 
26 impl ExtensionManager {
27     /// If the extension has not prefetched yet, sends a `QueryExtension`
28     /// requests, adds a field to the hash map and returns a reference to it.
prefetch_extension_information_aux<C: RequestConnection>( &mut self, conn: &C, extension_name: &'static str, ) -> Result<&mut CheckState, ConnectionError>29     fn prefetch_extension_information_aux<C: RequestConnection>(
30         &mut self,
31         conn: &C,
32         extension_name: &'static str,
33     ) -> Result<&mut CheckState, ConnectionError> {
34         match self.0.entry(extension_name) {
35             // Extension already checked, return the cached value
36             HashMapEntry::Occupied(entry) => Ok(entry.into_mut()),
37             HashMapEntry::Vacant(entry) => {
38                 let cookie = conn.query_extension(extension_name.as_bytes())?;
39                 Ok(entry.insert(CheckState::Prefetched(cookie.into_sequence_number())))
40             }
41         }
42     }
43 
44     /// Prefetchs an extension sending a `QueryExtension` without waiting for
45     /// the reply.
prefetch_extension_information<C: RequestConnection>( &mut self, conn: &C, extension_name: &'static str, ) -> Result<(), ConnectionError>46     pub fn prefetch_extension_information<C: RequestConnection>(
47         &mut self,
48         conn: &C,
49         extension_name: &'static str,
50     ) -> Result<(), ConnectionError> {
51         // We are not interested on the reference to the entry.
52         let _ = self.prefetch_extension_information_aux(conn, extension_name)?;
53         Ok(())
54     }
55 
56     /// An implementation of `RequestConnection::extension_information()`.
57     ///
58     /// The given connection is used for sending a `QueryExtension` request if needed.
extension_information<C: RequestConnection>( &mut self, conn: &C, extension_name: &'static str, ) -> Result<Option<ExtensionInformation>, ConnectionError>59     pub fn extension_information<C: RequestConnection>(
60         &mut self,
61         conn: &C,
62         extension_name: &'static str,
63     ) -> Result<Option<ExtensionInformation>, ConnectionError> {
64         let entry = self.prefetch_extension_information_aux(conn, extension_name)?;
65         match entry {
66             CheckState::Prefetched(sequence_number) => {
67                 match Cookie::<C, QueryExtensionReply>::new(conn, *sequence_number).reply() {
68                     Err(err) => {
69                         *entry = CheckState::Error;
70                         match err {
71                             ReplyError::ConnectionError(e) => Err(e),
72                             // The X11 protocol specification does not specify any error
73                             // for the QueryExtension request, so this should not happen.
74                             ReplyError::X11Error(_) => Err(ConnectionError::UnknownError),
75                         }
76                     }
77                     Ok(info) => {
78                         if info.present {
79                             let info = ExtensionInformation {
80                                 major_opcode: info.major_opcode,
81                                 first_event: info.first_event,
82                                 first_error: info.first_error,
83                             };
84                             *entry = CheckState::Present(info);
85                             Ok(Some(info))
86                         } else {
87                             *entry = CheckState::Missing;
88                             Ok(None)
89                         }
90                     }
91                 }
92             }
93             CheckState::Present(info) => Ok(Some(*info)),
94             CheckState::Missing => Ok(None),
95             CheckState::Error => Err(ConnectionError::UnknownError),
96         }
97     }
98 }
99 
100 impl ExtInfoProvider for ExtensionManager {
get_from_major_opcode(&self, major_opcode: u8) -> Option<(&str, ExtensionInformation)>101     fn get_from_major_opcode(&self, major_opcode: u8) -> Option<(&str, ExtensionInformation)> {
102         self.0
103             .iter()
104             .filter_map(|(name, state)| {
105                 if let CheckState::Present(info) = state {
106                     Some((*name, *info))
107                 } else {
108                     None
109                 }
110             })
111             .find(|(_, info)| info.major_opcode == major_opcode)
112     }
113 
get_from_event_code(&self, event_code: u8) -> Option<(&str, ExtensionInformation)>114     fn get_from_event_code(&self, event_code: u8) -> Option<(&str, ExtensionInformation)> {
115         self.0
116             .iter()
117             .filter_map(|(name, state)| {
118                 if let CheckState::Present(info) = state {
119                     if info.first_event <= event_code {
120                         Some((*name, *info))
121                     } else {
122                         None
123                     }
124                 } else {
125                     None
126                 }
127             })
128             .max_by_key(|(_, info)| info.first_event)
129     }
130 
get_from_error_code(&self, error_code: u8) -> Option<(&str, ExtensionInformation)>131     fn get_from_error_code(&self, error_code: u8) -> Option<(&str, ExtensionInformation)> {
132         self.0
133             .iter()
134             .filter_map(|(name, state)| {
135                 if let CheckState::Present(info) = state {
136                     if info.first_error <= error_code {
137                         Some((*name, *info))
138                     } else {
139                         None
140                     }
141                 } else {
142                     None
143                 }
144             })
145             .max_by_key(|(_, info)| info.first_error)
146     }
147 }
148 
#[cfg(test)]
mod test {
    use std::cell::RefCell;
    use std::io::IoSlice;

    use crate::connection::{
        BufWithFds, DiscardMode, ReplyOrError, RequestConnection, RequestKind, SequenceNumber,
    };
    use crate::cookie::{Cookie, CookieWithFds, VoidCookie};
    use crate::errors::{ConnectionError, ParseError};
    use crate::utils::RawFdContainer;
    use crate::x11_utils::{ExtInfoProvider, ExtensionInformation, TryParse, TryParseFd};

    use super::{CheckState, ExtensionManager};

    /// Minimal connection for these tests.
    ///
    /// Every request with a reply gets sequence number 1. `wait_for_reply_or_raw_error`
    /// records the awaited sequence number in the cell, asserts that sequence numbers only
    /// increase, and then always fails.
    struct FakeConnection(RefCell<SequenceNumber>);

    impl RequestConnection for FakeConnection {
        type Buf = Vec<u8>;

        fn send_request_with_reply<R>(
            &self,
            _bufs: &[IoSlice<'_>],
            _fds: Vec<RawFdContainer>,
        ) -> Result<Cookie<'_, Self, R>, ConnectionError>
        where
            R: TryParse,
        {
            Ok(Cookie::new(self, 1))
        }

        fn send_request_with_reply_with_fds<R>(
            &self,
            _bufs: &[IoSlice<'_>],
            _fds: Vec<RawFdContainer>,
        ) -> Result<CookieWithFds<'_, Self, R>, ConnectionError>
        where
            R: TryParseFd,
        {
            unimplemented!()
        }

        fn send_request_without_reply(
            &self,
            _bufs: &[IoSlice<'_>],
            _fds: Vec<RawFdContainer>,
        ) -> Result<VoidCookie<'_, Self>, ConnectionError> {
            unimplemented!()
        }

        fn discard_reply(&self, _sequence: SequenceNumber, _kind: RequestKind, _mode: DiscardMode) {
            unimplemented!()
        }

        fn prefetch_extension_information(
            &self,
            _extension_name: &'static str,
        ) -> Result<(), ConnectionError> {
            unimplemented!()
        }

        fn extension_information(
            &self,
            _extension_name: &'static str,
        ) -> Result<Option<ExtensionInformation>, ConnectionError> {
            unimplemented!()
        }

        fn wait_for_reply_or_raw_error(
            &self,
            sequence: SequenceNumber,
        ) -> Result<ReplyOrError<Vec<u8>>, ConnectionError> {
            // Code should only ask once for the reply to a request. Check that this is the case
            // (by requiring monotonically increasing sequence numbers here).
            let mut last = self.0.borrow_mut();
            assert!(
                *last < sequence,
                "Last sequence number that was awaited was {}, but now {}",
                *last,
                sequence
            );
            *last = sequence;
            // Then return an error, because that's what the #[test] below needs.
            Err(ConnectionError::UnknownError)
        }

        fn wait_for_reply(
            &self,
            _sequence: SequenceNumber,
        ) -> Result<Option<Vec<u8>>, ConnectionError> {
            unimplemented!()
        }

        fn wait_for_reply_with_fds_raw(
            &self,
            _sequence: SequenceNumber,
        ) -> Result<ReplyOrError<BufWithFds<Vec<u8>>, Vec<u8>>, ConnectionError> {
            unimplemented!()
        }

        fn check_for_raw_error(
            &self,
            _sequence: SequenceNumber,
        ) -> Result<Option<Vec<u8>>, ConnectionError> {
            unimplemented!()
        }

        fn maximum_request_bytes(&self) -> usize {
            0
        }

        fn prefetch_maximum_request_bytes(&self) {
            unimplemented!()
        }

        fn parse_error(&self, _error: &[u8]) -> Result<crate::x11_utils::X11Error, ParseError> {
            unimplemented!()
        }

        fn parse_event(&self, _event: &[u8]) -> Result<crate::protocol::Event, ParseError> {
            unimplemented!()
        }
    }

    #[test]
    fn test_double_await() {
        let conn = FakeConnection(RefCell::new(0));
        let mut ext_info = ExtensionManager::default();

        // Ask for an extension info. FakeConnection will return an error.
        match ext_info.extension_information(&conn, "whatever") {
            Err(ConnectionError::UnknownError) => {}
            r => panic!("Unexpected result: {:?}", r),
        }

        // Ask again for the extension information. ExtensionManager should not try to get the
        // reply again, because that would just hang. Once upon a time, this caused a hang.
        match ext_info.extension_information(&conn, "whatever") {
            Err(ConnectionError::UnknownError) => {}
            r => panic!("Unexpected result: {:?}", r),
        }
    }

    #[test]
    fn test_info_provider() {
        let info = ExtensionInformation {
            major_opcode: 4,
            first_event: 5,
            first_error: 6,
        };

        let mut ext_info = ExtensionManager::default();
        let _ = ext_info.0.insert("prefetched", CheckState::Prefetched(42));
        let _ = ext_info.0.insert("present", CheckState::Present(info));
        let _ = ext_info.0.insert("missing", CheckState::Missing);
        let _ = ext_info.0.insert("error", CheckState::Error);

        assert_eq!(ext_info.get_from_major_opcode(4), Some(("present", info)));
        assert_eq!(ext_info.get_from_event_code(5), Some(("present", info)));
        assert_eq!(ext_info.get_from_error_code(6), Some(("present", info)));

        // An opcode no present extension owns, and codes below the only present extension's
        // bases, must not be attributed to any extension (the other entries carry no info).
        assert_eq!(ext_info.get_from_major_opcode(5), None);
        assert_eq!(ext_info.get_from_event_code(4), None);
        assert_eq!(ext_info.get_from_error_code(5), None);
    }
}
311