1#!/usr/bin/env python
2
3# This Source Code Form is subject to the terms of the Mozilla Public
4# License, v. 2.0. If a copy of the MPL was not distributed with this
5# file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
"""Spawns the HTTP and HTTPS fixture servers needed for testing
Marionette, each in its own child process.

"""
11
12from __future__ import absolute_import, print_function
13
14import argparse
15import multiprocessing
16import os
17import sys
18
19from collections import defaultdict
20
21from six import iteritems
22
23from . import httpd
24
25
# Names exported by ``from <this module> import *``.
__all__ = [
    "default_doc_root",
    "iter_proc",
    "iter_url",
    "registered_servers",
    "servers",
    "start",
    "where_is",
]
# Absolute path of the directory that contains this module.
here = os.path.dirname(os.path.abspath(__file__))
36
37
class BlockingChannel(object):
    """Lock-guarded wrapper around one end of a ``multiprocessing.Pipe``.

    Messages always travel as a tuple of positional arguments, and a
    received 1-tuple is unwrapped again, so ``send(x)`` on one end pairs
    with ``recv() == x`` on the other.
    """

    def __init__(self, channel):
        # ``channel`` is one of the two connections from multiprocessing.Pipe().
        self.chan = channel
        self.lock = multiprocessing.Lock()

    def call(self, func, args=()):
        """Send a ``(func, args)`` request, then block for and return the reply.

        The lock is taken separately for the send and for the receive; it
        is not held across the whole round-trip.
        """
        self.send((func, args))
        return self.recv()

    def send(self, *args):
        """Send ``args`` (always packed as a tuple) over the channel."""
        # ``with`` releases the lock only if it was actually acquired; the
        # old acquire-inside-try form called release() on an unheld lock
        # (ValueError) if acquire() itself was interrupted.
        with self.lock:
            self.chan.send(args)

    def recv(self):
        """Block until a message arrives and return it.

        A received 1-tuple (the shape ``send(x)`` produces) is unwrapped to
        its single element; any other payload is returned unchanged.  If the
        wait is interrupted by ``KeyboardInterrupt``, the synthetic message
        ``("stop", ())`` is returned instead of propagating the interrupt.
        """
        with self.lock:
            try:
                payload = self.chan.recv()
            except KeyboardInterrupt:
                return ("stop", ())
        if isinstance(payload, tuple) and len(payload) == 1:
            return payload[0]
        return payload
65
66
class ServerProxy(multiprocessing.Process, BlockingChannel):
    """Child process that owns a server object and answers requests about it.

    Inside the child, ``init_func(*init_args, **init_kwargs)`` builds the
    server and ``server.start()`` is called; ``("ok", ())`` is then sent to
    the parent.  Each subsequent ``(attr_name, args)`` request is answered
    with ``getattr(server, attr_name)(*args)`` for callables, or with the
    attribute's value otherwise.  A ``"stop"`` request is answered and then
    ends the loop.  Any ``Exception`` is reported as ``("stop", exc)``.
    """

    def __init__(self, channel, init_func, *init_args, **init_kwargs):
        multiprocessing.Process.__init__(self)
        BlockingChannel.__init__(self, channel)
        self.init_func = init_func
        self.init_args = init_args
        self.init_kwargs = init_kwargs

    def run(self):
        # Bound before the try so the KeyboardInterrupt handler can tell
        # whether a server was ever created; previously an interrupt during
        # init_func() hit ``server.stop()`` with ``server`` unbound.
        server = None
        try:
            server = self.init_func(*self.init_args, **self.init_kwargs)
            server.start()
            self.send(("ok", ()))

            while True:
                # Requests look like ("func", ("arg", ...)) or ("prop", ()).
                sattr, fargs = self.recv()
                attr = getattr(server, sattr)

                if callable(attr):
                    rv = attr(*fargs)
                else:
                    # Plain attribute or property: reply with its value.
                    rv = attr

                self.send(rv)

                if sattr == "stop":
                    return

        except Exception as e:
            try:
                self.send(("stop", e))
            except Exception:
                # The original exception could not be pickled; send a plain
                # description so the parent is not left waiting forever.
                self.send(("stop", RuntimeError(repr(e))))

        except KeyboardInterrupt:
            if server is not None:
                server.stop()
105
106
class ServerProc(BlockingChannel):
    """Parent-side handle for one fixture server running in a child process.

    ``init_func`` is the server factory; :meth:`start` runs it inside a
    ``ServerProxy`` child as ``init_func(doc_root, ssl_config, **kwargs)``.
    The other methods forward requests to that child over a pipe.
    """

    def __init__(self, init_func):
        self._init_func = init_func
        self.proc = None

        # This object talks over parent_chan; child_chan is handed to the
        # ServerProxy process when start() creates it.
        parent_chan, self.child_chan = multiprocessing.Pipe()
        BlockingChannel.__init__(self, parent_chan)

    def start(self, doc_root, ssl_config, **kwargs):
        """Spawn the child process and block until it reports startup.

        Re-raises the exception sent back by the child if building or
        starting the server failed.
        """
        self.proc = ServerProxy(
            self.child_chan, self._init_func, doc_root, ssl_config, **kwargs
        )
        # Daemonic: the child is terminated when this process exits.
        self.proc.daemon = True
        self.proc.start()

        res, exc = self.recv()
        if res == "stop":
            if isinstance(exc, BaseException):
                raise exc
            # recv() maps a KeyboardInterrupt to ("stop", ()); ``raise ()``
            # would fail with a TypeError, so re-raise the interrupt itself.
            raise KeyboardInterrupt()

    def get_url(self, url):
        """Return the child server's ``get_url(url)`` result."""
        return self.call("get_url", (url,))

    @property
    def doc_root(self):
        """The child server's ``doc_root`` (called with no args if callable)."""
        return self.call("doc_root", ())

    def stop(self):
        """Ask the child's server to stop, then wait for the child to exit.

        Returns immediately if the child is not running: nothing would
        answer the request, so waiting for the reply would block forever.
        """
        if not self.is_alive:
            return
        self.call("stop")
        self.proc.join()

    def kill(self):
        """Terminate the child process without a graceful stop request."""
        if not self.is_alive:
            return
        self.proc.terminate()
        self.proc.join(0)

    @property
    def is_alive(self):
        """True if the child process has been started and is still running."""
        if self.proc is not None:
            return self.proc.is_alive()
        return False
150
151
def http_server(doc_root, ssl_config, host="127.0.0.1", **kwargs):
    """Build a plain-HTTP ``httpd.FixtureServer`` serving ``doc_root``.

    The URL uses port 0 on ``host`` (presumably letting the OS pick a free
    port -- depends on ``FixtureServer``).  ``ssl_config`` is unused; it is
    accepted only so both server factories share one call signature.
    """
    server_url = "http://{}:0/".format(host)
    return httpd.FixtureServer(doc_root, url=server_url, **kwargs)
154
155
def https_server(doc_root, ssl_config, host="127.0.0.1", **kwargs):
    """Build a TLS ``httpd.FixtureServer`` serving ``doc_root``.

    The URL uses port 0 on ``host``.  ``ssl_config`` must contain the
    ``"key_path"`` and ``"cert_path"`` entries, which are forwarded as the
    ``ssl_key`` and ``ssl_cert`` arguments.
    """
    server_url = "https://{}:0/".format(host)
    key_path = ssl_config["key_path"]
    cert_path = ssl_config["cert_path"]
    return httpd.FixtureServer(
        doc_root,
        url=server_url,
        ssl_key=key_path,
        ssl_cert=cert_path,
        **kwargs
    )
164
165
def start_servers(doc_root, ssl_config, **kwargs):
    """Start one ``ServerProc`` per ``(scheme, factory)`` in ``registered_servers``.

    Returns a mapping of scheme -> ``(root_url, ServerProc)``, where
    ``root_url`` is the server's reply to ``get_url("/")``.  If any server
    fails to start, every child process spawned so far is killed before
    the error propagates.
    """
    servers = defaultdict()
    spawned = []
    try:
        for schema, builder_fn in registered_servers:
            proc = ServerProc(builder_fn)
            spawned.append(proc)
            proc.start(doc_root, ssl_config, **kwargs)
            servers[schema] = (proc.get_url("/"), proc)
    except BaseException:
        # Don't leave earlier servers running when a later one fails;
        # kill() is a no-op for children that are not alive.
        for proc in spawned:
            proc.kill()
        raise
    return servers
173
174
def start(doc_root=None, **kwargs):
    """Launch every registered test server and remember them globally.

    When ``doc_root`` is falsy, the default
    testing/marionette/harness/marionette_harness/www directory is served.

    Extra keyword arguments are forwarded unchanged to each
    ``FixtureServer`` constructed in httpd.py.

    The resulting scheme -> ``(url, ServerProc)`` mapping is stored in the
    module-level ``servers`` and also returned.
    """
    global servers

    if not doc_root:
        doc_root = default_doc_root

    ssl_config = dict(
        cert_path=httpd.default_ssl_cert,
        key_path=httpd.default_ssl_key,
    )
    servers = start_servers(doc_root, ssl_config, **kwargs)
    return servers
194
195
def where_is(uri, on="http"):
    """Returns the full URL, including scheme, hostname, and port, for
    a fixture resource from the server associated with the ``on`` key.
    It will by default look for the resource in the "http" server.

    Raises ``KeyError`` with an explanatory message if no server is
    registered under ``on`` (for example because ``start()`` has not been
    called yet), instead of an opaque "NoneType is not subscriptable".

    """
    entry = servers.get(on)
    if entry is None:
        raise KeyError(
            "No test server registered for {!r}; has start() been called?".format(on)
        )
    return entry[1].get_url(uri)
203
204
def iter_proc(servers):
    """Yield the ``ServerProc`` from every ``(url, proc)`` value in ``servers``."""
    for _scheme, entry in iteritems(servers):
        _url, server_proc = entry
        yield server_proc
208
209
def iter_url(servers):
    """Yield the root URL from every ``(url, proc)`` value in ``servers``."""
    for _scheme, entry in iteritems(servers):
        root_url, _proc = entry
        yield root_url
213
214
# Served when start() gets no explicit doc_root: the "www" directory inside
# the parent of this module's directory.
default_doc_root = os.path.join(os.path.dirname(here), "www")

# (scheme, factory) pairs; start_servers() launches one child process per
# entry, in this order.
registered_servers = [
    ("http", http_server),
    ("https", https_server),
]

# Filled in by start(): scheme -> (root URL, ServerProc).
servers = defaultdict()
218
219
def main(args):
    """Command-line entry point: start the servers and wait on them.

    ``args`` is the argument list to parse (normally ``sys.argv[1:]``).
    Each server's URL is printed to stderr; the function then loops until
    no child process is alive.  On ``KeyboardInterrupt`` all children are
    killed.
    """
    global servers

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r", dest="doc_root", help="Path to document root.  Overrides default."
    )
    # Parse the list we were given rather than silently re-reading
    # sys.argv, so main() can be driven programmatically.
    opts = parser.parse_args(args)

    servers = start(opts.doc_root)
    for url in iter_url(servers):
        print("{}: listening on {}".format(sys.argv[0], url), file=sys.stderr)

    try:
        while any(proc.is_alive for proc in iter_proc(servers)):
            for proc in iter_proc(servers):
                # Bounded wait so the outer loop keeps re-checking whether
                # any child is still alive.
                proc.proc.join(1)
    except KeyboardInterrupt:
        for proc in iter_proc(servers):
            proc.kill()
240
241
if __name__ == "__main__":
    # Script entry point: forward the command line minus the program name.
    main(args=sys.argv[1:])
244