# Copyright: (c) 2020, Jordan Borean (@jborean93) <jborean93@gmail.com>
# MIT License (see LICENSE or https://opensource.org/licenses/MIT)

import collections
import datetime
import enum
import struct
import typing

from spnego._text import to_bytes, to_text

ASN1Value = collections.namedtuple('ASN1Value', ['tag_class', 'constructed', 'tag_number', 'b_data'])
"""A representation of an ASN.1 TLV as a Python object.

Defines the ASN.1 Type Length Value (TLV) values as separate objects for easier parsing. This is returned by
:method:`unpack_asn1`.

Attributes:
    tag_class (TagClass): The tag class of the TLV.
    constructed (bool): Whether the value is constructed of 0, 1, or more element encodings (True) or is primitive
        (False).
    tag_number (Union[TypeTagNumber, int]): The tag number of the value, can be a TypeTagNumber if the tag_class
        is `universal` otherwise it's an explicit tag number value.
    b_data (bytes): The raw bytes of the TLV value.
"""


class TagClass(enum.IntEnum):
    """The ASN.1 tag class stored in the 2 most significant bits of the identifier octet."""

    universal = 0
    application = 1
    context_specific = 2
    private = 3

    @classmethod
    def native_labels(cls) -> typing.Dict["TagClass", str]:
        return {
            TagClass.universal: 'Universal',
            TagClass.application: 'Application',
            TagClass.context_specific: 'Context-specific',
            TagClass.private: 'Private',
        }


class TypeTagNumber(enum.IntEnum):
    """The tag numbers of the universal ASN.1 types.

    `sequence_of` and `set_of` are aliases of `sequence` and `set` as they share the same tag number.
    """

    end_of_content = 0
    boolean = 1
    integer = 2
    bit_string = 3
    octet_string = 4
    null = 5
    object_identifier = 6
    object_descriptor = 7
    external = 8
    real = 9
    enumerated = 10
    embedded_pdv = 11
    utf8_string = 12
    relative_oid = 13
    time = 14
    reserved = 15
    sequence = 16
    sequence_of = 16
    set = 17
    set_of = 17
    numeric_string = 18
    printable_string = 19
    t61_string = 20
    videotex_string = 21
    ia5_string = 22
    utc_time = 23
    generalized_time = 24
    graphic_string = 25
    visible_string = 26
    general_string = 27
    universal_string = 28
    character_string = 29
    bmp_string = 30
    date = 31
    time_of_day = 32
    date_time = 33
    duration = 34
    oid_iri = 35
    relative_oid_iri = 36

    @classmethod
    def native_labels(cls) -> typing.Dict[int, str]:
        return {
            TypeTagNumber.end_of_content: 'End-of-Content (EOC)',
            TypeTagNumber.boolean: 'BOOLEAN',
            TypeTagNumber.integer: 'INTEGER',
            TypeTagNumber.bit_string: 'BIT STRING',
            TypeTagNumber.octet_string: 'OCTET STRING',
            TypeTagNumber.null: 'NULL',
            TypeTagNumber.object_identifier: 'OBJECT IDENTIFIER',
            TypeTagNumber.object_descriptor: 'Object Descriptor',
            TypeTagNumber.external: 'EXTERNAL',
            TypeTagNumber.real: 'REAL (float)',
            TypeTagNumber.enumerated: 'ENUMERATED',
            TypeTagNumber.embedded_pdv: 'EMBEDDED PDV',
            TypeTagNumber.utf8_string: 'UTF8String',
            TypeTagNumber.relative_oid: 'RELATIVE-OID',
            TypeTagNumber.time: 'TIME',
            TypeTagNumber.reserved: 'RESERVED',
            TypeTagNumber.sequence: 'SEQUENCE or SEQUENCE OF',
            TypeTagNumber.set: 'SET or SET OF',
            TypeTagNumber.numeric_string: 'NumericString',
            TypeTagNumber.printable_string: 'PrintableString',
            TypeTagNumber.t61_string: 'T61String',
            TypeTagNumber.videotex_string: 'VideotexString',
            TypeTagNumber.ia5_string: 'IA5String',
            TypeTagNumber.utc_time: 'UTCTime',
            TypeTagNumber.generalized_time: 'GeneralizedTime',
            TypeTagNumber.graphic_string: 'GraphicString',
            TypeTagNumber.visible_string: 'VisibleString',
            TypeTagNumber.general_string: 'GeneralString',
            TypeTagNumber.universal_string: 'UniversalString',
            TypeTagNumber.character_string: 'CHARACTER',
            TypeTagNumber.bmp_string: 'BMPString',
            TypeTagNumber.date: 'DATE',
            TypeTagNumber.time_of_day: 'TIME-OF-DAY',
            TypeTagNumber.date_time: 'DATE-TIME',
            TypeTagNumber.duration: 'DURATION',
            TypeTagNumber.oid_iri: 'OID-IRI',
            TypeTagNumber.relative_oid_iri: 'RELATIVE-OID-IRI',
        }


def extract_asn1_tlv(
    tlv: typing.Union[bytes, ASN1Value],
    tag_class: TagClass,
    tag_number: typing.Union[int, TypeTagNumber],
) -> bytes:
    """Extract the bytes and validate the existing tag of an ASN.1 value.

    Args:
        tlv: An unpacked ASN1Value whose tags are validated, or raw bytes which are returned unchanged.
        tag_class: The expected tag class of the value.
        tag_number: The expected tag number of the value.

    Returns:
        bytes: The value bytes of the ASN1Value, or the input bytes if raw bytes were passed in.

    Raises:
        ValueError: The ASN1Value tag class or tag number does not match the expected values.
    """
    if not isinstance(tlv, ASN1Value):
        return tlv

    if tlv.tag_class != tag_class or tlv.tag_number != tag_number:
        # Only build the error message when it is actually needed.
        if tag_class == TagClass.universal:
            label_name = TypeTagNumber.native_labels().get(tag_number, 'Unknown tag type')
            msg = "Invalid ASN.1 %s tags, actual tag class %s and tag number %s" \
                % (label_name, tlv.tag_class, tlv.tag_number)

        else:
            msg = "Invalid ASN.1 tags, actual tag %s and number %s, expecting class %s and number %s" \
                % (tlv.tag_class, tlv.tag_number, tag_class, tag_number)

        raise ValueError(msg)

    return tlv.b_data


def get_sequence_value(
    sequence: typing.Dict[int, ASN1Value],
    tag: int,
    structure_name: str,
    field_name: typing.Optional[str] = None,
    unpack_func: typing.Optional[typing.Callable[[typing.Union[bytes, ASN1Value]], typing.Any]] = None,
) -> typing.Any:
    """Gets an optional tag entry in a tagged sequence with a further unpacking of the value.

    Args:
        sequence: The tagged sequence as returned by :func:`unpack_asn1_tagged_sequence`.
        tag: The tag number of the entry to retrieve.
        structure_name: The name of the structure being unpacked, used in error messages.
        field_name: The name of the field being unpacked, used in error messages.
        unpack_func: Optional function used to unpack the raw ASN1Value.

    Returns:
        Any: None if the tag is not present, the raw ASN1Value if no unpack_func was set, otherwise the result of
        unpack_func.

    Raises:
        ValueError: unpack_func failed, the message is prefixed with the field and structure name.
    """
    if tag not in sequence:
        return None

    if not unpack_func:
        return sequence[tag]

    try:
        return unpack_func(sequence[tag])
    except ValueError as e:
        where = '%s in %s' % (field_name, structure_name) if field_name else structure_name
        raise ValueError("Failed unpacking %s: %s" % (where, str(e))) from e


def pack_asn1(
    tag_class: TagClass,
    constructed: bool,
    tag_number: typing.Union[TypeTagNumber, int],
    b_data: bytes,
) -> bytes:
    """Pack the ASN.1 value into the ASN.1 bytes.

    Will pack the raw bytes into an ASN.1 Type Length Value (TLV) value. A TLV is in the form:

    | Identifier Octet(s) | Length Octet(s) | Data Octet(s) |

    Args:
        tag_class: The tag class of the data.
        constructed: Whether the data is constructed (True), i.e. contains 0, 1, or more element encodings, or is
            primitive (False).
        tag_number: The type tag number if tag_class is universal else the explicit tag number of the TLV.
        b_data: The encoded value to pack into the ASN.1 TLV.

    Returns:
        bytes: The ASN.1 value as raw bytes.

    Raises:
        ValueError: tag_class is not between 0 and 3.
    """
    b_asn1_data = bytearray()

    # ASN.1 Identifier octet is
    #
    # |             Octet 1             |  |              Octet 2              |
    # | 8 | 7 |  6  | 5 | 4 | 3 | 2 | 1 |  |   8  | 7 | 6 | 5 | 4 | 3 | 2 | 1 |
    # | Class | P/C | Tag Number (0-30) |  | More |        Tag number          |
    #
    # If the tag number is >= 31 the low 5 bits of the first octet are all set to 1 and the tag number itself is
    # encoded in base 128 across the subsequent octets.
    if tag_class < 0 or tag_class > 3:
        raise ValueError("tag_class must be between 0 and 3")

    identifier_octets = tag_class << 6
    identifier_octets |= ((1 if constructed else 0) << 5)

    if tag_number < 31:
        identifier_octets |= tag_number
        b_asn1_data.append(identifier_octets)
    else:
        # Set the low 5 bits of the first octet to 1 and encode the tag number in subsequent octets.
        identifier_octets |= 31
        b_asn1_data.append(identifier_octets)
        b_asn1_data.extend(_pack_asn1_octet_number(tag_number))

    # ASN.1 Length octet for DER encoding is always in the definite form. This form packs the lengths in the following
    # octet structure:
    #
    # |                 Octet 1                 |  |           Octet n           |
    # |     8     | 7 | 6 | 5 | 4 | 3 | 2 | 1 |  | 8 | 7 | 6 | 5 | 4 | 3 | 2 | 1 |
    # | Long form | Short = length, Long = num octets |  | Big endian length for long  |
    #
    # If the length is < 128 it's encoded directly in the first octet, otherwise the lower 7 bits of the first octet
    # indicate how many subsequent octets were used to encode the length.
    length = len(b_data)
    if length < 128:
        b_asn1_data.append(length)
    else:
        # Minimal big endian encoding of the length, the first octet has the MSB set plus the number of octets used.
        length_octets = length.to_bytes((length.bit_length() + 7) // 8, byteorder='big')
        b_asn1_data.append(len(length_octets) | 0b10000000)
        b_asn1_data.extend(length_octets)

    return bytes(b_asn1_data) + b_data


def pack_asn1_bit_string(
    value: bytes,
    tag: bool = True,
) -> bytes:
    """Packs a bytes value into an ASN.1 BIT STRING byte value with optional universal tagging.

    All bits of the last octet are treated as used, i.e. the unused bits octet is always 0.
    """
    # First octet is the number of unused bits in the last octet from the LSB.
    b_data = b"\x00" + value
    if tag:
        b_data = pack_asn1(TagClass.universal, False, TypeTagNumber.bit_string, b_data)

    return b_data


def pack_asn1_enumerated(
    value: int,
    tag: bool = True,
) -> bytes:
    """Packs an int into an ASN.1 ENUMERATED byte value with optional universal tagging."""
    b_data = pack_asn1_integer(value, tag=False)
    if tag:
        b_data = pack_asn1(TagClass.universal, False, TypeTagNumber.enumerated, b_data)

    return b_data


def pack_asn1_general_string(
    value: typing.Union[str, bytes],
    tag: bool = True,
    encoding: str = 'ascii',
) -> bytes:
    """Packs a string value into an ASN.1 GeneralString byte value with optional universal tagging."""
    b_data = to_bytes(value, encoding=encoding)
    if tag:
        b_data = pack_asn1(TagClass.universal, False, TypeTagNumber.general_string, b_data)

    return b_data


def pack_asn1_integer(
    value: int,
    tag: bool = True,
) -> bytes:
    """Packs an int value into an ASN.1 INTEGER byte value with optional universal tagging.

    The value is encoded as the minimal big endian two's complement representation as required by DER.
    """
    # n octets can hold -2**(8n - 1) <= value < 2**(8n - 1). For negative values ~value (== -value - 1) is the
    # magnitude that must fit in the 8n - 1 non-sign bits, which also covers -128 fitting in a single octet.
    magnitude = value if value >= 0 else ~value
    b_value = value.to_bytes(magnitude.bit_length() // 8 + 1, byteorder='big', signed=True)

    if tag:
        b_value = pack_asn1(TagClass.universal, False, TypeTagNumber.integer, b_value)

    return b_value


def pack_asn1_object_identifier(
    oid: str,
    tag: bool = True,
) -> bytes:
    """Packs a str value into an ASN.1 OBJECT IDENTIFIER byte value with optional universal tagging.

    Raises:
        ValueError: The OID has less than 2 elements or contains a non-integer element.
    """
    oid_split = [int(i) for i in oid.split('.')]

    if len(oid_split) < 2:
        raise ValueError("An OID must have 2 or more elements split by '.'")

    # The first 2 elements (x.y) are combined as (x * 40) + y. This combined value is encoded like any other
    # subidentifier and can span multiple octets, e.g. for 2.999.
    b_oid = bytearray(_pack_asn1_octet_number((oid_split[0] * 40) + oid_split[1]))

    for val in oid_split[2:]:
        b_oid.extend(_pack_asn1_octet_number(val))

    b_value = bytes(b_oid)
    if tag:
        b_value = pack_asn1(TagClass.universal, False, TypeTagNumber.object_identifier, b_value)

    return b_value


def pack_asn1_octet_string(
    b_data: bytes,
    tag: bool = True,
) -> bytes:
    """Packs a bytes value into an ASN.1 OCTET STRING byte value with optional universal tagging."""
    if tag:
        b_data = pack_asn1(TagClass.universal, False, TypeTagNumber.octet_string, b_data)

    return b_data


def pack_asn1_sequence(
    sequence: typing.List[bytes],
    tag: bool = True,
) -> bytes:
    """Packs a list of encoded bytes into an ASN.1 SEQUENCE byte value with optional universal tagging."""
    b_data = b"".join(sequence)
    if tag:
        b_data = pack_asn1(TagClass.universal, True, TypeTagNumber.sequence, b_data)

    return b_data


def _pack_asn1_octet_number(num: int) -> bytes:
    """Packs a non-negative int into base 128 octets, as used by long tag numbers and OID subidentifiers.

    Every octet except the last has the MSB set to indicate more octets follow.
    """
    # The final (least significant) group is always emitted so that 0 is encoded as a single 0x00 octet instead of
    # disappearing from the output.
    num_octets = bytearray([num & 0b01111111])
    num >>= 7

    while num:
        # Every octet before the final one has the MSB set.
        num_octets.append((num & 0b01111111) | 0b10000000)

        # Shift the number by 7 bits as we've just processed them.
        num >>= 7

    # The groups were collected least significant first, reverse them so the higher octets are first.
    num_octets.reverse()

    return bytes(num_octets)


def unpack_asn1(b_data: bytes) -> typing.Tuple[ASN1Value, bytes]:
    """Unpacks an ASN.1 TLV into each element.

    Unpacks the raw ASN.1 value into a `ASN1Value` tuple and returns the remaining bytes that are not part of the
    ASN.1 TLV.

    Args:
        b_data: The raw bytes to unpack as an ASN.1 TLV.

    Returns:
        ASN1Value: The ASN.1 value that is unpacked from the raw bytes passed in.
        bytes: Any remaining bytes that are not part of the ASN1Value.

    Raises:
        ValueError: The tag class is universal but the tag number is not a known TypeTagNumber.
    """
    octet1 = struct.unpack("B", b_data[:1])[0]
    tag_class = TagClass((octet1 & 0b11000000) >> 6)
    constructed = bool(octet1 & 0b00100000)
    tag_number = octet1 & 0b00011111

    length_offset = 1
    if tag_number == 31:
        # Long form tag number, the actual value is encoded in base 128 in the following octets.
        tag_number, octet_count = _unpack_asn1_octet_number(b_data[1:])
        length_offset += octet_count

    if tag_class == TagClass.universal:
        tag_number = TypeTagNumber(tag_number)

    b_data = b_data[length_offset:]

    length = struct.unpack("B", b_data[:1])[0]
    length_octets = 1

    if length & 0b10000000:
        # If the MSB is set then the length octet just contains the number of octets that encodes the actual length.
        length_octets += length & 0b01111111
        length = 0

        for idx in range(1, length_octets):
            octet_val = struct.unpack("B", b_data[idx:idx + 1])[0]
            length += octet_val << (8 * (length_octets - 1 - idx))

    value = ASN1Value(tag_class=tag_class, constructed=constructed, tag_number=tag_number,
                      b_data=b_data[length_octets:length_octets + length])

    return value, b_data[length_octets + length:]


def unpack_asn1_bit_string(value: typing.Union[ASN1Value, bytes]) -> bytes:
    """Unpacks an ASN.1 BIT STRING value.

    Returns the bit string octets with the unused trailing bits of the last octet cleared.
    """
    b_data = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.bit_string)

    # First octet is the number of unused bits in the last octet from the LSB.
    unused_bits = struct.unpack("B", b_data[:1])[0]
    b_bits = b_data[1:]
    if not b_bits:
        # An empty bit string is just the unused bits octet with no data.
        return b""

    # Clear the unused bits of the final data octet.
    last_octet = (b_bits[-1] >> unused_bits) << unused_bits

    return b_bits[:-1] + struct.pack("B", last_octet)


def unpack_asn1_boolean(value: typing.Union[ASN1Value, bytes]) -> bool:
    """Unpacks an ASN.1 BOOLEAN value, any value other than a single 0x00 octet is treated as True."""
    b_data = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.boolean)

    return b_data != b"\x00"


def unpack_asn1_enumerated(value: typing.Union[ASN1Value, bytes]) -> int:
    """Unpacks an ASN.1 ENUMERATED value."""
    b_data = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.enumerated)

    return unpack_asn1_integer(b_data)


def unpack_asn1_general_string(value: typing.Union[ASN1Value, bytes]) -> bytes:
    """Unpacks an ASN.1 GeneralString value."""
    return extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.general_string)


def unpack_asn1_generalized_time(value: typing.Union[ASN1Value, bytes]) -> datetime.datetime:
    """Unpacks an ASN.1 GeneralizedTime value into a UTC aware datetime.

    Raises:
        ValueError: The value does not match either supported time format.
    """
    data = to_text(extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.generalized_time))

    # While ASN.1 can have a timezone encoded, KerberosTime is the only thing we use and it is always in UTC with the
    # Z suffix. The Z is stripped and the UTC tz is attached to the parsed object instead.
    # https://www.rfc-editor.org/rfc/rfc4120#section-5.2.3
    if data.endswith('Z'):
        data = data[:-1]

    err = None
    for datetime_format in ['%Y%m%d%H%M%S.%f', '%Y%m%d%H%M%S']:
        try:
            dt = datetime.datetime.strptime(data, datetime_format)
        except ValueError as e:
            err = e
        else:
            return dt.replace(tzinfo=datetime.timezone.utc)

    # Neither format matched, surface the error from the last attempt.
    raise err  # type: ignore


def unpack_asn1_integer(value: typing.Union[ASN1Value, bytes]) -> int:
    """Unpacks an ASN.1 INTEGER value.

    Raises:
        ValueError: The INTEGER contains no octets.
    """
    b_int = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.integer)
    if not b_int:
        raise ValueError("Invalid ASN.1 INTEGER, value must contain at least 1 octet")

    # The contents are a big endian two's complement value of any length.
    return int.from_bytes(b_int, byteorder='big', signed=True)


def unpack_asn1_object_identifier(value: typing.Union[ASN1Value, bytes]) -> str:
    """Unpacks an ASN.1 OBJECT IDENTIFIER value into its dotted string form."""
    b_data = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.object_identifier)

    # The first subidentifier combines the first 2 elements (x.y) as (x * 40) + y where x is 0, 1, or 2. Only x = 2
    # allows y >= 40, so anything >= 80 belongs to that arc. It may span multiple octets like any subidentifier.
    first_element, idx = _unpack_asn1_octet_number(b_data)
    if first_element < 80:
        ids = list(divmod(first_element, 40))
    else:
        ids = [2, first_element - 80]

    while idx < len(b_data):
        oid, octet_len = _unpack_asn1_octet_number(b_data[idx:])
        ids.append(oid)
        idx += octet_len

    return ".".join(str(i) for i in ids)


def unpack_asn1_octet_string(value: typing.Union[ASN1Value, bytes]) -> bytes:
    """Unpacks an ASN.1 OCTET STRING value."""
    return extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.octet_string)


def unpack_asn1_sequence(value: typing.Union[ASN1Value, bytes]) -> typing.List[ASN1Value]:
    """Unpacks an ASN.1 SEQUENCE value into the list of its element TLVs."""
    b_data = extract_asn1_tlv(value, TagClass.universal, TypeTagNumber.sequence)

    values = []
    while b_data:
        v, b_data = unpack_asn1(b_data)
        values.append(v)

    return values


def unpack_asn1_tagged_sequence(value: typing.Union[ASN1Value, bytes]) -> typing.Dict[int, ASN1Value]:
    """Unpacks an ASN.1 SEQUENCE of explicitly tagged values as a dictionary.

    The keys are the explicit tag numbers and the values are the inner TLVs wrapped by each explicit tag.
    """
    return {e.tag_number: unpack_asn1(e.b_data)[0] for e in unpack_asn1_sequence(value)}


def _unpack_asn1_octet_number(b_data: bytes) -> typing.Tuple[int, int]:
    """Unpacks a base 128 number that can span across multiple octets.

    Each octet contributes its lower 7 bits and the MSB indicates whether more octets follow.

    Returns:
        Tuple[int, int]: The int value and the number of octets used.
    """
    i = 0
    idx = 0
    while True:
        element = struct.unpack("B", b_data[idx:idx + 1])[0]
        idx += 1

        i = (i << 7) + (element & 0b01111111)
        if not element & 0b10000000:
            break

    return i, idx  # int value and the number of octets used.