1import inspect
2import logging
3import re
4import os
5import shutil
6import stat
7import subprocess
8import sys
9import time
10from datetime import datetime, timedelta
11from string import Template
12from typing import List, Optional
13
14from configparser import ConfigParser, ExtendedInterpolation
15from urllib.parse import urlparse
16
17from .certs import Credentials, HttpdTestCA, CertificateSpec
18from .log import HttpdErrorLog
19from .nghttp import Nghttp
20from .result import ExecResult
21
22
23log = logging.getLogger(__name__)
24
25
class Dummy:
    """Empty marker class defined in this module.

    ``HttpdTestEnv.__init__`` passes it to ``inspect.getfile`` to find the
    directory this module lives in.
    """
28
29
30class HttpdTestSetup:
31
32    # the modules we want to load
33    MODULES = [
34        "log_config",
35        "logio",
36        "unixd",
37        "version",
38        "authn_core",
39        "authz_host",
40        "authz_groupfile",
41        "authz_user",
42        "authz_core",
43        "access_compat",
44        "auth_basic",
45        "cache",
46        "cache_disk",
47        "cache_socache",
48        "socache_shmcb",
49        "dumpio",
50        "reqtimeout",
51        "filter",
52        "mime",
53        "env",
54        "headers",
55        "setenvif",
56        "slotmem_shm",
57        "status",
58        "dir",
59        "alias",
60        "rewrite",
61        "deflate",
62        "proxy",
63        "proxy_http",
64    ]
65
66    def __init__(self, env: 'HttpdTestEnv'):
67        self.env = env
68        self._source_dirs = [os.path.dirname(inspect.getfile(HttpdTestSetup))]
69        self._modules = HttpdTestSetup.MODULES.copy()
70
71    def add_source_dir(self, source_dir):
72        self._source_dirs.append(source_dir)
73
74    def add_modules(self, modules: List[str]):
75        self._modules.extend(modules)
76
77    def make(self):
78        self._make_dirs()
79        self._make_conf()
80        if self.env.mpm_module is not None:
81            self.add_modules([self.env.mpm_module])
82        if self.env.ssl_module is not None:
83            self.add_modules([self.env.ssl_module])
84        self._make_modules_conf()
85        self._make_htdocs()
86        self.env.clear_curl_headerfiles()
87
88    def _make_dirs(self):
89        if os.path.exists(self.env.gen_dir):
90            shutil.rmtree(self.env.gen_dir)
91        os.makedirs(self.env.gen_dir)
92        if not os.path.exists(self.env.server_logs_dir):
93            os.makedirs(self.env.server_logs_dir)
94
95    def _make_conf(self):
96        # remove anything from another run/test suite
97        conf_dest_dir = os.path.join(self.env.server_dir, 'conf')
98        if os.path.isdir(conf_dest_dir):
99            shutil.rmtree(conf_dest_dir)
100        for d in self._source_dirs:
101            conf_src_dir = os.path.join(d, 'conf')
102            if os.path.isdir(conf_src_dir):
103                if not os.path.exists(conf_dest_dir):
104                    os.makedirs(conf_dest_dir)
105                for name in os.listdir(conf_src_dir):
106                    src_path = os.path.join(conf_src_dir, name)
107                    m = re.match(r'(.+).template', name)
108                    if m:
109                        self._make_template(src_path, os.path.join(conf_dest_dir, m.group(1)))
110                    elif os.path.isfile(src_path):
111                        shutil.copy(src_path, os.path.join(conf_dest_dir, name))
112
113    def _make_template(self, src, dest):
114        var_map = dict()
115        for name, value in HttpdTestEnv.__dict__.items():
116            if isinstance(value, property):
117                var_map[name] = value.fget(self.env)
118        t = Template(''.join(open(src).readlines()))
119        with open(dest, 'w') as fd:
120            fd.write(t.substitute(var_map))
121
122    def _make_modules_conf(self):
123        loaded = set()
124        modules_conf = os.path.join(self.env.server_dir, 'conf/modules.conf')
125        with open(modules_conf, 'w') as fd:
126            # issue load directives for all modules we want that are shared
127            missing_mods = list()
128            for m in self._modules:
129                match = re.match(r'^mod_(.+)$', m)
130                if match:
131                    m = match.group(1)
132                if m in loaded:
133                    continue
134                mod_path = os.path.join(self.env.libexec_dir, f"mod_{m}.so")
135                if os.path.isfile(mod_path):
136                    fd.write(f"LoadModule {m}_module   \"{mod_path}\"\n")
137                elif m in self.env.static_modules:
138                    fd.write(f"#built static: LoadModule {m}_module   \"{mod_path}\"\n")
139                else:
140                    missing_mods.append(m)
141                loaded.add(m)
142        if len(missing_mods) > 0:
143            raise Exception(f"Unable to find modules: {missing_mods} "
144                            f"DSOs: {self.env.dso_modules}")
145
146    def _make_htdocs(self):
147        if not os.path.exists(self.env.server_docs_dir):
148            os.makedirs(self.env.server_docs_dir)
149        dest_dir = os.path.join(self.env.server_dir, 'htdocs')
150        # remove anything from another run/test suite
151        if os.path.isdir(dest_dir):
152            shutil.rmtree(dest_dir)
153        for d in self._source_dirs:
154            srcdocs = os.path.join(d, 'htdocs')
155            if os.path.isdir(srcdocs):
156                shutil.copytree(srcdocs, dest_dir, dirs_exist_ok=True)
157        # make all contained .py scripts executable
158        for dirpath, _dirnames, filenames in os.walk(dest_dir):
159            for fname in filenames:
160                if re.match(r'.+\.py', fname):
161                    py_file = os.path.join(dirpath, fname)
162                    st = os.stat(py_file)
163                    os.chmod(py_file, st.st_mode | stat.S_IEXEC)
164
165
166class HttpdTestEnv:
167
168    @classmethod
169    def get_ssl_module(cls):
170        return os.environ['SSL'] if 'SSL' in os.environ else 'mod_ssl'
171
172    def __init__(self, pytestconfig=None):
173        self._our_dir = os.path.dirname(inspect.getfile(Dummy))
174        self.config = ConfigParser(interpolation=ExtendedInterpolation())
175        self.config.read(os.path.join(self._our_dir, 'config.ini'))
176
177        self._bin_dir = self.config.get('global', 'bindir')
178        self._apxs = self.config.get('global', 'apxs')
179        self._prefix = self.config.get('global', 'prefix')
180        self._apachectl = self.config.get('global', 'apachectl')
181        self._libexec_dir = self.get_apxs_var('LIBEXECDIR')
182
183        self._curl = self.config.get('global', 'curl_bin')
184        self._nghttp = self.config.get('global', 'nghttp')
185        if self._nghttp is None:
186            self._nghttp = 'nghttp'
187        self._h2load = self.config.get('global', 'h2load')
188        if self._h2load is None:
189            self._h2load = 'h2load'
190
191        self._http_port = int(self.config.get('test', 'http_port'))
192        self._https_port = int(self.config.get('test', 'https_port'))
193        self._proxy_port = int(self.config.get('test', 'proxy_port'))
194        self._http_tld = self.config.get('test', 'http_tld')
195        self._test_dir = self.config.get('test', 'test_dir')
196        self._gen_dir = self.config.get('test', 'gen_dir')
197        self._server_dir = os.path.join(self._gen_dir, 'apache')
198        self._server_conf_dir = os.path.join(self._server_dir, "conf")
199        self._server_docs_dir = os.path.join(self._server_dir, "htdocs")
200        self._server_logs_dir = os.path.join(self.server_dir, "logs")
201        self._server_access_log = os.path.join(self._server_logs_dir, "access_log")
202        self._error_log = HttpdErrorLog(os.path.join(self._server_logs_dir, "error_log"))
203        self._apachectl_stderr = None
204
205        self._dso_modules = self.config.get('httpd', 'dso_modules').split(' ')
206        self._static_modules = self.config.get('httpd', 'static_modules').split(' ')
207        self._mpm_module = f"mpm_{os.environ['MPM']}" if 'MPM' in os.environ else 'mpm_event'
208        self._ssl_module = self.get_ssl_module()
209        if len(self._ssl_module.strip()) == 0:
210            self._ssl_module = None
211
212        self._httpd_addr = "127.0.0.1"
213        self._http_base = f"http://{self._httpd_addr}:{self.http_port}"
214        self._https_base = f"https://{self._httpd_addr}:{self.https_port}"
215
216        self._verbosity = pytestconfig.option.verbose if pytestconfig is not None else 0
217        self._test_conf = os.path.join(self._server_conf_dir, "test.conf")
218        self._httpd_base_conf = []
219        self._httpd_log_modules = []
220        self._log_interesting = None
221        self._setup = None
222
223        self._ca = None
224        self._cert_specs = [CertificateSpec(domains=[
225            f"test1.{self._http_tld}",
226            f"test2.{self._http_tld}",
227            f"test3.{self._http_tld}",
228            f"cgi.{self._http_tld}",
229        ], key_type='rsa4096')]
230
231        self._verify_certs = False
232        self._curl_headerfiles_n = 0
233
234    def add_httpd_conf(self, lines: List[str]):
235        self._httpd_base_conf.extend(lines)
236
237    def add_httpd_log_modules(self, modules: List[str]):
238        self._httpd_log_modules.extend(modules)
239
240    def issue_certs(self):
241        if self._ca is None:
242            self._ca = HttpdTestCA.create_root(name=self.http_tld,
243                                               store_dir=os.path.join(self.server_dir, 'ca'),
244                                               key_type="rsa4096")
245        self._ca.issue_certs(self._cert_specs)
246
247    def setup_httpd(self, setup: HttpdTestSetup = None):
248        """Create the server environment with config, htdocs and certificates"""
249        self._setup = setup if setup is not None else HttpdTestSetup(env=self)
250        self._setup.make()
251        self.issue_certs()
252        if self._httpd_log_modules:
253            if self._verbosity >= 2:
254                log_level = "trace2"
255            elif self._verbosity >= 1:
256                log_level = "debug"
257            else:
258                log_level = "info"
259            self._log_interesting = "LogLevel"
260            for name in self._httpd_log_modules:
261                self._log_interesting += f" {name}:{log_level}"
262
    @property
    def apxs(self) -> str:
        """The ``apxs`` setting from the ``[global]`` section of config.ini."""
        return self._apxs
266
    @property
    def verbosity(self) -> int:
        """pytest's ``option.verbose`` value, or 0 when no pytest config was given."""
        return self._verbosity
270
    @property
    def prefix(self) -> str:
        """The ``prefix`` setting from the ``[global]`` section of config.ini."""
        return self._prefix
274
    @property
    def mpm_module(self) -> str:
        """``mpm_<MPM>`` from the ``MPM`` environment variable, else ``mpm_event``."""
        return self._mpm_module
278
    @property
    def ssl_module(self) -> Optional[str]:
        """Name of the SSL module to load; None when the setting was blank."""
        return self._ssl_module
282
    @property
    def http_addr(self) -> str:
        """Address the test server is contacted on (set to 127.0.0.1 in ``__init__``)."""
        return self._httpd_addr
286
    @property
    def http_port(self) -> int:
        """The ``http_port`` setting from the ``[test]`` section of config.ini."""
        return self._http_port
290
    @property
    def https_port(self) -> int:
        """The ``https_port`` setting from the ``[test]`` section of config.ini."""
        return self._https_port
294
    @property
    def proxy_port(self) -> int:
        """The ``proxy_port`` setting from the ``[test]`` section of config.ini."""
        return self._proxy_port
298
    @property
    def http_tld(self) -> str:
        """Domain suffix of all test host names (``http_tld`` in ``[test]``)."""
        return self._http_tld
302
    @property
    def http_base_url(self) -> str:
        """``http://<address>:<http_port>`` of the test server, without a path."""
        return self._http_base
306
    @property
    def https_base_url(self) -> str:
        """``https://<address>:<https_port>`` of the test server, without a path."""
        return self._https_base
310
    @property
    def bin_dir(self) -> str:
        """The ``bindir`` setting from the ``[global]`` section of config.ini."""
        return self._bin_dir
314
    @property
    def gen_dir(self) -> str:
        """Directory for generated files (``gen_dir`` in ``[test]``)."""
        return self._gen_dir
318
    @property
    def test_dir(self) -> str:
        """The ``test_dir`` setting from the ``[test]`` section of config.ini."""
        return self._test_dir
322
    @property
    def server_dir(self) -> str:
        """Server root of the test httpd: ``<gen_dir>/apache``."""
        return self._server_dir
326
    @property
    def server_logs_dir(self) -> str:
        """Log directory of the test httpd: ``<server_dir>/logs``."""
        return self._server_logs_dir
330
    @property
    def libexec_dir(self) -> str:
        """Output of ``apxs -q LIBEXECDIR``; empty when that call failed."""
        return self._libexec_dir
334
    @property
    def dso_modules(self) -> List[str]:
        """The ``dso_modules`` setting of ``[httpd]``, split on single spaces."""
        return self._dso_modules
338
    @property
    def static_modules(self) -> List[str]:
        """The ``static_modules`` setting of ``[httpd]``, split on single spaces."""
        return self._static_modules
342
    @property
    def server_conf_dir(self) -> str:
        """Configuration directory of the test httpd: ``<server_dir>/conf``."""
        return self._server_conf_dir
346
    @property
    def server_docs_dir(self) -> str:
        """Document directory of the test httpd: ``<server_dir>/htdocs``."""
        return self._server_docs_dir
350
    @property
    def httpd_error_log(self) -> HttpdErrorLog:
        """The ``HttpdErrorLog`` created for ``<server_logs_dir>/error_log``."""
        return self._error_log
354
355    def htdocs_src(self, path):
356        return os.path.join(self._our_dir, 'htdocs', path)
357
    @property
    def h2load(self) -> str:
        """The ``h2load`` setting from the ``[global]`` section of config.ini."""
        return self._h2load
361
    @property
    def ca(self) -> Credentials:
        """The test CA; still None until ``issue_certs()`` has created it."""
        return self._ca
365
    @property
    def apachectl_stderr(self):
        """stderr of the most recent ``_run_apachectl`` call; None before the first."""
        return self._apachectl_stderr
369
370    def add_cert_specs(self, specs: List[CertificateSpec]):
371        self._cert_specs.extend(specs)
372
373    def get_credentials_for_name(self, dns_name) -> List['Credentials']:
374        for spec in [s for s in self._cert_specs if s.domains is not None]:
375            if dns_name in spec.domains:
376                return self.ca.get_credentials_for_name(spec.domains[0])
377        return []
378
379    def _versiontuple(self, v):
380        return tuple(map(int, v.split('.')))
381
382    def httpd_is_at_least(self, minv):
383        hv = self._versiontuple(self.get_httpd_version())
384        return hv >= self._versiontuple(minv)
385
386    def has_h2load(self):
387        return self._h2load != ""
388
389    def h2load_is_at_least(self, minv):
390        if not self.has_h2load():
391            return False
392        p = subprocess.run([self._h2load, '--version'], capture_output=True, text=True)
393        if p.returncode != 0:
394            return False
395        s = p.stdout.strip()
396        m = re.match(r'h2load nghttp2/(\S+)', s)
397        if m:
398            hv = self._versiontuple(m.group(1))
399            return hv >= self._versiontuple(minv)
400        return False
401
402    def has_nghttp(self):
403        return self._nghttp != ""
404
405    def has_nghttp_get_assets(self):
406        if not self.has_nghttp():
407            return False
408        args = [self._nghttp, "-a"]
409        p = subprocess.run(args, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
410        rv = p.returncode
411        if rv != 0:
412            return False
413        return p.stderr == ""
414
415    def get_apxs_var(self, name: str) -> str:
416        p = subprocess.run([self._apxs, "-q", name], capture_output=True, text=True)
417        if p.returncode != 0:
418            return ""
419        return p.stdout.strip()
420
421    def get_httpd_version(self) -> str:
422        return self.get_apxs_var("HTTPD_VERSION")
423
424    def mkpath(self, path):
425        if not os.path.exists(path):
426            return os.makedirs(path)
427
428    def run(self, args, intext=None, debug_log=True):
429        if debug_log:
430            log.debug(f"run: {args}")
431        start = datetime.now()
432        p = subprocess.run(args, stderr=subprocess.PIPE, stdout=subprocess.PIPE,
433                           input=intext.encode() if intext else None)
434        return ExecResult(args=args, exit_code=p.returncode,
435                          stdout=p.stdout, stderr=p.stderr,
436                          duration=datetime.now() - start)
437
438    def mkurl(self, scheme, hostname, path='/'):
439        port = self.https_port if scheme == 'https' else self.http_port
440        return f"{scheme}://{hostname}.{self.http_tld}:{port}{path}"
441
442    def install_test_conf(self, lines: List[str]):
443        with open(self._test_conf, 'w') as fd:
444            fd.write('\n'.join(self._httpd_base_conf))
445            fd.write('\n')
446            if self._verbosity >= 2:
447                fd.write(f"LogLevel core:trace5 {self.mpm_module}:trace5\n")
448            if self._log_interesting:
449                fd.write(self._log_interesting)
450            fd.write('\n\n')
451            fd.write('\n'.join(lines))
452            fd.write('\n')
453
454    def is_live(self, url: str = None, timeout: timedelta = None):
455        if url is None:
456            url = self._http_base
457        if timeout is None:
458            timeout = timedelta(seconds=5)
459        try_until = datetime.now() + timeout
460        last_err = ""
461        while datetime.now() < try_until:
462            # noinspection PyBroadException
463            try:
464                r = self.curl_get(url, insecure=True)
465                if r.exit_code == 0:
466                    return True
467                time.sleep(.1)
468            except ConnectionRefusedError:
469                log.debug("connection refused")
470                time.sleep(.1)
471            except:
472                if last_err != str(sys.exc_info()[0]):
473                    last_err = str(sys.exc_info()[0])
474                    log.debug("Unexpected error: %s", last_err)
475                time.sleep(.1)
476        log.debug(f"Unable to contact server after {timeout}")
477        return False
478
479    def is_dead(self, url: str = None, timeout: timedelta = None):
480        if url is None:
481            url = self._http_base
482        if timeout is None:
483            timeout = timedelta(seconds=5)
484        try_until = datetime.now() + timeout
485        last_err = None
486        while datetime.now() < try_until:
487            # noinspection PyBroadException
488            try:
489                r = self.curl_get(url)
490                if r.exit_code != 0:
491                    return True
492                time.sleep(.1)
493            except ConnectionRefusedError:
494                log.debug("connection refused")
495                return True
496            except:
497                if last_err != str(sys.exc_info()[0]):
498                    last_err = str(sys.exc_info()[0])
499                    log.debug("Unexpected error: %s", last_err)
500                time.sleep(.1)
501        log.debug(f"Server still responding after {timeout}")
502        return False
503
504    def _run_apachectl(self, cmd) -> ExecResult:
505        conf_file = 'stop.conf' if cmd == 'stop' else 'httpd.conf'
506        args = [self._apachectl,
507                "-d", self.server_dir,
508                "-f", os.path.join(self._server_dir, f'conf/{conf_file}'),
509                "-k", cmd]
510        r = self.run(args)
511        self._apachectl_stderr = r.stderr
512        if r.exit_code != 0:
513            log.warning(f"failed: {r}")
514        return r
515
516    def apache_reload(self):
517        r = self._run_apachectl("graceful")
518        if r.exit_code == 0:
519            timeout = timedelta(seconds=10)
520            return 0 if self.is_live(self._http_base, timeout=timeout) else -1
521        return r.exit_code
522
523    def apache_restart(self):
524        self.apache_stop()
525        r = self._run_apachectl("start")
526        if r.exit_code == 0:
527            timeout = timedelta(seconds=10)
528            return 0 if self.is_live(self._http_base, timeout=timeout) else -1
529        return r.exit_code
530
531    def apache_stop(self):
532        r = self._run_apachectl("stop")
533        if r.exit_code == 0:
534            timeout = timedelta(seconds=10)
535            return 0 if self.is_dead(self._http_base, timeout=timeout) else -1
536        return r
537
538    def apache_graceful_stop(self):
539        log.debug("stop apache")
540        self._run_apachectl("graceful-stop")
541        return 0 if self.is_dead() else -1
542
543    def apache_fail(self):
544        log.debug("expect apache fail")
545        self._run_apachectl("stop")
546        rv = self._run_apachectl("start")
547        if rv == 0:
548            rv = 0 if self.is_dead() else -1
549        else:
550            rv = 0
551        return rv
552
553    def apache_access_log_clear(self):
554        if os.path.isfile(self._server_access_log):
555            os.remove(self._server_access_log)
556
557    def get_ca_pem_file(self, hostname: str) -> Optional[str]:
558        if len(self.get_credentials_for_name(hostname)) > 0:
559            return self.ca.cert_file
560        return None
561
562    def clear_curl_headerfiles(self):
563        for fname in os.listdir(path=self.gen_dir):
564            if re.match(r'curl\.headers\.\d+', fname):
565                os.remove(os.path.join(self.gen_dir, fname))
566        self._curl_headerfiles_n = 0
567
568    def curl_complete_args(self, urls, timeout=None, options=None,
569                           insecure=False, force_resolve=True):
570        if not isinstance(urls, list):
571            urls = [urls]
572        u = urlparse(urls[0])
573        assert u.hostname, f"hostname not in url: {urls[0]}"
574        headerfile = f"{self.gen_dir}/curl.headers.{self._curl_headerfiles_n}"
575        self._curl_headerfiles_n += 1
576
577        args = [
578            self._curl, "-s", "--path-as-is", "-D", headerfile,
579        ]
580        if u.scheme == 'http':
581            pass
582        elif insecure:
583            args.append('--insecure')
584        elif options and "--cacert" in options:
585            pass
586        else:
587            ca_pem = self.get_ca_pem_file(u.hostname)
588            if ca_pem:
589                args.extend(["--cacert", ca_pem])
590
591        if force_resolve and u.hostname != 'localhost' \
592                and u.hostname != self._httpd_addr \
593                and not re.match(r'^(\d+|\[|:).*', u.hostname):
594            assert u.port, f"port not in url: {urls[0]}"
595            args.extend(["--resolve", f"{u.hostname}:{u.port}:{self._httpd_addr}"])
596        if timeout is not None and int(timeout) > 0:
597            args.extend(["--connect-timeout", str(int(timeout))])
598        if options:
599            args.extend(options)
600        args += urls
601        return args, headerfile
602
603    def curl_parse_headerfile(self, headerfile: str, r: ExecResult = None) -> ExecResult:
604        lines = open(headerfile).readlines()
605        exp_stat = True
606        if r is None:
607            r = ExecResult(args=[], exit_code=0, stdout=b'', stderr=b'')
608        header = {}
609        for line in lines:
610            if exp_stat:
611                log.debug("reading 1st response line: %s", line)
612                m = re.match(r'^(\S+) (\d+) (.*)$', line)
613                assert m
614                r.add_response({
615                    "protocol": m.group(1),
616                    "status": int(m.group(2)),
617                    "description": m.group(3),
618                    "body": r.outraw
619                })
620                exp_stat = False
621                header = {}
622            elif re.match(r'^$', line):
623                exp_stat = True
624            else:
625                log.debug("reading header line: %s", line)
626                m = re.match(r'^([^:]+):\s*(.*)$', line)
627                assert m
628                header[m.group(1).lower()] = m.group(2)
629        if r.response:
630            r.response["header"] = header
631        return r
632
633    def curl_raw(self, urls, timeout=10, options=None, insecure=False,
634                 force_resolve=True):
635        xopt = ['-vvvv']
636        if options:
637            xopt.extend(options)
638        args, headerfile = self.curl_complete_args(
639            urls=urls, timeout=timeout, options=options, insecure=insecure,
640            force_resolve=force_resolve)
641        r = self.run(args)
642        if r.exit_code == 0:
643            self.curl_parse_headerfile(headerfile, r=r)
644            if r.json:
645                r.response["json"] = r.json
646        os.remove(headerfile)
647        return r
648
649    def curl_get(self, url, insecure=False, options=None):
650        return self.curl_raw([url], insecure=insecure, options=options)
651
652    def curl_upload(self, url, fpath, timeout=5, options=None):
653        if not options:
654            options = []
655        options.extend([
656            "--form", ("file=@%s" % fpath)
657        ])
658        return self.curl_raw(urls=[url], timeout=timeout, options=options)
659
660    def curl_post_data(self, url, data="", timeout=5, options=None):
661        if not options:
662            options = []
663        options.extend(["--data", "%s" % data])
664        return self.curl_raw(url, timeout, options)
665
666    def curl_post_value(self, url, key, value, timeout=5, options=None):
667        if not options:
668            options = []
669        options.extend(["--form", "{0}={1}".format(key, value)])
670        return self.curl_raw(url, timeout, options)
671
672    def curl_protocol_version(self, url, timeout=5, options=None):
673        if not options:
674            options = []
675        options.extend(["-w", "%{http_version}\n", "-o", "/dev/null"])
676        r = self.curl_raw(url, timeout=timeout, options=options)
677        if r.exit_code == 0 and r.response:
678            return r.response["body"].decode('utf-8').rstrip()
679        return -1
680
681    def nghttp(self):
682        return Nghttp(self._nghttp, connect_addr=self._httpd_addr, tmp_dir=self.gen_dir)
683
684    def h2load_status(self, run: ExecResult):
685        stats = {}
686        m = re.search(
687            r'requests: (\d+) total, (\d+) started, (\d+) done, (\d+) succeeded'
688            r', (\d+) failed, (\d+) errored, (\d+) timeout', run.stdout)
689        if m:
690            stats["requests"] = {
691                "total": int(m.group(1)),
692                "started": int(m.group(2)),
693                "done": int(m.group(3)),
694                "succeeded": int(m.group(4))
695            }
696            m = re.search(r'status codes: (\d+) 2xx, (\d+) 3xx, (\d+) 4xx, (\d+) 5xx',
697                          run.stdout)
698            if m:
699                stats["status"] = {
700                    "2xx": int(m.group(1)),
701                    "3xx": int(m.group(2)),
702                    "4xx": int(m.group(3)),
703                    "5xx": int(m.group(4))
704                }
705            run.add_results({"h2load": stats})
706        return run
707