1#!/usr/local/bin/python3.8
2#
3# Electrum - lightweight Bitcoin client
4# Copyright (C) 2014 Thomas Voegtlin
5#
6# Permission is hereby granted, free of charge, to any person
7# obtaining a copy of this software and associated documentation files
8# (the "Software"), to deal in the Software without restriction,
9# including without limitation the rights to use, copy, modify, merge,
10# publish, distribute, sublicense, and/or sell copies of the Software,
11# and to permit persons to whom the Software is furnished to do so,
12# subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be
15# included in all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
21# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
22# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
23# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24# SOFTWARE.
25
26import hashlib
27import time
28from datetime import datetime
29
30from . import util
31from .util import profiler, bh2u
32from .logging import get_logger
33
34
# Module-level logger; load_certificates uses it to report certificates it skips.
_logger = get_logger(__name__)
36
37
# Signature-algorithm OIDs (RSA PKCS#1 v1.5 with SHA-1/256/384/512, ECDSA with SHA-256).
ALGO_RSA_SHA1 = '1.2.840.113549.1.1.5'
ALGO_RSA_SHA256 = '1.2.840.113549.1.1.11'
ALGO_RSA_SHA384 = '1.2.840.113549.1.1.12'
ALGO_RSA_SHA512 = '1.2.840.113549.1.1.13'
ALGO_ECDSA_SHA256 = '1.2.840.10045.4.3.2'

# DER header of a DigestInfo SEQUENCE { AlgorithmIdentifier { hash OID, NULL },
# OCTET STRING } for SHA-256 / SHA-384 / SHA-512; the raw digest (32 / 48 / 64
# bytes) follows the header. See
# http://stackoverflow.com/questions/3713774/c-sharp-how-to-calculate-asn-1-der-encoding-of-a-particular-hash-algorithm
PREFIX_RSA_SHA256 = bytearray(
    [0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20])
PREFIX_RSA_SHA384 = bytearray(
    [0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, 0x00, 0x04, 0x30])
PREFIX_RSA_SHA512 = bytearray(
    [0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40])

# Identifier (tag) octets of the ASN.1 types used in structured data, exactly
# as they appear on the wire. ASN1_Node.get_value_of_type compares the full
# tag byte against these values. SEQUENCE and SET are always constructed in
# DER, so their octets carry the 0x20 "constructed" bit: 0x10 | 0x20 = 0x30
# and 0x11 | 0x20 = 0x31.
ASN1_TYPES = {
    'BOOLEAN'          : 0x01,
    'INTEGER'          : 0x02,
    'BIT STRING'       : 0x03,
    'OCTET STRING'     : 0x04,
    'NULL'             : 0x05,
    'OBJECT IDENTIFIER': 0x06,
    'SEQUENCE'         : 0x30,
    'SET'              : 0x31,
    'PrintableString'  : 0x13,
    'IA5String'        : 0x16,
    'UTCTime'          : 0x17,
    'GeneralizedTime'  : 0x18,
    'ENUMERATED'       : 0x0A,
    'UTF8String'       : 0x0C,
}
70
71
72class CertificateError(Exception):
73    pass
74
75
76# helper functions
def bitstr_to_bytestr(s):
    """Return the payload of a BIT STRING's content octets.

    The first content octet is the count of unused trailing bits; only a
    count of zero (byte-aligned data) is accepted.

    Raises:
        TypeError: if the leading unused-bits octet is non-zero.
    """
    unused_bits = s[0]
    if unused_bits == 0x00:
        return s[1:]
    raise TypeError('no padding')
81
82
def bytestr_to_int(s):
    """Interpret the byte sequence *s* as an unsigned big-endian integer.

    An empty sequence yields 0.
    """
    return int.from_bytes(s, byteorder='big')
89
90
def decode_OID(s):
    """Decode the content octets of a DER OBJECT IDENTIFIER to dotted form.

    Every subidentifier is base-128 with the high bit set on all but its last
    octet. The first subidentifier packs the first two arcs as
    ``40 * arc0 + arc1``. arc0 can only be 0, 1 or 2, so values >= 80 belong
    to arc 2 (whose second arc may exceed 39 and may span several octets).
    Trailing octets of an unterminated subidentifier are ignored.

    Raises:
        IndexError: if *s* contains no complete subidentifier.
    """
    subids = []
    acc = 0
    for byte in s:
        acc = (acc << 7) | (byte & 0x7F)
        if not byte & 0x80:
            subids.append(acc)
            acc = 0
    first = subids[0]
    if first < 80:
        arcs = [first // 40, first % 40]
    else:
        arcs = [2, first - 80]
    arcs.extend(subids[1:])
    return '.'.join(map(str, arcs))
103
104
def encode_OID(oid):
    """Encode a dotted OID string into DER OBJECT IDENTIFIER content octets.

    The result is returned as a ``str`` whose code points (all < 256) are the
    content octets; ``.encode('latin-1')`` turns it into bytes. The first two
    arcs are packed as ``40 * arc0 + arc1``. Every subidentifier (including
    that first one) is written base-128, high bit set on all but its last
    octet.

    Raises:
        ValueError: if an arc is not an integer.
        IndexError: if fewer than two arcs are given.
    """
    arcs = [int(a) for a in oid.split('.')]
    subids = [arcs[0] * 40 + arcs[1]] + arcs[2:]
    pieces = []
    for subid in subids:
        chunk = chr(subid % 128)
        subid //= 128
        # Emit continuation octets for as long as any higher-order digits remain.
        while subid > 0:
            chunk = chr(128 + subid % 128) + chunk
            subid //= 128
        pieces.append(chunk)
    return ''.join(pieces)
115
116
class ASN1_Node(bytes):
    """DER-encoded bytes with helpers for walking the TLV structure.

    A node is an ``(ixs, ixf, ixl)`` index triple into the buffer: the index
    of the tag byte, of the first content byte, and of the last content byte
    (inclusive). No bounds checking is done beyond Python indexing.
    """

    def get_node(self, ix):
        """Return the ``(ixs, ixf, ixl)`` triple of the TLV whose tag is at *ix*."""
        first = self[ix + 1]
        if (first & 0x80) == 0:
            # Short form: the length fits in the low 7 bits.
            length = first
            ixf = ix + 2
        else:
            # Long form: the low 7 bits count the big-endian length octets.
            lengthbytes = first & 0x7F
            length = bytestr_to_int(self[ix + 2:ix + 2 + lengthbytes])
            ixf = ix + 2 + lengthbytes
        ixl = ixf + length - 1
        return ix, ixf, ixl

    def root(self):
        """Return the node starting at offset 0."""
        return self.get_node(0)

    def next_node(self, node):
        """Return the sibling node that immediately follows *node*."""
        ixs, ixf, ixl = node
        return self.get_node(ixl + 1)

    def first_child(self, node):
        """Return the first node inside the content of constructed *node*.

        Raises:
            TypeError: if *node*'s tag does not have the constructed bit (0x20).
        """
        ixs, ixf, ixl = node
        if self[ixs] & 0x20 != 0x20:
            raise TypeError('Can only open constructed types.', hex(self[ixs]))
        return self.get_node(ixf)

    @staticmethod
    def is_child_of(node1, node2):
        """Return True if either node lies within the other's content range.

        The test is symmetric in its two arguments. Declared static so that
        both ``ASN1_Node.is_child_of(a, b)`` and ``der.is_child_of(a, b)`` work.
        """
        ixs, ixf, ixl = node1
        jxs, jxf, jxl = node2
        return ((ixf <= jxs) and (jxl <= ixl)) or ((jxf <= ixs) and (ixl <= jxl))

    def get_all(self, node):
        """Return *node*'s full encoding: tag + length + value."""
        ixs, ixf, ixl = node
        return self[ixs:ixl + 1]

    def get_value_of_type(self, node, asn1_type):
        """Return *node*'s content after checking its tag against ASN1_TYPES.

        Raises:
            TypeError: if the tag byte differs from ``ASN1_TYPES[asn1_type]``.
            KeyError: if *asn1_type* is not a key of ASN1_TYPES.
        """
        ixs, ixf, ixl = node
        if ASN1_TYPES[asn1_type] != self[ixs]:
            raise TypeError('Wrong type:', hex(self[ixs]), hex(ASN1_TYPES[asn1_type]))
        return self[ixf:ixl + 1]

    def get_value(self, node):
        """Return *node*'s content bytes without checking its tag."""
        ixs, ixf, ixl = node
        return self[ixf:ixl + 1]

    def get_children(self, node):
        """Return the list of direct child nodes of constructed *node*."""
        nodes = []
        ii = self.first_child(node)
        nodes.append(ii)
        while ii[2] < node[2]:
            ii = self.next_node(ii)
            nodes.append(ii)
        return nodes

    def get_sequence(self):
        """Return the content bytes of each direct child of the root node."""
        return list(map(lambda j: self.get_value(j), self.get_children(self.root())))

    def get_dict(self, node):
        """Collect ``{dotted OID: value bytes}`` from a two-level structure.

        For each grandchild of *node*, its first child must be an OBJECT
        IDENTIFIER and the following sibling supplies the (untyped) value.
        This is the shape of an X.501 Name (SEQUENCE OF SET OF
        SEQUENCE { OID, value }). Later duplicates overwrite earlier ones.
        """
        p = {}
        for ii in self.get_children(node):
            for iii in self.get_children(ii):
                iiii = self.first_child(iii)
                oid = decode_OID(self.get_value_of_type(iiii, 'OBJECT IDENTIFIER'))
                iiii = self.next_node(iiii)
                value = self.get_value(iiii)
                p[oid] = value
        return p

    @staticmethod
    def _expand_utctime(s):
        """Prefix a 'YYMMDDHHMMSSZ' UTCTime string with its century (RFC 5280).

        YY >= 50 means 19YY and YY < 50 means 20YY. This differs from
        strptime's ``%y``, which maps 50-68 to 2050-2068.

        Raises:
            ValueError: if the first two characters are not an integer.
        """
        century = '19' if int(s[:2]) >= 50 else '20'
        return century + s

    def decode_time(self, ii):
        """Decode a UTCTime or GeneralizedTime node into a ``time.struct_time``.

        The returned fields are the ones written in the node. The literal 'Z'
        suffix required by the format marks them as UTC.

        Raises:
            TypeError: if the node is neither UTCTime nor GeneralizedTime.
            ValueError: if the timestamp does not match 'YYYY/YYMMDDHHMMSSZ'.
        """
        GENERALIZED_TIMESTAMP_FMT = '%Y%m%d%H%M%SZ'

        try:
            raw = self.get_value_of_type(ii, 'UTCTime')
        except TypeError:
            stamp = self.get_value_of_type(ii, 'GeneralizedTime').decode('ascii')
        else:
            stamp = self._expand_utctime(raw.decode('ascii'))
        return time.strptime(stamp, GENERALIZED_TIMESTAMP_FMT)
197
class X509(object):
    """Parsed view of a single DER-encoded X.509 certificate.

    Only the fields this module needs are extracted: serial number, signature
    algorithms, issuer/subject names, validity window, the subject public key,
    and the Basic Constraints / Subject Key Identifier / Authority Key
    Identifier extensions.
    """

    def __init__(self, b):
        """Parse the DER certificate bytes *b*.

        Malformed input raises whatever the ASN1_Node helpers raise, e.g.
        TypeError on an unexpected tag or IndexError on truncated data.
        """
        self.bytes = bytearray(b)

        der = ASN1_Node(b)
        root = der.root()
        cert = der.first_child(root)
        # data for signature: the complete TBSCertificate encoding
        self.data = der.get_all(cert)

        # optional version field ([0] EXPLICIT, tag byte 0xa0)
        if der.get_value(cert)[0] == 0xa0:
            version = der.first_child(cert)
            serial_number = der.next_node(version)
        else:
            serial_number = der.first_child(cert)
        self.serial_number = bytestr_to_int(der.get_value_of_type(serial_number, 'INTEGER'))

        # signature algorithm
        sig_algo = der.next_node(serial_number)
        ii = der.first_child(sig_algo)
        self.sig_algo = decode_OID(der.get_value_of_type(ii, 'OBJECT IDENTIFIER'))

        # issuer
        issuer = der.next_node(sig_algo)
        self.issuer = der.get_dict(issuer)

        # validity
        validity = der.next_node(issuer)
        ii = der.first_child(validity)
        self.notBefore = der.decode_time(ii)
        ii = der.next_node(ii)
        self.notAfter = der.decode_time(ii)

        # subject
        subject = der.next_node(validity)
        self.subject = der.get_dict(subject)
        subject_pki = der.next_node(subject)
        public_key_algo = der.first_child(subject_pki)
        ii = der.first_child(public_key_algo)
        self.public_key_algo = decode_OID(der.get_value_of_type(ii, 'OBJECT IDENTIFIER'))

        subject_public_key = der.next_node(public_key_algo)
        spk = der.get_value_of_type(subject_public_key, 'BIT STRING')
        if self.public_key_algo != '1.2.840.10045.2.1':  # for non EC public key
            # pubkey modulus and exponent, DER-encoded inside the BIT STRING
            spk = ASN1_Node(bitstr_to_bytestr(spk))
            r = spk.root()
            modulus = spk.first_child(r)
            exponent = spk.next_node(modulus)
            rsa_n = spk.get_value_of_type(modulus, 'INTEGER')
            rsa_e = spk.get_value_of_type(exponent, 'INTEGER')
            self.modulus = int.from_bytes(rsa_n, byteorder='big', signed=False)
            self.exponent = int.from_bytes(rsa_e, byteorder='big', signed=False)
        else:
            # EC key: raw BIT STRING content, unused-bits octet included
            self.ec_public_key = spk

        # extensions: walk the optional fields that follow subjectPublicKeyInfo
        self.CA = False
        self.AKI = None
        self.SKI = None
        node = subject_pki
        while node[2] < cert[2]:
            node = der.next_node(node)
            # Only the [3] EXPLICIT extensions field (tag 0xa3) is parsed; the
            # primitive issuerUniqueID [1] / subjectUniqueID [2] are skipped.
            if der[node[0]] == 0xa3:
                self._parse_extensions(der, node)

        # cert signature
        cert_sig_algo = der.next_node(cert)
        ii = der.first_child(cert_sig_algo)
        self.cert_sig_algo = decode_OID(der.get_value_of_type(ii, 'OBJECT IDENTIFIER'))
        cert_sig = der.next_node(cert_sig_algo)
        # drop the BIT STRING unused-bits octet
        self.signature = der.get_value(cert_sig)[1:]

    def _parse_extensions(self, der, node):
        """Set self.CA / self.SKI / self.AKI from the extensions [3] *node*."""
        # [3] EXPLICIT wraps a single SEQUENCE OF Extension.
        extensions = der.first_child(node)
        for ext in der.get_children(extensions):
            # Extension ::= SEQUENCE { extnID OID,
            #                          critical BOOLEAN DEFAULT FALSE,
            #                          extnValue OCTET STRING }
            field = der.first_child(ext)
            oid = decode_OID(der.get_value_of_type(field, 'OBJECT IDENTIFIER'))
            field = der.next_node(field)
            if der[field[0]] == ASN1_TYPES['BOOLEAN']:
                # Skip the optional 'critical' flag so we read extnValue,
                # not the flag itself.
                field = der.next_node(field)
            value = ASN1_Node(der.get_value_of_type(field, 'OCTET STRING'))
            if oid == '2.5.29.19':
                # Basic Constraints
                self.CA = self._basic_constraints_ca(value)
            elif oid == '2.5.29.14':
                # Subject Key Identifier: KeyIdentifier ::= OCTET STRING
                r = value.root()
                self.SKI = bh2u(value.get_value_of_type(r, 'OCTET STRING'))
            elif oid == '2.5.29.35':
                # Authority Key Identifier
                aki = value.root()
                if aki[2] >= aki[1]:  # non-empty SEQUENCE
                    key_id = value.first_child(aki)
                    # keyIdentifier is the optional [0] IMPLICIT field (tag 0x80);
                    # other fields are ignored, so get_issuer_keyID falls back
                    # to the issuer name.
                    if value[key_id[0]] == 0x80:
                        self.AKI = bh2u(value.get_value(key_id))

    @staticmethod
    def _basic_constraints_ca(value):
        """Return the cA flag of a DER BasicConstraints value (default False)."""
        # BasicConstraints ::= SEQUENCE { cA BOOLEAN DEFAULT FALSE,
        #                                 pathLenConstraint INTEGER OPTIONAL }
        seq = value.root()
        if seq[2] < seq[1]:
            return False  # empty SEQUENCE: cA takes its default FALSE
        first = value.first_child(seq)
        if value[first[0]] != ASN1_TYPES['BOOLEAN']:
            return False  # cA omitted, only pathLenConstraint present
        # any non-zero content octet means TRUE
        return any(value.get_value(first))

    def get_keyID(self):
        """Return the Subject Key Identifier (hex), or repr(subject) if absent."""
        # http://security.stackexchange.com/questions/72077/validating-an-ssl-certificate-chain-according-to-rfc-5280-am-i-understanding-th
        return self.SKI if self.SKI else repr(self.subject)

    def get_issuer_keyID(self):
        """Return the Authority Key Identifier (hex), or repr(issuer) if absent."""
        return self.AKI if self.AKI else repr(self.issuer)

    def get_common_name(self):
        """Return the subject commonName (OID 2.5.4.3) decoded as UTF-8, or 'unknown'."""
        return self.subject.get('2.5.4.3', b'unknown').decode()

    def get_signature(self):
        """Return (signature algorithm OID, signature bytes, signed TBS bytes)."""
        return self.cert_sig_algo, self.signature, self.data

    def check_ca(self):
        """Return True if Basic Constraints marks this certificate as a CA."""
        return self.CA

    def check_date(self):
        """Raise CertificateError unless the current UTC time is in [notBefore, notAfter)."""
        now = time.gmtime()
        if self.notBefore > now:
            raise CertificateError('Certificate has not entered its valid date range. (%s)' % self.get_common_name())
        if self.notAfter <= now:
            # notAfter holds UTC fields; build the datetime from them directly
            # (time.mktime would misinterpret them as local time).
            dt = datetime(*self.notAfter[:6])
            raise CertificateError(f'Certificate ({self.get_common_name()}) has expired (at {dt} UTC).')

    def getFingerprint(self):
        """Return the SHA-1 digest of the full certificate encoding."""
        return hashlib.sha1(self.bytes).digest()
315
316
@profiler
def load_certificates(ca_path):
    """Parse every PEM "CERTIFICATE" in the file at *ca_path*.

    Certificates that fail to parse, or whose validity window excludes the
    current time, are logged and skipped.

    Returns:
        (ca_list, ca_keyID): ``ca_list`` maps SHA-1 fingerprint -> X509, and
        ``ca_keyID`` maps X509.get_keyID() -> fingerprint. A later certificate
        with the same key ID overwrites an earlier entry.
    """
    from . import pem
    ca_list = {}
    ca_keyID = {}
    with open(ca_path, 'r', encoding='utf-8') as f:
        s = f.read()
    for b in pem.dePemList(s, "CERTIFICATE"):
        try:
            x = X509(b)
            x.check_date()
        except Exception as e:
            # Best-effort: skip bad/expired certificates, but let
            # KeyboardInterrupt / SystemExit propagate.
            _logger.info(f"cert error: {e}")
            continue

        fp = x.getFingerprint()
        ca_list[fp] = x
        ca_keyID[x.get_keyID()] = fp

    return ca_list, ca_keyID
341
342
if __name__ == "__main__":
    # Smoke test: parse the CA bundle shipped with the certifi package.
    import certifi
    ca_list, ca_keyID = load_certificates(certifi.where())
348