1 // Copyright (c) 2021-present, Gregory Szorc
2 // All rights reserved.
3 //
4 // This software may be modified and distributed under the terms
5 // of the BSD license. See the LICENSE file for details.
6
7 use {
8 crate::exceptions::ZstdError,
9 pyo3::{
10 buffer::PyBuffer,
11 class::{PyBufferProtocol, PySequenceProtocol},
12 exceptions::{PyIndexError, PyTypeError, PyValueError},
13 ffi::Py_buffer,
14 prelude::*,
15 types::{PyBytes, PyTuple},
16 AsPyPointer,
17 },
18 };
19
/// One entry of a segments table: a byte range within a data buffer.
///
/// `#[repr(C)]` pins the field order and layout (two consecutive `u64`s:
/// offset, then length). That layout is what allows a raw segments buffer to
/// be read as, and exported as, a packed array of these records, so the
/// fields must not be reordered.
#[repr(C)]
#[derive(Clone, Debug)]
pub(crate) struct BufferSegment {
    /// Byte offset of the segment's first byte within the data buffer.
    pub offset: u64,
    /// Number of bytes in the segment.
    pub length: u64,
}
26
/// A single segment handed out by a `BufferWithSegments`, exposed to Python
/// as `BufferSegment`.
///
/// It holds its own reference to the parent's source object and its own
/// buffer view of it, plus the byte range that this segment covers.
#[pyclass(module = "zstandard.backend_rust", name = "BufferSegment")]
pub struct ZstdBufferSegment {
    /// The object backing storage. For reference counting.
    _parent: PyObject,
    /// PyBuffer into parent object.
    buffer: PyBuffer<u8>,
    /// Offset of segment within data.
    offset: usize,
    /// Length of segment within data.
    len: usize,
}
38
39 impl ZstdBufferSegment {
as_slice(&self) -> &[u8]40 pub fn as_slice(&self) -> &[u8] {
41 unsafe {
42 std::slice::from_raw_parts(self.buffer.buf_ptr().add(self.offset) as *const _, self.len)
43 }
44 }
45 }
46
47 #[pymethods]
48 impl ZstdBufferSegment {
49 #[getter]
offset(&self) -> usize50 fn offset(&self) -> usize {
51 self.offset
52 }
53
tobytes<'p>(&self, py: Python<'p>) -> PyResult<&'p PyBytes>54 fn tobytes<'p>(&self, py: Python<'p>) -> PyResult<&'p PyBytes> {
55 Ok(PyBytes::new(py, self.as_slice()))
56 }
57 }
58
59 #[pyproto]
60 impl PySequenceProtocol for ZstdBufferSegment {
__len__(&self) -> usize61 fn __len__(&self) -> usize {
62 self.len
63 }
64 }
65
66 #[pyproto]
67 impl PyBufferProtocol for ZstdBufferSegment {
bf_getbuffer(slf: PyRefMut<Self>, view: *mut Py_buffer, flags: i32) -> PyResult<()>68 fn bf_getbuffer(slf: PyRefMut<Self>, view: *mut Py_buffer, flags: i32) -> PyResult<()> {
69 let slice = slf.as_slice();
70
71 if unsafe {
72 pyo3::ffi::PyBuffer_FillInfo(
73 view,
74 slf.as_ptr(),
75 slice.as_ptr() as *mut _,
76 slice.len() as _,
77 1,
78 flags,
79 )
80 } != 0
81 {
82 Err(PyErr::fetch(slf.py()))
83 } else {
84 Ok(())
85 }
86 }
87
88 #[allow(unused_variables)]
bf_releasebuffer(slf: PyRefMut<Self>, view: *mut Py_buffer)89 fn bf_releasebuffer(slf: PyRefMut<Self>, view: *mut Py_buffer) {}
90 }
91
/// Python-visible view of a `BufferWithSegments`'s segment table.
///
/// Its buffer-protocol implementation exports the parent's `segments`
/// vector as raw bytes.
#[pyclass(module = "zstandard.backend_rust", name = "BufferSegments")]
pub struct ZstdBufferSegments {
    // The owning BufferWithSegments (extracted as PyCell<ZstdBufferWithSegments>
    // when the buffer is exported). Holding it keeps the segment table alive.
    parent: PyObject,
}
96
97 #[pyproto]
98 impl PyBufferProtocol for ZstdBufferSegments {
bf_getbuffer(slf: PyRefMut<Self>, view: *mut Py_buffer, flags: i32) -> PyResult<()>99 fn bf_getbuffer(slf: PyRefMut<Self>, view: *mut Py_buffer, flags: i32) -> PyResult<()> {
100 let py = slf.py();
101
102 let parent: &PyCell<ZstdBufferWithSegments> = slf.parent.extract(py)?;
103
104 if unsafe {
105 pyo3::ffi::PyBuffer_FillInfo(
106 view,
107 slf.as_ptr(),
108 parent.borrow().segments.as_ptr() as *const _ as *mut _,
109 (parent.borrow().segments.len() * std::mem::size_of::<BufferSegment>()) as isize,
110 1,
111 flags,
112 )
113 } != 0
114 {
115 Err(PyErr::fetch(py))
116 } else {
117 Ok(())
118 }
119 }
120
121 #[allow(unused_variables)]
bf_releasebuffer(slf: PyRefMut<Self>, view: *mut Py_buffer)122 fn bf_releasebuffer(slf: PyRefMut<Self>, view: *mut Py_buffer) {}
123 }
124
/// A data buffer paired with a table of `(offset, length)` segments into it.
#[pyclass(module = "zstandard.backend_rust", name = "BufferWithSegments")]
pub struct ZstdBufferWithSegments {
    /// The Python object that exported `buffer`; each handed-out segment
    /// takes its own reference to it.
    source: PyObject,
    /// Buffer view of `source` holding the segment data.
    pub(crate) buffer: PyBuffer<u8>,
    /// Owned copy of the segment table; each entry is checked against the
    /// length of `buffer` in `new`.
    pub(crate) segments: Vec<BufferSegment>,
}
131
132 impl ZstdBufferWithSegments {
as_slice(&self) -> &[u8]133 fn as_slice(&self) -> &[u8] {
134 unsafe {
135 std::slice::from_raw_parts(self.buffer.buf_ptr() as *const _, self.buffer.len_bytes())
136 }
137 }
138
get_segment_slice<'p>(&self, _py: Python<'p>, i: usize) -> &'p [u8]139 pub fn get_segment_slice<'p>(&self, _py: Python<'p>, i: usize) -> &'p [u8] {
140 let segment = &self.segments[i];
141
142 unsafe {
143 std::slice::from_raw_parts(
144 self.buffer.buf_ptr().add(segment.offset as usize) as *const _,
145 segment.length as usize,
146 )
147 }
148 }
149 }
150
151 #[pymethods]
152 impl ZstdBufferWithSegments {
153 #[new]
new(py: Python, data: &PyAny, segments: PyBuffer<u8>) -> PyResult<Self>154 pub fn new(py: Python, data: &PyAny, segments: PyBuffer<u8>) -> PyResult<Self> {
155 let data_buffer = PyBuffer::get(data)?;
156
157 if segments.len_bytes() % std::mem::size_of::<BufferSegment>() != 0 {
158 return Err(PyValueError::new_err(format!(
159 "segments array size is not a multiple of {}",
160 std::mem::size_of::<BufferSegment>()
161 )));
162 }
163
164 let segments_slice: &[BufferSegment] = unsafe {
165 std::slice::from_raw_parts(
166 segments.buf_ptr() as *const _,
167 segments.len_bytes() / std::mem::size_of::<BufferSegment>(),
168 )
169 };
170
171 // Make a copy of the segments data. It is cheap to do so and is a
172 // guard against caller changing offsets, which has security implications.
173 let segments = segments_slice.to_vec();
174
175 // Validate segments data, as blindly trusting it could lead to
176 // arbitrary memory access.
177 for segment in &segments {
178 if segment.offset + segment.length > data_buffer.len_bytes() as _ {
179 return Err(PyValueError::new_err(
180 "offset within segments array references memory outside buffer",
181 ));
182 }
183 }
184
185 Ok(Self {
186 source: data.into_py(py),
187 buffer: data_buffer,
188 segments,
189 })
190 }
191
192 #[getter]
size(&self) -> usize193 fn size(&self) -> usize {
194 self.buffer.len_bytes()
195 }
196
segments(slf: PyRef<Self>, py: Python) -> PyResult<ZstdBufferSegments>197 fn segments(slf: PyRef<Self>, py: Python) -> PyResult<ZstdBufferSegments> {
198 Ok(ZstdBufferSegments {
199 // TODO surely there is a better way to cast self to PyObject?
200 parent: unsafe { Py::from_borrowed_ptr(py, slf.as_ptr()) },
201 })
202 }
203
tobytes<'p>(&self, py: Python<'p>) -> PyResult<&'p PyBytes>204 fn tobytes<'p>(&self, py: Python<'p>) -> PyResult<&'p PyBytes> {
205 Ok(PyBytes::new(py, self.as_slice()))
206 }
207 }
208
209 #[pyproto]
210 impl PySequenceProtocol for ZstdBufferWithSegments {
__len__(&self) -> usize211 fn __len__(&self) -> usize {
212 self.segments.len()
213 }
214
__getitem__(&self, key: isize) -> PyResult<ZstdBufferSegment>215 fn __getitem__(&self, key: isize) -> PyResult<ZstdBufferSegment> {
216 let py = unsafe { Python::assume_gil_acquired() };
217
218 if key < 0 {
219 return Err(PyIndexError::new_err("offset must be non-negative"));
220 }
221
222 let key = key as usize;
223
224 if key >= self.segments.len() {
225 return Err(PyIndexError::new_err(format!(
226 "offset must be less than {}",
227 self.segments.len()
228 )));
229 }
230
231 let segment = &self.segments[key];
232
233 Ok(ZstdBufferSegment {
234 _parent: self.source.clone_ref(py),
235 buffer: PyBuffer::get(self.source.extract(py)?)?,
236 offset: segment.offset as _,
237 len: segment.length as _,
238 })
239 }
240 }
241
242 #[pyproto]
243 impl PyBufferProtocol for ZstdBufferWithSegments {
bf_getbuffer(slf: PyRefMut<Self>, view: *mut Py_buffer, flags: i32) -> PyResult<()>244 fn bf_getbuffer(slf: PyRefMut<Self>, view: *mut Py_buffer, flags: i32) -> PyResult<()> {
245 if unsafe {
246 pyo3::ffi::PyBuffer_FillInfo(
247 view,
248 slf.as_ptr(),
249 slf.buffer.buf_ptr(),
250 slf.buffer.len_bytes() as _,
251 1,
252 flags,
253 )
254 } != 0
255 {
256 Err(PyErr::fetch(slf.py()))
257 } else {
258 Ok(())
259 }
260 }
261
262 #[allow(unused_variables)]
bf_releasebuffer(slf: PyRefMut<Self>, view: *mut Py_buffer)263 fn bf_releasebuffer(slf: PyRefMut<Self>, view: *mut Py_buffer) {}
264 }
265
/// An ordered collection of `BufferWithSegments` objects, addressed as one
/// flat sequence of segments.
#[pyclass(
    module = "zstandard.backend_rust",
    name = "BufferWithSegmentsCollection"
)]
pub struct ZstdBufferWithSegmentsCollection {
    // Py<ZstdBufferWithSegments>.
    pub(crate) buffers: Vec<PyObject>,
    // Running segment totals: entry `i` is the number of segments in
    // buffers[0..=i]. Despite the name, this is one past the flat index of
    // buffer i's last segment, not the index of its first one.
    first_elements: Vec<usize>,
}
275
276 #[pymethods]
277 impl ZstdBufferWithSegmentsCollection {
278 #[new]
279 #[args(py_args = "*")]
new(py: Python, py_args: &PyTuple) -> PyResult<Self>280 pub fn new(py: Python, py_args: &PyTuple) -> PyResult<Self> {
281 if py_args.is_empty() {
282 return Err(PyValueError::new_err("must pass at least 1 argument"));
283 }
284
285 let mut buffers = Vec::with_capacity(py_args.len());
286 let mut first_elements = Vec::with_capacity(py_args.len());
287 let mut offset = 0;
288
289 for item in py_args {
290 let item: &PyCell<ZstdBufferWithSegments> = item.extract().map_err(|_| {
291 PyTypeError::new_err("arguments must be BufferWithSegments instances")
292 })?;
293 let segment = item.borrow();
294
295 if segment.segments.is_empty() || segment.buffer.len_bytes() == 0 {
296 return Err(PyValueError::new_err(
297 "ZstdBufferWithSegments cannot be empty",
298 ));
299 }
300
301 offset += segment.segments.len();
302
303 buffers.push(item.to_object(py));
304 first_elements.push(offset);
305 }
306
307 Ok(Self {
308 buffers,
309 first_elements,
310 })
311 }
312
size(&self, py: Python) -> PyResult<usize>313 fn size(&self, py: Python) -> PyResult<usize> {
314 let mut size = 0;
315
316 for buffer in &self.buffers {
317 let item: &PyCell<ZstdBufferWithSegments> = buffer.extract(py)?;
318
319 for segment in &item.borrow().segments {
320 size += segment.length as usize;
321 }
322 }
323
324 Ok(size)
325 }
326 }
327
328 #[pyproto]
329 impl PySequenceProtocol for ZstdBufferWithSegmentsCollection {
__len__(&self) -> usize330 fn __len__(&self) -> usize {
331 self.first_elements.last().unwrap().clone()
332 }
333
__getitem__(&self, key: isize) -> PyResult<ZstdBufferSegment>334 fn __getitem__(&self, key: isize) -> PyResult<ZstdBufferSegment> {
335 let py = unsafe { Python::assume_gil_acquired() };
336
337 if key < 0 {
338 return Err(PyIndexError::new_err("offset must be non-negative"));
339 }
340
341 let key = key as usize;
342
343 if key >= self.__len__() {
344 return Err(PyIndexError::new_err(format!(
345 "offset must be less than {}",
346 self.__len__()
347 )));
348 }
349
350 let mut offset = 0;
351 for (buffer_index, segment) in self.buffers.iter().enumerate() {
352 if key < self.first_elements[buffer_index] {
353 if buffer_index > 0 {
354 offset = self.first_elements[buffer_index - 1];
355 }
356
357 let item: &PyCell<ZstdBufferWithSegments> = segment.extract(py)?;
358
359 return item.borrow().__getitem__((key - offset) as isize);
360 }
361 }
362
363 Err(ZstdError::new_err(
364 "error resolving segment; this should not happen",
365 ))
366 }
367 }
368
init_module(module: &PyModule) -> PyResult<()>369 pub(crate) fn init_module(module: &PyModule) -> PyResult<()> {
370 module.add_class::<ZstdBufferSegment>()?;
371 module.add_class::<ZstdBufferSegments>()?;
372 module.add_class::<ZstdBufferWithSegments>()?;
373 module.add_class::<ZstdBufferWithSegmentsCollection>()?;
374
375 Ok(())
376 }
377