1import re
2import struct
3from datetime import datetime, timedelta
4from io import BytesIO
5
6from .compat import timezone, xrange, byte_as_integer, unpack_float16
7from .types import CBORTag, undefined, break_marker, CBORSimpleValue, FrozenDict
8
# RFC 3339 style timestamp: date, "T", time, optional fraction, then "Z" or +HH:MM/-HH:MM.
timestamp_re = re.compile(r'''
    ^(\d{4})-(\d\d)-(\d\d)       # year, month, day
    T(\d\d):(\d\d):(\d\d)        # hour, minute, second
    (?:\.(\d+))?                 # optional fractional-second digits
    (?:Z|([+-]\d\d):(\d\d))$     # "Z", or signed offset hours and offset minutes
''', re.VERBOSE)
11
12
class CBORDecodeError(Exception):
    """
    Error signalled when a CBOR data stream cannot be deserialized.
    """
15
16
def decode_uint(decoder, subtype, shareable_index=None, allow_indefinite=False):
    """
    Decode the argument encoded by an initial byte's subtype as an unsigned integer.

    Subtypes below 24 are the value itself. Subtypes 24-27 are followed by a
    big-endian unsigned integer of 1, 2, 4 or 8 bytes. Subtype 31 yields ``None``
    (indefinite length) only when ``allow_indefinite`` is true.

    :raises CBORDecodeError: for any other subtype
    """
    # Major tag 0
    if subtype < 24:
        return subtype

    if subtype <= 27:
        # 24 -> 1 byte ('B'), 25 -> 2 ('H'), 26 -> 4 ('L'), 27 -> 8 ('Q')
        power = subtype - 24
        fmt = '>' + 'BHLQ'[power]
        return struct.unpack(fmt, decoder.read(1 << power))[0]

    if allow_indefinite and subtype == 31:
        return None

    raise CBORDecodeError('unknown unsigned integer subtype 0x%x' % subtype)
33
34
def decode_negint(decoder, subtype, shareable_index=None):
    """Decode a negative integer (major type 1): the encoded argument n stands for ``-1 - n``."""
    # Major tag 1
    return -1 - decode_uint(decoder, subtype)
39
40
def decode_bytestring(decoder, subtype, shareable_index=None):
    """
    Decode a byte string (major type 2), definite or indefinite length.

    For indefinite-length strings, chunks are read and concatenated until a 0xff
    break byte. Only the low five bits of each chunk header are used, so the
    chunk's major type is not checked.
    """
    # Major tag 2
    length = decode_uint(decoder, subtype, allow_indefinite=True)
    if length is not None:
        return decoder.read(length)

    # Indefinite length
    chunks = []
    while True:
        chunk_header = byte_as_integer(decoder.read(1))
        if chunk_header == 255:
            break
        chunk_length = decode_uint(decoder, chunk_header & 31)
        chunks.append(decoder.read(chunk_length))
    return b''.join(chunks)
57
58
def decode_string(decoder, subtype, shareable_index=None):
    """Decode a text string (major type 3) by reading its bytes and decoding them as UTF-8."""
    # Major tag 3
    raw = decode_bytestring(decoder, subtype)
    return raw.decode('utf-8')
62
63
def decode_array(decoder, subtype, shareable_index=None):
    """
    Decode an array (major type 4), definite or indefinite length.

    Returns a list, or a tuple when the decoder is in immutable mode. When
    ``shareable_index`` is given, the object finally returned is the one left
    registered under that index. References resolved while the elements were still
    being decoded see the intermediate list.
    """
    # Major tag 4
    items = []
    # Register before decoding the elements so self-references inside the array resolve
    decoder.set_shareable(shareable_index, items)
    length = decode_uint(decoder, subtype, allow_indefinite=True)
    if length is None:
        # Indefinite length: read items until the break marker
        while True:
            value = decoder.decode()
            if value is break_marker:
                break
            items.append(value)
    else:
        for _ in xrange(length):
            items.append(decoder.decode())

    if decoder.immutable:
        items = tuple(items)
        # The caller gets the tuple, not the list registered above, so re-register it.
        # Later shared references then resolve to the same object.
        decoder.set_shareable(shareable_index, items)
    return items
86
87
def decode_map(decoder, subtype, shareable_index=None):
    """
    Decode a map (major type 5), definite or indefinite length.

    Keys are always decoded in immutable mode so nested containers come back
    hashable. If ``decoder.object_hook`` is set, its return value replaces the dict.
    Otherwise the dict is returned as-is, or wrapped in ``FrozenDict`` in immutable
    mode. When ``shareable_index`` is given, the object finally returned is the one
    left registered under that index.
    """
    # Major tag 5
    dictionary = {}
    # Register before decoding the contents so self-references inside the map resolve
    decoder.set_shareable(shareable_index, dictionary)
    length = decode_uint(decoder, subtype, allow_indefinite=True)

    def decode_key():
        # Keys must be hashable; restore the caller's flag even if decoding fails
        previous = decoder.immutable
        decoder._immutable = True
        try:
            return decoder.decode()
        finally:
            decoder._immutable = previous

    if length is None:
        # Indefinite length: read key/value pairs until the break marker
        while True:
            key = decode_key()
            if key is break_marker:
                break
            dictionary[key] = decoder.decode()
    else:
        for _ in xrange(length):
            key = decode_key()
            dictionary[key] = decoder.decode()

    if decoder.object_hook:
        result = decoder.object_hook(decoder, dictionary)
    elif decoder.immutable:
        result = FrozenDict(dictionary)
    else:
        return dictionary

    # The caller receives a different object than the dict registered above.
    # Point the shared-value slot at it so later references resolve consistently.
    decoder.set_shareable(shareable_index, result)
    return result
120
121
def decode_semantic(decoder, subtype, shareable_index=None):
    """
    Decode a tagged item (major type 6).

    Tag 28 allocates a shared-value slot and decodes the enclosed item into it.
    Tag 258 (set) decodes its payload in immutable mode. Tags with a built-in decoder
    in ``semantic_decoders`` are converted by it. Any other tag becomes a ``CBORTag``,
    which is passed through ``decoder.tag_hook`` when one is set. The hook receives
    ``shareable_index`` and is responsible for registering its own result.
    """
    # Major tag 6
    tagnum = decode_uint(decoder, subtype)

    # Special handling for the "shareable" tag
    if tagnum == 28:
        shareable_index = decoder._allocate_shareable()
        return decoder.decode(shareable_index)

    # Special handling for sets
    if tagnum == 258:
        # Set members must be hashable; restore the caller's flag even on error
        previous = decoder.immutable
        decoder._immutable = True
        try:
            value = decoder.decode()
        finally:
            decoder._immutable = previous
    else:
        value = decoder.decode()

    semantic_decoder = semantic_decoders.get(tagnum)
    if semantic_decoder:
        result = semantic_decoder(decoder, value, shareable_index)
        # Built-in decoders do not register their result. Do it here so a tag-28 wrapped
        # value (e.g. a shared set) can be found by later tag-29 references.
        decoder.set_shareable(shareable_index, result)
        return result

    tag = CBORTag(tagnum, value)
    if decoder.tag_hook:
        return decoder.tag_hook(decoder, tag, shareable_index)

    decoder.set_shareable(shareable_index, tag)
    return tag
149
150
def decode_special(decoder, subtype, shareable_index=None):
    """
    Decode a major type 7 item: simple values, booleans, null/undefined, floats, break.

    Subtypes below 20 are returned as ``CBORSimpleValue``. The rest are looked up in
    ``special_decoders``.

    :raises CBORDecodeError: for subtypes that have no decoder (reserved values)
    """
    # Simple value
    if subtype < 20:
        return CBORSimpleValue(subtype)

    # Major tag 7
    try:
        special_decoder = special_decoders[subtype]
    except KeyError:
        # Without this the caller only sees the bare KeyError text (e.g. "28")
        raise CBORDecodeError('undefined reserved major type 7 subtype 0x%x' % subtype)
    return special_decoder(decoder)
158
159
160#
161# Semantic decoders (major tag 6)
162#
163
def decode_datetime_string(decoder, value, shareable_index=None):
    """
    Decode an RFC 3339 date/time string (semantic tag 0) into an aware datetime.

    Fractional seconds are read as a decimal fraction of a second (``.5`` is half a
    second) and truncated to microsecond precision. A numeric UTC offset applies its
    sign to both hours and minutes (``-05:30`` is five and a half hours behind UTC).
    ``Z`` means UTC.

    :raises CBORDecodeError: if ``value`` does not match ``timestamp_re``
    """
    # Semantic tag 0
    match = timestamp_re.match(value)
    if not match:
        raise CBORDecodeError('invalid datetime string: {}'.format(value))

    year, month, day, hour, minute, second, secfrac, offset_h, offset_m = match.groups()
    if offset_h:
        # offset_h carries the sign ("+05"/"-05"); it must also apply to the minutes
        offset = timedelta(hours=int(offset_h[1:]), minutes=int(offset_m))
        if offset_h[0] == '-':
            offset = -offset
        tz = timezone(offset)
    else:
        tz = timezone.utc

    # Right-pad to six digits so ".5" becomes 500000 us; drop digits finer than 1 us
    microsecond = int(secfrac.ljust(6, '0')[:6]) if secfrac else 0
    return datetime(int(year), int(month), int(day), int(hour), int(minute), int(second),
                    microsecond, tz)
178
179
def decode_epoch_datetime(decoder, value, shareable_index=None):
    """Convert a POSIX timestamp (semantic tag 1) into a UTC-aware datetime."""
    # Semantic tag 1
    utc = timezone.utc
    return datetime.fromtimestamp(value, utc)
183
184
def decode_positive_bignum(decoder, value, shareable_index=None):
    """
    Decode an unsigned bignum (semantic tag 2) from its big-endian byte string.

    Leading zero bytes are accepted. An empty byte string decodes to 0, which RFC 8949
    names as the preferred encoding of zero.
    """
    # Semantic tag 2
    from binascii import hexlify
    # int(b'', 16) raises ValueError, so map the empty payload to "0" explicitly
    return int(hexlify(value) or b'0', 16)
189
190
def decode_negative_bignum(decoder, value, shareable_index=None):
    """Decode a negative bignum (semantic tag 3): the payload n stands for ``-1 - n``."""
    # Semantic tag 3
    magnitude = decode_positive_bignum(decoder, value)
    return -1 - magnitude
194
195
def decode_fraction(decoder, value, shareable_index=None):
    """
    Decode a decimal fraction (semantic tag 4) into a :class:`~decimal.Decimal`.

    ``value`` is the ``[exponent, mantissa]`` array, and the result equals
    ``mantissa * 10 ** exponent``. It is built directly from the mantissa's digits
    rather than by multiplication. That avoids rounding to the active decimal
    context's precision, and avoids overflow from computing ``10 ** exponent`` for
    large exponents.
    """
    # Semantic tag 4
    from decimal import Decimal
    sign, digits, mantissa_exp = Decimal(value[1]).as_tuple()
    # mantissa_exp is 0 for integer mantissas; adding it keeps the value exact regardless
    return Decimal((sign, digits, mantissa_exp + value[0]))
202
203
def decode_bigfloat(decoder, value, shareable_index=None):
    """
    Decode a bigfloat (semantic tag 5) from its ``[exponent, mantissa]`` array.

    Returns ``mantissa * 2 ** exponent`` as a :class:`~decimal.Decimal`, computed with
    Decimal arithmetic in the current decimal context.
    """
    # Semantic tag 5
    from decimal import Decimal
    exponent, mantissa = Decimal(value[0]), Decimal(value[1])
    return mantissa * 2 ** exponent
210
211
def decode_sharedref(decoder, value, shareable_index=None):
    """
    Resolve a shared-value reference (semantic tag 29) to a previously registered value.

    ``value`` is the index into the decoder's shared-value registry.

    :raises CBORDecodeError: if the index is negative, out of range, or refers to a
        slot that has been allocated but not yet filled
    """
    # Semantic tag 29
    # Python's negative indexing would otherwise silently return an entry counted from
    # the end of the registry
    if value < 0:
        raise CBORDecodeError('shared reference %d not found' % value)

    try:
        shared = decoder._shareables[value]
    except IndexError:
        raise CBORDecodeError('shared reference %d not found' % value)

    if shared is None:
        raise CBORDecodeError('shared value %d has not been initialized' % value)
    return shared
223
224
def decode_rational(decoder, value, shareable_index=None):
    """Build a :class:`fractions.Fraction` from the elements of a rational-number array (semantic tag 30)."""
    # Semantic tag 30
    import fractions
    return fractions.Fraction(*value)
229
230
def decode_regexp(decoder, value, shareable_index=None):
    """
    Compile the pattern carried by semantic tag 35 into a regular expression object.

    NOTE(review): the pattern comes from the input stream. Patterns from untrusted
    data can make later matching arbitrarily slow.
    """
    # Semantic tag 35
    pattern = re.compile(value)
    return pattern
234
235
def decode_mime(decoder, value, shareable_index=None):
    """Parse a MIME message string (semantic tag 36) with the stdlib email parser."""
    # Semantic tag 36
    from email import parser
    return parser.Parser().parsestr(value)
240
241
def decode_uuid(decoder, value, shareable_index=None):
    """Build a :class:`uuid.UUID` from its 16-byte binary form (semantic tag 37)."""
    # Semantic tag 37
    import uuid
    return uuid.UUID(bytes=value)
246
247
def decode_set(decoder, value, shareable_index=None):
    """Turn a decoded array (semantic tag 258) into a set, or a frozenset in immutable mode."""
    # Semantic tag 258
    factory = frozenset if decoder.immutable else set
    return factory(value)
254
255
256#
257# Special decoders (major tag 7)
258#
259
def decode_simple_value(decoder, shareable_index=None):
    """Read a one-byte simple value number (major type 7, subtype 24) and wrap it."""
    (number,) = struct.unpack('>B', decoder.read(1))
    return CBORSimpleValue(number)
262
263
def decode_float16(decoder, shareable_index=None):
    """Decode a big-endian IEEE 754 half-precision float (major type 7, subtype 25)."""
    raw = decoder.read(2)
    try:
        (result,) = struct.unpack('>e', raw)
    except struct.error:
        # The 'e' struct format only exists on Python 3.6+; use the compat fallback otherwise
        result = unpack_float16(raw)
    return result
271
272
def decode_float32(decoder, shareable_index=None):
    """Decode a big-endian IEEE 754 single-precision float (major type 7, subtype 26)."""
    (result,) = struct.unpack('>f', decoder.read(4))
    return result
275
276
def decode_float64(decoder, shareable_index=None):
    """Decode a big-endian IEEE 754 double-precision float (major type 7, subtype 27)."""
    (result,) = struct.unpack('>d', decoder.read(8))
    return result
279
280
# Decoder for each CBOR major type, keyed by the top three bits of the initial byte
major_decoders = dict(enumerate((
    decode_uint,        # 0: unsigned integer
    decode_negint,      # 1: negative integer
    decode_bytestring,  # 2: byte string
    decode_string,      # 3: text string
    decode_array,       # 4: array
    decode_map,         # 5: map
    decode_semantic,    # 6: semantic tag
    decode_special,     # 7: simple value / float / break
)))

# Decoders for major type 7, keyed by subtype. Subtypes below 20 are handled
# directly in decode_special; subtypes 28-30 have no entry here.
special_decoders = {
    20: lambda decoder: False,
    21: lambda decoder: True,
    22: lambda decoder: None,
    23: lambda decoder: undefined,
    24: decode_simple_value,
    25: decode_float16,
    26: decode_float32,
    27: decode_float64,
    31: lambda decoder: break_marker,
}

# Built-in converters for semantic tags (major type 6), keyed by tag number.
# Tag 28 (shareable) is handled directly in decode_semantic.
semantic_decoders = {
    0: decode_datetime_string,   # date/time string
    1: decode_epoch_datetime,    # epoch-based date/time
    2: decode_positive_bignum,   # unsigned bignum
    3: decode_negative_bignum,   # negative bignum
    4: decode_fraction,          # decimal fraction
    5: decode_bigfloat,          # bigfloat
    29: decode_sharedref,        # shared value reference
    30: decode_rational,         # rational number
    35: decode_regexp,           # regular expression
    36: decode_mime,             # MIME message
    37: decode_uuid,             # UUID
    258: decode_set,             # set
}
318
319
class CBORDecoder(object):
    """
    Deserializes a CBOR encoded byte stream.

    :param tag_hook: Callable that takes 3 arguments: the decoder instance, the
        :class:`~cbor2.types.CBORTag` and the shareable index for the resulting object, if any.
        This callback is called for any tags for which there is no built-in decoder.
        The return value is substituted for the CBORTag object in the deserialized output.
    :param object_hook: Callable that takes 2 arguments: the decoder instance and the dictionary.
        This callback is called for each deserialized :class:`dict` object.
        The return value is substituted for the dict in the deserialized output.
    """

    __slots__ = ('fp', 'tag_hook', 'object_hook', '_shareables', '_immutable')

    def __init__(self, fp, tag_hook=None, object_hook=None):
        self.fp = fp
        self.tag_hook = tag_hook
        self.object_hook = object_hook
        # Shared-value registry; a slot holds None until its value has been decoded
        self._shareables = []
        self._immutable = False

    @property
    def immutable(self):
        """
        Used by decoders to check if the calling context requires an immutable type.
        Object_hook or tag_hook should raise an exception if this flag is set unless
        the result can be safely used as a dict key.
        """
        return self._immutable

    def _allocate_shareable(self):
        """Reserve a new, empty slot in the shared-value registry and return its index."""
        self._shareables.append(None)
        return len(self._shareables) - 1

    def set_shareable(self, index, value):
        """
        Set the shareable value for the last encountered shared value marker, if any.

        If the given index is ``None``, nothing is done.

        :param index: the value of the ``shared_index`` argument to the decoder
        :param value: the shared value

        """
        if index is not None:
            self._shareables[index] = value

    def read(self, amount):
        """
        Read bytes from the data stream.

        :param int amount: the number of bytes to read
        :raises CBORDecodeError: if fewer than ``amount`` bytes are available

        """
        data = self.fp.read(amount)
        if len(data) < amount:
            raise CBORDecodeError('premature end of stream (expected to read {} bytes, got {} '
                                  'instead)'.format(amount, len(data)))

        return data

    def _stream_position(self):
        """
        Return the current stream offset for use in error messages.

        Falls back to ``'unknown'`` for streams that cannot report a position
        (no ``tell``, non-seekable, or closed). Otherwise the failing ``tell()`` call
        would replace the real decoding error.
        """
        try:
            return self.fp.tell()
        except (AttributeError, IOError, OSError, ValueError):
            return 'unknown'

    def decode(self, shareable_index=None):
        """
        Decode the next value from the stream.

        :param shareable_index: index of the shared-value slot the decoded value should
            be registered under, or ``None``
        :raises CBORDecodeError: if there is any problem decoding the stream

        """
        try:
            initial_byte = byte_as_integer(self.fp.read(1))
            major_type = initial_byte >> 5
            subtype = initial_byte & 31
        except Exception as e:
            raise CBORDecodeError('error reading major type at index {}: {}'
                                  .format(self._stream_position(), e))

        decoder = major_decoders[major_type]
        try:
            return decoder(self, subtype, shareable_index)
        except CBORDecodeError:
            raise
        except Exception as e:
            # Normalize any other failure (including hook errors) to CBORDecodeError
            raise CBORDecodeError('error decoding value at index {}: {}'
                                  .format(self._stream_position(), e))

    def decode_from_bytes(self, buf):
        """
        Wrap the given bytestring as a file and call :meth:`decode` with it as the argument.

        This method was intended to be used from the ``tag_hook`` hook when an object needs to be
        decoded separately from the rest but while still taking advantage of the shared value
        registry. The original stream is restored even if decoding fails.

        """
        old_fp = self.fp
        self.fp = BytesIO(buf)
        try:
            return self.decode()
        finally:
            self.fp = old_fp
419
420
def loads(payload, **kwargs):
    """
    Deserialize an object from a bytestring.

    Only the first CBOR item in ``payload`` is decoded.

    :param bytes payload: the bytestring to deserialize
    :param kwargs: keyword arguments passed to :class:`~.CBORDecoder`
    :return: the deserialized object
    :raises CBORDecodeError: if the payload cannot be decoded

    """
    fp = BytesIO(payload)
    return CBORDecoder(fp, **kwargs).decode()
432
433
def load(fp, **kwargs):
    """
    Read and decode one CBOR item from an open file.

    :param fp: the input file (any file-like object)
    :param kwargs: keyword arguments forwarded to :class:`~.CBORDecoder`
    :return: the deserialized object

    """
    decoder = CBORDecoder(fp, **kwargs)
    return decoder.decode()
444