1#! /usr/bin/env python3
2
3"""
This script should be called *manually* when we want to upgrade SSLError
`library` and `reason` mnemonics to a more recent OpenSSL version.

It takes two arguments:
- the path to the OpenSSL source tree (e.g. git checkout)
- the path to the header file to be generated, Modules/_ssl_data_{version}.h
  (optional; defaults to stdout). Error codes are version specific, hence
  the version in the header's file name.
11"""
12
13import argparse
14import datetime
15import operator
16import os
17import re
18import sys
19
20
# Command-line interface: the OpenSSL source tree to read from, and an
# optional destination for the generated header (stdout when omitted).
parser = argparse.ArgumentParser(
    description="Generate ssl_data.h from OpenSSL sources"
)
parser.add_argument("srcdir", help="OpenSSL source directory")
parser.add_argument(
    "output",
    nargs="?",
    type=argparse.FileType("w"),
    default=sys.stdout,
    help="header file to generate (default: stdout)",
)
28
29
30def _file_search(fname, pat):
31    with open(fname, encoding="utf-8") as f:
32        for line in f:
33            match = pat.search(line)
34            if match is not None:
35                yield match
36
37
def parse_err_h(args):
    """Map OpenSSL library short names to their ERR_LIB_* numbers.

    Scans the header at ``args.err_h`` for ``#define ERR_LIB_<NAME> <num>``
    lines and returns a dict such as ``{"X509": 11, ...}``. If a name is
    defined more than once, the last definition wins.
    """
    define_re = re.compile(r"#\s*define\W+ERR_LIB_(\w+)\s+(\d+)")
    pairs = (found.groups() for found in _file_search(args.err_h, define_re))
    return {name: int(value) for name, value in pairs}
47
48
def parse_openssl_error_text(args):
    """Yield ``(reason, libname, errname, num)`` tuples from openssl.txt.

    Matching lines in ``args.errtxt`` start with ``<LIB>_R_<NAME>:<num>:``
    (e.g. ``X509_R_AKID_MISMATCH:110:...``); ``num`` is yielded as an
    ``int``. Entries whose full reason name contains ``_F_`` are skipped
    (intended to drop function codes).
    """
    # Backslash line continuations are not handled yet.
    reason_re = re.compile(r"^((\w+?)_R_(\w+)):(\d+):")
    for found in _file_search(args.errtxt, reason_re):
        full_name, lib, err, code = found.groups()
        if "_F_" not in full_name:
            yield full_name, lib, err, int(code)
60
61
def parse_extra_reasons(args):
    """Yield ``(reason, libname, errname, num)`` tuples from openssl.ec.

    Only lines in ``args.errcodes`` of the form ``R <LIB>_R_<NAME> <num>``
    are used; ``num`` is yielded as an ``int``.
    """
    extra_re = re.compile(r"^R\s+((\w+)_R_(\w+))\s+(\d+)")
    for found in _file_search(args.errcodes, extra_re):
        full_name, lib, err, code = found.groups()
        yield full_name, lib, err, int(code)
69
70
def gen_library_codes(args):
    """Yield the C source lines of the ``library_codes`` table.

    One ``{"NAME", ERR_LIB_NAME}`` entry is produced per key of
    ``args.lib2errnum``, in sorted order. Each entry is wrapped in
    ``#ifdef ERR_LIB_NAME`` so it compiles away when the macro is not
    defined. The table is terminated by a ``{ NULL }`` sentinel and
    followed by an empty line.
    """
    table_open = "static struct py_ssl_library_code library_codes[] = {"
    table_close = ("    { NULL }", "};", "")

    yield table_open
    for name in sorted(args.lib2errnum):
        yield from (
            f"#ifdef ERR_LIB_{name}",
            f'    {{"{name}", ERR_LIB_{name}}},',
            "#endif",
        )
    yield from table_close
81
82
def gen_error_codes(args):
    """Yield the C source lines of the ``error_codes`` table.

    For each ``(reason, libname, errname, num)`` in ``args.reasons`` the
    entry uses the symbolic ``ERR_LIB_<libname>`` and ``<reason>`` macros
    when ``<reason>`` is defined. Otherwise it falls back to the numeric
    library code ``args.lib2errnum[libname]`` and the numeric reason
    ``num``. The table is terminated by a ``{ NULL }`` sentinel and
    followed by an empty line. A ``libname`` missing from
    ``args.lib2errnum`` raises ``KeyError``.
    """
    yield "static struct py_ssl_error_code error_codes[] = {"
    for reason, lib, mnemonic, code in args.reasons:
        yield f"  #ifdef {reason}"
        yield f'    {{"{mnemonic}", ERR_LIB_{lib}, {reason}}},'
        yield "  #else"
        fallback_lib = args.lib2errnum[lib]
        yield f'    {{"{mnemonic}", {fallback_lib}, {code}}},'
        yield "  #endif"
    yield from ("    { NULL }", "};", "")
96
97
def main():
    """Generate the ssl_data header from an OpenSSL source tree.

    Reads ``include/openssl/err.h`` (or ``err.h.in`` when ``err.h`` is
    absent), ``crypto/err/openssl.txt`` and ``crypto/err/openssl.ec`` below
    ``srcdir``. Writes two header comments plus the ``library_codes`` and
    ``error_codes`` C tables to ``output`` (stdout by default). Exits via
    ``parser.error`` if any of the input files is missing.
    """
    args = parser.parse_args()

    args.err_h = os.path.join(args.srcdir, "include", "openssl", "err.h")
    if not os.path.isfile(args.err_h):
        # Fall back to infile for OpenSSL 3.0.0
        args.err_h += ".in"
    args.errcodes = os.path.join(args.srcdir, "crypto", "err", "openssl.ec")
    args.errtxt = os.path.join(args.srcdir, "crypto", "err", "openssl.txt")

    # All three inputs are read unconditionally below; report a missing one
    # as a usage error rather than a FileNotFoundError traceback.
    for path in (args.err_h, args.errcodes, args.errtxt):
        if not os.path.isfile(path):
            parser.error(f"File {path} not found in srcdir.")

    # {X509: 11, ...}
    args.lib2errnum = parse_err_h(args)

    # [('X509_R_AKID_MISMATCH', 'X509', 'AKID_MISMATCH', 110), ...]
    reasons = []
    reasons.extend(parse_openssl_error_text(args))
    reasons.extend(parse_extra_reasons(args))
    # Sort by full reason name (which starts with the library name), then
    # by numeric error code, so the table order is deterministic.
    args.reasons = sorted(reasons, key=operator.itemgetter(0, 3))

    # utcnow() is deprecated; use an explicit UTC-aware timestamp.
    timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
    lines = [
        # Two separate comment lines; without the comma Python would
        # concatenate them into a single "*//*" line.
        "/* File generated by Tools/ssl/make_ssl_data.py */",
        f"/* Generated on {timestamp} */",
    ]
    lines.extend(gen_library_codes(args))
    lines.append("")
    lines.extend(gen_error_codes(args))

    args.output.writelines(line + "\n" for line in lines)
    # Close a real output file so it is flushed; never close stdout.
    if args.output is not sys.stdout:
        args.output.close()
131
132
if __name__ == "__main__":
    # Generate the header only when run as a script, not when imported.
    main()
135