1 /**
2  * Copyright (c) 2016-present, Gregory Szorc
3  * All rights reserved.
4  *
5  * This software may be modified and distributed under the terms
6  * of the BSD license. See the LICENSE file for details.
7  */
8 
9 #include "python-zstandard.h"
10 
11 extern PyObject *ZstdError;
12 
13 /**
14  * Ensure the ZSTD_DCtx on a decompressor is initiated and ready for a new
15  * operation.
16  */
ensure_dctx(ZstdDecompressor * decompressor,int loadDict)17 int ensure_dctx(ZstdDecompressor *decompressor, int loadDict) {
18     size_t zresult;
19 
20     ZSTD_DCtx_reset(decompressor->dctx, ZSTD_reset_session_only);
21 
22     if (decompressor->maxWindowSize) {
23         zresult = ZSTD_DCtx_setMaxWindowSize(decompressor->dctx,
24                                              decompressor->maxWindowSize);
25         if (ZSTD_isError(zresult)) {
26             PyErr_Format(ZstdError, "unable to set max window size: %s",
27                          ZSTD_getErrorName(zresult));
28             return 1;
29         }
30     }
31 
32     zresult = ZSTD_DCtx_setParameter(decompressor->dctx, ZSTD_d_format,
33                                      decompressor->format);
34     if (ZSTD_isError(zresult)) {
35         PyErr_Format(ZstdError, "unable to set decoding format: %s",
36                      ZSTD_getErrorName(zresult));
37         return 1;
38     }
39 
40     if (loadDict && decompressor->dict) {
41         if (ensure_ddict(decompressor->dict)) {
42             return 1;
43         }
44 
45         zresult =
46             ZSTD_DCtx_refDDict(decompressor->dctx, decompressor->dict->ddict);
47         if (ZSTD_isError(zresult)) {
48             PyErr_Format(ZstdError,
49                          "unable to reference prepared dictionary: %s",
50                          ZSTD_getErrorName(zresult));
51             return 1;
52         }
53     }
54 
55     return 0;
56 }
57 
/*
 * ZstdDecompressor.__init__(dict_data=None, max_window_size=0, format=...)
 *
 * Validates arguments, creates the ZSTD_DCtx, and records the window-size
 * limit, frame format and optional ZstdCompressionDict. It then primes the
 * context via ensure_dctx(self, 1).
 *
 * Returns 0 on success or -1 with an exception set. On failure any acquired
 * dictionary reference and context are released.
 */
static int Decompressor_init(ZstdDecompressor *self, PyObject *args,
                             PyObject *kwargs) {
    static char *kwlist[] = {"dict_data", "max_window_size", "format", NULL};

    PyObject *dict = NULL;
    Py_ssize_t maxWindowSize = 0;
    ZSTD_format_e format = ZSTD_f_zstd1;
    int isDict;

    self->dctx = NULL;
    self->dict = NULL;

    /* NOTE(review): "I" stores an unsigned int into a ZSTD_format_e; this
       assumes the enum has the size of unsigned int on this platform. */
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OnI:ZstdDecompressor",
                                     kwlist, &dict, &maxWindowSize, &format)) {
        return -1;
    }

    /* dict_data=None is the same as not passing a dictionary. */
    if (dict == Py_None) {
        dict = NULL;
    }

    if (dict) {
        /* PyObject_IsInstance() returns -1 with an exception set when the
           check itself fails; that must not be treated as "is an instance". */
        isDict =
            PyObject_IsInstance(dict, (PyObject *)&ZstdCompressionDictType);
        if (isDict < 0) {
            return -1;
        }
        if (!isDict) {
            PyErr_SetString(PyExc_TypeError,
                            "dict_data must be zstd.ZstdCompressionDict");
            return -1;
        }
    }

    self->dctx = ZSTD_createDCtx();
    if (!self->dctx) {
        PyErr_NoMemory();
        goto except;
    }

    self->maxWindowSize = maxWindowSize;
    self->format = format;

    if (dict) {
        self->dict = (ZstdCompressionDict *)dict;
        Py_INCREF(dict);
    }

    if (ensure_dctx(self, 1)) {
        goto except;
    }

    return 0;

except:
    Py_CLEAR(self->dict);

    if (self->dctx) {
        ZSTD_freeDCtx(self->dctx);
        self->dctx = NULL;
    }

    return -1;
}
116 
/*
 * tp_dealloc for ZstdDecompressor: drops the dictionary reference, frees the
 * zstd decompression context (if one was created), then frees the object.
 */
static void Decompressor_dealloc(ZstdDecompressor *self) {
    ZSTD_DCtx *ctx;

    Py_CLEAR(self->dict);

    /* Detach the context from the object before releasing it. */
    ctx = self->dctx;
    self->dctx = NULL;
    if (ctx != NULL) {
        ZSTD_freeDCtx(ctx);
    }

    PyObject_Del(self);
}
127 
Decompressor_memory_size(ZstdDecompressor * self)128 static PyObject *Decompressor_memory_size(ZstdDecompressor *self) {
129     if (self->dctx) {
130         return PyLong_FromSize_t(ZSTD_sizeof_DCtx(self->dctx));
131     }
132     else {
133         PyErr_SetString(
134             ZstdError,
135             "no decompressor context found; this should never happen");
136         return NULL;
137     }
138 }
139 
/*
 * copy_stream(ifh, ofh, read_size=ZSTD_DStreamInSize(),
 *             write_size=ZSTD_DStreamOutSize())
 *
 * Repeatedly calls ifh.read(read_size) until it returns an empty result,
 * decompresses the data, and passes decompressed output to ofh.write() in
 * pieces of at most write_size bytes.
 *
 * Returns a (bytes_read, bytes_written) tuple.
 *
 * Raises:
 * - ValueError for objects lacking read()/write() or a zero write_size.
 * - TypeError if read() returns something other than bytes.
 * - ZstdError on decompression failure.
 * - Any exception raised by read() or write().
 */
static PyObject *Decompressor_copy_stream(ZstdDecompressor *self,
                                          PyObject *args, PyObject *kwargs) {
    static char *kwlist[] = {"ifh", "ofh", "read_size", "write_size", NULL};

    PyObject *source;
    PyObject *dest;
    size_t inSize = ZSTD_DStreamInSize();
    size_t outSize = ZSTD_DStreamOutSize();
    ZSTD_inBuffer input;
    ZSTD_outBuffer output;
    Py_ssize_t totalRead = 0;
    Py_ssize_t totalWrite = 0;
    char *readBuffer;
    Py_ssize_t readSize;
    PyObject *readResult = NULL;
    PyObject *writeResult;
    PyObject *res = NULL;
    size_t zresult;
    int outputFull;

    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO|kk:copy_stream", kwlist,
                                     &source, &dest, &inSize, &outSize)) {
        return NULL;
    }

    if (!PyObject_HasAttrString(source, "read")) {
        PyErr_SetString(PyExc_ValueError,
                        "first argument must have a read() method");
        return NULL;
    }

    if (!PyObject_HasAttrString(dest, "write")) {
        PyErr_SetString(PyExc_ValueError,
                        "second argument must have a write() method");
        return NULL;
    }

    /* A zero-sized output buffer can never receive decompressed data and
       would make the decompression loop below spin forever. */
    if (!outSize) {
        PyErr_SetString(PyExc_ValueError, "write_size must be positive");
        return NULL;
    }

    /* Prevent free on uninitialized memory in finally. */
    output.dst = NULL;

    if (ensure_dctx(self, 1)) {
        goto finally;
    }

    output.dst = PyMem_Malloc(outSize);
    if (!output.dst) {
        PyErr_NoMemory();
        goto finally;
    }
    output.size = outSize;
    output.pos = 0;

    /* Read source stream until EOF */
    while (1) {
        readResult =
            PyObject_CallMethod(source, "read", "n", (Py_ssize_t)inSize);
        if (!readResult) {
            goto finally;
        }

        /* Fails with TypeError when read() returned a non-bytes object;
           readBuffer/readSize must not be used in that case. */
        if (-1 == PyBytes_AsStringAndSize(readResult, &readBuffer, &readSize)) {
            goto finally;
        }

        /* If no data was read, we're at EOF. */
        if (0 == readSize) {
            break;
        }

        totalRead += readSize;

        /* Send data to decompressor */
        input.src = readBuffer;
        input.size = (size_t)readSize;
        input.pos = 0;

        /* Keep calling the decompressor while input remains OR while the
           previous call filled the whole output buffer. In the latter case
           zstd may still hold decompressed bytes internally even though all
           input was consumed; stopping early would drop them at EOF. */
        do {
            Py_BEGIN_ALLOW_THREADS
            zresult = ZSTD_decompressStream(self->dctx, &output, &input);
            Py_END_ALLOW_THREADS

            if (ZSTD_isError(zresult)) {
                PyErr_Format(ZstdError, "zstd decompressor error: %s",
                             ZSTD_getErrorName(zresult));
                goto finally;
            }

            outputFull = output.pos == output.size;

            if (output.pos) {
                writeResult = PyObject_CallMethod(dest, "write", "y#",
                                                  output.dst,
                                                  (Py_ssize_t)output.pos);
                if (NULL == writeResult) {
                    goto finally;
                }

                Py_DECREF(writeResult);
                totalWrite += (Py_ssize_t)output.pos;
                output.pos = 0;
            }
        } while (input.pos < input.size || outputFull);

        Py_CLEAR(readResult);
    }

    /* Source stream is exhausted. Finish up. */
    res = Py_BuildValue("nn", totalRead, totalWrite);

finally:
    /* PyMem_Free(NULL) is a no-op. */
    PyMem_Free(output.dst);

    Py_XDECREF(readResult);

    return res;
}
262 
/*
 * decompress(data, max_output_size=0)
 *
 * Decompresses the zstd frame at the start of ``data`` in a single
 * ZSTD_decompressStream() call and returns the result as bytes.
 *
 * - If the frame header records the content size, an output buffer of
 *   exactly that size is allocated. A recorded size of 0 returns b"".
 * - If the content size is not recorded, max_output_size must be non-zero.
 *   It is used as the output capacity, and the result is shrunk to the number
 *   of bytes actually produced.
 *
 * Raises ZstdError if the header cannot be parsed, the content size is
 * unknown and no max_output_size was given, or the frame is too large for
 * this platform. It also raises ZstdError if decompression fails, the frame
 * does not fully decompress, or the produced size differs from the header.
 */
PyObject *Decompressor_decompress(ZstdDecompressor *self, PyObject *args,
                                  PyObject *kwargs) {
    static char *kwlist[] = {"data", "max_output_size", NULL};

    Py_buffer source;
    Py_ssize_t maxOutputSize = 0;
    unsigned long long decompressedSize;
    size_t destCapacity;
    PyObject *result = NULL;
    size_t zresult;
    ZSTD_outBuffer outBuffer;
    ZSTD_inBuffer inBuffer;

    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y*|n:decompress", kwlist,
                                     &source, &maxOutputSize)) {
        return NULL;
    }

    if (ensure_dctx(self, 1)) {
        goto finally;
    }

    decompressedSize = ZSTD_getFrameContentSize(source.buf, source.len);

    if (ZSTD_CONTENTSIZE_ERROR == decompressedSize) {
        PyErr_SetString(ZstdError,
                        "error determining content size from frame header");
        goto finally;
    }
    /* Special case of empty frame. */
    else if (0 == decompressedSize) {
        result = PyBytes_FromStringAndSize("", 0);
        goto finally;
    }
    /* Missing content size in frame header. */
    if (ZSTD_CONTENTSIZE_UNKNOWN == decompressedSize) {
        if (0 == maxOutputSize) {
            PyErr_SetString(ZstdError,
                            "could not determine content size in frame header");
            goto finally;
        }

        result = PyBytes_FromStringAndSize(NULL, maxOutputSize);
        destCapacity = maxOutputSize;
        /* 0 disables the exact-size check below; the real size is unknown. */
        decompressedSize = 0;
    }
    /* Size is recorded in frame header. */
    else {
        assert(SIZE_MAX >= PY_SSIZE_T_MAX);
        if (decompressedSize > PY_SSIZE_T_MAX) {
            PyErr_SetString(
                ZstdError, "frame is too large to decompress on this platform");
            goto finally;
        }

        result = PyBytes_FromStringAndSize(NULL, (Py_ssize_t)decompressedSize);
        destCapacity = (size_t)decompressedSize;
    }

    if (!result) {
        goto finally;
    }

    outBuffer.dst = PyBytes_AsString(result);
    outBuffer.size = destCapacity;
    outBuffer.pos = 0;

    inBuffer.src = source.buf;
    inBuffer.size = source.len;
    inBuffer.pos = 0;

    Py_BEGIN_ALLOW_THREADS
    zresult = ZSTD_decompressStream(self->dctx, &outBuffer, &inBuffer);
    Py_END_ALLOW_THREADS

    if (ZSTD_isError(zresult)) {
        PyErr_Format(ZstdError, "decompression error: %s",
                     ZSTD_getErrorName(zresult));
        Py_CLEAR(result);
        goto finally;
    }
    else if (zresult) {
        /* A non-zero hint means the frame was not completely decoded. */
        PyErr_Format(ZstdError,
                     "decompression error: did not decompress full frame");
        Py_CLEAR(result);
        goto finally;
    }
    else if (decompressedSize && outBuffer.pos != decompressedSize) {
        /* Report the number of bytes actually produced, not zresult (which
           is always 0 on this branch). */
        PyErr_Format(
            ZstdError,
            "decompression error: decompressed %zu bytes; expected %llu",
            outBuffer.pos, decompressedSize);
        Py_CLEAR(result);
        goto finally;
    }
    else if (outBuffer.pos < destCapacity) {
        if (safe_pybytes_resize(&result, outBuffer.pos)) {
            Py_CLEAR(result);
            goto finally;
        }
    }

finally:
    PyBuffer_Release(&source);
    return result;
}
369 
Decompressor_decompressobj(ZstdDecompressor * self,PyObject * args,PyObject * kwargs)370 static ZstdDecompressionObj *Decompressor_decompressobj(ZstdDecompressor *self,
371                                                         PyObject *args,
372                                                         PyObject *kwargs) {
373     static char *kwlist[] = {"write_size", NULL};
374 
375     ZstdDecompressionObj *result = NULL;
376     size_t outSize = ZSTD_DStreamOutSize();
377 
378     if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|k:decompressobj", kwlist,
379                                      &outSize)) {
380         return NULL;
381     }
382 
383     if (!outSize) {
384         PyErr_SetString(PyExc_ValueError, "write_size must be positive");
385         return NULL;
386     }
387 
388     result = (ZstdDecompressionObj *)PyObject_CallObject(
389         (PyObject *)&ZstdDecompressionObjType, NULL);
390     if (!result) {
391         return NULL;
392     }
393 
394     if (ensure_dctx(self, 1)) {
395         Py_DECREF(result);
396         return NULL;
397     }
398 
399     result->decompressor = self;
400     Py_INCREF(result->decompressor);
401     result->outSize = outSize;
402 
403     return result;
404 }
405 
406 static ZstdDecompressorIterator *
Decompressor_read_to_iter(ZstdDecompressor * self,PyObject * args,PyObject * kwargs)407 Decompressor_read_to_iter(ZstdDecompressor *self, PyObject *args,
408                           PyObject *kwargs) {
409     static char *kwlist[] = {"reader", "read_size", "write_size", "skip_bytes",
410                              NULL};
411 
412     PyObject *reader;
413     size_t inSize = ZSTD_DStreamInSize();
414     size_t outSize = ZSTD_DStreamOutSize();
415     ZstdDecompressorIterator *result;
416     size_t skipBytes = 0;
417 
418     if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kkk:read_to_iter", kwlist,
419                                      &reader, &inSize, &outSize, &skipBytes)) {
420         return NULL;
421     }
422 
423     if (skipBytes >= inSize) {
424         PyErr_SetString(PyExc_ValueError,
425                         "skip_bytes must be smaller than read_size");
426         return NULL;
427     }
428 
429     result = (ZstdDecompressorIterator *)PyObject_CallObject(
430         (PyObject *)&ZstdDecompressorIteratorType, NULL);
431     if (!result) {
432         return NULL;
433     }
434 
435     if (PyObject_HasAttrString(reader, "read")) {
436         result->reader = reader;
437         Py_INCREF(result->reader);
438     }
439     else if (1 == PyObject_CheckBuffer(reader)) {
440         /* Object claims it is a buffer. Try to get a handle to it. */
441         if (0 != PyObject_GetBuffer(reader, &result->buffer, PyBUF_CONTIG_RO)) {
442             goto except;
443         }
444     }
445     else {
446         PyErr_SetString(PyExc_ValueError,
447                         "must pass an object with a read() method or conforms "
448                         "to buffer protocol");
449         goto except;
450     }
451 
452     result->decompressor = self;
453     Py_INCREF(result->decompressor);
454 
455     result->inSize = inSize;
456     result->outSize = outSize;
457     result->skipBytes = skipBytes;
458 
459     if (ensure_dctx(self, 1)) {
460         goto except;
461     }
462 
463     result->input.src = PyMem_Malloc(inSize);
464     if (!result->input.src) {
465         PyErr_NoMemory();
466         goto except;
467     }
468 
469     goto finally;
470 
471 except:
472     Py_CLEAR(result);
473 
474 finally:
475 
476     return result;
477 }
478 
479 static ZstdDecompressionReader *
Decompressor_stream_reader(ZstdDecompressor * self,PyObject * args,PyObject * kwargs)480 Decompressor_stream_reader(ZstdDecompressor *self, PyObject *args,
481                            PyObject *kwargs) {
482     static char *kwlist[] = {"source", "read_size", "read_across_frames",
483                              "closefd", NULL};
484 
485     PyObject *source;
486     size_t readSize = ZSTD_DStreamInSize();
487     PyObject *readAcrossFrames = NULL;
488     PyObject *closefd = NULL;
489     ZstdDecompressionReader *result;
490 
491     if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kOO:stream_reader",
492                                      kwlist, &source, &readSize,
493                                      &readAcrossFrames, &closefd)) {
494         return NULL;
495     }
496 
497     if (ensure_dctx(self, 1)) {
498         return NULL;
499     }
500 
501     result = (ZstdDecompressionReader *)PyObject_CallObject(
502         (PyObject *)&ZstdDecompressionReaderType, NULL);
503     if (NULL == result) {
504         return NULL;
505     }
506 
507     result->entered = 0;
508     result->closed = 0;
509 
510     if (PyObject_HasAttrString(source, "read")) {
511         result->reader = source;
512         Py_INCREF(source);
513         result->readSize = readSize;
514     }
515     else if (1 == PyObject_CheckBuffer(source)) {
516         if (0 != PyObject_GetBuffer(source, &result->buffer, PyBUF_CONTIG_RO)) {
517             Py_CLEAR(result);
518             return NULL;
519         }
520     }
521     else {
522         PyErr_SetString(PyExc_TypeError,
523                         "must pass an object with a read() method or that "
524                         "conforms to the buffer protocol");
525         Py_CLEAR(result);
526         return NULL;
527     }
528 
529     result->decompressor = self;
530     Py_INCREF(self);
531     result->readAcrossFrames =
532         readAcrossFrames ? PyObject_IsTrue(readAcrossFrames) : 0;
533     result->closefd = closefd ? PyObject_IsTrue(closefd) : 1;
534 
535     return result;
536 }
537 
538 static ZstdDecompressionWriter *
Decompressor_stream_writer(ZstdDecompressor * self,PyObject * args,PyObject * kwargs)539 Decompressor_stream_writer(ZstdDecompressor *self, PyObject *args,
540                            PyObject *kwargs) {
541     static char *kwlist[] = {"writer", "write_size", "write_return_read",
542                              "closefd", NULL};
543 
544     PyObject *writer;
545     size_t outSize = ZSTD_DStreamOutSize();
546     PyObject *writeReturnRead = NULL;
547     PyObject *closefd = NULL;
548     ZstdDecompressionWriter *result;
549 
550     if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kOO:stream_writer",
551                                      kwlist, &writer, &outSize,
552                                      &writeReturnRead, &closefd)) {
553         return NULL;
554     }
555 
556     if (!PyObject_HasAttrString(writer, "write")) {
557         PyErr_SetString(PyExc_ValueError,
558                         "must pass an object with a write() method");
559         return NULL;
560     }
561 
562     if (ensure_dctx(self, 1)) {
563         return NULL;
564     }
565 
566     result = (ZstdDecompressionWriter *)PyObject_CallObject(
567         (PyObject *)&ZstdDecompressionWriterType, NULL);
568     if (!result) {
569         return NULL;
570     }
571 
572     result->entered = 0;
573     result->closing = 0;
574     result->closed = 0;
575 
576     result->decompressor = self;
577     Py_INCREF(result->decompressor);
578 
579     result->writer = writer;
580     Py_INCREF(result->writer);
581 
582     result->outSize = outSize;
583     result->writeReturnRead =
584         writeReturnRead ? PyObject_IsTrue(writeReturnRead) : 1;
585     result->closefd = closefd ? PyObject_IsTrue(closefd) : 1;
586 
587     return result;
588 }
589 
590 static PyObject *
Decompressor_decompress_content_dict_chain(ZstdDecompressor * self,PyObject * args,PyObject * kwargs)591 Decompressor_decompress_content_dict_chain(ZstdDecompressor *self,
592                                            PyObject *args, PyObject *kwargs) {
593     static char *kwlist[] = {"frames", NULL};
594 
595     PyObject *chunks;
596     Py_ssize_t chunksLen;
597     Py_ssize_t chunkIndex;
598     char parity = 0;
599     PyObject *chunk;
600     char *chunkData;
601     Py_ssize_t chunkSize;
602     size_t zresult;
603     ZSTD_frameHeader frameHeader;
604     void *buffer1 = NULL;
605     size_t buffer1Size = 0;
606     size_t buffer1ContentSize = 0;
607     void *buffer2 = NULL;
608     size_t buffer2Size = 0;
609     size_t buffer2ContentSize = 0;
610     void *destBuffer = NULL;
611     PyObject *result = NULL;
612     ZSTD_outBuffer outBuffer;
613     ZSTD_inBuffer inBuffer;
614 
615     if (!PyArg_ParseTupleAndKeywords(args, kwargs,
616                                      "O!:decompress_content_dict_chain", kwlist,
617                                      &PyList_Type, &chunks)) {
618         return NULL;
619     }
620 
621     chunksLen = PyList_Size(chunks);
622     if (!chunksLen) {
623         PyErr_SetString(PyExc_ValueError, "empty input chain");
624         return NULL;
625     }
626 
627     /* The first chunk should not be using a dictionary. We handle it specially.
628      */
629     chunk = PyList_GetItem(chunks, 0);
630     if (!PyBytes_Check(chunk)) {
631         PyErr_SetString(PyExc_ValueError, "chunk 0 must be bytes");
632         return NULL;
633     }
634 
635     /* We require that all chunks be zstd frames and that they have content size
636      * set. */
637     PyBytes_AsStringAndSize(chunk, &chunkData, &chunkSize);
638     zresult = ZSTD_getFrameHeader(&frameHeader, (void *)chunkData, chunkSize);
639     if (ZSTD_isError(zresult)) {
640         PyErr_SetString(PyExc_ValueError, "chunk 0 is not a valid zstd frame");
641         return NULL;
642     }
643     else if (zresult) {
644         PyErr_SetString(PyExc_ValueError,
645                         "chunk 0 is too small to contain a zstd frame");
646         return NULL;
647     }
648 
649     if (ZSTD_CONTENTSIZE_UNKNOWN == frameHeader.frameContentSize) {
650         PyErr_SetString(PyExc_ValueError,
651                         "chunk 0 missing content size in frame");
652         return NULL;
653     }
654 
655     assert(ZSTD_CONTENTSIZE_ERROR != frameHeader.frameContentSize);
656 
657     /* We check against PY_SSIZE_T_MAX here because we ultimately cast the
658      * result to a Python object and it's length can be no greater than
659      * Py_ssize_t. In theory, we could have an intermediate frame that is
660      * larger. But a) why would this API be used for frames that large b)
661      * it isn't worth the complexity to support. */
662     assert(SIZE_MAX >= PY_SSIZE_T_MAX);
663     if (frameHeader.frameContentSize > PY_SSIZE_T_MAX) {
664         PyErr_SetString(PyExc_ValueError,
665                         "chunk 0 is too large to decompress on this platform");
666         return NULL;
667     }
668 
669     if (ensure_dctx(self, 0)) {
670         goto finally;
671     }
672 
673     buffer1Size = (size_t)frameHeader.frameContentSize;
674     buffer1 = PyMem_Malloc(buffer1Size);
675     if (!buffer1) {
676         goto finally;
677     }
678 
679     outBuffer.dst = buffer1;
680     outBuffer.size = buffer1Size;
681     outBuffer.pos = 0;
682 
683     inBuffer.src = chunkData;
684     inBuffer.size = chunkSize;
685     inBuffer.pos = 0;
686 
687     Py_BEGIN_ALLOW_THREADS zresult =
688         ZSTD_decompressStream(self->dctx, &outBuffer, &inBuffer);
689     Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) {
690         PyErr_Format(ZstdError, "could not decompress chunk 0: %s",
691                      ZSTD_getErrorName(zresult));
692         goto finally;
693     }
694     else if (zresult) {
695         PyErr_Format(ZstdError, "chunk 0 did not decompress full frame");
696         goto finally;
697     }
698 
699     buffer1ContentSize = outBuffer.pos;
700 
701     /* Special case of a simple chain. */
702     if (1 == chunksLen) {
703         result = PyBytes_FromStringAndSize(buffer1, buffer1Size);
704         goto finally;
705     }
706 
707     /* This should ideally look at next chunk. But this is slightly simpler. */
708     buffer2Size = (size_t)frameHeader.frameContentSize;
709     buffer2 = PyMem_Malloc(buffer2Size);
710     if (!buffer2) {
711         goto finally;
712     }
713 
714     /* For each subsequent chunk, use the previous fulltext as a content
715        dictionary. Our strategy is to have 2 buffers. One holds the previous
716        fulltext (to be used as a content dictionary) and the other holds the new
717        fulltext. The buffers grow when needed but never decrease in size. This
718        limits the memory allocator overhead.
719     */
720     for (chunkIndex = 1; chunkIndex < chunksLen; chunkIndex++) {
721         chunk = PyList_GetItem(chunks, chunkIndex);
722         if (!PyBytes_Check(chunk)) {
723             PyErr_Format(PyExc_ValueError, "chunk %zd must be bytes",
724                          chunkIndex);
725             goto finally;
726         }
727 
728         PyBytes_AsStringAndSize(chunk, &chunkData, &chunkSize);
729         zresult =
730             ZSTD_getFrameHeader(&frameHeader, (void *)chunkData, chunkSize);
731         if (ZSTD_isError(zresult)) {
732             PyErr_Format(PyExc_ValueError,
733                          "chunk %zd is not a valid zstd frame", chunkIndex);
734             goto finally;
735         }
736         else if (zresult) {
737             PyErr_Format(PyExc_ValueError,
738                          "chunk %zd is too small to contain a zstd frame",
739                          chunkIndex);
740             goto finally;
741         }
742 
743         if (ZSTD_CONTENTSIZE_UNKNOWN == frameHeader.frameContentSize) {
744             PyErr_Format(PyExc_ValueError,
745                          "chunk %zd missing content size in frame", chunkIndex);
746             goto finally;
747         }
748 
749         assert(ZSTD_CONTENTSIZE_ERROR != frameHeader.frameContentSize);
750 
751         if (frameHeader.frameContentSize > PY_SSIZE_T_MAX) {
752             PyErr_Format(
753                 PyExc_ValueError,
754                 "chunk %zd is too large to decompress on this platform",
755                 chunkIndex);
756             goto finally;
757         }
758 
759         inBuffer.src = chunkData;
760         inBuffer.size = chunkSize;
761         inBuffer.pos = 0;
762 
763         parity = chunkIndex % 2;
764 
765         /* This could definitely be abstracted to reduce code duplication. */
766         if (parity) {
767             /* Resize destination buffer to hold larger content. */
768             if (buffer2Size < frameHeader.frameContentSize) {
769                 buffer2Size = (size_t)frameHeader.frameContentSize;
770                 destBuffer = PyMem_Realloc(buffer2, buffer2Size);
771                 if (!destBuffer) {
772                     goto finally;
773                 }
774                 buffer2 = destBuffer;
775             }
776 
777             Py_BEGIN_ALLOW_THREADS zresult = ZSTD_DCtx_refPrefix_advanced(
778                 self->dctx, buffer1, buffer1ContentSize, ZSTD_dct_rawContent);
779             Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) {
780                 PyErr_Format(ZstdError,
781                              "failed to load prefix dictionary at chunk %zd",
782                              chunkIndex);
783                 goto finally;
784             }
785 
786             outBuffer.dst = buffer2;
787             outBuffer.size = buffer2Size;
788             outBuffer.pos = 0;
789 
790             Py_BEGIN_ALLOW_THREADS zresult =
791                 ZSTD_decompressStream(self->dctx, &outBuffer, &inBuffer);
792             Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) {
793                 PyErr_Format(ZstdError, "could not decompress chunk %zd: %s",
794                              chunkIndex, ZSTD_getErrorName(zresult));
795                 goto finally;
796             }
797             else if (zresult) {
798                 PyErr_Format(ZstdError,
799                              "chunk %zd did not decompress full frame",
800                              chunkIndex);
801                 goto finally;
802             }
803 
804             buffer2ContentSize = outBuffer.pos;
805         }
806         else {
807             if (buffer1Size < frameHeader.frameContentSize) {
808                 buffer1Size = (size_t)frameHeader.frameContentSize;
809                 destBuffer = PyMem_Realloc(buffer1, buffer1Size);
810                 if (!destBuffer) {
811                     goto finally;
812                 }
813                 buffer1 = destBuffer;
814             }
815 
816             Py_BEGIN_ALLOW_THREADS zresult = ZSTD_DCtx_refPrefix_advanced(
817                 self->dctx, buffer2, buffer2ContentSize, ZSTD_dct_rawContent);
818             Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) {
819                 PyErr_Format(ZstdError,
820                              "failed to load prefix dictionary at chunk %zd",
821                              chunkIndex);
822                 goto finally;
823             }
824 
825             outBuffer.dst = buffer1;
826             outBuffer.size = buffer1Size;
827             outBuffer.pos = 0;
828 
829             Py_BEGIN_ALLOW_THREADS zresult =
830                 ZSTD_decompressStream(self->dctx, &outBuffer, &inBuffer);
831             Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) {
832                 PyErr_Format(ZstdError, "could not decompress chunk %zd: %s",
833                              chunkIndex, ZSTD_getErrorName(zresult));
834                 goto finally;
835             }
836             else if (zresult) {
837                 PyErr_Format(ZstdError,
838                              "chunk %zd did not decompress full frame",
839                              chunkIndex);
840                 goto finally;
841             }
842 
843             buffer1ContentSize = outBuffer.pos;
844         }
845     }
846 
847     result = PyBytes_FromStringAndSize(parity ? buffer2 : buffer1,
848                                        parity ? buffer2ContentSize
849                                               : buffer1ContentSize);
850 
851 finally:
852     if (buffer2) {
853         PyMem_Free(buffer2);
854     }
855     if (buffer1) {
856         PyMem_Free(buffer1);
857     }
858 
859     return result;
860 }
861 
/*
 * Describes one compressed frame to decompress.
 */
typedef struct {
    /* Start of the compressed frame bytes. NOTE(review): presumably not
       owned by this struct; confirm with the code that fills it in. */
    void *sourceData;
    /* Length in bytes of the compressed frame at sourceData. */
    size_t sourceSize;
    /* Expected decompressed size. 0 means "not provided", in which case
       decompress_worker() reads the size from the frame header via
       ZSTD_getFrameContentSize(). */
    size_t destSize;
} FramePointer;
867 
/*
 * A collection of frames to decompress.
 * NOTE(review): not used in the visible part of this file; the field
 * descriptions below are inferred from names and should be confirmed.
 */
typedef struct {
    /* Array of frame descriptors (presumably framesSize entries). */
    FramePointer *frames;
    /* Presumably the number of entries in frames. */
    Py_ssize_t framesSize;
    /* Presumably the sum of all frames' compressed sizes. */
    unsigned long long compressedSize;
} FrameSources;
873 
/*
 * An output buffer produced by a decompression worker.
 * NOTE(review): filled in by code outside the visible part of this file;
 * the field descriptions below are inferred from names and should be
 * confirmed.
 */
typedef struct {
    /* Storage holding decompressed data. */
    void *dest;
    /* Presumably the size in bytes of dest. */
    Py_ssize_t destSize;
    /* Presumably one BufferSegment per frame, locating it within dest. */
    BufferSegment *segments;
    /* Presumably the number of entries in segments. */
    Py_ssize_t segmentsSize;
} DecompressorDestBuffer;
880 
/*
 * Error status reported by a decompression worker in
 * DecompressorWorkerState.error.
 */
typedef enum {
    /* No error occurred. */
    DecompressorWorkerError_none = 0,
    /* Presumably a zstd call failed; the zstd result is kept in
       DecompressorWorkerState.zresult. */
    DecompressorWorkerError_zstd = 1,
    /* Set by decompress_worker() when the worker's totalSourceSize exceeds
       SIZE_MAX; presumably also used for allocation failures. */
    DecompressorWorkerError_memory = 2,
    /* Presumably the decompressed size differed from the expected size. */
    DecompressorWorkerError_sizeMismatch = 3,
    /* Set by decompress_worker() when ZSTD_getFrameContentSize() reports an
       error for a frame with no caller-provided destSize. */
    DecompressorWorkerError_unknownSize = 4,
} DecompressorWorkerError;
888 
/*
 * Per-worker input, output and error state for decompress_worker().
 */
typedef struct {
    /* Source records and length */
    FramePointer *framePointers;
    /* Which records to process: an inclusive range. decompress_worker()
       iterates from startOffset through endOffset. */
    Py_ssize_t startOffset;
    Py_ssize_t endOffset;
    /* Total compressed size of this worker's records. decompress_worker()
       fails with DecompressorWorkerError_memory if it exceeds SIZE_MAX. */
    unsigned long long totalSourceSize;

    /* Decompression state and settings. */
    ZSTD_DCtx *dctx;
    /* Consulted by decompress_worker() when a frame's content size is
       unknown. NOTE(review): presumably non-zero makes that an error. */
    int requireOutputSizes;

    /* Output storage. decompress_worker() asserts these start as NULL/0. */
    DecompressorDestBuffer *destBuffers;
    Py_ssize_t destCount;

    /* Item that error occurred on (index into framePointers; set to 0 when
       the whole work unit is rejected as too large). */
    Py_ssize_t errorOffset;
    /* If an error occurred. */
    DecompressorWorkerError error;
    /* result from zstd decompression operation */
    size_t zresult;
} DecompressorWorkerState;
912 
#ifdef HAVE_ZSTD_POOL_APIS
/*
 * Decompress frames framePointers[startOffset] .. framePointers[endOffset]
 * (inclusive) into one or more malloc()'d output buffers.
 *
 * This is invoked while the GIL is released (either on a zstd pool thread or
 * inline by decompress_from_framesources()), so it must not call into the
 * Python C API. On success state->destBuffers / state->destCount describe the
 * output. On failure state->error is set (plus errorOffset / zresult where
 * meaningful) and anything already allocated is left attached to state for
 * the caller to free.
 */
static void decompress_worker(DecompressorWorkerState *state) {
    size_t allocationSize;
    DecompressorDestBuffer *destBuffer;
    Py_ssize_t frameIndex;
    Py_ssize_t localOffset = 0;
    Py_ssize_t currentBufferStartIndex = state->startOffset;
    Py_ssize_t remainingItems = state->endOffset - state->startOffset + 1;
    void *tmpBuf;
    Py_ssize_t destOffset = 0;
    FramePointer *framePointers = state->framePointers;
    size_t zresult;

    assert(NULL == state->destBuffers);
    assert(0 == state->destCount);
    assert(state->endOffset - state->startOffset >= 0);

    /* We could get here due to the way work is allocated. Ideally we wouldn't
       get here. But that would require a bit of a refactor in the caller. */
    if (state->totalSourceSize > SIZE_MAX) {
        state->error = DecompressorWorkerError_memory;
        state->errorOffset = 0;
        return;
    }

    /*
     * We need to allocate a buffer to hold decompressed data. How we do this
     * depends on what we know about the output. The following scenarios are
     * possible:
     *
     * 1. All structs defining frames declare the output size.
     * 2. The decompressed size is embedded within the zstd frame.
     * 3. The decompressed size is not stored anywhere.
     *
     * For now, we only support #1 and #2.
     */

    /* Resolve output segments. */
    for (frameIndex = state->startOffset; frameIndex <= state->endOffset;
         frameIndex++) {
        FramePointer *fp = &framePointers[frameIndex];
        unsigned long long decompressedSize;

        if (0 == fp->destSize) {
            decompressedSize =
                ZSTD_getFrameContentSize(fp->sourceData, fp->sourceSize);

            if (ZSTD_CONTENTSIZE_ERROR == decompressedSize) {
                state->error = DecompressorWorkerError_unknownSize;
                state->errorOffset = frameIndex;
                return;
            }
            else if (ZSTD_CONTENTSIZE_UNKNOWN == decompressedSize) {
                if (state->requireOutputSizes) {
                    state->error = DecompressorWorkerError_unknownSize;
                    state->errorOffset = frameIndex;
                    return;
                }

                /* This will fail the assert for .destSize > 0 below. */
                decompressedSize = 0;
            }

            if (decompressedSize > SIZE_MAX) {
                state->error = DecompressorWorkerError_memory;
                state->errorOffset = frameIndex;
                return;
            }

            fp->destSize = (size_t)decompressedSize;
        }
    }

    state->destBuffers = calloc(1, sizeof(DecompressorDestBuffer));
    if (NULL == state->destBuffers) {
        state->error = DecompressorWorkerError_memory;
        return;
    }

    state->destCount = 1;

    destBuffer = &state->destBuffers[state->destCount - 1];

    assert(framePointers[state->startOffset].destSize > 0); /* For now. */

    allocationSize = roundpow2((size_t)state->totalSourceSize);

    if (framePointers[state->startOffset].destSize > allocationSize) {
        allocationSize = roundpow2(framePointers[state->startOffset].destSize);
    }

    destBuffer->dest = malloc(allocationSize);
    if (NULL == destBuffer->dest) {
        state->error = DecompressorWorkerError_memory;
        return;
    }

    destBuffer->destSize = allocationSize;

    destBuffer->segments = calloc(remainingItems, sizeof(BufferSegment));
    if (NULL == destBuffer->segments) {
        /* Caller will free state->dest as part of cleanup. */
        state->error = DecompressorWorkerError_memory;
        return;
    }

    destBuffer->segmentsSize = remainingItems;

    for (frameIndex = state->startOffset; frameIndex <= state->endOffset;
         frameIndex++) {
        ZSTD_outBuffer outBuffer;
        ZSTD_inBuffer inBuffer;
        const void *source = framePointers[frameIndex].sourceData;
        const size_t sourceSize = framePointers[frameIndex].sourceSize;
        void *dest;
        const size_t decompressedSize = framePointers[frameIndex].destSize;
        size_t destAvailable = destBuffer->destSize - destOffset;

        assert(decompressedSize > 0); /* For now. */

        /*
         * Not enough space in current buffer. Finish current before and
         * allocate and switch to a new one.
         */
        if (decompressedSize > destAvailable) {
            /*
             * Shrinking the destination buffer is optional. But it should be
             * cheap, so we just do it. Skip it when nothing has been written
             * yet (possible when frames in this buffer decompressed to 0
             * bytes): realloc(ptr, 0) may free ptr and return NULL, which
             * would report a bogus memory error and leave a dangling pointer
             * that the caller would free a second time.
             */
            if (destAvailable && destOffset > 0) {
                tmpBuf = realloc(destBuffer->dest, destOffset);
                if (NULL == tmpBuf) {
                    state->error = DecompressorWorkerError_memory;
                    return;
                }

                destBuffer->dest = tmpBuf;
                destBuffer->destSize = destOffset;
            }

            /* Truncate segments buffer. */
            tmpBuf = realloc(destBuffer->segments,
                             (frameIndex - currentBufferStartIndex) *
                                 sizeof(BufferSegment));
            if (NULL == tmpBuf) {
                state->error = DecompressorWorkerError_memory;
                return;
            }

            destBuffer->segments = tmpBuf;
            destBuffer->segmentsSize = frameIndex - currentBufferStartIndex;

            /* Grow space for new DestBuffer. */
            tmpBuf =
                realloc(state->destBuffers, (state->destCount + 1) *
                                                sizeof(DecompressorDestBuffer));
            if (NULL == tmpBuf) {
                state->error = DecompressorWorkerError_memory;
                return;
            }

            state->destBuffers = tmpBuf;
            state->destCount++;

            destBuffer = &state->destBuffers[state->destCount - 1];

            /* Don't take any chances with non-NULL pointers. */
            memset(destBuffer, 0, sizeof(DecompressorDestBuffer));

            allocationSize = roundpow2((size_t)state->totalSourceSize);

            if (decompressedSize > allocationSize) {
                allocationSize = roundpow2(decompressedSize);
            }

            destBuffer->dest = malloc(allocationSize);
            if (NULL == destBuffer->dest) {
                state->error = DecompressorWorkerError_memory;
                return;
            }

            destBuffer->destSize = allocationSize;
            destAvailable = allocationSize;
            destOffset = 0;
            localOffset = 0;

            destBuffer->segments =
                calloc(remainingItems, sizeof(BufferSegment));
            if (NULL == destBuffer->segments) {
                state->error = DecompressorWorkerError_memory;
                return;
            }

            destBuffer->segmentsSize = remainingItems;
            currentBufferStartIndex = frameIndex;
        }

        dest = (char *)destBuffer->dest + destOffset;

        outBuffer.dst = dest;
        outBuffer.size = decompressedSize;
        outBuffer.pos = 0;

        inBuffer.src = source;
        inBuffer.size = sourceSize;
        inBuffer.pos = 0;

        zresult = ZSTD_decompressStream(state->dctx, &outBuffer, &inBuffer);
        if (ZSTD_isError(zresult)) {
            state->error = DecompressorWorkerError_zstd;
            state->zresult = zresult;
            state->errorOffset = frameIndex;
            return;
        }
        else if (zresult || outBuffer.pos != decompressedSize) {
            state->error = DecompressorWorkerError_sizeMismatch;
            state->zresult = outBuffer.pos;
            state->errorOffset = frameIndex;
            return;
        }

        destBuffer->segments[localOffset].offset = destOffset;
        destBuffer->segments[localOffset].length = outBuffer.pos;
        destOffset += outBuffer.pos;
        localOffset++;
        remainingItems--;
    }

    /*
     * Trim the unused tail of the last buffer. As above, never realloc() to
     * zero bytes; an all-empty buffer is simply kept at its allocated size.
     */
    if (destOffset > 0 && destBuffer->destSize > destOffset) {
        tmpBuf = realloc(destBuffer->dest, destOffset);
        if (NULL == tmpBuf) {
            state->error = DecompressorWorkerError_memory;
            return;
        }

        destBuffer->dest = tmpBuf;
        destBuffer->destSize = destOffset;
    }
}
#endif
1156 
#ifdef HAVE_ZSTD_POOL_APIS
/*
 * Decompress every frame described by ``frames`` using up to ``threadCount``
 * workers.
 *
 * Frames are divided into contiguous runs of roughly equal compressed size.
 * Each run is handed to decompress_worker() together with its own copy of the
 * decompressor's DCtx. When only one worker is used it runs inline rather than
 * on a zstd thread pool.
 *
 * Returns a new ZstdBufferWithSegmentsCollection, or NULL with a Python
 * exception set.
 */
ZstdBufferWithSegmentsCollection *
decompress_from_framesources(ZstdDecompressor *decompressor,
                             FrameSources *frames, Py_ssize_t threadCount) {
    Py_ssize_t i = 0;
    int errored = 0;
    Py_ssize_t segmentsCount;
    ZstdBufferWithSegments *bws = NULL;
    PyObject *resultArg = NULL;
    Py_ssize_t resultIndex;
    ZstdBufferWithSegmentsCollection *result = NULL;
    FramePointer *framePointers = frames->frames;
    unsigned long long workerBytes = 0;
    Py_ssize_t currentThread = 0;
    Py_ssize_t workerStartOffset = 0;
    POOL_ctx *pool = NULL;
    DecompressorWorkerState *workerStates = NULL;
    unsigned long long bytesPerWorker;

    /* Caller should normalize 0 and negative values to 1 or larger. */
    assert(threadCount >= 1);

    /*
     * With no frames, threadCount would be clamped to 0 below and the
     * bytesPerWorker computation would divide by zero.
     */
    if (frames->framesSize < 1) {
        PyErr_SetString(PyExc_ValueError, "no source elements found");
        return NULL;
    }

    /* More threads than inputs makes no sense under any conditions. */
    threadCount =
        frames->framesSize < threadCount ? frames->framesSize : threadCount;

    /* TODO lower thread count if input size is too small and threads would just
       add overhead. */

    if (decompressor->dict) {
        if (ensure_ddict(decompressor->dict)) {
            return NULL;
        }
    }

    /* If threadCount==1, we don't start a thread pool. But we do leverage the
       same API for dispatching work. */
    workerStates = PyMem_Malloc(threadCount * sizeof(DecompressorWorkerState));
    if (NULL == workerStates) {
        PyErr_NoMemory();
        goto finally;
    }

    memset(workerStates, 0, threadCount * sizeof(DecompressorWorkerState));

    if (threadCount > 1) {
        pool = POOL_create(threadCount, 1);
        if (NULL == pool) {
            PyErr_SetString(ZstdError, "could not initialize zstd thread pool");
            goto finally;
        }
    }

    bytesPerWorker = frames->compressedSize / threadCount;

    if (bytesPerWorker > SIZE_MAX) {
        PyErr_SetString(ZstdError,
                        "too much data per worker for this platform");
        goto finally;
    }

    for (i = 0; i < threadCount; i++) {
        size_t zresult;

        workerStates[i].dctx = ZSTD_createDCtx();
        if (NULL == workerStates[i].dctx) {
            PyErr_NoMemory();
            goto finally;
        }

        ZSTD_copyDCtx(workerStates[i].dctx, decompressor->dctx);

        if (decompressor->dict) {
            zresult = ZSTD_DCtx_refDDict(workerStates[i].dctx,
                                         decompressor->dict->ddict);
            if (ZSTD_isError(zresult)) {
                PyErr_Format(ZstdError,
                             "unable to reference prepared dictionary: %s",
                             ZSTD_getErrorName(zresult));
                goto finally;
            }
        }

        workerStates[i].framePointers = framePointers;
        workerStates[i].requireOutputSizes = 1;
    }

    Py_BEGIN_ALLOW_THREADS
    /* There are many ways to split work among workers.

       For now, we take a simple approach of splitting work so each worker
       gets roughly the same number of input bytes. This will result in more
       starvation than running N>threadCount jobs. But it avoids
       complications around state tracking, which could involve extra
       locking.
    */
    for (i = 0; i < frames->framesSize; i++) {
        workerBytes += frames->frames[i].sourceSize;

        /*
         * The last worker/thread needs to handle all remaining work. Don't
         * trigger it prematurely. Defer to the block outside of the loop.
         * (But still process this loop so workerBytes is correct.)
         */
        if (currentThread == threadCount - 1) {
            continue;
        }

        if (workerBytes >= bytesPerWorker) {
            workerStates[currentThread].startOffset = workerStartOffset;
            workerStates[currentThread].endOffset = i;
            workerStates[currentThread].totalSourceSize = workerBytes;

            if (threadCount > 1) {
                POOL_add(pool, (POOL_function)decompress_worker,
                         &workerStates[currentThread]);
            }
            else {
                decompress_worker(&workerStates[currentThread]);
            }
            currentThread++;
            workerStartOffset = i + 1;
            workerBytes = 0;
        }
    }

    /*
     * Hand whatever is left to the final worker. Test for unassigned frames
     * rather than for leftover bytes: zero-length inputs add nothing to
     * workerBytes and would otherwise be silently dropped from the output
     * instead of being reported as errors by the worker.
     */
    if (workerStartOffset < frames->framesSize) {
        workerStates[currentThread].startOffset = workerStartOffset;
        workerStates[currentThread].endOffset = frames->framesSize - 1;
        workerStates[currentThread].totalSourceSize = workerBytes;

        if (threadCount > 1) {
            POOL_add(pool, (POOL_function)decompress_worker,
                     &workerStates[currentThread]);
        }
        else {
            decompress_worker(&workerStates[currentThread]);
        }
    }

    if (threadCount > 1) {
        POOL_free(pool);
        pool = NULL;
    }
    Py_END_ALLOW_THREADS

    for (i = 0; i < threadCount; i++) {
        switch (workerStates[i].error) {
        case DecompressorWorkerError_none:
            break;

        case DecompressorWorkerError_zstd:
            PyErr_Format(ZstdError, "error decompressing item %zd: %s",
                         workerStates[i].errorOffset,
                         ZSTD_getErrorName(workerStates[i].zresult));
            errored = 1;
            break;

        case DecompressorWorkerError_memory:
            PyErr_NoMemory();
            errored = 1;
            break;

        case DecompressorWorkerError_sizeMismatch:
            PyErr_Format(ZstdError,
                         "error decompressing item %zd: decompressed %zu "
                         "bytes; expected %zu",
                         workerStates[i].errorOffset, workerStates[i].zresult,
                         framePointers[workerStates[i].errorOffset].destSize);
            errored = 1;
            break;

        case DecompressorWorkerError_unknownSize:
            PyErr_Format(PyExc_ValueError,
                         "could not determine decompressed size of item %zd",
                         workerStates[i].errorOffset);
            errored = 1;
            break;

        default:
            PyErr_Format(ZstdError, "unhandled error type: %d; this is a bug",
                         workerStates[i].error);
            errored = 1;
            break;
        }

        if (errored) {
            break;
        }
    }

    if (errored) {
        goto finally;
    }

    segmentsCount = 0;
    for (i = 0; i < threadCount; i++) {
        segmentsCount += workerStates[i].destCount;
    }

    resultArg = PyTuple_New(segmentsCount);
    if (NULL == resultArg) {
        goto finally;
    }

    resultIndex = 0;

    for (i = 0; i < threadCount; i++) {
        Py_ssize_t bufferIndex;
        DecompressorWorkerState *state = &workerStates[i];

        for (bufferIndex = 0; bufferIndex < state->destCount; bufferIndex++) {
            DecompressorDestBuffer *destBuffer =
                &state->destBuffers[bufferIndex];

            bws = BufferWithSegments_FromMemory(
                destBuffer->dest, destBuffer->destSize, destBuffer->segments,
                destBuffer->segmentsSize);
            if (NULL == bws) {
                goto finally;
            }

            /*
             * Memory for buffer and segments was allocated using malloc() in
             * worker and the memory is transferred to the BufferWithSegments
             * instance. So tell instance to use free() and NULL the reference
             * in the state struct so it isn't freed below.
             */
            bws->useFree = 1;
            destBuffer->dest = NULL;
            destBuffer->segments = NULL;

            PyTuple_SET_ITEM(resultArg, resultIndex++, (PyObject *)bws);
        }
    }

    result = (ZstdBufferWithSegmentsCollection *)PyObject_CallObject(
        (PyObject *)&ZstdBufferWithSegmentsCollectionType, resultArg);

finally:
    Py_CLEAR(resultArg);

    if (workerStates) {
        for (i = 0; i < threadCount; i++) {
            Py_ssize_t bufferIndex;
            DecompressorWorkerState *state = &workerStates[i];

            if (state->dctx) {
                ZSTD_freeDCtx(state->dctx);
            }

            for (bufferIndex = 0; bufferIndex < state->destCount;
                 bufferIndex++) {
                if (state->destBuffers) {
                    /*
                     * Will be NULL if memory transferred to a
                     * BufferWithSegments. Otherwise it is left over after an
                     * error occurred.
                     */
                    free(state->destBuffers[bufferIndex].dest);
                    free(state->destBuffers[bufferIndex].segments);
                }
            }

            free(state->destBuffers);
        }

        PyMem_Free(workerStates);
    }

    POOL_free(pool);

    return result;
}
#endif
1432 
#ifdef HAVE_ZSTD_POOL_APIS
/*
 * ZstdDecompressor.multi_decompress_to_buffer(frames, decompressed_sizes,
 *                                             threads)
 *
 * ``frames`` may be a BufferWithSegments, a BufferWithSegmentsCollection or a
 * list of bytes-like objects. ``decompressed_sizes``, when given, must be a
 * buffer holding exactly one ``unsigned long long`` per frame; a 0 entry means
 * "read the size from the frame header". A negative ``threads`` means
 * cpu_count(); anything below 2 means a single (inline) worker.
 *
 * Returns a new BufferWithSegmentsCollection, or NULL with an exception set.
 */
static ZstdBufferWithSegmentsCollection *
Decompressor_multi_decompress_to_buffer(ZstdDecompressor *self, PyObject *args,
                                        PyObject *kwargs) {
    static char *kwlist[] = {"frames", "decompressed_sizes", "threads", NULL};

    PyObject *frames;
    Py_buffer frameSizes;
    int threads = 0;
    Py_ssize_t frameCount;
    Py_buffer *frameBuffers = NULL;
    FramePointer *framePointers = NULL;
    unsigned long long *frameSizesP = NULL;
    unsigned long long totalInputSize = 0;
    FrameSources frameSources;
    ZstdBufferWithSegmentsCollection *result = NULL;
    Py_ssize_t i;

    memset(&frameSizes, 0, sizeof(frameSizes));

    if (!PyArg_ParseTupleAndKeywords(args, kwargs,
                                     "O|y*i:multi_decompress_to_buffer", kwlist,
                                     &frames, &frameSizes, &threads)) {
        return NULL;
    }

    if (frameSizes.buf) {
        frameSizesP = (unsigned long long *)frameSizes.buf;
    }

    if (threads < 0) {
        threads = cpu_count();
    }

    if (threads < 2) {
        threads = 1;
    }

    if (PyObject_TypeCheck(frames, &ZstdBufferWithSegmentsType)) {
        ZstdBufferWithSegments *buffer = (ZstdBufferWithSegments *)frames;
        frameCount = buffer->segmentCount;

        if (frameSizes.buf &&
            frameSizes.len !=
                frameCount * (Py_ssize_t)sizeof(unsigned long long)) {
            PyErr_Format(
                PyExc_ValueError,
                "decompressed_sizes size mismatch; expected %zd, got %zd",
                frameCount * (Py_ssize_t)sizeof(unsigned long long),
                frameSizes.len);
            goto finally;
        }

        framePointers = PyMem_Malloc(frameCount * sizeof(FramePointer));
        if (!framePointers) {
            PyErr_NoMemory();
            goto finally;
        }

        for (i = 0; i < frameCount; i++) {
            void *sourceData;
            unsigned long long sourceSize;
            unsigned long long decompressedSize = 0;

            if (buffer->segments[i].offset + buffer->segments[i].length >
                buffer->dataSize) {
                PyErr_Format(PyExc_ValueError,
                             "item %zd has offset outside memory area", i);
                goto finally;
            }

            sourceData = (char *)buffer->data + buffer->segments[i].offset;
            sourceSize = buffer->segments[i].length;
            totalInputSize += sourceSize;

            if (frameSizesP) {
                decompressedSize = frameSizesP[i];
            }

            if (sourceSize > SIZE_MAX) {
                PyErr_Format(PyExc_ValueError,
                             "item %zd is too large for this platform", i);
                goto finally;
            }

            if (decompressedSize > SIZE_MAX) {
                PyErr_Format(PyExc_ValueError,
                             "decompressed size of item %zd is too large for "
                             "this platform",
                             i);
                goto finally;
            }

            framePointers[i].sourceData = sourceData;
            framePointers[i].sourceSize = (size_t)sourceSize;
            framePointers[i].destSize = (size_t)decompressedSize;
        }
    }
    else if (PyObject_TypeCheck(frames,
                                &ZstdBufferWithSegmentsCollectionType)) {
        Py_ssize_t offset = 0;
        ZstdBufferWithSegments *buffer;
        ZstdBufferWithSegmentsCollection *collection =
            (ZstdBufferWithSegmentsCollection *)frames;

        frameCount = BufferWithSegmentsCollection_length(collection);

        /*
         * decompressed_sizes holds one unsigned long long per frame, so its
         * byte length must be frameCount * 8 (on common platforms). Comparing
         * against frameCount alone would reject valid input and accept a
         * buffer too small for the frameSizesP[offset] reads below.
         */
        if (frameSizes.buf &&
            frameSizes.len !=
                frameCount * (Py_ssize_t)sizeof(unsigned long long)) {
            PyErr_Format(
                PyExc_ValueError,
                "decompressed_sizes size mismatch; expected %zd; got %zd",
                frameCount * (Py_ssize_t)sizeof(unsigned long long),
                frameSizes.len);
            goto finally;
        }

        framePointers = PyMem_Malloc(frameCount * sizeof(FramePointer));
        if (NULL == framePointers) {
            PyErr_NoMemory();
            goto finally;
        }

        /* Iterate the data structure directly because it is faster. */
        for (i = 0; i < collection->bufferCount; i++) {
            Py_ssize_t segmentIndex;
            buffer = collection->buffers[i];

            for (segmentIndex = 0; segmentIndex < buffer->segmentCount;
                 segmentIndex++) {
                unsigned long long decompressedSize =
                    frameSizesP ? frameSizesP[offset] : 0;

                if (buffer->segments[segmentIndex].offset +
                        buffer->segments[segmentIndex].length >
                    buffer->dataSize) {
                    PyErr_Format(PyExc_ValueError,
                                 "item %zd has offset outside memory area",
                                 offset);
                    goto finally;
                }

                if (buffer->segments[segmentIndex].length > SIZE_MAX) {
                    PyErr_Format(
                        PyExc_ValueError,
                        "item %zd in buffer %zd is too large for this platform",
                        segmentIndex, i);
                    goto finally;
                }

                if (decompressedSize > SIZE_MAX) {
                    PyErr_Format(PyExc_ValueError,
                                 "decompressed size of item %zd in buffer %zd "
                                 "is too large for this platform",
                                 segmentIndex, i);
                    goto finally;
                }

                totalInputSize += buffer->segments[segmentIndex].length;

                framePointers[offset].sourceData =
                    (char *)buffer->data +
                    buffer->segments[segmentIndex].offset;
                framePointers[offset].sourceSize =
                    (size_t)buffer->segments[segmentIndex].length;
                framePointers[offset].destSize = (size_t)decompressedSize;

                offset++;
            }
        }
    }
    else if (PyList_Check(frames)) {
        frameCount = PyList_GET_SIZE(frames);

        if (frameSizes.buf &&
            frameSizes.len !=
                frameCount * (Py_ssize_t)sizeof(unsigned long long)) {
            PyErr_Format(
                PyExc_ValueError,
                "decompressed_sizes size mismatch; expected %zd, got %zd",
                frameCount * (Py_ssize_t)sizeof(unsigned long long),
                frameSizes.len);
            goto finally;
        }

        framePointers = PyMem_Malloc(frameCount * sizeof(FramePointer));
        if (!framePointers) {
            PyErr_NoMemory();
            goto finally;
        }

        frameBuffers = PyMem_Malloc(frameCount * sizeof(Py_buffer));
        if (NULL == frameBuffers) {
            PyErr_NoMemory();
            goto finally;
        }

        /* Zeroed so PyBuffer_Release() in cleanup is safe on unfilled slots. */
        memset(frameBuffers, 0, frameCount * sizeof(Py_buffer));

        /* Do a pass to assemble info about our input buffers and output sizes.
         */
        for (i = 0; i < frameCount; i++) {
            unsigned long long decompressedSize =
                frameSizesP ? frameSizesP[i] : 0;

            if (0 != PyObject_GetBuffer(PyList_GET_ITEM(frames, i),
                                        &frameBuffers[i], PyBUF_CONTIG_RO)) {
                PyErr_Clear();
                PyErr_Format(PyExc_TypeError,
                             "item %zd not a bytes like object", i);
                goto finally;
            }

            if (decompressedSize > SIZE_MAX) {
                PyErr_Format(PyExc_ValueError,
                             "decompressed size of item %zd is too large for "
                             "this platform",
                             i);
                goto finally;
            }

            totalInputSize += frameBuffers[i].len;

            framePointers[i].sourceData = frameBuffers[i].buf;
            framePointers[i].sourceSize = frameBuffers[i].len;
            framePointers[i].destSize = (size_t)decompressedSize;
        }
    }
    else {
        PyErr_SetString(PyExc_TypeError,
                        "argument must be list or BufferWithSegments");
        goto finally;
    }

    /* We now have an array with info about our inputs and outputs. Feed it into
       our generic decompression function. */
    frameSources.frames = framePointers;
    frameSources.framesSize = frameCount;
    frameSources.compressedSize = totalInputSize;

    result = decompress_from_framesources(self, &frameSources, threads);

finally:
    if (frameSizes.buf) {
        PyBuffer_Release(&frameSizes);
    }
    PyMem_Free(framePointers);

    if (frameBuffers) {
        for (i = 0; i < frameCount; i++) {
            PyBuffer_Release(&frameBuffers[i]);
        }

        PyMem_Free(frameBuffers);
    }

    return result;
}
#endif
1688 
1689 static PyMethodDef Decompressor_methods[] = {
1690     {"copy_stream", (PyCFunction)Decompressor_copy_stream,
1691      METH_VARARGS | METH_KEYWORDS, NULL},
1692     {"decompress", (PyCFunction)Decompressor_decompress,
1693      METH_VARARGS | METH_KEYWORDS, NULL},
1694     {"decompressobj", (PyCFunction)Decompressor_decompressobj,
1695      METH_VARARGS | METH_KEYWORDS, NULL},
1696     {"read_to_iter", (PyCFunction)Decompressor_read_to_iter,
1697      METH_VARARGS | METH_KEYWORDS, NULL},
1698     {"stream_reader", (PyCFunction)Decompressor_stream_reader,
1699      METH_VARARGS | METH_KEYWORDS, NULL},
1700     {"stream_writer", (PyCFunction)Decompressor_stream_writer,
1701      METH_VARARGS | METH_KEYWORDS, NULL},
1702     {"decompress_content_dict_chain",
1703      (PyCFunction)Decompressor_decompress_content_dict_chain,
1704      METH_VARARGS | METH_KEYWORDS, NULL},
1705 #ifdef HAVE_ZSTD_POOL_APIS
1706     {"multi_decompress_to_buffer",
1707      (PyCFunction)Decompressor_multi_decompress_to_buffer,
1708      METH_VARARGS | METH_KEYWORDS, NULL},
1709 #endif
1710     {"memory_size", (PyCFunction)Decompressor_memory_size, METH_NOARGS, NULL},
1711     {NULL, NULL}};
1712 
1713 PyTypeObject ZstdDecompressorType = {
1714     PyVarObject_HEAD_INIT(NULL, 0) "zstd.ZstdDecompressor", /* tp_name */
1715     sizeof(ZstdDecompressor),                               /* tp_basicsize */
1716     0,                                                      /* tp_itemsize */
1717     (destructor)Decompressor_dealloc,                       /* tp_dealloc */
1718     0,                                                      /* tp_print */
1719     0,                                                      /* tp_getattr */
1720     0,                                                      /* tp_setattr */
1721     0,                                                      /* tp_compare */
1722     0,                                                      /* tp_repr */
1723     0,                                                      /* tp_as_number */
1724     0,                                                      /* tp_as_sequence */
1725     0,                                                      /* tp_as_mapping */
1726     0,                                                      /* tp_hash */
1727     0,                                                      /* tp_call */
1728     0,                                                      /* tp_str */
1729     0,                                                      /* tp_getattro */
1730     0,                                                      /* tp_setattro */
1731     0,                                                      /* tp_as_buffer */
1732     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,               /* tp_flags */
1733     0,                                                      /* tp_doc */
1734     0,                                                      /* tp_traverse */
1735     0,                                                      /* tp_clear */
1736     0,                                                      /* tp_richcompare */
1737     0,                           /* tp_weaklistoffset */
1738     0,                           /* tp_iter */
1739     0,                           /* tp_iternext */
1740     Decompressor_methods,        /* tp_methods */
1741     0,                           /* tp_members */
1742     0,                           /* tp_getset */
1743     0,                           /* tp_base */
1744     0,                           /* tp_dict */
1745     0,                           /* tp_descr_get */
1746     0,                           /* tp_descr_set */
1747     0,                           /* tp_dictoffset */
1748     (initproc)Decompressor_init, /* tp_init */
1749     0,                           /* tp_alloc */
1750     PyType_GenericNew,           /* tp_new */
1751 };
1752 
/**
 * Finalize ZstdDecompressorType and register it on `mod` as
 * "ZstdDecompressor".
 *
 * Errors are not returned (the function is void). On failure a Python
 * exception is left set by PyType_Ready / PyModule_AddObject.
 * NOTE(review): presumably the module init caller checks PyErr_Occurred()
 * after calling this — confirm against the caller.
 */
void decompressor_module_init(PyObject *mod) {
    Py_SET_TYPE(&ZstdDecompressorType, &PyType_Type);
    if (PyType_Ready(&ZstdDecompressorType) < 0) {
        return;
    }

    /* This reference is handed to the module object below. */
    Py_INCREF((PyObject *)&ZstdDecompressorType);
    if (PyModule_AddObject(mod, "ZstdDecompressor",
                           (PyObject *)&ZstdDecompressorType) < 0) {
        /*
         * PyModule_AddObject only steals the reference on success. On
         * failure we still own it, so drop it here to avoid a leak.
         */
        Py_DECREF((PyObject *)&ZstdDecompressorType);
    }
}
1763