1import copy
2import inspect
3import json
4import logging
5from configparser import ConfigParser, ExtendedInterpolation
6
7import pytest
8import re
9import os
10import shutil
11import subprocess
12import time
13
14from datetime import datetime, timedelta
15from typing import Dict, Optional
16
17from pyhttpd.certs import CertificateSpec
18from .md_cert_util import MDCertUtil
19from pyhttpd.env import HttpdTestSetup, HttpdTestEnv
20from pyhttpd.result import ExecResult
21
# Logger for this module, named after the module's import path.
log: logging.Logger = logging.getLogger(name=__name__)
23
24
class MDTestSetup(HttpdTestSetup):
    """httpd test setup for mod_md.

    Loads the modules mod_md needs and, when pebble is the ACME server,
    generates the pebble configuration into the environment's gen dir.
    """

    # Template files in the pebble config dir: "<target>.template" is
    # rendered to "<target>"; group 1 is the target file name.
    _TEMPLATE_RE = re.compile(r'^(.+)\.template$')

    def __init__(self, env: 'MDTestEnv'):
        super().__init__(env=env)
        # Keep a typed reference for the md-specific attributes used below.
        self.mdenv = env
        self.add_modules(["watchdog", "proxy_connect", "md"])

    def make(self):
        """Run the base setup, add pebble config if needed, then clear the MD store."""
        super().make()
        if "pebble" == self.mdenv.acme_server:
            self._make_pebble_conf()
        self.mdenv.clear_store()

    def _make_pebble_conf(self):
        """Populate ``<gen_dir>/pebble`` from the 'pebble' dir next to this module.

        Files named ``*.template`` are rendered via ``_make_template`` (from
        the base class) without the suffix. Other regular files are copied
        unchanged. Subdirectories are skipped.
        """
        our_dir = os.path.dirname(inspect.getfile(MDTestSetup))
        conf_src_dir = os.path.join(our_dir, 'pebble')
        conf_dest_dir = os.path.join(self.env.gen_dir, 'pebble')
        os.makedirs(conf_dest_dir, exist_ok=True)
        for name in os.listdir(conf_src_dir):
            src_path = os.path.join(conf_src_dir, name)
            m = self._TEMPLATE_RE.match(name)
            if m:
                self._make_template(src_path, os.path.join(conf_dest_dir, m.group(1)))
            elif os.path.isfile(src_path):
                shutil.copy(src_path, os.path.join(conf_dest_dir, name))
51
52
class MDTestEnv(HttpdTestEnv):
    """Test environment for the mod_md test suite.

    Extends :class:`HttpdTestEnv` with:

    * the ACME test server to use (``pebble`` or ``boulder``, selected via the
      ``ACME`` environment variable),
    * the location and layout of the mod_md store directory,
    * argument presets for running the ``a2md`` command line tool,
    * polling and assertion helpers used by the md test cases.
    """

    # Values of the 'state' field of an MD (md.json / md-status JSON).
    MD_S_UNKNOWN = 0
    MD_S_INCOMPLETE = 1
    MD_S_COMPLETE = 2
    MD_S_EXPIRED = 3
    MD_S_ERROR = 4

    EMPTY_JOUT = {'status': 0, 'output': []}

    # Suffix for generated test domain names. It is derived once, at import
    # time, from the current time, so names differ between test runs.
    DOMAIN_SUFFIX = "%d.org" % time.time()
    LOG_FMT_TIGHT = '%(levelname)s: %(message)s'

    @classmethod
    def get_acme_server(cls):
        """Return the configured ACME server type ($ACME, default 'pebble')."""
        return os.environ.get('ACME', 'pebble')

    @classmethod
    def has_acme_server(cls):
        """Return True unless $ACME is set to 'none'."""
        return cls.get_acme_server() != 'none'

    @classmethod
    def has_acme_eab(cls):
        """Return True if external account binding tests can run.

        Only the pebble setup gets an EAB directory URL in ``__init__``.
        """
        return cls.get_acme_server() == 'pebble'

    @classmethod
    def is_pebble(cls) -> bool:
        """Return True if the ACME test server is pebble."""
        return cls.get_acme_server() == 'pebble'

    @classmethod
    def lacks_ocsp(cls):
        """Return True if the ACME setup is treated as lacking OCSP (currently: pebble)."""
        return cls.is_pebble()

    @classmethod
    def has_a2md(cls):
        """Return True if the ``a2md`` binary exists in the configured bindir.

        The bindir is read from the ``[global]`` section of the
        ``config.ini`` that sits next to the pyhttpd env module.
        """
        d = os.path.dirname(inspect.getfile(HttpdTestEnv))
        config = ConfigParser(interpolation=ExtendedInterpolation())
        config.read(os.path.join(d, 'config.ini'))
        bin_dir = config.get('global', 'bindir')
        return os.path.isfile(os.path.join(bin_dir, 'a2md'))

    def __init__(self, pytestconfig=None):
        super().__init__(pytestconfig=pytestconfig)
        self.add_httpd_log_modules(["md"])
        self._acme_server = self.get_acme_server()
        self._acme_tos = "accepted"
        self._acme_ca_pemfile = os.path.join(self.gen_dir, "apache/acme-ca.pem")
        if "pebble" == self._acme_server:
            self._acme_url = "https://localhost:14000/dir"
            self._acme_eab_url = "https://localhost:14001/dir"
        elif "boulder" == self._acme_server:
            self._acme_url = "http://localhost:4001/directory"
            self._acme_eab_url = None
        else:
            raise Exception(f"unknown ACME server type: {self._acme_server}")
        # Liveness of the ACME server, determined lazily by check_acme().
        self._acme_server_down = False
        self._acme_server_ok = False

        self._a2md_bin = os.path.join(self.bin_dir, 'a2md')
        self._default_domain = f"test1.{self.http_tld}"
        # Placeholder; set_store_dir_default() replaces it immediately.
        self._store_dir = "./md"
        self.set_store_dir_default()

        self.add_cert_specs([
            # Certificate whose validity ended 10 days ago.
            CertificateSpec(domains=[f"expired.{self._http_tld}"],
                            valid_from=timedelta(days=-100),
                            valid_to=timedelta(days=-10)),
            CertificateSpec(domains=["localhost"], key_type='rsa2048'),
        ])

    def setup_httpd(self, setup: HttpdTestSetup = None):
        """Set up httpd with an :class:`MDTestSetup`.

        ``setup`` is accepted for signature compatibility with the base
        class, but it is ignored.
        """
        super().setup_httpd(setup=MDTestSetup(env=self))

    def set_store_dir_default(self):
        """Use the default store dir: 'md', or 'state/md' for httpd >= 2.5.0."""
        dirpath = "md"
        if self.httpd_is_at_least("2.5.0"):
            dirpath = os.path.join("state", dirpath)
        self.set_store_dir(dirpath)

    def set_store_dir(self, dirpath):
        """Set the store dir to ``dirpath`` joined onto ``server_dir``.

        Also refreshes the a2md argument presets so they point at the new
        store.
        """
        self._store_dir = os.path.join(self.server_dir, dirpath)
        if self.acme_url:
            raw_args = [self.a2md_bin, "-a", self.acme_url,
                        "-d", self._store_dir, "-C", self.acme_ca_pemfile]
            # The standard preset additionally passes "-j" (presumably JSON
            # output; see a2md's usage).
            self.a2md_stdargs(raw_args + ["-j"])
            self.a2md_rawargs(raw_args)

    def get_apxs_var(self, name: str) -> str:
        """Return the stripped output of ``apxs -q <name>``, or "" if apxs fails."""
        p = subprocess.run([self._apxs, "-q", name], capture_output=True, text=True)
        if p.returncode != 0:
            return ""
        return p.stdout.strip()

    @property
    def acme_server(self):
        """ACME server type ('pebble' or 'boulder')."""
        return self._acme_server

    @property
    def acme_url(self):
        """ACME directory URL of the test server."""
        return self._acme_url

    @property
    def acme_tos(self):
        return self._acme_tos

    @property
    def a2md_bin(self):
        """Path to the a2md binary in the httpd bindir."""
        return self._a2md_bin

    @property
    def acme_ca_pemfile(self):
        """PEM file used as CA file when talking TLS to ACME server / httpd."""
        return self._acme_ca_pemfile

    @property
    def store_dir(self):
        """Current MD store directory."""
        return self._store_dir

    def _name_to_domain(self, name: str) -> str:
        """Turn a test/module/class name into a domain with DOMAIN_SUFFIX."""
        return f"{name.replace('_', '-')}-{MDTestEnv.DOMAIN_SUFFIX}"

    def get_request_domain(self, request):
        """Domain for the pytest ``request``'s test node (case kept as is)."""
        name = request.node.originalname or request.node.name
        return self._name_to_domain(name)

    def get_method_domain(self, method):
        """Domain derived from the lower-cased name of ``method``."""
        return self._name_to_domain(method.__name__.lower())

    def get_module_domain(self, module):
        """Domain derived from the lower-cased name of ``module``."""
        return self._name_to_domain(module.__name__.lower())

    def get_class_domain(self, c):
        """Domain derived from the lower-cased name of class ``c``."""
        return self._name_to_domain(c.__name__.lower())

    # --------- cmd execution ---------

    # Argument presets prepended to a2md invocations. set_store_dir()
    # assigns per-instance copies, so these class-level lists are only
    # fallbacks.
    _a2md_args = []
    _a2md_args_raw = []

    def a2md_stdargs(self, args):
        """Set (a copy of) the preset used by ``a2md(..., raw=False)``."""
        self._a2md_args = list(args)

    def a2md_rawargs(self, args):
        """Set (a copy of) the preset used by ``a2md(..., raw=True)``."""
        self._a2md_args_raw = list(args)

    def a2md(self, args, raw=False) -> ExecResult:
        """Run a2md with the std (or, if ``raw``, the raw) preset plus ``args``."""
        preargs = self._a2md_args_raw if raw else self._a2md_args
        log.debug("running: %s %s", preargs, args)
        return self.run(preargs + args)

    def check_acme(self):
        """Ensure the ACME server is reachable, caching the outcome.

        The first failed probe fails the current test. Once the server is
        known to be down, later calls skip the test instead. Both
        ``pytest.fail`` and ``pytest.skip`` raise, so only True is ever
        returned.
        """
        if self._acme_server_ok:
            return True
        if self._acme_server_down:
            pytest.skip("ACME server not running")
        if self.is_live(self.acme_url, timeout=timedelta(seconds=0.5)):
            self._acme_server_ok = True
            return True
        self._acme_server_down = True
        pytest.fail("ACME server not running", pytrace=False)

    def get_ca_pem_file(self, hostname: str) -> Optional[str]:
        """Return the base class' CA file for ``hostname``, else the ACME CA file."""
        pem_file = super().get_ca_pem_file(hostname)
        if pem_file is None:
            pem_file = self.acme_ca_pemfile
        return pem_file

    # --------- access local store ---------

    def purge_store(self):
        """Delete the whole store directory and recreate it empty."""
        log.debug("purge store dir: %s", self._store_dir)
        # Sanity guard against rmtree on an empty or one-character path.
        assert len(self._store_dir) > 1
        if os.path.exists(self._store_dir):
            shutil.rmtree(self._store_dir, ignore_errors=False)
        os.makedirs(self._store_dir)

    def clear_store(self):
        """Ensure the store dir exists and remove its known subdirectories.

        Files directly in the store dir (e.g. md_store.json) are kept.
        """
        log.debug("clear store dir: %s", self._store_dir)
        assert len(self._store_dir) > 1
        os.makedirs(self._store_dir, exist_ok=True)
        for sub_dir in ["challenges", "tmp", "archive", "domains", "accounts", "staging", "ocsp"]:
            shutil.rmtree(os.path.join(self._store_dir, sub_dir), ignore_errors=True)

    def clear_ocsp_store(self):
        """Remove the store's 'ocsp' subdirectory if present."""
        assert len(self._store_dir) > 1
        dirpath = os.path.join(self._store_dir, "ocsp")
        log.debug("clear ocsp store dir: %s", dirpath)
        if os.path.exists(dirpath):
            shutil.rmtree(dirpath, ignore_errors=True)

    def authz_save(self, name, content):
        """Write ``content`` to staging/<name>/authz.json (dir must not exist yet)."""
        dirpath = os.path.join(self._store_dir, 'staging', name)
        os.makedirs(dirpath)
        with open(os.path.join(dirpath, 'authz.json'), "w") as f:
            f.write(content)

    def path_store_json(self):
        return os.path.join(self._store_dir, 'md_store.json')

    def path_account(self, acct):
        return os.path.join(self._store_dir, 'accounts', acct, 'account.json')

    def path_account_key(self, acct):
        return os.path.join(self._store_dir, 'accounts', acct, 'account.pem')

    def store_domains(self):
        return os.path.join(self._store_dir, 'domains')

    def store_archives(self):
        return os.path.join(self._store_dir, 'archive')

    def store_stagings(self):
        return os.path.join(self._store_dir, 'staging')

    def store_challenges(self):
        return os.path.join(self._store_dir, 'challenges')

    def store_domain_file(self, domain, filename):
        return os.path.join(self.store_domains(), domain, filename)

    def store_archived_file(self, domain, version, filename):
        """Path of ``filename`` in archive/<domain>.<version>."""
        return os.path.join(self.store_archives(), "%s.%d" % (domain, version), filename)

    def store_staged_file(self, domain, filename):
        return os.path.join(self.store_stagings(), domain, filename)

    def path_fallback_cert(self, domain):
        return self.store_domain_file(domain, 'fallback-pubcert.pem')

    def path_job(self, domain):
        return self.store_staged_file(domain, 'job.json')

    def replace_store(self, src):
        """Replace the store directory with a copy of directory ``src``."""
        shutil.rmtree(self._store_dir, ignore_errors=False)
        shutil.copytree(src, self._store_dir)

    def list_accounts(self):
        return os.listdir(os.path.join(self._store_dir, 'accounts'))

    def check_md(self, domain, md=None, state=-1, ca=None, protocol=None, agreement=None, contacts=None):
        """Assert properties of an MD as recorded in its stored md.json.

        :param domain: MD name, or a list of domains. For a list, the first
            entry names the MD and the whole list must equal md.json's
            'domains'.
        :param md: if given, load this MD instead of the one named by ``domain``.
        :param state: expected 'state'; checked only when >= 0.
        :param ca: expected ca 'url'; checked only when truthy.
        :param protocol: expected ca 'proto'; checked only when truthy.
        :param agreement: expected ca 'agreement'; checked only when truthy.
        :param contacts: expected 'contacts'; checked only when truthy.
        """
        domains = None
        if isinstance(domain, list):
            domains = domain
            domain = domains[0]
        if md:
            domain = md
        with open(self.store_domain_file(domain, 'md.json')) as f:
            md_json = json.load(f)
        assert md_json
        if domains:
            assert md_json['domains'] == domains
        if state >= 0:
            assert md_json['state'] == state
        if ca:
            assert md_json['ca']['url'] == ca
        if protocol:
            assert md_json['ca']['proto'] == protocol
        if agreement:
            assert md_json['ca']['agreement'] == agreement
        if contacts:
            assert md_json['contacts'] == contacts

    @staticmethod
    def _pkey_specific_fname(prefix, pkeyspec=None):
        """Return '<prefix>.pem' for no/RSA key specs, else '<prefix>.<spec>.pem'."""
        if pkeyspec and not re.match(r'^rsa( ?\d+)?$', pkeyspec.lower()):
            return f"{prefix}.{pkeyspec.lower()}.pem"
        return f"{prefix}.pem"

    def pkey_fname(self, pkeyspec=None):
        """Private key file name for key spec ``pkeyspec`` (RSA -> 'privkey.pem')."""
        return self._pkey_specific_fname("privkey", pkeyspec)

    def cert_fname(self, pkeyspec=None):
        """Certificate file name for key spec ``pkeyspec`` (RSA -> 'pubcert.pem')."""
        return self._pkey_specific_fname("pubcert", pkeyspec)

    def check_md_complete(self, domain, pkey=None):
        """Assert md-status reports ``domain`` complete and its key/cert files exist."""
        md = self.get_md_status(domain)
        assert md
        assert 'state' in md, "md is unexpected: {0}".format(md)
        assert md['state'] == MDTestEnv.MD_S_COMPLETE, f"unexpected state: {md['state']}"
        pkey_file = self.store_domain_file(domain, self.pkey_fname(pkey))
        cert_file = self.store_domain_file(domain, self.cert_fname(pkey))
        for label, fpath in (("pkey", pkey_file), ("cert", cert_file)):
            if not os.path.isfile(fpath):
                # Only list the domain dir when needed, for the failure message.
                r = self.run(['ls', os.path.dirname(pkey_file)])
                raise AssertionError(f"{label} missing: {fpath}: {r.stdout}")

    def check_md_credentials(self, domain):
        """Validate the stored RSA key/cert of an MD.

        The key must be valid and match the cert. The CN must equal the MD
        name, the SANs must equal the domains (in any order), and the cert
        must be currently valid.

        :param domain: MD name, or a list of domains whose first entry is the MD name.
        """
        if isinstance(domain, list):
            domains = domain
            domain = domains[0]
        else:
            domains = [domain]
        # check private key, validate certificate, etc
        MDCertUtil.validate_privkey(self.store_domain_file(domain, 'privkey.pem'))
        cert = MDCertUtil(self.store_domain_file(domain, 'pubcert.pem'))
        cert.validate_cert_matches_priv_key(self.store_domain_file(domain, 'privkey.pem'))
        # check SANs and CN
        assert cert.get_cn() == domain
        # compare lists twice in opposite directions: SAN may not respect ordering
        san_list = list(cert.get_san_list())
        assert len(san_list) == len(domains)
        assert set(san_list).issubset(domains)
        assert set(domains).issubset(san_list)
        # check valid dates interval
        not_before = cert.get_not_before()
        not_after = cert.get_not_after()
        assert not_before < datetime.now(not_before.tzinfo)
        assert not_after > datetime.now(not_after.tzinfo)

    # --------- check utilities ---------

    def check_json_contains(self, actual, expected):
        """Assert every key/value of dict ``expected`` is present in ``actual``."""
        # write all expected key:value bindings to a copy of the actual data ...
        # ... assert it stays unchanged
        test_json = copy.deepcopy(actual)
        test_json.update(expected)
        assert actual == test_json

    def check_file_access(self, path, exp_mask):
        """Assert the permission bits of ``path`` (not following symlinks)."""
        actual_mask = os.lstat(path).st_mode & 0o777
        # Compare as octal strings so failures show readable modes.
        assert oct(actual_mask) == oct(exp_mask)

    def check_dir_empty(self, path):
        assert os.listdir(path) == []

    def _url(self, domain, path, use_https=True):
        """Build the http(s) URL for ``path`` on ``domain`` at the server's port."""
        schema = "https" if use_https else "http"
        port = self.https_port if use_https else self.http_port
        return f"{schema}://{domain}:{port}{path}"

    def get_http_status(self, domain, path, use_https=True):
        r = self.get_meta(domain, path, use_https, insecure=True)
        return r.response['status']

    def get_cert(self, domain, tls=None, ciphers=None):
        return MDCertUtil.load_server_cert(self._httpd_addr, self.https_port,
                                           domain, tls=tls, ciphers=ciphers)

    def _s_client_args(self, domain, proto=None, cipher=None, ca_file=None):
        """Command line for ``openssl s_client`` against the https port.

        :param proto: appended as ``-<proto>`` (e.g. 'tls1_2') if given.
        :param cipher: passed via ``-cipher`` if given.
        :param ca_file: CA file; defaults to the ACME CA pem file.
        """
        args = [
            "openssl", "s_client", "-status",
            "-connect", "%s:%s" % (self._httpd_addr, self.https_port),
            "-CAfile", ca_file if ca_file else self.acme_ca_pemfile,
            "-servername", domain,
            "-showcerts",
        ]
        if proto is not None:
            args.append(f"-{proto}")
        if cipher is not None:
            args.extend(["-cipher", cipher])
        return args

    def get_server_cert(self, domain, proto=None, ciphers=None):
        """Return the cert served for SNI ``domain``, or None if none could be parsed."""
        r = self.run(self._s_client_args(domain, proto=proto, cipher=ciphers))
        # noinspection PyBroadException
        try:
            return MDCertUtil.parse_pem_cert(r.stdout)
        except Exception:
            return None

    def verify_cert_key_lenghts(self, domain, pkeys):
        """Assert the served key length per cipher spec in ``pkeys``.

        Each entry needs 'ciphers' and 'keylen'. A 'keylen' of 0 means no
        cert is expected for those ciphers.
        """
        for p in pkeys:
            cert = self.get_server_cert(domain, proto="tls1_2", ciphers=p['ciphers'])
            if 0 == p['keylen']:
                assert cert is None
            else:
                assert cert, "no cert returned for cipher: {0}".format(p['ciphers'])
                assert cert.get_key_length() == p['keylen'], "key length, expected {0}, got {1}".format(
                    p['keylen'], cert.get_key_length()
                )

    # Correctly spelled alias; the original name stays for existing callers.
    verify_cert_key_lengths = verify_cert_key_lenghts

    def get_meta(self, domain, path, use_https=True, insecure=False):
        """GET ``path`` via curl and assert a successful run with response headers."""
        r = self.curl_get(self._url(domain, path, use_https), insecure=insecure)
        assert r.exit_code == 0
        assert r.response
        assert r.response['header']
        return r

    def get_content(self, domain, path, use_https=True):
        """GET ``path`` via curl and return stdout (asserts curl succeeded)."""
        r = self.curl_get(self._url(domain, path, use_https))
        assert r.exit_code == 0
        return r.stdout

    def get_json_content(self, domain, path, use_https=True, insecure=False):
        """GET ``path`` via curl and return the parsed JSON body."""
        url = self._url(domain, path, use_https)
        r = self.curl_get(url, insecure=insecure)
        if r.exit_code != 0:
            log.error(f"curl get on {url} returned {r.exit_code}"
                      f"\nstdout: {r.stdout}"
                      f"\nstderr: {r.stderr}")
        assert r.exit_code == 0, r.stderr
        return r.json

    def get_certificate_status(self, domain) -> Dict:
        return self.get_json_content(domain, "/.httpd/certificate-status", insecure=True)

    def get_md_status(self, domain, via_domain=None, use_https=True) -> Dict:
        """Fetch /md-status/<domain>, by default via the 'test1' vhost."""
        if via_domain is None:
            via_domain = self._default_domain
        return self.get_json_content(via_domain, f"/md-status/{domain}",
                                     use_https=use_https)

    def get_server_status(self, query="/", via_domain=None, use_https=True):
        if via_domain is None:
            via_domain = self._default_domain
        return self.get_content(via_domain, "/server-status%s" % query, use_https=use_https)

    def await_completion(self, names, must_renew=False, restart=True, timeout=60,
                         via_domain=None, use_https=True):
        """Poll md-status until every MD in ``names`` reports a finished renewal.

        ``names`` itself is not modified.

        :param restart: on success, restart apache and report whether that worked.
        :return: False on timeout or if an MD's status is None, else True
            (or the restart outcome).
        """
        try_until = time.time() + timeout
        renewals = {}
        pending = list(names)
        while pending:
            if time.time() >= try_until:
                return False
            # Iterate over a snapshot: entries are removed from ``pending``.
            for name in list(pending):
                mds = self.get_md_status(name, via_domain=via_domain, use_https=use_https)
                if mds is None:
                    log.debug("not managed by md: %s", name)
                    return False

                if 'renewal' in mds:
                    renewal = mds['renewal']
                    renewals[name] = True
                    # NOTE(review): ``name`` was just added to ``renewals``, so
                    # ``name in renewals`` below is always true and
                    # ``must_renew`` currently has no effect. Confirm intent.
                    if 'finished' in renewal and renewal['finished'] is True:
                        if (not must_renew) or (name in renewals):
                            log.debug(f"domain cert was renewed: {name}")
                            pending.remove(name)

            if pending:
                time.sleep(0.1)
        if restart:
            time.sleep(0.1)
            return self.apache_restart() == 0
        return True

    def is_renewing(self, name):
        """Return True if the certificate status of ``name`` has a 'renewal' entry."""
        stat = self.get_certificate_status(name)
        return 'renewal' in stat

    def await_renewal(self, names, timeout=60):
        """Poll md-status until every MD in ``names`` reports a 'renewal'.

        ``names`` itself is not modified.

        :return: True when all did, False on timeout or if an MD's status is None.
        """
        try_until = time.time() + timeout
        pending = list(names)
        while pending:
            if time.time() >= try_until:
                return False
            # Iterate over a snapshot: entries are removed from ``pending``.
            for name in list(pending):
                md = self.get_md_status(name)
                if md is None:
                    log.debug("not managed by md: %s", name)
                    return False

                if 'renewal' in md:
                    pending.remove(name)

            if pending:
                time.sleep(0.1)
        return True

    def await_error(self, domain, timeout=60, via_domain=None, use_https=True, errors=1):
        """Poll md-status of ``domain`` until it is in error state or has enough renewal errors.

        :param errors: minimum 'errors' count in the 'renewal' entry.
        :return: the md-status dict once an error shows, False on timeout.
        """
        try_until = time.time() + timeout
        while True:
            if time.time() >= try_until:
                return False
            md = self.get_md_status(domain, via_domain=via_domain, use_https=use_https)
            if md:
                if 'state' in md and md['state'] == MDTestEnv.MD_S_ERROR:
                    return md
                if 'renewal' in md and 'errors' in md['renewal'] \
                        and md['renewal']['errors'] >= errors:
                    return md
            time.sleep(0.1)

    def await_file(self, fpath, timeout=60):
        """Wait until ``fpath`` is a regular file; True if it appears within ``timeout``."""
        try_until = time.time() + timeout
        while True:
            if time.time() >= try_until:
                return False
            if os.path.isfile(fpath):
                return True
            time.sleep(0.1)

    def check_file_permissions(self, domain):
        """Assert the expected permission bits on store files of ``domain`` and its account."""
        dpath = os.path.join(self.store_dir, 'domains', domain)
        assert os.path.isdir(dpath)
        with open(os.path.join(dpath, 'md.json')) as f:
            md = json.load(f)
        assert md
        acct = md['ca']['account']
        assert acct
        self.check_file_access(self.path_store_json(), 0o600)
        # domains
        self.check_file_access(self.store_domains(), 0o700)
        self.check_file_access(os.path.join(self.store_domains(), domain), 0o700)
        self.check_file_access(self.store_domain_file(domain, 'privkey.pem'), 0o600)
        self.check_file_access(self.store_domain_file(domain, 'pubcert.pem'), 0o600)
        self.check_file_access(self.store_domain_file(domain, 'md.json'), 0o600)
        # archive
        self.check_file_access(self.store_archived_file(domain, 1, 'md.json'), 0o600)
        # accounts
        self.check_file_access(os.path.join(self._store_dir, 'accounts'), 0o755)
        self.check_file_access(os.path.join(self._store_dir, 'accounts', acct), 0o755)
        self.check_file_access(self.path_account(acct), 0o644)
        self.check_file_access(self.path_account_key(acct), 0o644)
        # staging
        self.check_file_access(self.store_stagings(), 0o755)

    @staticmethod
    def _last_match(pattern, text):
        """Return the last non-empty group 1 of ``pattern`` in ``text``, or None."""
        found = None
        for m in re.finditer(pattern, text):
            if m.group(1) != "":
                found = m.group(1)
        return found

    def get_ocsp_status(self, domain, proto=None, cipher=None, ca_file=None):
        """Run openssl s_client for ``domain`` and extract OCSP/verify results.

        :return: dict with 'ocsp' and/or 'verify' entries, each holding the
            last value found in the s_client output.
        """
        r = self.run(self._s_client_args(domain, proto=proto, cipher=cipher, ca_file=ca_file),
                     debug_log=False)
        stat = {}
        ocsp = self._last_match(r'OCSP response: +([^=\n]+)\n', r.stdout)
        if ocsp is None:
            # Fall back to the "OCSP Response Status" line format.
            ocsp = self._last_match(r'OCSP Response Status:\s*(.+)', r.stdout)
        if ocsp is not None:
            stat['ocsp'] = ocsp
        verify = self._last_match(r'Verify return code:\s*(.+)', r.stdout)
        if verify is not None:
            stat['verify'] = verify
        return stat

    def await_ocsp_status(self, domain, timeout=10, ca_file=None):
        """Poll until an OCSP response other than "no response sent" is seen.

        :return: the status dict from :meth:`get_ocsp_status`.
        :raises TimeoutError: if none appears within ``timeout`` seconds.
        """
        try_until = time.time() + timeout
        while time.time() < try_until:
            stat = self.get_ocsp_status(domain, ca_file=ca_file)
            if 'ocsp' in stat and stat['ocsp'] != "no response sent":
                return stat
            time.sleep(0.1)
        raise TimeoutError(f"ocsp response not available: {domain}")

    def create_self_signed_cert(self, name_list, valid_days, serial=1000, path=None):
        """Create a self-signed cert for ``name_list``.

        The cert goes into ``path``, or by default into the store domain dir
        of the first name.
        """
        dirpath = path
        if not path:
            dirpath = os.path.join(self.store_domains(), name_list[0])
        return MDCertUtil.create_self_signed_cert(dirpath, name_list, valid_days, serial)
606