1import functools
2import os
3import random
4import shutil
5import subprocess
6import tempfile
7from datetime import datetime, timedelta
8
# Safety margin applied when checking certificate expiry. A certificate whose
# end date falls within this window from the present is treated as already
# expired, so it is re-generated ahead of time instead of at the exact moment
# it lapses. The mapping is expanded as keyword arguments to
# ``datetime.timedelta``.
CERT_EXPIRY_BUFFER = {"hours": 6}
13
14
class OpenSSL(object):
    def __init__(self, logger, binary, base_path, conf_path, hosts, duration,
                 base_conf_path=None):
        """Context manager for interacting with OpenSSL.
        Creates a config file for the duration of the context.

        :param logger: stdlib logger or python structured logger
        :param binary: path to openssl binary
        :param base_path: path to directory for storing certificates
        :param conf_path: path for configuration file storing configuration data
        :param hosts: list of hosts to include in configuration (or None if not
                      generating host certificates)
        :param duration: Certificate duration in days
        :param base_conf_path: optional path exported to each openssl child
                               process as the ``OPENSSL_CONF`` environment
                               variable (None leaves the environment as is)"""

        self.base_path = base_path
        self.binary = binary
        self.conf_path = conf_path
        self.base_conf_path = base_conf_path
        self.logger = logger
        # Process and command line of the invocation currently running; both
        # are reset once the invocation finishes (successfully or not).
        self.proc = None
        self.cmd = []
        self.hosts = hosts
        self.duration = duration

    def __enter__(self):
        """Write the generated OpenSSL configuration to ``conf_path``."""
        with open(self.conf_path, "w") as f:
            f.write(get_config(self.base_path, self.hosts, self.duration))
        return self

    def __exit__(self, *args, **kwargs):
        """Delete the configuration file written by ``__enter__``."""
        os.unlink(self.conf_path)

    def log(self, line):
        """Send raw (bytes) process output to the logger.

        Structured loggers exposing ``process_output`` receive the pid and
        command line as well; other loggers get the text via ``debug``. The
        output is decoded as UTF-8 (undecodable bytes replaced) in both cases
        so stdlib loggers do not record a ``b'...'`` repr.
        """
        text = line.decode("utf8", "replace")
        if hasattr(self.logger, "process_output"):
            self.logger.process_output(self.proc.pid if self.proc is not None else None,
                                       text,
                                       command=" ".join(self.cmd))
        else:
            self.logger.debug(text)

    def __call__(self, cmd, *args, **kwargs):
        """Run a command using OpenSSL in the current context.

        :param cmd: The openssl subcommand to run
        :param *args: Additional arguments to pass to the command
        :returns: combined stdout/stderr of the command as bytes
        :raises subprocess.CalledProcessError: if the command exits non-zero
        """
        self.cmd = [self.binary, cmd]
        # x509 does not accept a -config option.
        if cmd != "x509":
            self.cmd += ["-config", self.conf_path]
        self.cmd += list(args)

        # Copy the environment and add OPENSSL_CONF if available.
        env = os.environ.copy()
        if self.base_conf_path is not None:
            env["OPENSSL_CONF"] = self.base_conf_path

        try:
            # stderr is merged into stdout, so communicate() yields no stderr.
            self.proc = subprocess.Popen(self.cmd, stdout=subprocess.PIPE,
                                         stderr=subprocess.STDOUT, env=env)
            stdout, _ = self.proc.communicate()
            self.log(stdout)
            if self.proc.returncode != 0:
                raise subprocess.CalledProcessError(self.proc.returncode, self.cmd,
                                                    output=stdout)
        finally:
            # Reset per-invocation state even when the command fails, so a
            # later log() call does not report a stale pid/command. The
            # exception above keeps its own reference to the old cmd list.
            self.cmd = []
            self.proc = None
        return stdout
82
83
def make_subject(common_name,
                 country=None,
                 state=None,
                 locality=None,
                 organization=None,
                 organization_unit=None):
    """Build a distinguished-name string for ``openssl ... -subj``.

    Components are emitted in the order C, ST, L, O, OU, CN as
    ``/KEY=value``; components whose value is None are left out.
    Backslashes and forward slashes inside values are backslash-escaped so
    openssl does not treat them as escapes or component separators.

    :param common_name: value of the CN component (always included unless None)
    :returns: e.g. ``"/C=US/CN=example"``
    """
    fields = [("C", country),
              ("ST", state),
              ("L", locality),
              ("O", organization),
              ("OU", organization_unit),
              ("CN", common_name)]

    rv = []
    for key, value in fields:
        if value is not None:
            # Escape backslashes first so the escape added for "/" is not
            # itself doubled.
            escaped = value.replace("\\", "\\\\").replace("/", "\\/")
            rv.append("/%s=%s" % (key, escaped))

    return "".join(rv)
105
def make_alt_names(hosts):
    """Return a ``subjectAltName`` value listing every host as a DNS entry,
    comma-separated in input order."""
    entries = ["DNS:%s" % name for name in hosts]
    return ",".join(entries)
108
def make_name_constraints(hosts):
    """Return a ``nameConstraints`` value permitting each host as a DNS name,
    comma-separated in input order."""
    entries = ["permitted;DNS:%s" % name for name in hosts]
    return ",".join(entries)
111
def get_config(root_dir, hosts, duration=30):
    """Render the OpenSSL configuration text used for the local CA.

    :param root_dir: directory holding the CA files; substituted as ``dir``
                     in the ``CA_default`` section
    :param hosts: hosts written to ``subjectAltName`` (``v3_req`` section)
                  and ``nameConstraints`` (``v3_ca`` section), or None to
                  leave both lines empty
    :param duration: value for ``default_days`` and ``default_crl_days``
    :returns: the configuration file contents as a string
    """
    san_line = constraints_line = ""
    if hosts is not None:
        san_line = "subjectAltName = %s" % make_alt_names(hosts)
        constraints_line = "nameConstraints = " + make_name_constraints(hosts)

    on_windows_paths = os.path.sep == "\\"
    if on_windows_paths:
        # Backslashes are doubled for the config parser; this seems to be
        # needed for the Shining Light OpenSSL on Windows, at least.
        root_dir = root_dir.replace("\\", "\\\\")

    values = {
        "root_dir": root_dir,
        "san_line": san_line,
        "duration": duration,
        "constraints_line": constraints_line,
        "sep": os.path.sep.replace("\\", "\\\\"),
    }

    template = """[ ca ]
default_ca = CA_default

[ CA_default ]
dir = %(root_dir)s
certs = $dir
new_certs_dir = $certs
crl_dir = $dir%(sep)scrl
database = $dir%(sep)sindex.txt
private_key = $dir%(sep)scacert.key
certificate = $dir%(sep)scacert.pem
serial = $dir%(sep)sserial
crldir = $dir%(sep)scrl
crlnumber = $dir%(sep)scrlnumber
crl = $crldir%(sep)scrl.pem
RANDFILE = $dir%(sep)sprivate%(sep)s.rand
x509_extensions = usr_cert
name_opt        = ca_default
cert_opt        = ca_default
default_days = %(duration)d
default_crl_days = %(duration)d
default_md = sha256
preserve = no
policy = policy_anything
copy_extensions = copy

[ policy_anything ]
countryName = optional
stateOrProvinceName = optional
localityName = optional
organizationName = optional
organizationalUnitName = optional
commonName = supplied
emailAddress = optional

[ req ]
default_bits = 2048
default_keyfile  = privkey.pem
distinguished_name = req_distinguished_name
attributes = req_attributes
x509_extensions = v3_ca

# Passwords for private keys if not present they will be prompted for
# input_password = secret
# output_password = secret
string_mask = utf8only
req_extensions = v3_req

[ req_distinguished_name ]
countryName = Country Name (2 letter code)
countryName_default = AU
countryName_min = 2
countryName_max = 2
stateOrProvinceName = State or Province Name (full name)
stateOrProvinceName_default =
localityName = Locality Name (eg, city)
0.organizationName = Organization Name
0.organizationName_default = Web Platform Tests
organizationalUnitName = Organizational Unit Name (eg, section)
#organizationalUnitName_default =
commonName = Common Name (e.g. server FQDN or YOUR name)
commonName_max = 64
emailAddress = Email Address
emailAddress_max = 64

[ req_attributes ]

[ usr_cert ]
basicConstraints=CA:false
subjectKeyIdentifier=hash
authorityKeyIdentifier=keyid,issuer

[ v3_req ]
basicConstraints = CA:FALSE
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
extendedKeyUsage = serverAuth
%(san_line)s

[ v3_ca ]
basicConstraints = CA:true
subjectKeyIdentifier=hash
authorityKeyIdentifier=keyid:always,issuer:always
keyUsage = keyCertSign
%(constraints_line)s
"""
    return template % values
216
class OpenSSLEnvironment(object):
    ssl_enabled = True

    def __init__(self, logger, openssl_binary="openssl", base_path=None,
                 password="web-platform-tests", force_regenerate=False,
                 duration=30, base_conf_path=None):
        """SSL environment that creates a local CA and host certificate using OpenSSL.

        By default this will look in base_path for existing certificates that are still
        valid and only create new certificates if there aren't any. This behaviour can
        be adjusted using the force_regenerate option.

        :param logger: a stdlib logging compatible logger or mozlog structured logger
        :param openssl_binary: Path to the OpenSSL binary
        :param base_path: Path in which certificates will be stored. If None, a temporary
                          directory will be used and removed in ``__exit__``
        :param password: Password protecting the CA private key (passed to
                         openssl via ``-passout``/``-passin``)
        :param force_regenerate: Always create a new certificate even if one already exists.
        :param duration: Certificate validity in days, written into the
                         generated OpenSSL config
        :param base_conf_path: Optional path exported as ``OPENSSL_CONF`` to
                               every openssl invocation
        """
        self.logger = logger

        self.temporary = False
        if base_path is None:
            base_path = tempfile.mkdtemp()
            self.temporary = True

        self.base_path = os.path.abspath(base_path)
        self.password = password
        self.force_regenerate = force_regenerate
        self.duration = duration
        self.base_conf_path = base_conf_path

        # Set in __enter__ to a helper joining names onto base_path.
        self.path = None
        self.binary = openssl_binary
        self.openssl = None

        self._ca_cert_path = None
        self._ca_key_path = None
        # Cache: tuple of hosts (sorted by length) -> (key path, cert path).
        self.host_certificates = {}

    def __enter__(self):
        """Prepare base_path: create it if needed, reset the CA database
        (index.txt) and write a random, even-length hex serial number."""
        if not os.path.exists(self.base_path):
            os.makedirs(self.base_path)

        path = functools.partial(os.path.join, self.base_path)

        with open(path("index.txt"), "w"):
            pass
        with open(path("serial"), "w") as f:
            serial = "%x" % random.randint(0, 1000000)
            # openssl expects the serial file to hold whole bytes of hex.
            if len(serial) % 2:
                serial = "0" + serial
            f.write(serial)

        self.path = path

        return self

    def __exit__(self, *args, **kwargs):
        """Remove base_path if it was a temporary directory created by __init__."""
        if self.temporary:
            shutil.rmtree(self.base_path)

    def _config_openssl(self, hosts):
        """Return an OpenSSL context manager whose config covers ``hosts``."""
        conf_path = self.path("openssl.cfg")
        return OpenSSL(self.logger, self.binary, self.base_path, conf_path, hosts,
                       self.duration, self.base_conf_path)

    def ca_cert_path(self, hosts):
        """Get the path to the CA certificate file, generating a
        new one if needed"""
        if self._ca_cert_path is None and not self.force_regenerate:
            self._load_ca_cert()
        if self._ca_cert_path is None:
            self._generate_ca(hosts)
        return self._ca_cert_path

    def _load_ca_cert(self):
        """Adopt an existing, still-valid CA key/cert from base_path, if any."""
        key_path = self.path("cacert.key")
        cert_path = self.path("cacert.pem")

        if self.check_key_cert(key_path, cert_path, None):
            self.logger.info("Using existing CA cert")
            self._ca_key_path, self._ca_cert_path = key_path, cert_path

    def check_key_cert(self, key_path, cert_path, hosts):
        """Check that a key and cert file exist and the cert is not (nearly) expired.

        Returns False when either file is missing, when the certificate's end
        date cannot be read or parsed, or when it expires within
        CERT_EXPIRY_BUFFER of the current UTC time. An unreadable certificate
        is reported as invalid (so callers regenerate it) rather than raising.
        """
        if not os.path.exists(key_path) or not os.path.exists(cert_path):
            return False

        with self._config_openssl(hosts) as openssl:
            try:
                output = openssl("x509",
                                 "-noout",
                                 "-enddate",
                                 "-in", cert_path)
                # Output looks like "notAfter=<date>".
                end_date_str = output.decode("utf8").split("=", 1)[1].strip()
                # Not sure if this works in other locales
                end_date = datetime.strptime(end_date_str, "%b %d %H:%M:%S %Y %Z")
            except (subprocess.CalledProcessError, IndexError, ValueError) as e:
                self.logger.warning("Unable to read expiry date of %s: %s" % (cert_path, e))
                return False
            time_buffer = timedelta(**CERT_EXPIRY_BUFFER)
            # Because `strptime` does not account for time zone offsets, it is
            # always in terms of UTC, so the current time should be calculated
            # accordingly.
            if end_date < datetime.utcnow() + time_buffer:
                return False

        #TODO: check the key actually signed the cert.
        return True

    def _generate_ca(self, hosts):
        """Create a new self-signed CA key/cert in base_path and remember them."""
        path = self.path
        self.logger.info("Generating new CA in %s" % self.base_path)

        key_path = path("cacert.key")
        req_path = path("careq.pem")
        cert_path = path("cacert.pem")

        with self._config_openssl(hosts) as openssl:
            openssl("req",
                    "-batch",
                    "-new",
                    "-newkey", "rsa:2048",
                    "-keyout", key_path,
                    "-out", req_path,
                    "-subj", make_subject("web-platform-tests"),
                    "-passout", "pass:%s" % self.password)

            openssl("ca",
                    "-batch",
                    "-create_serial",
                    "-keyfile", key_path,
                    "-passin", "pass:%s" % self.password,
                    "-selfsign",
                    "-extensions", "v3_ca",
                    "-notext",
                    "-in", req_path,
                    "-out", cert_path)

        os.unlink(req_path)

        self._ca_key_path, self._ca_cert_path = key_path, cert_path

    def host_cert_path(self, hosts):
        """Get a tuple of (private key path, certificate path) for a host,
        generating new ones if necessary.

        hosts must be a list of all hosts to appear on the certificate. The
        hosts are sorted by length (stable, so equal-length names keep their
        order) and the first, i.e. shortest, one is used as the primary
        hostname for the file names and subject CN."""
        hosts = tuple(sorted(hosts, key=len))
        if hosts not in self.host_certificates:
            if not self.force_regenerate:
                key_cert = self._load_host_cert(hosts)
            else:
                key_cert = None
            if key_cert is None:
                key, cert = self._generate_host_cert(hosts)
            else:
                key, cert = key_cert
            self.host_certificates[hosts] = key, cert

        return self.host_certificates[hosts]

    def _load_host_cert(self, hosts):
        """Return (key path, cert path) for an existing valid host cert, else None."""
        host = hosts[0]
        key_path = self.path("%s.key" % host)
        cert_path = self.path("%s.pem" % host)

        # TODO: check that this cert was signed by the CA cert
        if self.check_key_cert(key_path, cert_path, hosts):
            self.logger.info("Using existing host cert")
            return key_path, cert_path
        return None

    def _generate_host_cert(self, hosts):
        """Create a key and CA-signed cert for ``hosts`` (primary host first)."""
        host = hosts[0]
        if not self.force_regenerate:
            self._load_ca_cert()
        if self._ca_key_path is None:
            self._generate_ca(hosts)
        ca_key_path = self._ca_key_path

        assert os.path.exists(ca_key_path)

        path = self.path

        req_path = path("wpt.req")
        cert_path = path("%s.pem" % host)
        key_path = path("%s.key" % host)

        self.logger.info("Generating new host cert")

        with self._config_openssl(hosts) as openssl:
            # NOTE(review): per the openssl-req docs a request is only read
            # from -in when -new/-newkey are absent, so "-in ca_key_path"
            # appears to be ignored here; confirm before removing it.
            openssl("req",
                    "-batch",
                    "-newkey", "rsa:2048",
                    "-keyout", key_path,
                    "-in", ca_key_path,
                    "-nodes",
                    "-out", req_path)

            # The CA key comes from private_key in the generated config.
            openssl("ca",
                    "-batch",
                    "-in", req_path,
                    "-passin", "pass:%s" % self.password,
                    "-subj", make_subject(host),
                    "-out", cert_path)

        os.unlink(req_path)

        return key_path, cert_path
423