1# Copyright: (c) 2020, Jordan Borean (@jborean93) <jborean93@gmail.com>
2# MIT License (see LICENSE or https://opensource.org/licenses/MIT)
3
4import collections
5import datetime
6import enum
7import struct
8import typing
9
10from spnego._text import to_bytes, to_text
11
class ASN1Value(typing.NamedTuple):
    """A representation of an ASN.1 TLV as a Python object.

    Defines the ASN.1 Type Length Value (TLV) values as separate objects for easier parsing. This is returned by
    :func:`unpack_asn1`.

    Attributes:
        tag_class (TagClass): The tag class of the TLV.
        constructed (bool): Whether the value is constructed, i.e. contains 0, 1, or more element encodings (True),
            or is primitive (False).
        tag_number (Union[TypeTagNumber, int]): The tag number of the value, can be a TypeTagNumber if the tag_class
            is `universal` otherwise it's an explicit tag number value.
        b_data (bytes): The raw bytes of the TLV value.
    """

    # The annotations are quoted because TagClass and TypeTagNumber are defined further down in this module.
    tag_class: "TagClass"
    constructed: bool
    tag_number: typing.Union["TypeTagNumber", int]
    b_data: bytes
25
26
class TagClass(enum.IntEnum):
    """ The tag class of an ASN.1 value. """

    universal = 0
    application = 1
    context_specific = 2
    private = 3

    @classmethod
    def native_labels(cls) -> typing.Dict["TagClass", str]:
        """ Gets a mapping of each tag class member to its human readable label. """
        # Iterating the enum yields the members in definition order which lines up with the labels below.
        labels = ('Universal', 'Application', 'Context-specific', 'Private')
        return dict(zip(cls, labels))
41
42
class TypeTagNumber(enum.IntEnum):
    """ The tag numbers used by ASN.1 values that have the universal tag class. """

    end_of_content = 0
    boolean = 1
    integer = 2
    bit_string = 3
    octet_string = 4
    null = 5
    object_identifier = 6
    object_descriptor = 7
    external = 8
    real = 9
    enumerated = 10
    embedded_pdv = 11
    utf8_string = 12
    relative_oid = 13
    time = 14
    reserved = 15
    # sequence_of and set_of share a value with sequence and set so they are aliases of those members.
    sequence = 16
    sequence_of = 16
    set = 17
    set_of = 17
    numeric_string = 18
    printable_string = 19
    t61_string = 20
    videotex_string = 21
    ia5_string = 22
    utc_time = 23
    generalized_time = 24
    graphic_string = 25
    visible_string = 26
    general_string = 27
    universal_string = 28
    character_string = 29
    bmp_string = 30
    date = 31
    time_of_day = 32
    date_time = 33
    duration = 34
    oid_iri = 35
    relative_oid_iri = 36

    @classmethod
    def native_labels(cls) -> typing.Dict[int, str]:
        """ Gets a mapping of each universal tag number to its human readable label. """
        return {
            cls.end_of_content: 'End-of-Content (EOC)',
            cls.boolean: 'BOOLEAN',
            cls.integer: 'INTEGER',
            cls.bit_string: 'BIT STRING',
            cls.octet_string: 'OCTET STRING',
            cls.null: 'NULL',
            cls.object_identifier: 'OBJECT IDENTIFIER',
            cls.object_descriptor: 'Object Descriptor',
            cls.external: 'EXTERNAL',
            cls.real: 'REAL (float)',
            cls.enumerated: 'ENUMERATED',
            cls.embedded_pdv: 'EMBEDDED PDV',
            cls.utf8_string: 'UTF8String',
            cls.relative_oid: 'RELATIVE-OID',
            cls.time: 'TIME',
            cls.reserved: 'RESERVED',
            cls.sequence: 'SEQUENCE or SEQUENCE OF',
            cls.set: 'SET or SET OF',
            cls.numeric_string: 'NumericString',
            cls.printable_string: 'PrintableString',
            cls.t61_string: 'T61String',
            cls.videotex_string: 'VideotexString',
            cls.ia5_string: 'IA5String',
            cls.utc_time: 'UTCTime',
            cls.generalized_time: 'GeneralizedTime',
            cls.graphic_string: 'GraphicString',
            cls.visible_string: 'VisibleString',
            cls.general_string: 'GeneralString',
            cls.universal_string: 'UniversalString',
            cls.character_string: 'CHARACTER',
            cls.bmp_string: 'BMPString',
            cls.date: 'DATE',
            cls.time_of_day: 'TIME-OF-DAY',
            cls.date_time: 'DATE-TIME',
            cls.duration: 'DURATION',
            cls.oid_iri: 'OID-IRI',
            cls.relative_oid_iri: 'RELATIVE-OID-IRI',
        }
125
126
def extract_asn1_tlv(
    tlv: typing.Union[bytes, ASN1Value],
    tag_class: TagClass,
    tag_number: typing.Union[int, TypeTagNumber],
) -> bytes:
    """Extracts the raw bytes of an ASN.1 value after validating its tags.

    Args:
        tlv: Either an already unpacked ASN1Value or the raw value bytes.
        tag_class: The tag class an ASN1Value is expected to have.
        tag_number: The tag number an ASN1Value is expected to have.

    Returns:
        bytes: The `b_data` of the ASN1Value, or `tlv` itself when it is not an ASN1Value.

    Raises:
        ValueError: The ASN1Value tag class or tag number does not match the expected values.
    """
    # Anything that is not an ASN1Value has no tags to validate and is handed back as is.
    if not isinstance(tlv, ASN1Value):
        return tlv

    if tlv.tag_class == tag_class and tlv.tag_number == tag_number:
        return tlv.b_data

    if tag_class == TagClass.universal:
        expected_label = TypeTagNumber.native_labels().get(tag_number, 'Unknown tag type')
        raise ValueError("Invalid ASN.1 %s tags, actual tag class %s and tag number %s"
                         % (expected_label, tlv.tag_class, tlv.tag_number))

    raise ValueError("Invalid ASN.1 tags, actual tag %s and number %s, expecting class %s and number %s"
                     % (tlv.tag_class, tlv.tag_number, tag_class, tag_number))
149
150
def get_sequence_value(
    sequence: typing.Dict[int, ASN1Value],
    tag: int,
    structure_name: str,
    field_name: typing.Optional[str] = None,
    unpack_func: typing.Optional[typing.Callable[[typing.Union[bytes, ASN1Value]], typing.Any]] = None,
) -> typing.Any:
    """Gets an optional tag entry in a tagged sequence with optional further unpacking of the value.

    Args:
        sequence: The tagged sequence keyed by the tag number.
        tag: The tag number of the entry to get.
        structure_name: The name of the structure being unpacked, used in the error message.
        field_name: The name of the field being unpacked, used in the error message when set.
        unpack_func: Called with the entry to produce the return value, the entry is returned as is when not set.

    Returns:
        Any: None if the tag is not in the sequence, otherwise the entry or the result of `unpack_func`.

    Raises:
        ValueError: `unpack_func` raised a ValueError, the new error states what was being unpacked.
    """
    if tag not in sequence:
        return None

    entry = sequence[tag]
    if not unpack_func:
        return entry

    try:
        return unpack_func(entry)
    except ValueError as exc:
        if field_name:
            location = '%s in %s' % (field_name, structure_name)
        else:
            location = structure_name

        raise ValueError("Failed unpacking %s: %s" % (location, str(exc))) from exc
170
171
def pack_asn1(
    tag_class: TagClass,
    constructed: bool,
    tag_number: typing.Union[TypeTagNumber, int],
    b_data: bytes,
) -> bytes:
    """Pack the ASN.1 value into the ASN.1 bytes.

    Wraps the raw bytes in an ASN.1 Type Length Value (TLV) which has the layout:

    | Identifier Octet(s) | Length Octet(s) | Data Octet(s) |

    Args:
        tag_class: The tag class of the data.
        constructed: Whether the data is constructed (True), i.e. contains 0, 1, or more element encodings, or is
            primitive (False).
        tag_number: The type tag number if tag_class is universal else the explicit tag number of the TLV.
        b_data: The encoded value to pack into the ASN.1 TLV.

    Returns:
        bytes: The ASN.1 value as raw bytes.

    Raises:
        ValueError: The tag_class is not between 0 and 3.
    """
    if tag_class < 0 or tag_class > 3:
        raise ValueError("tag_class must be between 0 and 3")

    # Identifier octets:
    #
    # |             Octet 1             |  |              Octet 2              |
    # | 8 | 7 |  6  | 5 | 4 | 3 | 2 | 1 |  |   8   | 7 | 6 | 5 | 4 | 3 | 2 | 1 |
    # | Class | P/C | Tag Number (0-30) |  | More  | Tag number                |
    #
    # A tag number of 31 or more sets all 5 tag bits of the first octet and the actual tag number is encoded in the
    # octets that follow.
    first_octet = (tag_class << 6) | (0b00100000 if constructed else 0)

    if tag_number < 31:
        b_header = bytearray([first_octet | tag_number])
    else:
        b_header = bytearray([first_octet | 0b00011111])
        b_header.extend(_pack_asn1_octet_number(tag_number))

    # Length octets, always written in the definite form:
    #
    # |                       Octet 1                       |  |            Octet n            |
    # |     8     |  7  |  6  |  5  |  4  |  3  |  2  |  1  |  | 8 | 7 | 6 | 5 | 4 | 3 | 2 | 1 |
    # | Long form | Short = length, Long = num octets       |  | Big endian length for long    |
    #
    # A length below 128 fits in the first octet. Anything larger sets the MSB of the first octet, stores the number
    # of length octets in the remaining 7 bits, and is followed by the length itself in big endian order.
    data_length = len(b_data)
    if data_length < 128:
        b_header.append(data_length)
    else:
        b_length = data_length.to_bytes((data_length.bit_length() + 7) // 8, byteorder='big')
        b_header.append(0b10000000 | len(b_length))
        b_header.extend(b_length)

    return bytes(b_header) + b_data
243
244
def pack_asn1_bit_string(
    value: bytes,
    tag: bool = True,
) -> bytes:
    """ Packs a bytes value into an ASN.1 BIT STRING byte value with optional universal tagging. """
    # The leading octet states how many bits of the last octet, from the LSB, are unused. This always writes 0.
    b_bit_string = b"\x00" + value
    if not tag:
        return b_bit_string

    return pack_asn1(TagClass.universal, False, TypeTagNumber.bit_string, b_bit_string)
255
256
def pack_asn1_enumerated(
    value: int,
    tag: bool = True,
) -> bytes:
    """ Packs an int into an ASN.1 ENUMERATED byte value with optional universal tagging. """
    # The value octets are produced by the INTEGER packer, only the universal tag number differs.
    b_enumerated = pack_asn1_integer(value, tag=False)
    if not tag:
        return b_enumerated

    return pack_asn1(TagClass.universal, False, TypeTagNumber.enumerated, b_enumerated)
267
268
def pack_asn1_general_string(
    value: typing.Union[str, bytes],
    tag: bool = True,
    encoding: str = 'ascii',
) -> bytes:
    """ Packs a string value into an ASN.1 GeneralString byte value with optional universal tagging. """
    b_string = to_bytes(value, encoding=encoding)
    if not tag:
        return b_string

    return pack_asn1(TagClass.universal, False, TypeTagNumber.general_string, b_string)
280
281
def pack_asn1_integer(
    value: int,
    tag: bool = True,
) -> bytes:
    """ Packs an int value into an ASN.1 INTEGER byte value with optional universal tagging. """
    # The value is written as a big endian two's complement number using the fewest octets possible. One bit is
    # always needed for the sign, so the octet count is based on the bit length of the magnitude for positive values
    # and of the bitwise complement (-value - 1) for negative values.
    magnitude = value if value >= 0 else ~value
    octet_count = (magnitude.bit_length() // 8) + 1
    b_integer = value.to_bytes(octet_count, byteorder='big', signed=True)

    if not tag:
        return b_integer

    return pack_asn1(TagClass.universal, False, TypeTagNumber.integer, b_integer)
325
326
def pack_asn1_object_identifier(
    oid: str,
    tag: bool = True,
) -> bytes:
    """Packs a str value into an ASN.1 OBJECT IDENTIFIER byte value with optional universal tagging.

    Args:
        oid: The OID in dotted string form, e.g. `1.3.6.1.5.5.2`.
        tag: Whether to wrap the encoded value in the universal OBJECT IDENTIFIER tag.

    Returns:
        bytes: The encoded OID, with the ASN.1 identifier and length octets when `tag` is True.

    Raises:
        ValueError: The OID has an element that is not an integer, has fewer than 2 elements, has a negative
            element, or has first and second elements that cannot be encoded.
    """
    oid_split = [int(i) for i in oid.split('.')]

    if len(oid_split) < 2:
        raise ValueError("An OID must have 2 or more elements split by '.'")

    if any(element < 0 for element in oid_split):
        raise ValueError("An OID cannot contain a negative element")

    first, second = oid_split[0], oid_split[1]
    if first > 2 or (first < 2 and second > 39):
        # The first 2 elements are merged into one number, going outside these ranges makes the result ambiguous.
        raise ValueError("The first OID element must be 0, 1, or 2 and the second must be less than 40 unless the "
                         "first is 2")

    # The first 2 elements (x.y) are encoded as one number (x * 40) + y, the rest are encoded on their own.
    sub_identifiers = [(first * 40) + second]
    sub_identifiers.extend(oid_split[2:])

    b_oid = bytearray()
    for sub_identifier in sub_identifiers:
        if sub_identifier < 0b10000000:
            # Fits in the 7 value bits of a single octet, this includes 0 which must still be written as an octet.
            b_oid.append(sub_identifier)
        else:
            b_oid.extend(_pack_asn1_octet_number(sub_identifier))

    b_value = bytes(b_oid)
    if tag:
        b_value = pack_asn1(TagClass.universal, False, TypeTagNumber.object_identifier, b_value)

    return b_value
349
350
def pack_asn1_octet_string(
    b_data: bytes,
    tag: bool = True,
) -> bytes:
    """ Packs a bytes value into an ASN.1 OCTET STRING byte value with optional universal tagging. """
    # The value octets of an OCTET STRING are the bytes themselves so only the tagging is optional work.
    if not tag:
        return b_data

    return pack_asn1(TagClass.universal, False, TypeTagNumber.octet_string, b_data)
360
361
def pack_asn1_sequence(
    sequence: typing.List[bytes],
    tag: bool = True,
) -> bytes:
    """ Packs a list of encoded bytes into an ASN.1 SEQUENCE byte value with optional universal tagging. """
    # The entries are expected to already be encoded, they are joined together in the order given.
    b_entries = b"".join(sequence)
    if not tag:
        return b_entries

    return pack_asn1(TagClass.universal, True, TypeTagNumber.sequence, b_entries)
372
373
374def _pack_asn1_octet_number(num: int) -> bytes:
375    """ Packs an int number into an ASN.1 integer value that spans multiple octets. """
376    num_octets = bytearray()
377
378    while num:
379        # Get the 7 bit value of the number.
380        octet_value = num & 0b01111111
381
382        # Set the MSB if this isn't the first octet we are processing (overall last octet)
383        if len(num_octets):
384            octet_value |= 0b10000000
385
386        num_octets.append(octet_value)
387
388        # Shift the number by 7 bits as we've just processed them.
389        num >>= 7
390
391    # Finally we reverse the order so the higher octets are first.
392    num_octets.reverse()
393
394    return num_octets
395
396
def unpack_asn1(b_data: bytes) -> typing.Tuple[ASN1Value, bytes]:
    """Unpacks an ASN.1 TLV into each element.

    Reads the first ASN.1 TLV in the raw bytes into an `ASN1Value` tuple and also returns whatever bytes come after
    that TLV.

    Args:
        b_data: The raw bytes to unpack as an ASN.1 TLV.

    Returns:
        ASN1Value: The ASN.1 value that is unpacked from the raw bytes passed in.
        bytes: Any remaining bytes that are not part of the ASN1Value.
    """
    # The first identifier octet has the class in the top 2 bits, the constructed flag in the next bit and the tag
    # number in the low 5 bits.
    (identifier,) = struct.unpack("B", b_data[:1])
    tag_class = TagClass(identifier >> 6)
    constructed = (identifier & 0b00100000) != 0
    tag_number = identifier & 0b00011111

    identifier_len = 1
    if tag_number == 31:
        # All 5 tag bits being set means the tag number is stored in the octets that follow.
        tag_number, extra_octets = _unpack_asn1_octet_number(b_data[1:])
        identifier_len += extra_octets

    if tag_class == TagClass.universal:
        tag_number = TypeTagNumber(tag_number)

    b_remaining = b_data[identifier_len:]

    (length,) = struct.unpack("B", b_remaining[:1])
    value_offset = 1

    if length & 0b10000000:
        # With the MSB set the low 7 bits are the number of following octets that hold the big endian length.
        value_offset += length & 0b01111111
        length = 0

        for pos in range(1, value_offset):
            (octet,) = struct.unpack("B", b_remaining[pos:pos + 1])
            length = (length << 8) | octet

    value_end = value_offset + length
    asn1_value = ASN1Value(tag_class=tag_class, constructed=constructed, tag_number=tag_number,
                           b_data=b_remaining[value_offset:value_end])

    return asn1_value, b_remaining[value_end:]
441
442
def unpack_asn1_bit_string(value: typing.Union[ASN1Value, bytes]) -> bytes:
    """Unpacks an ASN.1 BIT STRING value.

    Args:
        value: The ASN1Value or raw value bytes of the BIT STRING.

    Returns:
        bytes: The octets of the bit string with the unused bits of the last octet set to 0.

    Raises:
        ValueError: The value has invalid tags or does not contain the leading unused bits octet.
    """
    b_data = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.bit_string)
    if not b_data:
        raise ValueError("An ASN.1 BIT STRING must contain at least the unused bits octet")

    # First octet is the number of unused bits in the last octet from the LSB.
    unused_bits = struct.unpack("B", b_data[:1])[0]

    b_bits = b_data[1:]
    if not b_bits:
        # Only the unused bits octet is present, there is no last octet to mask.
        return b_bits

    # Clear the unused bits of the actual last octet and put it back on the end of the octets that precede it.
    last_octet = struct.unpack("B", b_bits[-1:])[0]
    last_octet = (last_octet >> unused_bits) << unused_bits

    return b_bits[:-1] + struct.pack("B", last_octet)
453
454
def unpack_asn1_boolean(value: typing.Union[ASN1Value, bytes]) -> bool:
    """ Unpacks an ASN.1 BOOLEAN value. """
    b_boolean = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.boolean)

    # Only a single zero octet is False, any other content is treated as True.
    is_false = b_boolean == b"\x00"
    return not is_false
460
461
def unpack_asn1_enumerated(value: typing.Union[ASN1Value, bytes]) -> int:
    """ Unpacks an ASN.1 ENUMERATED value. """
    # Once the ENUMERATED tags are validated the raw value octets are decoded by the INTEGER unpacker.
    return unpack_asn1_integer(
        extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.enumerated)
    )
467
468
def unpack_asn1_general_string(value: typing.Union[ASN1Value, bytes]) -> bytes:
    """ Unpacks an ASN.1 GeneralString value. """
    # The string is returned as the raw bytes, no decoding is done here.
    b_string = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.general_string)
    return b_string
472
473
def unpack_asn1_generalized_time(value: typing.Union[ASN1Value, bytes]) -> datetime.datetime:
    """ Unpacks an ASN.1 GeneralizedTime value. """
    b_time = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.generalized_time)
    time_text = to_text(b_time)

    # While ASN.1 can have a timezone encoded, KerberosTime is the only thing we use and it is always in UTC with the
    # Z suffix. The Z is stripped before parsing and the UTC tz is set on the parsed object instead.
    # https://www.rfc-editor.org/rfc/rfc4120#section-5.2.3
    if time_text.endswith('Z'):
        time_text = time_text[:-1]

    # Try the format with fractional seconds first and then the one without.
    last_error = None
    for time_format in ('%Y%m%d%H%M%S.%f', '%Y%m%d%H%M%S'):
        try:
            parsed = datetime.datetime.strptime(time_text, time_format)
        except ValueError as exc:
            last_error = exc
            continue

        return parsed.replace(tzinfo=datetime.timezone.utc)

    # No format matched, surface the failure of the last format that was tried.
    raise last_error  # type: ignore
494
495
def unpack_asn1_integer(value: typing.Union[ASN1Value, bytes]) -> int:
    """Unpacks an ASN.1 INTEGER value.

    Args:
        value: The ASN1Value or raw value bytes of the INTEGER.

    Returns:
        int: The decoded integer, negative when the MSB of the first octet is set.

    Raises:
        ValueError: The value has invalid tags or does not contain any octets.
    """
    b_int = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.integer)
    if not b_int:
        raise ValueError("An ASN.1 INTEGER must contain at least one octet")

    # The octets are a big endian two's complement number. Letting int do the conversion handles the carry across
    # every octet for negative values, e.g. FF 00 00 is -65536.
    return int.from_bytes(b_int, byteorder='big', signed=True)
524
525
def unpack_asn1_object_identifier(value: typing.Union[ASN1Value, bytes]) -> str:
    """Unpacks an ASN.1 OBJECT IDENTIFIER value.

    Args:
        value: The ASN1Value or raw value bytes of the OBJECT IDENTIFIER.

    Returns:
        str: The OID in dotted string form.

    Raises:
        ValueError: The value has invalid tags or does not contain any octets.
    """
    b_data = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.object_identifier)
    if not b_data:
        raise ValueError("An ASN.1 OBJECT IDENTIFIER must contain at least one octet")

    # The first 2 elements (x.y) are stored as the one number (x * 40) + y which, like every other element, can span
    # more than 1 octet. As y is only limited to less than 40 when x is 0 or 1, anything from 80 up belongs to x = 2.
    first_number, idx = _unpack_asn1_octet_number(b_data)
    if first_number < 80:
        ids = list(divmod(first_number, 40))
    else:
        ids = [2, first_number - 80]

    data_len = len(b_data)
    while idx < data_len:
        oid, octet_len = _unpack_asn1_octet_number(b_data[idx:])
        ids.append(oid)
        idx += octet_len

    return ".".join(str(i) for i in ids)
541
542
def unpack_asn1_octet_string(value: typing.Union[ASN1Value, bytes]) -> bytes:
    """ Unpacks an ASN.1 OCTET STRING value. """
    # The value octets are the result, only the tags of an ASN1Value need to be validated.
    b_octets = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.octet_string)
    return b_octets
546
547
def unpack_asn1_sequence(value: typing.Union[ASN1Value, bytes]) -> typing.List[ASN1Value]:
    """ Unpacks an ASN.1 SEQUENCE value into a list of the ASN.1 values it contains. """
    b_remaining = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.sequence)

    # Keep unpacking the next TLV until there are no bytes left.
    entries = []
    while b_remaining:
        entry, b_remaining = unpack_asn1(b_remaining)
        entries.append(entry)

    return entries
558
559
def unpack_asn1_tagged_sequence(value: typing.Union[ASN1Value, bytes]) -> typing.Dict[int, ASN1Value]:
    """ Unpacks an ASN.1 SEQUENCE value as a dictionary. """
    # Each entry is keyed by its own tag number and the value is the first TLV unpacked from the entry's data.
    return {
        entry.tag_number: unpack_asn1(entry.b_data)[0]
        for entry in unpack_asn1_sequence(value)
    }
563
564
565def _unpack_asn1_octet_number(b_data: bytes) -> typing.Tuple[int, int]:
566    """ Unpacks an ASN.1 INTEGER value that can span across multiple octets. """
567    i = 0
568    idx = 0
569    while True:
570        element = struct.unpack("B", b_data[idx:idx + 1])[0]
571        idx += 1
572
573        i = (i << 7) + (element & 0b01111111)
574        if not element & 0b10000000:
575            break
576
577    return i, idx  # int value and the number of octets used.
578