1# Copyright (c) 2009,2016,2019 MetPy Developers.
2# Distributed under the terms of the BSD 3-Clause License.
3# SPDX-License-Identifier: BSD-3-Clause
4"""A collection of general purpose tools for reading files."""
5
6import bz2
7from collections import namedtuple
8import contextlib
9import gzip
10from io import BytesIO
11import logging
12from struct import Struct
13import zlib
14
15import numpy as np
16
# Module-level logger named after this module; the helpers below use it for diagnostics.
log = logging.getLogger(__name__)
18
19
def open_as_needed(filename, mode='rb'):
    """Return a file-object given either a filename or an object.

    For paths, the opening class is chosen from the file extension. For file-like objects,
    the leading bytes of the stream are inspected and the object is wrapped in a
    decompressor when they match the gzip or bzip2 signature.

    Parameters
    ----------
    filename : str, pathlib.Path, or file-like object
        Path of the file to open, or an object providing a ``read`` method.
    mode : str, optional
        Mode used when opening a path. It is not used when `filename` is a file-like
        object. Defaults to ``'rb'``.

    Returns
    -------
    file-like object
        A `gzip.GzipFile`, a `bz2.BZ2File`, the result of `open`, the object that was passed
        in, or an `io.BytesIO` holding the contents of a stream that could not be rewound.

    """
    # Handle file-like objects
    if hasattr(filename, 'read'):
        # Peek at the start of the stream to see if it is really gzipped or bzipped.
        lead = filename.read(4)

        # Rewind if we can. Having a `seek` attribute does not mean seeking works: pipes,
        # sockets, and HTTP responses expose one that raises (io.UnsupportedOperation is an
        # OSError). In that case, like when there is no `seek` at all, read all the data into
        # an in-memory file-like object.
        try:
            filename.seek(0)
        except (AttributeError, OSError):
            filename = BytesIO(lead + filename.read())

        # Text streams hand back str; convert so the comparison against the byte signatures
        # works. Non-ASCII characters are replaced rather than raising, since they can never
        # be part of a signature anyway. bytes has no `encode`, hence the suppress.
        with contextlib.suppress(AttributeError):
            lead = lead.encode('ascii', 'replace')

        # If the leading bytes match one of the signatures, pass into the appropriate class.
        if lead.startswith(b'\x1f\x8b'):
            filename = gzip.GzipFile(fileobj=filename)
        elif lead.startswith(b'BZh'):
            filename = bz2.BZ2File(filename)

        return filename

    # This will convert pathlib.Path instances to strings
    filename = str(filename)

    if filename.endswith('.bz2'):
        return bz2.BZ2File(filename, mode)
    elif filename.endswith('.gz'):
        return gzip.GzipFile(filename, mode)
    else:
        # `errors` is only valid for text modes; open() raises ValueError if it is given
        # together with any binary mode ('rb', 'wb', 'ab', ...).
        kwargs = {} if 'b' in mode else {'errors': 'surrogateescape'}
        return open(filename, mode, **kwargs)  # noqa: SIM115
58
59
class NamedStruct(Struct):
    """Parse bytes using :class:`Struct` but provide named fields.

    Unpacked values are returned as instances of a namedtuple whose fields come from the
    names given at construction. Optional per-field converters are applied to the raw
    values before the tuple is built.
    """

    def __init__(self, info, prefmt='', tuple_name=None):
        """Initialize the NamedStruct.

        Parameters
        ----------
        info : sequence of tuples
            One ``(name, format)`` or ``(name, format, converter)`` entry per item. Entries
            with an empty name (e.g. `None`) contribute their format but no field. A
            converter is a callable applied to the unpacked value of its field.
        prefmt : str, optional
            Text placed in front of the combined format string, such as a byte-order
            character. Defaults to an empty string.
        tuple_name : str, optional
            Name of the generated namedtuple class. Defaults to ``'NamedStruct'``.

        """
        if tuple_name is None:
            tuple_name = 'NamedStruct'

        # Pick out name and format by position rather than with zip(*info): entries may have
        # two or three elements, and unpacking zip's result fails when every entry carries a
        # converter or when `info` is empty.
        names = []
        fmts = []
        self.converters = {}
        conv_off = 0
        for ind, item in enumerate(info):
            name, fmt = item[0], item[1]
            if name:
                names.append(name)
            if fmt:
                fmts.append(fmt)

            if len(item) > 2:
                # Converters are keyed by position in the unpacked values, so discount the
                # unnamed items seen so far.
                self.converters[ind - conv_off] = item[-1]
            elif not name:  # Skip items with no name
                conv_off += 1

        self._tuple = namedtuple(tuple_name, ' '.join(names))
        super().__init__(prefmt + ''.join(fmts))

    def _create(self, items):
        """Apply any converters to the raw values and build the namedtuple."""
        if self.converters:
            items = list(items)
            for ind, conv in self.converters.items():
                items[ind] = conv(items[ind])

            # Fill fields that received no value with None. Note this padding only happens
            # when converters are in use.
            if len(items) < len(self._tuple._fields):
                items.extend([None] * (len(self._tuple._fields) - len(items)))
        return self.make_tuple(*items)

    def make_tuple(self, *args, **kwargs):
        """Construct the underlying tuple from values."""
        return self._tuple(*args, **kwargs)

    def unpack(self, s):
        """Parse bytes and return a namedtuple."""
        return self._create(super().unpack(s))

    def unpack_from(self, buff, offset=0):
        """Read bytes from a buffer, starting at `offset`, and return as a namedtuple."""
        return self._create(super().unpack_from(buff, offset))

    def unpack_file(self, fobj):
        """Unpack the next bytes from a file object.

        Reads ``self.size`` bytes from `fobj` and parses them with `unpack`.
        """
        return self.unpack(fobj.read(self.size))

    def pack(self, **kwargs):
        """Pack the arguments into bytes using the structure.

        Values are given by field name and are packed as-is; converters are not applied.
        """
        t = self.make_tuple(**kwargs)
        return super().pack(*t)
107
108
109# This works around times when we have more than 255 items and can't use
110# NamedStruct. This is a CPython limit for arguments.
class DictStruct(Struct):
    """Parse bytes using :class:`Struct` but provide named fields using dictionary access."""

    def __init__(self, info, prefmt=''):
        """Initialize the DictStruct.

        Parameters
        ----------
        info : iterable of tuples
            One ``(name, format)`` entry per item; only the first two elements of each
            entry are used. Entries with an empty name (e.g. `None`) contribute their
            format but no key.
        prefmt : str, optional
            Text placed in front of the combined format string, such as a byte-order
            character. Defaults to an empty string.

        """
        # Collect by position in a single pass instead of unpacking zip(*info), which
        # raises when `info` is empty.
        self._names = []
        formats = []
        for item in info:
            name, fmt = item[0], item[1]
            # Remove empty names
            if name:
                self._names.append(name)
            if fmt:
                formats.append(fmt)

        super().__init__(prefmt + ''.join(formats))

    def _create(self, items):
        """Pair the stored names with the unpacked values."""
        return dict(zip(self._names, items))

    def unpack(self, s):
        """Parse bytes and return a dict."""
        return self._create(super().unpack(s))

    def unpack_from(self, buff, offset=0):
        """Read bytes from a buffer, starting at `offset`, and return as a dict."""
        return self._create(super().unpack_from(buff, offset))
133
134
class Enum:
    """Translate values into descriptive strings.

    Positional names are numbered from 0 in the order given; keyword arguments supply
    explicit ``name=value`` pairs.
    """

    def __init__(self, *args, **kwargs):
        """Build the value-to-name lookup table."""
        # Positional names take the values 0, 1, 2, ...
        lookup = dict(enumerate(args))

        # Keyword arguments arrive as name -> value; store them as value -> name
        for name, value in kwargs.items():
            lookup[value] = name

        self.val_map = lookup

    def __call__(self, val):
        """Return the name for `val`, or a placeholder string when it is not known."""
        try:
            return self.val_map[val]
        except KeyError:
            return f'Unknown ({val})'
149
150
class Bits:
    """Split an integer into a fixed number of boolean flags, lowest bit first."""

    def __init__(self, num_bits):
        """Store the bit positions (0 up to ``num_bits - 1``) to test."""
        self._bits = range(num_bits)

    def __call__(self, val):
        """Return one True/False entry per bit position of `val`."""
        return [bool(val >> position & 0x1) for position in self._bits]
161
162
class BitField:
    """Translate the set bits of an integer into their names."""

    def __init__(self, *names):
        """Store the bit names, given in order starting from the lowest bit."""
        self._names = names

    def __call__(self, val):
        """Return the name(s) of the bits set in `val`.

        Gives `None` for a false-y `val`, the bare name when exactly one named bit is set,
        and otherwise a list of names (empty if none of the set bits has a name).
        """
        if not val:
            return None

        selected = []
        remaining = val
        for label in self._names:
            # Nothing left but zero bits, so no further name can match
            if not remaining:
                break
            if remaining & 0x1:
                selected.append(label)
            remaining >>= 1

        # A lone match is unwrapped; zero or several matches stay as a list
        if len(selected) == 1:
            return selected[0]
        return selected
185
186
class Array:
    """Callable wrapper around a Struct that unpacks bytes into a list of values."""

    def __init__(self, fmt):
        """Compile `fmt` into the Struct used for unpacking."""
        self._struct = Struct(fmt)

    def __call__(self, buf):
        """Unpack `buf` with the stored Struct and return the values as a list."""
        values = self._struct.unpack(buf)
        return [*values]
197
198
class IOBuffer:
    """Holds bytes from a buffer to simplify parsing and random access.

    The data are copied into a :class:`bytearray`. The buffer tracks a current offset, which
    the ``read*`` and `skip` methods advance, plus a list of bookmarked offsets that can be
    jumped back to later.
    """

    def __init__(self, source):
        """Initialize the IOBuffer with the source data.

        Parameters
        ----------
        source : bytes-like
            Data handed to `bytearray` to populate the buffer.

        """
        self._data = bytearray(source)
        self.reset()

    @classmethod
    def fromfile(cls, fobj):
        """Initialize the IOBuffer with the contents of the file object."""
        return cls(fobj.read())

    def reset(self):
        """Reset buffer back to initial state: offset at the start and no marks."""
        self._offset = 0
        self.clear_marks()

    def set_mark(self):
        """Mark the current location and return its id so that the buffer can return later."""
        self._bookmarks.append(self._offset)
        return len(self._bookmarks) - 1

    def jump_to(self, mark, offset=0):
        """Jump to a previously set mark, optionally displaced by `offset` bytes."""
        self._offset = self._bookmarks[mark] + offset

    def offset_from(self, mark):
        """Calculate the current offset relative to a marked location."""
        return self._offset - self._bookmarks[mark]

    def clear_marks(self):
        """Clear all marked locations."""
        self._bookmarks = []

    def splice(self, mark, newdata):
        """Replace the data after the marked location with the specified data.

        The current offset is moved to the mark as part of this operation.
        """
        self.jump_to(mark)
        self._data = self._data[:self._offset] + bytearray(newdata)

    def read_struct(self, struct_class):
        """Parse and return a structure from the current buffer offset.

        `struct_class` must provide ``unpack_from(buffer, offset)`` and a ``size``
        attribute; the offset is advanced by that size.
        """
        struct = struct_class.unpack_from(memoryview(self._data), self._offset)
        self.skip(struct_class.size)
        return struct

    def read_func(self, func, num_bytes=None):
        """Parse data from the current buffer offset using a function.

        `func` receives the next `num_bytes` bytes (everything remaining when `None`) and
        its result is returned.
        """
        # only advance if func succeeds
        res = func(self.get_next(num_bytes))
        self.skip(num_bytes)
        return res

    def read_ascii(self, num_bytes=None):
        """Return the specified bytes as ascii-formatted text."""
        return self.read(num_bytes).decode('ascii')

    def read_binary(self, num, item_type='B'):
        """Parse the current buffer offset as the specified code.

        Parameters
        ----------
        num : int
            Number of items to read.
        item_type : str, optional
            Struct format character for one item, optionally preceded by a byte-order
            character. Defaults to ``'B'``.

        Returns
        -------
        bytearray or list
            The raw bytes when `item_type` contains ``'B'``, otherwise a list of the
            unpacked values.

        """
        if 'B' in item_type:
            return self.read(num)

        # Separate an explicit byte-order prefix from the type code; native is the default
        if item_type[0] in ('@', '=', '<', '>', '!'):
            order = item_type[0]
            item_type = item_type[1:]
        else:
            order = '@'

        return list(self.read_struct(Struct(order + f'{int(num):d}' + item_type)))

    def read_int(self, size, endian, signed):
        """Parse the current buffer offset as the specified integer code.

        Reads `size` bytes and converts them with `int.from_bytes` using the given byte
        order (`endian`) and signedness.
        """
        return int.from_bytes(self.read(size), endian, signed=signed)

    def read_array(self, count, dtype):
        """Read an array of `count` values of type `dtype` from the buffer.

        The array is created with `numpy.frombuffer` on the buffer's data, and the offset is
        advanced by the size of the array in bytes.
        """
        ret = np.frombuffer(self._data, offset=self._offset, dtype=dtype, count=count)
        self.skip(ret.nbytes)
        return ret

    def read(self, num_bytes=None):
        """Read and return the specified bytes from the buffer.

        Fewer bytes are returned if the buffer runs out; the offset only advances by the
        number of bytes actually returned.
        """
        res = self.get_next(num_bytes)
        self.skip(len(res))
        return res

    def get_next(self, num_bytes=None):
        """Get the next bytes in the buffer without modifying the offset.

        Everything from the current offset to the end is returned when `num_bytes` is
        `None`.
        """
        if num_bytes is None:
            return self._data[self._offset:]
        else:
            return self._data[self._offset:self._offset + num_bytes]

    def skip(self, num_bytes):
        """Jump ahead the specified bytes in the buffer, or to the end if given `None`."""
        if num_bytes is None:
            self._offset = len(self._data)
        else:
            self._offset += num_bytes

    def check_remains(self, num_bytes):
        """Check that exactly the number of bytes specified remains in the buffer."""
        return len(self._data[self._offset:]) == num_bytes

    def truncate(self, num_bytes):
        """Remove the specified number of bytes from the end of the buffer."""
        # Removing nothing must leave the data alone: a plain `[:-num_bytes]` slice turns
        # into `[:0]` for zero and would throw away the entire buffer.
        if num_bytes:
            self._data = self._data[:-num_bytes]

    def at_end(self):
        """Return whether the buffer has reached the end of data."""
        return self._offset >= len(self._data)

    def __getitem__(self, item):
        """Return the data at the specified location."""
        return self._data[item]

    def __str__(self):
        """Return a string representation of the IOBuffer."""
        return f'Size: {len(self._data)} Offset: {self._offset}'

    def __len__(self):
        """Return the amount of data in the buffer."""
        return len(self._data)
322
323
def zlib_decompress_all_frames(data):
    """Decompress every zlib frame found at the start of some bytes.

    Frames are decompressed one after another until the input is used up. As soon as the
    bytes left over cannot be decompressed by zlib, they are appended to the output
    unchanged and processing stops.

    Parameters
    ----------
    data : bytearray or bytes
        Binary data compressed using zlib.

    Returns
    -------
    bytearray
        All decompressed bytes, followed by any trailing bytes that were not compressed

    """
    remaining = bytes(data)
    frames = bytearray()
    while remaining:
        # Each frame needs its own decompressor object
        inflater = zlib.decompressobj()
        try:
            chunk = inflater.decompress(remaining)
        except zlib.error:
            log.debug('Remaining %d bytes are not zlib compressed.', len(remaining))
            frames.extend(remaining)
            break

        frames += chunk
        # Whatever follows the end of this frame is the input for the next pass
        remaining = inflater.unused_data
        log.debug('Decompressed zlib frame. %d bytes remain.', len(remaining))
    return frames
354
355
def bits_to_code(val):
    """Return the struct format code for an unsigned value of `val` bits.

    8 maps to ``'B'`` and 16 to ``'H'``. Any other size is logged as a warning and treated
    as ``'B'``.
    """
    code = 'B' if val == 8 else 'H' if val == 16 else None
    if code is None:
        log.warning('Unsupported bit size: %s. Returning "B"', val)
        code = 'B'
    return code
365