1"""Classes and functions for managing compressors."""
2
3import io
4import zlib
5from distutils.version import LooseVersion
6
7try:
8    from threading import RLock
9except ImportError:
10    from dummy_threading import RLock
11
12try:
13    import bz2
14except ImportError:
15    bz2 = None
16
17try:
18    import lz4
19    from lz4.frame import LZ4FrameFile
20except ImportError:
21    lz4 = None
22
23try:
24    import lzma
25except ImportError:
26    lzma = None
27
28
# Message of the ValueError raised by LZ4CompressorWrapper._check_versions()
# when the lz4 package is missing or older than 0.19.
LZ4_NOT_INSTALLED_ERROR = ('LZ4 is not installed. Install it with pip: '
                           'https://python-lz4.readthedocs.io/')

# Registered compressors: maps a compressor name (str) to the
# CompressorWrapper instance stored by register_compressor().
_COMPRESSORS = {}

# Magic numbers of supported compression file formats. The compressor
# wrappers defined in this module expose them through their ``prefix``
# attribute.
_ZFILE_PREFIX = b'ZF'  # used with pickle files created before 0.9.3.
_ZLIB_PREFIX = b'\x78'
_GZIP_PREFIX = b'\x1f\x8b'
_BZ2_PREFIX = b'BZ'
_XZ_PREFIX = b'\xfd\x37\x7a\x58\x5a'
_LZMA_PREFIX = b'\x5d\x00'
_LZ4_PREFIX = b'\x04\x22\x4D\x18'
43
44
def register_compressor(compressor_name, compressor,
                        force=False):
    """Store a compressor in the module-level registry.

    Parameters
    ----------
    compressor_name: str.
        The key under which the compressor is registered.
    compressor: CompressorWrapper
        An instance of a 'CompressorWrapper'.
    force: bool
        When True, a compressor already registered under the same name is
        replaced instead of triggering an error. Defaults to False.

    Raises
    ------
    ValueError
        If the name is not a string, if the compressor is not a
        CompressorWrapper, if its 'fileobj_factory' is set but lacks one of
        the 'read', 'write', 'seek' or 'tell' attributes, or if the name is
        already registered and 'force' is False.
    """
    if not isinstance(compressor_name, str):
        raise ValueError("Compressor name should be a string, "
                         "'{}' given.".format(compressor_name))

    if not isinstance(compressor, CompressorWrapper):
        raise ValueError("Compressor should implement the CompressorWrapper "
                         "interface, '{}' given.".format(compressor))

    # A missing factory (None) is tolerated; otherwise it must look like a
    # file object class.
    if compressor.fileobj_factory is not None:
        file_api = ('read', 'write', 'seek', 'tell')
        if not all(hasattr(compressor.fileobj_factory, attr)
                   for attr in file_api):
            raise ValueError("Compressor 'fileobj_factory' attribute should "
                             "implement the file object interface, '{}' given."
                             .format(compressor.fileobj_factory))

    already_known = compressor_name in _COMPRESSORS
    if already_known and not force:
        raise ValueError("Compressor '{}' already registered."
                         .format(compressor_name))

    _COMPRESSORS[compressor_name] = compressor
79
80
class CompressorWrapper():
    """A wrapper around a compressor file object.

    Attributes
    ----------
    fileobj_factory: callable (the ``obj`` constructor argument)
        Called as ``fileobj_factory(fileobj, 'wb'[, compresslevel=...])`` to
        build a compressing file object and as
        ``fileobj_factory(fileobj, 'rb')`` to build a decompressing one.
    prefix: bytestring
        A bytestring corresponding to the magic number that identifies the
        file format associated to the compressor.
    extension: str
        The file extension used to automatically select this compressor during
        a dump to a file.
    """

    def __init__(self, obj, prefix=b'', extension=''):
        self.extension = extension
        self.prefix = prefix
        self.fileobj_factory = obj

    def compressor_file(self, fileobj, compresslevel=None):
        """Returns an instance of a compressor file object."""
        # Only forward the compression level when the caller supplied one so
        # that the factory's own default applies otherwise.
        options = {}
        if compresslevel is not None:
            options['compresslevel'] = compresslevel
        return self.fileobj_factory(fileobj, 'wb', **options)

    def decompressor_file(self, fileobj):
        """Returns an instance of a decompressor file object."""
        opener = self.fileobj_factory
        return opener(fileobj, 'rb')
113
114
class BZ2CompressorWrapper(CompressorWrapper):
    """Compressor wrapper backed by ``bz2.BZ2File``.

    The ``bz2`` module is optional: when it could not be imported the
    factory is None and any attempt to open a file raises a ValueError.
    """

    prefix = _BZ2_PREFIX
    extension = '.bz2'

    def __init__(self):
        self.fileobj_factory = None if bz2 is None else bz2.BZ2File

    def _check_versions(self):
        """Raise a ValueError when the bz2 module is unavailable."""
        if bz2 is not None:
            return
        raise ValueError('bz2 module is not compiled on your python '
                         'standard library.')

    def compressor_file(self, fileobj, compresslevel=None):
        """Returns an instance of a compressor file object."""
        self._check_versions()
        # Let BZ2File pick its own default level unless one was requested.
        extra = {}
        if compresslevel is not None:
            extra['compresslevel'] = compresslevel
        return self.fileobj_factory(fileobj, 'wb', **extra)

    def decompressor_file(self, fileobj):
        """Returns an instance of a decompressor file object."""
        self._check_versions()
        return self.fileobj_factory(fileobj, 'rb')
145
146
class LZMACompressorWrapper(CompressorWrapper):
    """Compressor wrapper backed by ``lzma.LZMAFile``.

    Files are written with the lzma container format named by
    ``_lzma_format_name`` (an attribute name looked up on the ``lzma``
    module), which lets subclasses select another format.

    The ``lzma`` module is optional: when it could not be imported the
    factory is None and any attempt to open a file raises a ValueError.
    """

    prefix = _LZMA_PREFIX
    extension = '.lzma'
    _lzma_format_name = 'FORMAT_ALONE'

    def __init__(self):
        if lzma is not None:
            self.fileobj_factory = lzma.LZMAFile
            self._lzma_format = getattr(lzma, self._lzma_format_name)
        else:
            self.fileobj_factory = None
            # Keep the attribute defined so that instances have the same
            # shape whether or not lzma is available.
            self._lzma_format = None

    def _check_versions(self):
        """Raise a ValueError when the lzma module is unavailable."""
        if lzma is None:
            raise ValueError('lzma module is not compiled on your python '
                             'standard library.')

    def compressor_file(self, fileobj, compresslevel=None):
        """Returns an instance of a compressor file object.

        Raises
        ------
        ValueError
            If the lzma module is not available.
        """
        # Fail with the explicit error message rather than with a
        # "'NoneType' object is not callable" TypeError when lzma is missing.
        self._check_versions()
        if compresslevel is None:
            return self.fileobj_factory(fileobj, 'wb',
                                        format=self._lzma_format)
        else:
            return self.fileobj_factory(fileobj, 'wb',
                                        format=self._lzma_format,
                                        preset=compresslevel)

    def decompressor_file(self, fileobj):
        """Returns an instance of a decompressor file object.

        Raises
        ------
        ValueError
            If the lzma module is not available.
        """
        self._check_versions()
        # No format is passed when reading, so LZMAFile's default applies.
        return lzma.LZMAFile(fileobj, 'rb')
178
179
class XZCompressorWrapper(LZMACompressorWrapper):
    """LZMA compressor wrapper configured for the xz container format."""

    # Name of the ``lzma`` module attribute that the parent __init__ resolves
    # into the format passed to LZMAFile when writing.
    _lzma_format_name = 'FORMAT_XZ'
    extension = '.xz'
    prefix = _XZ_PREFIX
185
186
class LZ4CompressorWrapper(CompressorWrapper):
    """Compressor wrapper backed by ``lz4.frame.LZ4FrameFile``.

    The ``lz4`` package is optional: when it could not be imported the
    factory is None and any attempt to open a file raises a ValueError.
    """

    prefix = _LZ4_PREFIX
    extension = '.lz4'

    def __init__(self):
        self.fileobj_factory = None if lz4 is None else LZ4FrameFile

    def _check_versions(self):
        """Raise a ValueError if lz4 is missing or older than 0.19."""
        if lz4 is None:
            raise ValueError(LZ4_NOT_INSTALLED_ERROR)
        installed = lz4.__version__
        # Drop a leading "v" (as in "v0.19") before comparing versions.
        if installed.startswith("v"):
            installed = installed[1:]
        too_old = LooseVersion(installed) < LooseVersion('0.19')
        if too_old:
            raise ValueError(LZ4_NOT_INSTALLED_ERROR)

    def compressor_file(self, fileobj, compresslevel=None):
        """Returns an instance of a compressor file object."""
        self._check_versions()
        # LZ4FrameFile names its level argument 'compression_level'.
        extra = {}
        if compresslevel is not None:
            extra['compression_level'] = compresslevel
        return self.fileobj_factory(fileobj, 'wb', **extra)

    def decompressor_file(self, fileobj):
        """Returns an instance of a decompressor file object."""
        self._check_versions()
        opener = self.fileobj_factory
        return opener(fileobj, 'rb')
220
221
222###############################################################################
223#  base file compression/decompression object definition
# Internal states of a BinaryZlibFile.
_MODE_CLOSED = 0
_MODE_READ = 1
_MODE_READ_EOF = 2  # opened for reading and the end of the data was reached
_MODE_WRITE = 3
# Number of compressed bytes requested from the underlying file per read.
_BUFFER_SIZE = 8192


class BinaryZlibFile(io.BufferedIOBase):
    """A file object providing transparent zlib (de)compression.

    A BinaryZlibFile can act as a wrapper for an existing file object, or
    refer directly to a named file on disk.

    Note that BinaryZlibFile provides only a *binary* file interface: data read
    is returned as bytes, and data to be written should be given as bytes.

    This object is an adaptation of the BZ2File object.

    If filename is a str, bytes or path-like object, it gives the name
    of the file to be opened. Otherwise, it should be a file object,
    which will be used to read or write the compressed data.

    mode can be 'rb' for reading (default) or 'wb' for (over)writing

    If mode is 'wb', compresslevel can be a number between 1
    and 9 specifying the level of compression: 1 produces the least
    compression, and 9 produces the most compression. 3 is the default.

    When reading, only the first compressed stream of the file is decoded:
    any bytes located after the end of that stream are ignored.

    TODO python2_drop: is this class still needed since we dropped Python 2
    support?
    """

    # Window size argument given to zlib.compressobj / zlib.decompressobj.
    wbits = zlib.MAX_WBITS

    def __init__(self, filename, mode="rb", compresslevel=3):
        """Open ``filename`` (a path or a file object) in mode 'rb' or 'wb'.

        Raises
        ------
        ValueError
            If compresslevel is not an integer between 1 and 9, or if mode
            is neither 'rb' nor 'wb'.
        TypeError
            If filename is neither a path nor an object with a 'read' or
            'write' attribute.
        """
        # This lock must be recursive, so that BufferedIOBase's
        # readline(), readlines() and writelines() don't deadlock.
        self._lock = RLock()
        self._fp = None
        self._closefp = False
        self._mode = _MODE_CLOSED
        self._pos = 0  # position in the uncompressed data
        self._size = -1  # uncompressed size, unknown (-1) until EOF is seen
        self.compresslevel = compresslevel

        if not isinstance(compresslevel, int) or not (1 <= compresslevel <= 9):
            raise ValueError("'compresslevel' must be an integer "
                             "between 1 and 9. You provided 'compresslevel={}'"
                             .format(compresslevel))

        if mode == "rb":
            self._mode = _MODE_READ
            self._decompressor = zlib.decompressobj(self.wbits)
            self._buffer = b""
            self._buffer_offset = 0
        elif mode == "wb":
            self._mode = _MODE_WRITE
            self._compressor = zlib.compressobj(self.compresslevel,
                                                zlib.DEFLATED, self.wbits,
                                                zlib.DEF_MEM_LEVEL, 0)
        else:
            raise ValueError("Invalid mode: %r" % (mode,))

        if isinstance(filename, (str, bytes)):
            # bytes paths are accepted as advertised by the documentation and
            # by the TypeError message below.
            self._fp = io.open(filename, mode)
            self._closefp = True
        elif hasattr(filename, "read") or hasattr(filename, "write"):
            self._fp = filename
        elif hasattr(filename, "__fspath__"):
            # Path-like objects (e.g. pathlib.Path). Checked after the file
            # object case so that previously accepted inputs behave the same.
            self._fp = io.open(filename, mode)
            self._closefp = True
        else:
            raise TypeError("filename must be a str or bytes object, "
                            "or a file")

    def close(self):
        """Flush and close the file.

        May be called more than once without error. Once the file is
        closed, any other operation on it will raise a ValueError.
        """
        with self._lock:
            if self._mode == _MODE_CLOSED:
                return
            try:
                if self._mode in (_MODE_READ, _MODE_READ_EOF):
                    self._decompressor = None
                elif self._mode == _MODE_WRITE:
                    # Emit the data still held back by the compressor.
                    self._fp.write(self._compressor.flush())
                    self._compressor = None
            finally:
                try:
                    # Only close files this object opened itself.
                    if self._closefp:
                        self._fp.close()
                finally:
                    self._fp = None
                    self._closefp = False
                    self._mode = _MODE_CLOSED
                    self._buffer = b""
                    self._buffer_offset = 0

    @property
    def closed(self):
        """True if this file is closed."""
        return self._mode == _MODE_CLOSED

    def fileno(self):
        """Return the file descriptor for the underlying file."""
        self._check_not_closed()
        return self._fp.fileno()

    def seekable(self):
        """Return whether the file supports seeking."""
        return self.readable() and self._fp.seekable()

    def readable(self):
        """Return whether the file was opened for reading."""
        self._check_not_closed()
        return self._mode in (_MODE_READ, _MODE_READ_EOF)

    def writable(self):
        """Return whether the file was opened for writing."""
        self._check_not_closed()
        return self._mode == _MODE_WRITE

    # Mode-checking helper functions.

    def _check_not_closed(self):
        """Raise a ValueError if the file is closed."""
        if self.closed:
            # close() resets _fp to None, in which case no name is reported.
            fname = getattr(self._fp, 'name', None)
            msg = "I/O operation on closed file"
            if fname is not None:
                msg += " {}".format(fname)
            msg += "."
            raise ValueError(msg)

    def _check_can_read(self):
        """Raise unless the file is open for reading."""
        if self._mode not in (_MODE_READ, _MODE_READ_EOF):
            self._check_not_closed()
            raise io.UnsupportedOperation("File not open for reading")

    def _check_can_write(self):
        """Raise unless the file is open for writing."""
        if self._mode != _MODE_WRITE:
            self._check_not_closed()
            raise io.UnsupportedOperation("File not open for writing")

    def _check_can_seek(self):
        """Raise unless the file is readable and its source is seekable."""
        if self._mode not in (_MODE_READ, _MODE_READ_EOF):
            self._check_not_closed()
            raise io.UnsupportedOperation("Seeking is only supported "
                                          "on files open for reading")
        if not self._fp.seekable():
            raise io.UnsupportedOperation("The underlying file object "
                                          "does not support seeking")

    # Fill the readahead buffer if it is empty. Returns False on EOF.
    def _fill_buffer(self):
        if self._mode == _MODE_READ_EOF:
            return False
        # Depending on the input data, our call to the decompressor may not
        # return any data. In this case, try again after reading another block.
        while self._buffer_offset == len(self._buffer):
            try:
                if self._decompressor.eof:
                    # The compressed stream is complete. Feeding more bytes
                    # (including its own unused_data) to a finished
                    # decompressor yields no output and only grows
                    # unused_data, which would make this loop spin forever
                    # on files carrying trailing bytes.
                    raise EOFError
                rawblock = self._fp.read(_BUFFER_SIZE)
                if not rawblock:
                    raise EOFError
            except EOFError:
                # End-of-stream marker and end of file. We're good.
                self._mode = _MODE_READ_EOF
                self._size = self._pos
                return False
            else:
                self._buffer = self._decompressor.decompress(rawblock)
            self._buffer_offset = 0
        return True

    # Read data until EOF.
    # If return_data is false, consume the data without returning it.
    def _read_all(self, return_data=True):
        # The loop assumes that _buffer_offset is 0. Ensure that this is true.
        self._buffer = self._buffer[self._buffer_offset:]
        self._buffer_offset = 0

        blocks = []
        while self._fill_buffer():
            if return_data:
                blocks.append(self._buffer)
            self._pos += len(self._buffer)
            self._buffer = b""
        if return_data:
            return b"".join(blocks)

    # Read a block of up to n bytes.
    # If return_data is false, consume the data without returning it.
    def _read_block(self, n_bytes, return_data=True):
        # If we have enough data buffered, return immediately.
        end = self._buffer_offset + n_bytes
        if end <= len(self._buffer):
            data = self._buffer[self._buffer_offset: end]
            self._buffer_offset = end
            self._pos += len(data)
            return data if return_data else None

        # The loop assumes that _buffer_offset is 0. Ensure that this is true.
        self._buffer = self._buffer[self._buffer_offset:]
        self._buffer_offset = 0

        blocks = []
        while n_bytes > 0 and self._fill_buffer():
            if n_bytes < len(self._buffer):
                # Take only what is needed and keep the rest buffered.
                data = self._buffer[:n_bytes]
                self._buffer_offset = n_bytes
            else:
                data = self._buffer
                self._buffer = b""
            if return_data:
                blocks.append(data)
            self._pos += len(data)
            n_bytes -= len(data)
        if return_data:
            return b"".join(blocks)

    def read(self, size=-1):
        """Read up to size uncompressed bytes from the file.

        If size is negative or omitted, read until EOF is reached.
        Returns b'' if the file is already at EOF.
        """
        with self._lock:
            self._check_can_read()
            if size == 0:
                return b""
            elif size < 0:
                return self._read_all()
            else:
                return self._read_block(size)

    def readinto(self, b):
        """Read up to len(b) bytes into b.

        Returns the number of bytes read (0 for EOF).
        """
        with self._lock:
            return io.BufferedIOBase.readinto(self, b)

    def write(self, data):
        """Write a byte string to the file.

        Returns the number of uncompressed bytes written, which is
        always len(data). Note that due to buffering, the file on disk
        may not reflect the data written until close() is called.
        """
        with self._lock:
            self._check_can_write()
            # Convert data type if called by io.BufferedWriter.
            if isinstance(data, memoryview):
                data = data.tobytes()

            compressed = self._compressor.compress(data)
            self._fp.write(compressed)
            self._pos += len(data)
            return len(data)

    # Rewind the file to the beginning of the data stream.
    def _rewind(self):
        self._fp.seek(0, 0)
        self._mode = _MODE_READ
        self._pos = 0
        # A fresh decompressor is required to decode from the start again.
        self._decompressor = zlib.decompressobj(self.wbits)
        self._buffer = b""
        self._buffer_offset = 0

    def seek(self, offset, whence=0):
        """Change the file position.

        The new position is specified by offset, relative to the
        position indicated by whence. Values for whence are:

            0: start of stream (default); offset must not be negative
            1: current stream position
            2: end of stream; offset must not be positive

        Returns the new file position.

        Note that seeking is emulated, so depending on the parameters,
        this operation may be extremely slow.
        """
        with self._lock:
            self._check_can_seek()

            # Recalculate offset as an absolute file position.
            if whence == 0:
                pass
            elif whence == 1:
                offset = self._pos + offset
            elif whence == 2:
                # Seeking relative to EOF - we need to know the file's size.
                if self._size < 0:
                    self._read_all(return_data=False)
                offset = self._size + offset
            else:
                raise ValueError("Invalid value for whence: %s" % (whence,))

            # Make it so that offset is the number of bytes to skip forward.
            if offset < self._pos:
                # Going backwards means decoding again from the start.
                self._rewind()
            else:
                offset -= self._pos

            # Read and discard data until we reach the desired position.
            self._read_block(offset, return_data=False)

            return self._pos

    def tell(self):
        """Return the current file position."""
        with self._lock:
            self._check_not_closed()
            return self._pos
540
541
class ZlibCompressorWrapper(CompressorWrapper):
    """Compressor wrapper that builds BinaryZlibFile objects ('.z' files)."""

    def __init__(self):
        super().__init__(BinaryZlibFile,
                         prefix=_ZLIB_PREFIX,
                         extension='.z')
547
548
class BinaryGzipFile(BinaryZlibFile):
    """A file object providing transparent gzip (de)compression.

    The behaviour is inherited from BinaryZlibFile; only the ``wbits``
    value handed to zlib differs.

    If filename is a str or bytes object, it gives the name
    of the file to be opened. Otherwise, it should be a file object,
    which will be used to read or write the compressed data.

    mode can be 'rb' for reading (default) or 'wb' for (over)writing

    If mode is 'wb', compresslevel can be a number between 1
    and 9 specifying the level of compression: 1 produces the least
    compression, and 9 produces the most compression. 3 is the default.
    """

    # zlib compressor/decompressor wbits value for gzip format: adding 16 to
    # the window size (15) selects the gzip header and trailer.
    wbits = 16 + 15
564
565
class GzipCompressorWrapper(CompressorWrapper):
    """Compressor wrapper that builds BinaryGzipFile objects ('.gz' files)."""

    def __init__(self):
        super().__init__(BinaryGzipFile,
                         prefix=_GZIP_PREFIX,
                         extension='.gz')
571