1"""Internal classes used by the gzip, lzma and bz2 modules"""
2
3import io
4
5
6BUFFER_SIZE = io.DEFAULT_BUFFER_SIZE  # Compressed data read chunk size
7
8
9class BaseStream(io.BufferedIOBase):
10    """Mode-checking helper functions."""

    def _check_not_closed(self):
        if self.closed:
            raise ValueError("I/O operation on closed file")

    def _check_can_read(self):
        if not self.readable():
            raise io.UnsupportedOperation("File not open for reading")

    def _check_can_write(self):
        if not self.writable():
            raise io.UnsupportedOperation("File not open for writing")

    def _check_can_seek(self):
        if not self.readable():
            raise io.UnsupportedOperation("Seeking is only supported "
                                          "on files open for reading")
        if not self.seekable():
            raise io.UnsupportedOperation("The underlying file object "
                                          "does not support seeking")


class DecompressReader(io.RawIOBase):
    """Adapts the decompressor API to a RawIOBase reader API"""

    def readable(self):
        return True

    def __init__(self, fp, decomp_factory, trailing_error=(), **decomp_args):
        self._fp = fp
        self._eof = False
        self._pos = 0  # Current offset in decompressed stream

        # Set to size of decompressed stream once it is known, for SEEK_END
        self._size = -1

        # Save the decompressor factory and arguments.
        # If the file contains multiple compressed streams, each
        # stream will need a separate decompressor object. A new decompressor
        # object is also needed when implementing a backwards seek().
        self._decomp_factory = decomp_factory
        self._decomp_args = decomp_args
        self._decompressor = self._decomp_factory(**self._decomp_args)

        # Exception class to catch from decompressor signifying invalid
        # trailing data to ignore
        self._trailing_error = trailing_error

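    # Drop the reference to the decompressor on close so that its internal
    # buffers can be freed promptly, rather than when this reader object is
    # eventually garbage collected.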
    def close(self):
        self._decompressor = None
        return super().close()

    def seekable(self):
        return self._fp.seekable()

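    # readinto() is the primitive that buffered wrappers such as
    # io.BufferedReader call on a raw stream, so implement it in terms of
    # read(). Casting the target buffer to format "B" allows filling buffers
    # whose elements are not single bytes.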
    def readinto(self, b):
        with memoryview(b) as view, view.cast("B") as byte_view:
            data = self.read(len(byte_view))
            byte_view[:len(data)] = data
        return len(data)

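    # Note that, like any RawIOBase.read(), this may return fewer bytes than
    # requested: it returns as soon as the decompressor yields any output at
    # all. Callers that need exactly `size` bytes should go through a
    # buffered wrapper such as io.BufferedReader.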
    def read(self, size=-1):
        if size < 0:
            return self.readall()

        if not size or self._eof:
            return b""
        data = None  # Default if EOF is encountered
        # Depending on the input data, our call to the decompressor may not
        # return any data. In this case, try again after reading another block.
        while True:
            if self._decompressor.eof:
                rawblock = (self._decompressor.unused_data or
                            self._fp.read(BUFFER_SIZE))
                if not rawblock:
                    break
                # Continue to next stream.
                self._decompressor = self._decomp_factory(
                    **self._decomp_args)
                try:
                    data = self._decompressor.decompress(rawblock, size)
                except self._trailing_error:
                    # Trailing data isn't a valid compressed stream; ignore it.
                    break
            else:
                if self._decompressor.needs_input:
                    rawblock = self._fp.read(BUFFER_SIZE)
                    if not rawblock:
                        raise EOFError("Compressed file ended before the "
                                       "end-of-stream marker was reached")
                else:
                    rawblock = b""
                data = self._decompressor.decompress(rawblock, size)
            if data:
                break
        if not data:
            self._eof = True
            self._size = self._pos
            return b""
        self._pos += len(data)
        return data

    # Rewind the file to the beginning of the data stream.
    def _rewind(self):
        self._fp.seek(0)
        self._eof = False
        self._pos = 0
        self._decompressor = self._decomp_factory(**self._decomp_args)

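    # Seeking is emulated: a forward seek reads and discards decompressed
    # data until the target offset is reached, and a backward seek rewinds to
    # the start of the stream and decompresses again from there, so seeks can
    # be expensive on large files.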
    def seek(self, offset, whence=io.SEEK_SET):
        # Recalculate offset as an absolute file position.
        if whence == io.SEEK_SET:
            pass
        elif whence == io.SEEK_CUR:
            offset = self._pos + offset
        elif whence == io.SEEK_END:
            # Seeking relative to EOF - we need to know the file's size.
            if self._size < 0:
                while self.read(io.DEFAULT_BUFFER_SIZE):
                    pass
            offset = self._size + offset
        else:
            raise ValueError("Invalid value for whence: {}".format(whence))

        # Make it so that offset is the number of bytes to skip forward.
        if offset < self._pos:
            self._rewind()
        else:
            offset -= self._pos

        # Read and discard data until we reach the desired position.
        while offset > 0:
            data = self.read(min(io.DEFAULT_BUFFER_SIZE, offset))
            if not data:
                break
            offset -= len(data)

        return self._pos

    def tell(self):
        """Return the current file position."""
        return self._pos
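
# A hedged sketch of how the stdlib compression modules wire this class up
# (the file name is purely illustrative; bz2.BZ2File does essentially this,
# and lzma.LZMAFile and gzip.GzipFile follow the same pattern):
#
#     import bz2
#     fp = open("example.bz2", "rb")
#     raw = DecompressReader(fp, bz2.BZ2Decompressor, trailing_error=OSError)
#     buffered = io.BufferedReader(raw)
#     data = buffered.read()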