1 // Copyright (c) 2021-present, Gregory Szorc
2 // All rights reserved.
3 //
4 // This software may be modified and distributed under the terms
5 // of the BSD license. See the LICENSE file for details.
6 
7 use {
8     crate::exceptions::ZstdError,
9     pyo3::{
10         buffer::PyBuffer,
11         class::{PyBufferProtocol, PySequenceProtocol},
12         exceptions::{PyIndexError, PyTypeError, PyValueError},
13         ffi::Py_buffer,
14         prelude::*,
15         types::{PyBytes, PyTuple},
16         AsPyPointer,
17     },
18 };
19 
/// One `(offset, length)` entry of a segment table, naming a byte range of a data buffer.
///
/// `#[repr(C)]` pins the layout to two consecutive native-endian `u64`s (16 bytes, no
/// padding). Segment tables are decoded from, and exported as, raw bytes in exactly this
/// layout, so the field order and types must not change.
#[repr(C)]
#[derive(Clone, Debug)]
pub(crate) struct BufferSegment {
    /// Byte offset of the segment from the start of the data buffer.
    pub offset: u64,
    /// Length of the segment in bytes.
    pub length: u64,
}
26 
/// A single segment of a `BufferWithSegments`, exposed to Python.
///
/// Holds its own buffer export of the parent data object plus a byte range within it, and
/// supports `len()`, `tobytes()` and a read-only buffer protocol export.
#[pyclass(module = "zstandard.backend_rust", name = "BufferSegment")]
pub struct ZstdBufferSegment {
    /// The object backing storage. Held only to keep it alive (reference counting).
    _parent: PyObject,
    /// PyBuffer into parent object; `buf_ptr()` is the base address `offset` is relative to.
    buffer: PyBuffer<u8>,
    /// Offset of segment within data, in bytes.
    offset: usize,
    /// Length of segment within data, in bytes.
    len: usize,
}
38 
39 impl ZstdBufferSegment {
as_slice(&self) -> &[u8]40     pub fn as_slice(&self) -> &[u8] {
41         unsafe {
42             std::slice::from_raw_parts(self.buffer.buf_ptr().add(self.offset) as *const _, self.len)
43         }
44     }
45 }
46 
47 #[pymethods]
48 impl ZstdBufferSegment {
49     #[getter]
offset(&self) -> usize50     fn offset(&self) -> usize {
51         self.offset
52     }
53 
tobytes<'p>(&self, py: Python<'p>) -> PyResult<&'p PyBytes>54     fn tobytes<'p>(&self, py: Python<'p>) -> PyResult<&'p PyBytes> {
55         Ok(PyBytes::new(py, self.as_slice()))
56     }
57 }
58 
59 #[pyproto]
60 impl PySequenceProtocol for ZstdBufferSegment {
__len__(&self) -> usize61     fn __len__(&self) -> usize {
62         self.len
63     }
64 }
65 
66 #[pyproto]
67 impl PyBufferProtocol for ZstdBufferSegment {
bf_getbuffer(slf: PyRefMut<Self>, view: *mut Py_buffer, flags: i32) -> PyResult<()>68     fn bf_getbuffer(slf: PyRefMut<Self>, view: *mut Py_buffer, flags: i32) -> PyResult<()> {
69         let slice = slf.as_slice();
70 
71         if unsafe {
72             pyo3::ffi::PyBuffer_FillInfo(
73                 view,
74                 slf.as_ptr(),
75                 slice.as_ptr() as *mut _,
76                 slice.len() as _,
77                 1,
78                 flags,
79             )
80         } != 0
81         {
82             Err(PyErr::fetch(slf.py()))
83         } else {
84             Ok(())
85         }
86     }
87 
88     #[allow(unused_variables)]
bf_releasebuffer(slf: PyRefMut<Self>, view: *mut Py_buffer)89     fn bf_releasebuffer(slf: PyRefMut<Self>, view: *mut Py_buffer) {}
90 }
91 
/// Read-only buffer-protocol view of a `BufferWithSegments`'s segment table.
///
/// Exporting a buffer from this object yields the parent's segments as packed
/// `BufferSegment` records (16 bytes each).
#[pyclass(module = "zstandard.backend_rust", name = "BufferSegments")]
pub struct ZstdBufferSegments {
    /// The `BufferWithSegments` whose segment table is exposed. Holding this reference keeps
    /// the table's memory alive while a buffer view of this object exists.
    parent: PyObject,
}
96 
97 #[pyproto]
98 impl PyBufferProtocol for ZstdBufferSegments {
bf_getbuffer(slf: PyRefMut<Self>, view: *mut Py_buffer, flags: i32) -> PyResult<()>99     fn bf_getbuffer(slf: PyRefMut<Self>, view: *mut Py_buffer, flags: i32) -> PyResult<()> {
100         let py = slf.py();
101 
102         let parent: &PyCell<ZstdBufferWithSegments> = slf.parent.extract(py)?;
103 
104         if unsafe {
105             pyo3::ffi::PyBuffer_FillInfo(
106                 view,
107                 slf.as_ptr(),
108                 parent.borrow().segments.as_ptr() as *const _ as *mut _,
109                 (parent.borrow().segments.len() * std::mem::size_of::<BufferSegment>()) as isize,
110                 1,
111                 flags,
112             )
113         } != 0
114         {
115             Err(PyErr::fetch(py))
116         } else {
117             Ok(())
118         }
119     }
120 
121     #[allow(unused_variables)]
bf_releasebuffer(slf: PyRefMut<Self>, view: *mut Py_buffer)122     fn bf_releasebuffer(slf: PyRefMut<Self>, view: *mut Py_buffer) {}
123 }
124 
/// A data buffer paired with a table of `(offset, length)` segments into it.
///
/// Indexing yields `BufferSegment` objects; the buffer protocol exports the whole data buffer.
#[pyclass(module = "zstandard.backend_rust", name = "BufferWithSegments")]
pub struct ZstdBufferWithSegments {
    /// The Python object the data buffer was obtained from; individual segments re-export it.
    source: PyObject,
    /// Buffer export of `source` holding the bytes the segments point into.
    pub(crate) buffer: PyBuffer<u8>,
    /// Owned copy of the segment table, bounds-checked against `buffer` at construction.
    pub(crate) segments: Vec<BufferSegment>,
}
131 
132 impl ZstdBufferWithSegments {
as_slice(&self) -> &[u8]133     fn as_slice(&self) -> &[u8] {
134         unsafe {
135             std::slice::from_raw_parts(self.buffer.buf_ptr() as *const _, self.buffer.len_bytes())
136         }
137     }
138 
get_segment_slice<'p>(&self, _py: Python<'p>, i: usize) -> &'p [u8]139     pub fn get_segment_slice<'p>(&self, _py: Python<'p>, i: usize) -> &'p [u8] {
140         let segment = &self.segments[i];
141 
142         unsafe {
143             std::slice::from_raw_parts(
144                 self.buffer.buf_ptr().add(segment.offset as usize) as *const _,
145                 segment.length as usize,
146             )
147         }
148     }
149 }
150 
151 #[pymethods]
152 impl ZstdBufferWithSegments {
153     #[new]
new(py: Python, data: &PyAny, segments: PyBuffer<u8>) -> PyResult<Self>154     pub fn new(py: Python, data: &PyAny, segments: PyBuffer<u8>) -> PyResult<Self> {
155         let data_buffer = PyBuffer::get(data)?;
156 
157         if segments.len_bytes() % std::mem::size_of::<BufferSegment>() != 0 {
158             return Err(PyValueError::new_err(format!(
159                 "segments array size is not a multiple of {}",
160                 std::mem::size_of::<BufferSegment>()
161             )));
162         }
163 
164         let segments_slice: &[BufferSegment] = unsafe {
165             std::slice::from_raw_parts(
166                 segments.buf_ptr() as *const _,
167                 segments.len_bytes() / std::mem::size_of::<BufferSegment>(),
168             )
169         };
170 
171         // Make a copy of the segments data. It is cheap to do so and is a
172         // guard against caller changing offsets, which has security implications.
173         let segments = segments_slice.to_vec();
174 
175         // Validate segments data, as blindly trusting it could lead to
176         // arbitrary memory access.
177         for segment in &segments {
178             if segment.offset + segment.length > data_buffer.len_bytes() as _ {
179                 return Err(PyValueError::new_err(
180                     "offset within segments array references memory outside buffer",
181                 ));
182             }
183         }
184 
185         Ok(Self {
186             source: data.into_py(py),
187             buffer: data_buffer,
188             segments,
189         })
190     }
191 
192     #[getter]
size(&self) -> usize193     fn size(&self) -> usize {
194         self.buffer.len_bytes()
195     }
196 
segments(slf: PyRef<Self>, py: Python) -> PyResult<ZstdBufferSegments>197     fn segments(slf: PyRef<Self>, py: Python) -> PyResult<ZstdBufferSegments> {
198         Ok(ZstdBufferSegments {
199             // TODO surely there is a better way to cast self to PyObject?
200             parent: unsafe { Py::from_borrowed_ptr(py, slf.as_ptr()) },
201         })
202     }
203 
tobytes<'p>(&self, py: Python<'p>) -> PyResult<&'p PyBytes>204     fn tobytes<'p>(&self, py: Python<'p>) -> PyResult<&'p PyBytes> {
205         Ok(PyBytes::new(py, self.as_slice()))
206     }
207 }
208 
209 #[pyproto]
210 impl PySequenceProtocol for ZstdBufferWithSegments {
__len__(&self) -> usize211     fn __len__(&self) -> usize {
212         self.segments.len()
213     }
214 
__getitem__(&self, key: isize) -> PyResult<ZstdBufferSegment>215     fn __getitem__(&self, key: isize) -> PyResult<ZstdBufferSegment> {
216         let py = unsafe { Python::assume_gil_acquired() };
217 
218         if key < 0 {
219             return Err(PyIndexError::new_err("offset must be non-negative"));
220         }
221 
222         let key = key as usize;
223 
224         if key >= self.segments.len() {
225             return Err(PyIndexError::new_err(format!(
226                 "offset must be less than {}",
227                 self.segments.len()
228             )));
229         }
230 
231         let segment = &self.segments[key];
232 
233         Ok(ZstdBufferSegment {
234             _parent: self.source.clone_ref(py),
235             buffer: PyBuffer::get(self.source.extract(py)?)?,
236             offset: segment.offset as _,
237             len: segment.length as _,
238         })
239     }
240 }
241 
242 #[pyproto]
243 impl PyBufferProtocol for ZstdBufferWithSegments {
bf_getbuffer(slf: PyRefMut<Self>, view: *mut Py_buffer, flags: i32) -> PyResult<()>244     fn bf_getbuffer(slf: PyRefMut<Self>, view: *mut Py_buffer, flags: i32) -> PyResult<()> {
245         if unsafe {
246             pyo3::ffi::PyBuffer_FillInfo(
247                 view,
248                 slf.as_ptr(),
249                 slf.buffer.buf_ptr(),
250                 slf.buffer.len_bytes() as _,
251                 1,
252                 flags,
253             )
254         } != 0
255         {
256             Err(PyErr::fetch(slf.py()))
257         } else {
258             Ok(())
259         }
260     }
261 
262     #[allow(unused_variables)]
bf_releasebuffer(slf: PyRefMut<Self>, view: *mut Py_buffer)263     fn bf_releasebuffer(slf: PyRefMut<Self>, view: *mut Py_buffer) {}
264 }
265 
/// Several `BufferWithSegments` objects, indexed together as one flat sequence of segments.
#[pyclass(
    module = "zstandard.backend_rust",
    name = "BufferWithSegmentsCollection"
)]
pub struct ZstdBufferWithSegmentsCollection {
    /// Member buffers in argument order; each element is a `Py<ZstdBufferWithSegments>`.
    pub(crate) buffers: Vec<PyObject>,
    /// Running segment totals: `first_elements[i]` is the number of segments in
    /// `buffers[0..=i]`, i.e. the exclusive end of buffer `i` in the flat index space.
    /// Despite the name, buffer `i` starts at `first_elements[i - 1]` (or 0 for `i == 0`).
    first_elements: Vec<usize>,
}
275 
276 #[pymethods]
277 impl ZstdBufferWithSegmentsCollection {
278     #[new]
279     #[args(py_args = "*")]
new(py: Python, py_args: &PyTuple) -> PyResult<Self>280     pub fn new(py: Python, py_args: &PyTuple) -> PyResult<Self> {
281         if py_args.is_empty() {
282             return Err(PyValueError::new_err("must pass at least 1 argument"));
283         }
284 
285         let mut buffers = Vec::with_capacity(py_args.len());
286         let mut first_elements = Vec::with_capacity(py_args.len());
287         let mut offset = 0;
288 
289         for item in py_args {
290             let item: &PyCell<ZstdBufferWithSegments> = item.extract().map_err(|_| {
291                 PyTypeError::new_err("arguments must be BufferWithSegments instances")
292             })?;
293             let segment = item.borrow();
294 
295             if segment.segments.is_empty() || segment.buffer.len_bytes() == 0 {
296                 return Err(PyValueError::new_err(
297                     "ZstdBufferWithSegments cannot be empty",
298                 ));
299             }
300 
301             offset += segment.segments.len();
302 
303             buffers.push(item.to_object(py));
304             first_elements.push(offset);
305         }
306 
307         Ok(Self {
308             buffers,
309             first_elements,
310         })
311     }
312 
size(&self, py: Python) -> PyResult<usize>313     fn size(&self, py: Python) -> PyResult<usize> {
314         let mut size = 0;
315 
316         for buffer in &self.buffers {
317             let item: &PyCell<ZstdBufferWithSegments> = buffer.extract(py)?;
318 
319             for segment in &item.borrow().segments {
320                 size += segment.length as usize;
321             }
322         }
323 
324         Ok(size)
325     }
326 }
327 
328 #[pyproto]
329 impl PySequenceProtocol for ZstdBufferWithSegmentsCollection {
__len__(&self) -> usize330     fn __len__(&self) -> usize {
331         self.first_elements.last().unwrap().clone()
332     }
333 
__getitem__(&self, key: isize) -> PyResult<ZstdBufferSegment>334     fn __getitem__(&self, key: isize) -> PyResult<ZstdBufferSegment> {
335         let py = unsafe { Python::assume_gil_acquired() };
336 
337         if key < 0 {
338             return Err(PyIndexError::new_err("offset must be non-negative"));
339         }
340 
341         let key = key as usize;
342 
343         if key >= self.__len__() {
344             return Err(PyIndexError::new_err(format!(
345                 "offset must be less than {}",
346                 self.__len__()
347             )));
348         }
349 
350         let mut offset = 0;
351         for (buffer_index, segment) in self.buffers.iter().enumerate() {
352             if key < self.first_elements[buffer_index] {
353                 if buffer_index > 0 {
354                     offset = self.first_elements[buffer_index - 1];
355                 }
356 
357                 let item: &PyCell<ZstdBufferWithSegments> = segment.extract(py)?;
358 
359                 return item.borrow().__getitem__((key - offset) as isize);
360             }
361         }
362 
363         Err(ZstdError::new_err(
364             "error resolving segment; this should not happen",
365         ))
366     }
367 }
368 
init_module(module: &PyModule) -> PyResult<()>369 pub(crate) fn init_module(module: &PyModule) -> PyResult<()> {
370     module.add_class::<ZstdBufferSegment>()?;
371     module.add_class::<ZstdBufferSegments>()?;
372     module.add_class::<ZstdBufferWithSegments>()?;
373     module.add_class::<ZstdBufferWithSegmentsCollection>()?;
374 
375     Ok(())
376 }
377