1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import six
8
9from cryptography.utils import int_from_bytes, int_to_bytes
10
11
12# This module contains a lightweight DER encoder and decoder. See X.690 for the
13# specification. This module intentionally does not implement the more complex
14# BER encoding, only DER.
15#
16# Note this implementation treats an element's constructed bit as part of the
17# tag. This is fine for DER, where the bit is always computable from the type.
18
19
# Flag bits that are OR-ed into a tag byte.
CONSTRUCTED = 1 << 5
CONTEXT_SPECIFIC = 1 << 7

# Tag bytes for the supported types. SEQUENCE and SET carry the CONSTRUCTED
# bit because this module treats that bit as part of the tag.
INTEGER = 0x02
BIT_STRING = 0x03
OCTET_STRING = 0x04
NULL = 0x05
OBJECT_IDENTIFIER = 0x06
SEQUENCE = CONSTRUCTED | 0x10
SET = CONSTRUCTED | 0x11
PRINTABLE_STRING = 0x13
UTC_TIME = 0x17
GENERALIZED_TIME = 0x18
33
34
class DERReader(object):
    """Cursor over a buffer of DER-encoded bytes.

    The input is wrapped in a ``memoryview`` and consumed from the front:
    each successful read advances ``self.data`` past the bytes it returned.
    Malformed input is reported by raising ``ValueError``.

    When used as a context manager, leaving the ``with`` block without an
    exception additionally checks that no unread data remains.
    """

    def __init__(self, data):
        # A memoryview lets the read methods slice the input without
        # copying it.
        self.data = memoryview(data)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Only insist on full consumption when the block succeeded, so an
        # in-flight exception is never replaced by a trailing-data error.
        if exc_value is not None:
            return
        self.check_empty()

    def is_empty(self):
        """Return True when there is no unread data left."""
        return not len(self.data)

    def check_empty(self):
        """Raise ValueError if unread data remains."""
        if self.is_empty():
            return
        raise ValueError("Invalid DER input: trailing data")

    def read_byte(self):
        """Consume one byte and return it as an integer.

        Raises ValueError if no data is left.
        """
        remaining = self.data
        if len(remaining) < 1:
            raise ValueError("Invalid DER input: insufficient data")
        value = six.indexbytes(remaining, 0)
        self.data = remaining[1:]
        return value

    def read_bytes(self, n):
        """Consume the next ``n`` bytes and return them as a memoryview.

        Raises ValueError if fewer than ``n`` bytes are left.
        """
        remaining = self.data
        if n > len(remaining):
            raise ValueError("Invalid DER input: insufficient data")
        head, tail = remaining[:n], remaining[n:]
        self.data = tail
        return head

    def _read_length(self):
        """Consume a DER length field and return the length it encodes.

        Raises ValueError for the indefinite form, for any encoding that is
        not the shortest possible one, and for truncated input.
        """
        first = self.read_byte()
        if first & 0x80 == 0:
            # Short form: the byte itself is the length.
            return first

        # Long form: the low seven bits say how many length bytes follow.
        num_octets = first & 0x7F
        if num_octets == 0:
            raise ValueError(
                "Invalid DER input: indefinite length form is not allowed "
                "in DER"
            )
        length = 0
        for _ in range(num_octets):
            length = (length << 8) | self.read_byte()
            # The accumulator can only still be zero if every byte read so
            # far was zero, i.e. the encoding has a redundant leading zero.
            if length == 0:
                raise ValueError(
                    "Invalid DER input: length was not minimally-encoded"
                )
        if length < 0x80:
            # Anything below 0x80 fits the short form and must use it.
            raise ValueError(
                "Invalid DER input: length was not minimally-encoded"
            )
        return length

    def read_any_element(self):
        """Consume the next element, whatever its tag.

        Returns a ``(tag, body)`` pair where ``tag`` is the tag byte as an
        integer and ``body`` is a new DERReader over the element contents.
        Raises ValueError on malformed or truncated input.
        """
        tag = self.read_byte()
        # Tag numbers 31 and above spill into further bytes. None of the
        # supported ASN.1 types need them, so they are rejected outright.
        if tag & 0x1F == 0x1F:
            raise ValueError("Invalid DER input: unexpected high tag number")
        contents = self.read_bytes(self._read_length())
        return tag, DERReader(contents)

    def read_element(self, expected_tag):
        """Consume the next element and return a reader over its contents.

        Raises ValueError if the element's tag is not ``expected_tag``.
        """
        tag, body = self.read_any_element()
        if tag == expected_tag:
            return body
        raise ValueError("Invalid DER input: unexpected tag")

    def read_single_element(self, expected_tag):
        """Read one element that must account for all remaining data.

        Returns a reader over the element contents. Raises ValueError if
        the tag differs from ``expected_tag`` or data follows the element.
        """
        with self:
            element = self.read_element(expected_tag)
        return element

    def read_optional_element(self, expected_tag):
        """Read the next element only if it starts with ``expected_tag``.

        Returns a reader over the element contents, or None (consuming
        nothing) when the data is exhausted or the next byte is a different
        tag.
        """
        if len(self.data) == 0:
            return None
        if six.indexbytes(self.data, 0) != expected_tag:
            return None
        return self.read_element(expected_tag)

    def as_integer(self):
        """Decode the unread data as a non-negative integer.

        The data is passed to ``int_from_bytes`` with big-endian byte order
        and is not consumed. Raises ValueError for empty contents, for a
        set sign bit (negative values are unsupported) and for encodings
        with a redundant leading zero byte.
        """
        contents = self.data
        if len(contents) == 0:
            raise ValueError("Invalid DER input: empty integer contents")
        lead = six.indexbytes(contents, 0)
        if lead & 0x80:
            raise ValueError("Negative DER integers are not supported")
        # A zero first byte is only justified when it keeps the next byte's
        # high bit from being read as a sign bit; otherwise the encoding
        # should have been one byte shorter.
        if lead == 0 and len(contents) > 1:
            if six.indexbytes(contents, 1) & 0x80 == 0:
                raise ValueError(
                    "Invalid DER input: integer not minimally-encoded"
                )
        return int_from_bytes(contents, "big")
133
134
def encode_der_integer(x):
    """Return the byte encoding of the non-negative integer ``x``.

    Only the integer's bytes are produced; no tag or length is prepended.
    Raises ValueError if ``x`` is not an integer or is negative.
    """
    if isinstance(x, six.integer_types):
        if x < 0:
            raise ValueError("Negative integers are not supported")
        # One byte more than bit_length() // 8 always leaves at least one
        # spare high bit, so the top bit of the output stays clear.
        width = 1 + x.bit_length() // 8
        return int_to_bytes(x, width)
    raise ValueError("Value must be an integer")
142
143
def encode_der(tag, *children):
    """Encode one element: tag byte, length, then the concatenated children.

    ``tag`` is the tag byte as an integer and each child is an
    already-encoded bytes-like piece of the contents. Returns the complete
    encoding as bytes.
    """
    length = sum(len(child) for child in children)
    pieces = [six.int2byte(tag)]
    if length >= 0x80:
        # Long form: a byte with the high bit set gives the number of
        # length bytes, followed by the length itself.
        length_bytes = int_to_bytes(length)
        pieces += [six.int2byte(0x80 | len(length_bytes)), length_bytes]
    else:
        # Short form: the length fits in a single byte.
        pieces.append(six.int2byte(length))
    pieces += children
    return b"".join(pieces)
157