1 // Take a look at the license at the top of the repository in the LICENSE file.
2 
3 use crate::prelude::*;
4 use crate::subclass::prelude::*;
5 use crate::OutputStream;
6 
7 use std::any::Any;
8 use std::io::{Seek, Write};
9 
10 use crate::read_input_stream::std_error_to_gio_error;
11 
12 mod imp {
13     use super::*;
14     use std::cell::RefCell;
15 
16     pub(super) enum Writer {
17         Write(AnyWriter),
18         WriteSeek(AnyWriter),
19     }
20 
21     #[derive(Default)]
22     pub struct WriteOutputStream {
23         pub(super) write: RefCell<Option<Writer>>,
24     }
25 
26     #[glib::object_subclass]
27     impl ObjectSubclass for WriteOutputStream {
28         const NAME: &'static str = "WriteOutputStream";
29         type Type = super::WriteOutputStream;
30         type ParentType = OutputStream;
31         type Interfaces = (crate::Seekable,);
32     }
33 
34     impl ObjectImpl for WriteOutputStream {}
35 
36     impl OutputStreamImpl for WriteOutputStream {
write( &self, _stream: &Self::Type, buffer: &[u8], _cancellable: Option<&crate::Cancellable>, ) -> Result<usize, glib::Error>37         fn write(
38             &self,
39             _stream: &Self::Type,
40             buffer: &[u8],
41             _cancellable: Option<&crate::Cancellable>,
42         ) -> Result<usize, glib::Error> {
43             let mut write = self.write.borrow_mut();
44             let write = match *write {
45                 None => {
46                     return Err(glib::Error::new(
47                         crate::IOErrorEnum::Closed,
48                         "Alwritey closed",
49                     ));
50                 }
51                 Some(Writer::Write(ref mut write)) => write,
52                 Some(Writer::WriteSeek(ref mut write)) => write,
53             };
54 
55             loop {
56                 match std_error_to_gio_error(write.write(buffer)) {
57                     None => continue,
58                     Some(res) => return res,
59                 }
60             }
61         }
62 
close( &self, _stream: &Self::Type, _cancellable: Option<&crate::Cancellable>, ) -> Result<(), glib::Error>63         fn close(
64             &self,
65             _stream: &Self::Type,
66             _cancellable: Option<&crate::Cancellable>,
67         ) -> Result<(), glib::Error> {
68             let _ = self.write.borrow_mut().take();
69             Ok(())
70         }
71 
flush( &self, _stream: &Self::Type, _cancellable: Option<&crate::Cancellable>, ) -> Result<(), glib::Error>72         fn flush(
73             &self,
74             _stream: &Self::Type,
75             _cancellable: Option<&crate::Cancellable>,
76         ) -> Result<(), glib::Error> {
77             let mut write = self.write.borrow_mut();
78             let write = match *write {
79                 None => {
80                     return Err(glib::Error::new(
81                         crate::IOErrorEnum::Closed,
82                         "Alwritey closed",
83                     ));
84                 }
85                 Some(Writer::Write(ref mut write)) => write,
86                 Some(Writer::WriteSeek(ref mut write)) => write,
87             };
88 
89             loop {
90                 match std_error_to_gio_error(write.flush()) {
91                     None => continue,
92                     Some(res) => return res,
93                 }
94             }
95         }
96     }
97 
98     impl SeekableImpl for WriteOutputStream {
tell(&self, _seekable: &Self::Type) -> i6499         fn tell(&self, _seekable: &Self::Type) -> i64 {
100             // XXX: stream_position is not stable yet
101             // let mut write = self.write.borrow_mut();
102             // match *write {
103             //     Some(Writer::WriteSeek(ref mut write)) => {
104             //         write.stream_position().map(|pos| pos as i64).unwrap_or(-1)
105             //     },
106             //     _ => -1,
107             // };
108             -1
109         }
110 
can_seek(&self, _seekable: &Self::Type) -> bool111         fn can_seek(&self, _seekable: &Self::Type) -> bool {
112             let write = self.write.borrow();
113             matches!(*write, Some(Writer::WriteSeek(_)))
114         }
115 
seek( &self, _seekable: &Self::Type, offset: i64, type_: glib::SeekType, _cancellable: Option<&crate::Cancellable>, ) -> Result<(), glib::Error>116         fn seek(
117             &self,
118             _seekable: &Self::Type,
119             offset: i64,
120             type_: glib::SeekType,
121             _cancellable: Option<&crate::Cancellable>,
122         ) -> Result<(), glib::Error> {
123             use std::io::SeekFrom;
124 
125             let mut write = self.write.borrow_mut();
126             match *write {
127                 Some(Writer::WriteSeek(ref mut write)) => {
128                     let pos = match type_ {
129                         glib::SeekType::Cur => SeekFrom::Current(offset),
130                         glib::SeekType::Set => {
131                             if offset < 0 {
132                                 return Err(glib::Error::new(
133                                     crate::IOErrorEnum::InvalidArgument,
134                                     "Invalid Argument",
135                                 ));
136                             } else {
137                                 SeekFrom::Start(offset as u64)
138                             }
139                         }
140                         glib::SeekType::End => SeekFrom::End(offset),
141                         _ => unimplemented!(),
142                     };
143 
144                     loop {
145                         match std_error_to_gio_error(write.seek(pos)) {
146                             None => continue,
147                             Some(res) => return res.map(|_| ()),
148                         }
149                     }
150                 }
151                 _ => Err(glib::Error::new(
152                     crate::IOErrorEnum::NotSupported,
153                     "Truncating not supported",
154                 )),
155             }
156         }
157 
can_truncate(&self, _seekable: &Self::Type) -> bool158         fn can_truncate(&self, _seekable: &Self::Type) -> bool {
159             false
160         }
161 
truncate( &self, _seekable: &Self::Type, _offset: i64, _cancellable: Option<&crate::Cancellable>, ) -> Result<(), glib::Error>162         fn truncate(
163             &self,
164             _seekable: &Self::Type,
165             _offset: i64,
166             _cancellable: Option<&crate::Cancellable>,
167         ) -> Result<(), glib::Error> {
168             Err(glib::Error::new(
169                 crate::IOErrorEnum::NotSupported,
170                 "Truncating not supported",
171             ))
172         }
173     }
174 }
175 
// Public GObject wrapper for `imp::WriteOutputStream`.
//
// It is a `crate::OutputStream` subclass that forwards writes to a wrapped Rust
// `std::io::Write` value. It also implements `crate::Seekable`, although seeking
// only works for streams built with `new_seekable`.
glib::wrapper! {
    pub struct WriteOutputStream(ObjectSubclass<imp::WriteOutputStream>) @extends crate::OutputStream, @implements crate::Seekable;
}
179 
180 impl WriteOutputStream {
new<W: Write + Send + Any + 'static>(write: W) -> WriteOutputStream181     pub fn new<W: Write + Send + Any + 'static>(write: W) -> WriteOutputStream {
182         let obj = glib::Object::new(&[]).expect("Failed to create write input stream");
183 
184         let imp = imp::WriteOutputStream::from_instance(&obj);
185         *imp.write.borrow_mut() = Some(imp::Writer::Write(AnyWriter::new(write)));
186         obj
187     }
188 
new_seekable<W: Write + Seek + Send + Any + 'static>(write: W) -> WriteOutputStream189     pub fn new_seekable<W: Write + Seek + Send + Any + 'static>(write: W) -> WriteOutputStream {
190         let obj = glib::Object::new(&[]).expect("Failed to create write input stream");
191 
192         let imp = imp::WriteOutputStream::from_instance(&obj);
193         *imp.write.borrow_mut() = Some(imp::Writer::WriteSeek(AnyWriter::new_seekable(write)));
194         obj
195     }
196 
close_and_take(&self) -> Box<dyn Any + Send + 'static>197     pub fn close_and_take(&self) -> Box<dyn Any + Send + 'static> {
198         let imp = imp::WriteOutputStream::from_instance(self);
199         let inner = imp.write.borrow_mut().take();
200 
201         let ret = match inner {
202             None => {
203                 panic!("Stream already closed or inner taken");
204             }
205             Some(imp::Writer::Write(write)) => write.writer,
206             Some(imp::Writer::WriteSeek(write)) => write.writer,
207         };
208 
209         let _ = self.close(crate::NONE_CANCELLABLE);
210 
211         match ret {
212             AnyOrPanic::Any(w) => w,
213             AnyOrPanic::Panic(p) => std::panic::resume_unwind(p),
214         }
215     }
216 }
217 
/// Storage slot for a type-erased writer.
///
/// Holds either the boxed writer itself or, once the writer has panicked
/// inside `AnyWriter::with_inner`, the captured panic payload.
enum AnyOrPanic {
    /// The boxed writer, still usable.
    Any(Box<dyn Any + Send + 'static>),
    /// Payload of a panic raised by the writer; the writer was dropped when
    /// this variant replaced `Any`.
    Panic(Box<dyn Any + Send + 'static>),
}

/// Helper for dynamically dispatching to any kind of `Write` (optionally also
/// `Seek`) value while catching panics it raises.
///
/// The function pointers are instantiated for the concrete writer type at
/// construction time. On every call they downcast `writer` back to that type.
struct AnyWriter {
    writer: AnyOrPanic,
    write_fn: fn(s: &mut AnyWriter, buffer: &[u8]) -> std::io::Result<usize>,
    flush_fn: fn(s: &mut AnyWriter) -> std::io::Result<()>,
    /// `Some` only for writers created through `new_seekable`.
    seek_fn: Option<fn(s: &mut AnyWriter, pos: std::io::SeekFrom) -> std::io::Result<u64>>,
}

impl AnyWriter {
    /// Erases `w`, recording write and flush entry points for its type.
    fn new<W: Write + Any + Send + 'static>(w: W) -> Self {
        AnyWriter {
            seek_fn: None,
            flush_fn: Self::flush_fn::<W>,
            write_fn: Self::write_fn::<W>,
            writer: AnyOrPanic::Any(Box::new(w)),
        }
    }

    /// Same as `new`, but also records a seek entry point.
    fn new_seekable<W: Write + Seek + Any + Send + 'static>(w: W) -> Self {
        let mut erased = Self::new(w);
        erased.seek_fn = Some(Self::seek_fn::<W>);
        erased
    }

    fn write_fn<W: Write + 'static>(this: &mut AnyWriter, buffer: &[u8]) -> std::io::Result<usize> {
        this.with_inner(|inner: &mut W| inner.write(buffer))
    }

    fn flush_fn<W: Write + 'static>(this: &mut AnyWriter) -> std::io::Result<()> {
        this.with_inner(|inner: &mut W| inner.flush())
    }

    fn seek_fn<W: Seek + 'static>(
        this: &mut AnyWriter,
        pos: std::io::SeekFrom,
    ) -> std::io::Result<u64> {
        this.with_inner(|inner: &mut W| inner.seek(pos))
    }

    /// Downcasts the stored writer to `W` and runs `func` on it under
    /// `catch_unwind`.
    ///
    /// If `func` panics, the writer is replaced by the panic payload and an
    /// `Other` error ("Panicked") is returned. Every later call then fails with
    /// "Panicked before".
    fn with_inner<W: 'static, T, F: FnOnce(&mut W) -> std::io::Result<T>>(
        &mut self,
        func: F,
    ) -> std::io::Result<T> {
        let boxed = match self.writer {
            AnyOrPanic::Panic(_) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "Panicked before",
                ));
            }
            AnyOrPanic::Any(ref mut boxed) => boxed,
        };

        let inner = boxed.downcast_mut::<W>().unwrap();
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || func(inner)));

        outcome.unwrap_or_else(|payload| {
            self.writer = AnyOrPanic::Panic(payload);
            Err(std::io::Error::new(std::io::ErrorKind::Other, "Panicked"))
        })
    }

    fn write(&mut self, buffer: &[u8]) -> std::io::Result<usize> {
        let dispatch = self.write_fn;
        dispatch(self, buffer)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        let dispatch = self.flush_fn;
        dispatch(self)
    }

    /// Must only be called on writers built by `new_seekable`.
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        match self.seek_fn {
            Some(dispatch) => dispatch(self, pos),
            None => unreachable!(),
        }
    }
}
304 
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Closes `stream`, checks it hands back a `Cursor<Vec<u8>>`, and returns
    /// a copy of the bytes written into that cursor.
    fn take_written_bytes(stream: &WriteOutputStream) -> Vec<u8> {
        let inner = stream.close_and_take();
        assert!(inner.is::<Cursor<Vec<u8>>>());
        let cursor = inner.downcast_ref::<Cursor<Vec<u8>>>().unwrap();
        cursor.get_ref().clone()
    }

    #[test]
    fn test_write() {
        let payload: [u8; 10] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let stream = WriteOutputStream::new(Cursor::new(Vec::<u8>::new()));

        assert_eq!(stream.write(&payload, crate::NONE_CANCELLABLE), Ok(10));

        assert_eq!(take_written_bytes(&stream), payload);
    }

    #[test]
    fn test_write_seek() {
        let first: [u8; 10] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let second: [u8; 10] = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20];
        let stream = WriteOutputStream::new_seekable(Cursor::new(Vec::<u8>::new()));

        assert_eq!(stream.write(&first, crate::NONE_CANCELLABLE), Ok(10));

        // Rewind to the start so the second payload overwrites the first.
        assert!(stream.can_seek());
        assert_eq!(
            stream.seek(0, glib::SeekType::Set, crate::NONE_CANCELLABLE),
            Ok(())
        );

        assert_eq!(stream.write(&second, crate::NONE_CANCELLABLE), Ok(10));

        assert_eq!(take_written_bytes(&stream), second);
    }
}
356