1 use std::sync::Arc;
2 
3 use x11rb::connection::Connection as _;
4 use x11rb::protocol::xproto::{
5     ClientMessageData, ClientMessageEvent, ConnectionExt as _, EventMask, CLIENT_MESSAGE_EVENT,
6 };
7 
// Regression test for https://github.com/psychon/x11rb/issues/231
//
// A worker thread performs a long series of request/reply round trips while
// the main thread is blocked waiting for events on the very same connection.
#[test]
fn multithread_test() {
    let conn = Arc::new(fake_stream::connect().unwrap());

    // Worker thread: issue requests and wait for each reply in turn.
    let worker_conn = Arc::clone(&conn);
    let worker = std::thread::spawn(move || {
        // Bug #231 sometimes caused `reply` to hang forever, so run a huge
        // number of round trips to see whether any of them gets stuck.
        for round in 1..=1_000_000 {
            worker_conn.get_input_focus().unwrap().reply().unwrap();

            // Periodic progress output.
            if round % 50_000 == 0 {
                eprintln!("{}", round);
            }
        }
        eprintln!("all replies received successfully");

        // Finally send an event so that the main thread can stop waiting.
        let wake_up = ClientMessageEvent {
            response_type: CLIENT_MESSAGE_EVENT,
            format: 32,
            sequence: 0,
            window: 0,
            // The message type is arbitrary; nothing inspects it.
            type_: 1,
            data: ClientMessageData::from([0, 0, 0, 0, 0]),
        };
        worker_conn
            .send_event(false, 0u32, EventMask::NO_EVENT, &wake_up)
            .unwrap();
        worker_conn.flush().unwrap();
    });

    // Main thread: keep pulling raw events until the client message shows up.
    while conn.wait_for_raw_event().unwrap()[0] != CLIENT_MESSAGE_EVENT {
        // Not the event we are waiting for; wait for the next one.
    }

    worker.join().unwrap();
}
56 
/// A fake implementation of `Stream` that does just enough for the test to work.
58 mod fake_stream {
59     use std::io::{Error, ErrorKind};
60     use std::sync::mpsc::{channel, Receiver, Sender};
61     use std::sync::{Condvar, Mutex};
62 
63     use x11rb::connection::SequenceNumber;
64     use x11rb::errors::ConnectError;
65     use x11rb::protocol::xproto::{
66         ImageOrder, Setup, CLIENT_MESSAGE_EVENT, GET_INPUT_FOCUS_REQUEST, SEND_EVENT_REQUEST,
67     };
68     use x11rb::rust_connection::{PollMode, RustConnection, Stream};
69     use x11rb::utils::RawFdContainer;
70 
71     /// Create a new `RustConnection` connected to a fake stream
connect() -> Result<RustConnection<FakeStream>, ConnectError>72     pub(crate) fn connect() -> Result<RustConnection<FakeStream>, ConnectError> {
73         let setup = Setup {
74             status: 0,
75             protocol_major_version: 0,
76             protocol_minor_version: 0,
77             length: 0,
78             release_number: 0,
79             resource_id_base: 0,
80             resource_id_mask: 0xff,
81             motion_buffer_size: 0,
82             maximum_request_length: 0,
83             image_byte_order: ImageOrder::LSB_FIRST,
84             bitmap_format_bit_order: ImageOrder::LSB_FIRST,
85             bitmap_format_scanline_unit: 0,
86             bitmap_format_scanline_pad: 0,
87             min_keycode: 0,
88             max_keycode: 0,
89             vendor: Vec::new(),
90             pixmap_formats: Vec::new(),
91             roots: Vec::new(),
92         };
93         let stream = fake_stream();
94         RustConnection::for_connected_stream(stream, setup)
95     }
96 
97     /// Get a pair of fake streams that are connected to each other
fake_stream() -> FakeStream98     fn fake_stream() -> FakeStream {
99         let (send, recv) = channel();
100         let pending = Vec::new();
101         FakeStream {
102             inner: Mutex::new(FakeStreamInner {
103                 read: FakeStreamRead { recv, pending },
104                 write: FakeStreamWrite {
105                     send,
106                     seqno: 0,
107                     skip: 0,
108                 },
109             }),
110             condvar: Condvar::new(),
111         }
112     }
113 
114     /// A packet that still needs to be read from FakeStreamRead
115     #[derive(Debug)]
116     enum Packet {
117         GetInputFocusReply(SequenceNumber),
118         Event,
119     }
120 
121     impl Packet {
to_raw(&self) -> Vec<u8>122         fn to_raw(&self) -> Vec<u8> {
123             match self {
124                 Packet::GetInputFocusReply(seqno) => {
125                     let seqno = (*seqno as u16).to_ne_bytes();
126                     let mut reply = vec![0; 32];
127                     reply[0] = 1; // This is a reply
128                     reply[2..4].copy_from_slice(&seqno);
129                     reply
130                 }
131                 Packet::Event => {
132                     let mut reply = vec![0; 32];
133                     reply[0] = CLIENT_MESSAGE_EVENT;
134                     reply
135                 }
136             }
137         }
138     }
139 
    /// In-memory stand-in for a server connection, used as the `Stream` of a `RustConnection`.
    #[derive(Debug)]
    pub(crate) struct FakeStream {
        /// Read and write state, guarded by a single lock.
        inner: Mutex<FakeStreamInner>,
        /// Notified by `write` after it queued a packet; `poll` waits on it for readability.
        condvar: Condvar,
    }
145 
    /// The lock-protected state of a `FakeStream`: its read half and its write half.
    #[derive(Debug)]
    struct FakeStreamInner {
        /// State used by `poll` and `read`.
        read: FakeStreamRead,
        /// State used by `write`.
        write: FakeStreamWrite,
    }
151 
    /// Read half of the fake stream.
    #[derive(Debug)]
    struct FakeStreamRead {
        /// Receives the packets that the write half queued in response to requests.
        recv: Receiver<Packet>,
        /// Raw bytes of an already received packet that `read` has not handed out yet.
        pending: Vec<u8>,
    }
157 
    /// Write half of the fake stream.
    #[derive(Debug)]
    pub(crate) struct FakeStreamWrite {
        /// Queues the packets that the read half will later hand out.
        send: Sender<Packet>,
        /// Incremented once for every new request seen by `write`; used as the
        /// sequence number of `GetInputFocus` replies.
        seqno: SequenceNumber,
        /// Number of bytes of the current request that a following `write` call
        /// is still expected to deliver (and that will be ignored).
        skip: usize,
    }
164 
165     impl Stream for FakeStream {
poll(&self, mode: PollMode) -> std::io::Result<()>166         fn poll(&self, mode: PollMode) -> std::io::Result<()> {
167             if mode.writable() {
168                 Ok(())
169             } else {
170                 let mut inner = self.inner.lock().unwrap();
171                 loop {
172                     if inner.read.pending.is_empty() {
173                         match inner.read.recv.try_recv() {
174                             Ok(packet) => {
175                                 inner.read.pending.extend(packet.to_raw());
176                                 return Ok(());
177                             }
178                             Err(std::sync::mpsc::TryRecvError::Empty) => {
179                                 inner = self.condvar.wait(inner).unwrap();
180                             }
181                             Err(std::sync::mpsc::TryRecvError::Disconnected) => unreachable!(),
182                         }
183                     } else {
184                         return Ok(());
185                     }
186                 }
187             }
188         }
189 
read( &self, buf: &mut [u8], _fd_storage: &mut Vec<RawFdContainer>, ) -> std::io::Result<usize>190         fn read(
191             &self,
192             buf: &mut [u8],
193             _fd_storage: &mut Vec<RawFdContainer>,
194         ) -> std::io::Result<usize> {
195             let mut inner = self.inner.lock().unwrap();
196             if inner.read.pending.is_empty() {
197                 match inner.read.recv.try_recv() {
198                     Ok(packet) => inner.read.pending.extend(packet.to_raw()),
199                     Err(std::sync::mpsc::TryRecvError::Empty) => {
200                         return Err(Error::new(ErrorKind::WouldBlock, "Would block"));
201                     }
202                     Err(std::sync::mpsc::TryRecvError::Disconnected) => unreachable!(),
203                 }
204             }
205 
206             let len = inner.read.pending.len().min(buf.len());
207             buf[..len].copy_from_slice(&inner.read.pending[..len]);
208             inner.read.pending.drain(..len);
209             Ok(len)
210         }
211 
write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> std::io::Result<usize>212         fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> std::io::Result<usize> {
213             assert!(fds.is_empty());
214 
215             let mut inner = self.inner.lock().unwrap();
216 
217             if inner.write.skip > 0 {
218                 assert_eq!(inner.write.skip, buf.len());
219                 inner.write.skip = 0;
220                 return Ok(buf.len());
221             }
222 
223             inner.write.seqno += 1;
224             match buf[0] {
225                 GET_INPUT_FOCUS_REQUEST => inner
226                     .write
227                     .send
228                     .send(Packet::GetInputFocusReply(inner.write.seqno))
229                     .unwrap(),
230                 SEND_EVENT_REQUEST => inner.write.send.send(Packet::Event).unwrap(),
231                 _ => unimplemented!(),
232             }
233             // Compute how much of the package was not yet received
234             inner.write.skip = usize::from(u16::from_ne_bytes([buf[2], buf[3]])) * 4 - buf.len();
235 
236             self.condvar.notify_all();
237 
238             Ok(buf.len())
239         }
240     }
241 }
242