1import re
2import struct
3from datetime import datetime, timedelta
4from io import BytesIO
5
6from cbor2.compat import timezone, xrange, byte_as_integer
7from cbor2.types import CBORTag, undefined, break_marker, CBORSimpleValue
8
# RFC 3339 date-time string as carried by CBOR semantic tag 0.
# Groups: year, month, day, hour, minute, second, fractional-second digits (optional),
# UTC offset hours including sign, UTC offset minutes (both None when the suffix is "Z").
# RFC 3339 section 5.6 allows the "T" separator and "Z" suffix in lower case as well.
timestamp_re = re.compile(r'^(\d{4})-(\d\d)-(\d\d)[Tt](\d\d):(\d\d):(\d\d)'
                          r'(?:\.(\d+))?(?:[Zz]|([+-]\d\d):(\d\d))$')
11
12
class CBORDecodeError(Exception):
    """
    Error signalling that a CBOR data stream could not be deserialized.

    :meth:`CBORDecoder.decode` re-raises unexpected low-level exceptions as this type.
    """
15
16
def decode_uint(decoder, subtype, shareable_index=None, allow_infinite=False):
    """
    Decode an unsigned integer argument (major type 0, also used for lengths and tag numbers).

    :param decoder: object with a ``read(n)`` method returning ``n`` bytes
    :param subtype: the low 5 bits of the initial byte
    :param shareable_index: unused
    :param allow_infinite: if true, subtype 31 (indefinite length marker) yields ``None``
    :return: the integer value, or ``None`` for an accepted indefinite-length marker
    :raises CBORDecodeError: for reserved subtypes, or 31 when not allowed
    """
    # Major tag 0
    if subtype < 24:
        # Small values live directly in the initial byte
        return subtype

    if 24 <= subtype <= 27:
        # 1, 2, 4 or 8 following bytes in network byte order
        fmt, size = (('>B', 1), ('>H', 2), ('>L', 4), ('>Q', 8))[subtype - 24]
        return struct.unpack(fmt, decoder.read(size))[0]

    if allow_infinite and subtype == 31:
        return None

    raise CBORDecodeError('unknown unsigned integer subtype 0x%x' % subtype)
33
34
def decode_negint(decoder, subtype, shareable_index=None):
    """
    Decode a negative integer (major type 1).

    The encoded unsigned argument ``n`` represents the value ``-1 - n``.
    """
    # Major tag 1
    return -1 - decode_uint(decoder, subtype)
39
40
def decode_bytestring(decoder, subtype, shareable_index=None):
    """
    Decode a byte string (major type 2) of definite or indefinite length.

    :func:`decode_string` also uses this to fetch the raw UTF-8 bytes of text strings.

    :param decoder: the decoder to read from
    :param subtype: the low 5 bits of the initial byte (length or indefinite-length marker)
    :param shareable_index: unused
    :return: the decoded data as ``bytes``, for both definite and indefinite lengths
    """
    # Major tag 2
    length = decode_uint(decoder, subtype, allow_infinite=True)
    if length is not None:
        return decoder.read(length)

    # Indefinite length: definite-length chunks terminated by a 0xff break byte.
    # The chunks' major type is not checked, because decode_string reuses this loop
    # for text-string chunks.
    buf = bytearray()
    while True:
        initial_byte = byte_as_integer(decoder.read(1))
        if initial_byte == 255:
            # Return bytes (not bytearray) so both paths yield the same immutable,
            # hashable type, e.g. when the value is used as a map key
            return bytes(buf)

        length = decode_uint(decoder, initial_byte & 31)
        buf.extend(decoder.read(length))
57
58
def decode_string(decoder, subtype, shareable_index=None):
    """
    Decode a text string (major type 3) by reading its bytes and decoding them as UTF-8.
    """
    # Major tag 3
    raw = decode_bytestring(decoder, subtype)
    return raw.decode('utf-8')
62
63
def decode_array(decoder, subtype, shareable_index=None):
    """
    Decode an array (major type 4) of definite or indefinite length into a list.

    The list is registered as a shared value before its items are decoded, so items may
    hold shared references back to the list itself.
    """
    # Major tag 4
    items = []
    decoder.set_shareable(shareable_index, items)
    length = decode_uint(decoder, subtype, allow_infinite=True)

    if length is not None:
        for _ in xrange(length):
            items.append(decoder.decode())
        return items

    # Indefinite length: keep decoding until the break marker shows up
    element = decoder.decode()
    while element is not break_marker:
        items.append(element)
        element = decoder.decode()
    return items
83
84
def decode_map(decoder, subtype, shareable_index=None):
    """
    Decode a map (major type 5) of definite or indefinite length into a dict.

    If the decoder has an ``object_hook``, the hook's return value replaces the dict. It is
    also recorded as the shared value, so later shared references (tag 29) resolve to the
    object the caller actually receives.
    """
    # Major tag 5
    dictionary = {}
    # Register before decoding the contents so nested values can refer back to this map
    decoder.set_shareable(shareable_index, dictionary)
    length = decode_uint(decoder, subtype, allow_infinite=True)
    if length is None:
        # Indefinite length: key/value pairs until a break marker appears in key position
        while True:
            key = decoder.decode()
            if key is break_marker:
                break
            dictionary[key] = decoder.decode()
    else:
        for _ in xrange(length):
            key = decoder.decode()
            dictionary[key] = decoder.decode()

    if decoder.object_hook:
        result = decoder.object_hook(decoder, dictionary)
        # Without this, the shared slot would keep pointing at the raw dict
        decoder.set_shareable(shareable_index, result)
        return result

    return dictionary
109
110
def decode_semantic(decoder, subtype, shareable_index=None):
    """
    Decode a semantically tagged item (major type 6).

    Tag 28 (shareable) allocates a shared-value slot and decodes the wrapped value into it.
    Tags with a built-in entry in ``semantic_decoders`` are converted by that function.
    Any other tag becomes a :class:`CBORTag`, which ``decoder.tag_hook`` may replace.
    """
    # Major tag 6
    tagnum = decode_uint(decoder, subtype)

    # Special handling for the "shareable" tag
    if tagnum == 28:
        shareable_index = decoder._allocate_shareable()
        value = decoder.decode(shareable_index)
        # Arrays and maps register themselves early (to allow self-references), but scalars,
        # semantic values and hook results do not. Record the final value so that a later
        # shared reference (tag 29) can resolve it.
        decoder.set_shareable(shareable_index, value)
        return value

    value = decoder.decode()
    semantic_decoder = semantic_decoders.get(tagnum)
    if semantic_decoder:
        return semantic_decoder(decoder, value, shareable_index)

    tag = CBORTag(tagnum, value)
    if decoder.tag_hook:
        return decoder.tag_hook(decoder, tag, shareable_index)
    else:
        return tag
130
131
def decode_special(decoder, subtype, shareable_index=None):
    """
    Decode a major type 7 item: a simple value, boolean, null, undefined, float or break.

    :raises CBORDecodeError: for the reserved subtypes (28-30) that have no decoder
    """
    # Major tag 7
    if subtype < 20:
        # Simple values 0..19 are encoded directly in the initial byte
        return CBORSimpleValue(subtype)

    try:
        special_decoder = special_decoders[subtype]
    except KeyError:
        # Report reserved subtypes clearly instead of leaking a bare KeyError
        raise CBORDecodeError('undefined reserved major type 7 subtype 0x%x' % subtype)

    return special_decoder(decoder)
139
140
141#
142# Semantic decoders (major tag 6)
143#
144
def decode_datetime_string(decoder, value, shareable_index=None):
    """
    Decode an RFC 3339 date/time string (semantic tag 0) into an aware :class:`datetime`.

    :param decoder: the decoder instance (unused)
    :param value: the tag content, expected to be a text string
    :param shareable_index: unused
    :return: a timezone-aware datetime (UTC when the string ends in "Z")
    :raises CBORDecodeError: if ``value`` does not match ``timestamp_re``
    """
    # Semantic tag 0
    match = timestamp_re.match(value)
    if not match:
        raise CBORDecodeError('invalid datetime string: {}'.format(value))

    year, month, day, hour, minute, second, secfrac, offset_h, offset_m = match.groups()
    if offset_h:
        # offset_h carries the sign (e.g. "-05"); the minutes must take the same sign,
        # otherwise "-05:30" would become -5h + 30m = -4:30
        sign = -1 if offset_h.startswith('-') else 1
        tz = timezone(timedelta(hours=int(offset_h), minutes=sign * int(offset_m)))
    else:
        tz = timezone.utc

    # The fraction is a decimal fraction of a second, not a microsecond count: ".5" means
    # 500000 microseconds. Pad/truncate to exactly 6 digits (finer precision is dropped).
    microsecond = int(secfrac[:6].ljust(6, '0')) if secfrac else 0

    return datetime(int(year), int(month), int(day), int(hour), int(minute), int(second),
                    microsecond, tz)
159
160
def decode_epoch_datetime(decoder, value, shareable_index=None):
    """
    Decode an epoch-based timestamp (semantic tag 1) into an aware UTC :class:`datetime`.

    :param value: seconds since 1970-01-01T00:00Z, integer or float
    """
    # Semantic tag 1
    utc = timezone.utc
    return datetime.fromtimestamp(value, utc)
164
165
def decode_positive_bignum(decoder, value, shareable_index=None):
    """
    Decode an unsigned bignum (semantic tag 2).

    :param value: a byte string holding the integer in network (big-endian) byte order
    :return: the integer; an empty byte string decodes to 0
    """
    # Semantic tag 2
    from binascii import hexlify
    if not value:
        # hexlify(b'') is empty and int('', 16) raises; the empty byte string is the
        # preferred serialization of zero
        return 0
    return int(hexlify(value), 16)
170
171
def decode_negative_bignum(decoder, value, shareable_index=None):
    """
    Decode a negative bignum (semantic tag 3): the byte string encodes ``n`` and the
    result is ``-1 - n``.
    """
    # Semantic tag 3
    magnitude = decode_positive_bignum(decoder, value)
    return -1 - magnitude
175
176
def decode_fraction(decoder, value, shareable_index=None):
    """
    Decode a decimal fraction (semantic tag 4) into an exact :class:`~decimal.Decimal`.

    :param value: a two-element array ``[exponent, mantissa]`` meaning
        ``mantissa * 10 ** exponent``
    :return: the Decimal, built without any context rounding
    """
    # Semantic tag 4
    from decimal import Decimal
    exp = value[0]
    # Build the result from its digit tuple. Multiplying Decimals would round to the context
    # precision (28 digits by default) and raise Overflow for exponents outside the
    # context's range; the tuple constructor is exact.
    sign, digits, mantissa_exp = Decimal(value[1]).as_tuple()
    return Decimal((sign, digits, mantissa_exp + exp))
183
184
def decode_bigfloat(decoder, value, shareable_index=None):
    """
    Decode a bigfloat (semantic tag 5) into a :class:`~decimal.Decimal`.

    :param value: a two-element array ``[exponent, mantissa]`` meaning
        ``mantissa * 2 ** exponent``
    """
    # Semantic tag 5
    # NOTE(review): computed with Decimal arithmetic, so results are rounded to the current
    # decimal context precision.
    from decimal import Decimal
    exponent, significand = Decimal(value[0]), Decimal(value[1])
    return significand * 2 ** exponent
191
192
def decode_sharedref(decoder, value, shareable_index=None):
    """
    Resolve a shared reference (semantic tag 29) to a previously marked shareable value.

    :param value: index into the decoder's shared value registry
    :return: the shared object
    :raises CBORDecodeError: if the index is out of range or the slot is not yet filled
    """
    # Semantic tag 29
    shareables = decoder._shareables
    # Check bounds explicitly: a negative index would otherwise wrap around the list
    # and silently return an unrelated shared value
    if value < 0 or value >= len(shareables):
        raise CBORDecodeError('shared reference %d not found' % value)

    shared = shareables[value]
    if shared is None:
        raise CBORDecodeError('shared value %d has not been initialized' % value)
    return shared
204
205
def decode_rational(decoder, value, shareable_index=None):
    """
    Decode a rational number (semantic tag 30) into a :class:`~fractions.Fraction`.

    :param value: an array whose items are passed positionally to ``Fraction``
        (normally ``[numerator, denominator]``)
    """
    # Semantic tag 30
    from fractions import Fraction
    rational = Fraction(*value)
    return rational
210
211
def decode_regexp(decoder, value, shareable_index=None):
    """
    Decode a regular expression (semantic tag 35) into a compiled pattern object.
    """
    # Semantic tag 35
    # NOTE(review): the pattern comes straight from the input stream. Matching text against
    # an attacker-supplied pattern can take exponential time (ReDoS); be careful when
    # decoding untrusted data.
    pattern = value
    return re.compile(pattern)
215
216
def decode_mime(decoder, value, shareable_index=None):
    """
    Decode a MIME message (semantic tag 36) from its text form into an
    :class:`email.message.Message`.
    """
    # Semantic tag 36
    from email.parser import Parser
    parser = Parser()
    return parser.parsestr(value)
221
222
def decode_uuid(decoder, value, shareable_index=None):
    """
    Decode a UUID (semantic tag 37) from its 16-byte big-endian binary form.
    """
    # Semantic tag 37
    from uuid import UUID
    raw = value
    return UUID(bytes=raw)
227
228
229#
230# Special decoders (major tag 7)
231#
232
def decode_simple_value(decoder, shareable_index=None):
    """
    Decode a simple value whose number is stored in the following byte (subtype 24).
    """
    (number,) = struct.unpack('>B', decoder.read(1))
    return CBORSimpleValue(number)
235
236
def decode_float16(decoder, shareable_index=None):
    """
    Decode an IEEE 754 half-precision float (major type 7, subtype 25).

    Algorithm adapted from RFC 7049, appendix D: the half-precision bits are moved into
    binary32 positions and then rescaled.
    """
    from math import ldexp

    (half,) = struct.unpack('>H', decoder.read(2))
    # Sign bit moves to bit 31; exponent and mantissa shift up by 13 to align with binary32
    bits = (half & 0x8000) << 16 | (half & 0x7fff) << 13

    if (half & 0x7c00) == 0x7c00:
        # All-ones exponent (infinity / NaN): force the binary32 exponent field to all ones
        return struct.unpack("!f", struct.pack("!I", bits | 0x7f800000))[0]

    # Finite value: compensate for the exponent bias difference (127 - 15 = 112)
    return ldexp(struct.unpack("!f", struct.pack("!I", bits))[0], 112)
250
251
def decode_float32(decoder, shareable_index=None):
    """Decode an IEEE 754 single-precision float (major type 7, subtype 26)."""
    (number,) = struct.unpack('>f', decoder.read(4))
    return number
254
255
def decode_float64(decoder, shareable_index=None):
    """Decode an IEEE 754 double-precision float (major type 7, subtype 27)."""
    (number,) = struct.unpack('>d', decoder.read(8))
    return number
258
259
# Dispatch table: CBOR major type (top 3 bits of the initial byte) -> decoder function,
# called as (decoder, subtype, shareable_index).
major_decoders = {
    0: decode_uint,
    1: decode_negint,
    2: decode_bytestring,
    3: decode_string,
    4: decode_array,
    5: decode_map,
    6: decode_semantic,
    7: decode_special,
}

# Major type 7 subtypes >= 20 -> decoder called with the decoder as its only argument.
# Subtypes below 20 are simple values handled directly in decode_special.
special_decoders = {
    20: lambda decoder: False,
    21: lambda decoder: True,
    22: lambda decoder: None,
    23: lambda decoder: undefined,
    24: decode_simple_value,
    25: decode_float16,
    26: decode_float32,
    27: decode_float64,
    31: lambda decoder: break_marker,
}

# Semantic tag number -> decoder called as (decoder, tag_content, shareable_index).
# Tag 28 (shareable) is handled directly in decode_semantic.
semantic_decoders = {
    0: decode_datetime_string,
    1: decode_epoch_datetime,
    2: decode_positive_bignum,
    3: decode_negative_bignum,
    4: decode_fraction,
    5: decode_bigfloat,
    29: decode_sharedref,
    30: decode_rational,
    35: decode_regexp,
    36: decode_mime,
    37: decode_uuid,
}
296
297
class CBORDecoder(object):
    """
    Deserializes a CBOR encoded byte stream.

    :param fp: the file-like object to read from. Only ``read()`` is required; ``tell()``
        is used, when available, to report the stream position in error messages.
    :param tag_hook: Callable that takes 3 arguments: the decoder instance, the
        :class:`~cbor2.types.CBORTag` and the shareable index for the resulting object, if any.
        This callback is called for any tags for which there is no built-in decoder.
        The return value is substituted for the CBORTag object in the deserialized output.
    :param object_hook: Callable that takes 2 arguments: the decoder instance and the dictionary.
        This callback is called for each deserialized :class:`dict` object.
        The return value is substituted for the dict in the deserialized output.
    """

    __slots__ = ('fp', 'tag_hook', 'object_hook', '_shareables')

    def __init__(self, fp, tag_hook=None, object_hook=None):
        self.fp = fp
        self.tag_hook = tag_hook
        self.object_hook = object_hook
        # Registry of values marked with the "shareable" tag, in allocation order.
        # None marks a slot that has been allocated but not filled yet.
        self._shareables = []

    def _allocate_shareable(self):
        """Reserve a new (empty) shared value slot and return its index."""
        self._shareables.append(None)
        return len(self._shareables) - 1

    def _position(self):
        """
        Return the current stream offset for use in error messages.

        Returns ``'?'`` when the stream cannot report a position (e.g. it has no ``tell()``
        or is not seekable), so the original decoding error is not masked by a new one.
        """
        try:
            return self.fp.tell()
        except (AttributeError, IOError, OSError, ValueError):
            return '?'

    def set_shareable(self, index, value):
        """
        Set the shareable value for the last encountered shared value marker, if any.

        If the given index is ``None``, nothing is done.

        :param index: the value of the ``shareable_index`` argument passed to the decoder function
        :param value: the shared value

        """
        if index is not None:
            self._shareables[index] = value

    def read(self, amount):
        """
        Read bytes from the data stream.

        :param int amount: the number of bytes to read
        :raises CBORDecodeError: if fewer than ``amount`` bytes are available

        """
        data = self.fp.read(amount)
        if len(data) < amount:
            raise CBORDecodeError('premature end of stream (expected to read {} bytes, got {} '
                                  'instead)'.format(amount, len(data)))

        return data

    def decode(self, shareable_index=None):
        """
        Decode the next value from the stream.

        :param shareable_index: index of the shared value slot the decoded value belongs to,
            if it is wrapped in a "shareable" tag
        :raises CBORDecodeError: if there is any problem decoding the stream

        """
        try:
            initial_byte = byte_as_integer(self.fp.read(1))
            major_type = initial_byte >> 5
            subtype = initial_byte & 31
        except Exception as e:
            raise CBORDecodeError('error reading major type at index {}: {}'
                                  .format(self._position(), e))

        decoder = major_decoders[major_type]
        try:
            return decoder(self, subtype, shareable_index)
        except CBORDecodeError:
            raise
        except Exception as e:
            raise CBORDecodeError('error decoding value at index {}: {}'
                                  .format(self._position(), e))

    def decode_from_bytes(self, buf):
        """
        Decode one value from the given bytestring instead of the current stream.

        The bytestring is wrapped in a :class:`~io.BytesIO` that temporarily replaces
        ``self.fp`` while :meth:`decode` runs. The original stream is restored afterwards,
        even if decoding fails.

        This method was intended to be used from the ``tag_hook`` hook when an object needs to be
        decoded separately from the rest but while still taking advantage of the shared value
        registry.

        """
        old_fp = self.fp
        self.fp = BytesIO(buf)
        try:
            return self.decode()
        finally:
            self.fp = old_fp
387
388
def loads(payload, **kwargs):
    """
    Deserialize an object from a bytestring.

    :param bytes payload: the bytestring to deserialize
    :param kwargs: keyword arguments passed to :class:`~.CBORDecoder`
    :return: the deserialized object
    :raises CBORDecodeError: if the payload cannot be decoded

    """
    decoder = CBORDecoder(BytesIO(payload), **kwargs)
    return decoder.decode()
400
401
def load(fp, **kwargs):
    """
    Deserialize an object from an open file.

    :param fp: the input file (any file-like object with a ``read()`` method)
    :param kwargs: keyword arguments passed to :class:`~.CBORDecoder`
    :return: the deserialized object
    :raises CBORDecodeError: if the stream cannot be decoded

    """
    decoder = CBORDecoder(fp, **kwargs)
    return decoder.decode()
412