1# This Source Code Form is subject to the terms of the Mozilla Public
2# License, v. 2.0. If a copy of the MPL was not distributed with this file,
3# You can obtain one at http://mozilla.org/MPL/2.0/.
4
5import abc
6import errno
7import os
8import platform
9import socket
10import threading
11import time
12import traceback
13import urlparse
14
15import mozprocess
16
17
# Public API of this module: the server launcher classes.
__all__ = [
    "SeleniumServer",
    "ChromeDriverServer",
    "GeckoDriverServer",
    "WebDriverServer",
]
20
21
class WebDriverServer(object):
    """Base class for running a WebDriver HTTP server as a child process.

    Subclasses implement :meth:`make_command`. :meth:`start` launches that
    command through ``mozprocess``, waits until ``(host, port)`` accepts TCP
    connections and optionally blocks until the process exits. Every output
    line of the process is forwarded to ``logger.process_output``.
    """
    __metaclass__ = abc.ABCMeta

    # Path component of ``url`` used when the caller passes base_path="".
    default_base_path = "/"
    # Ports already handed out by _find_next_free_port. This is a class
    # attribute shared by every instance and never cleared, so two servers
    # created in the same process are not given the same port, even before
    # either of them has bound it.
    _used_ports = set()

    def __init__(self, logger, binary, host="127.0.0.1", port=None,
                 base_path="", env=None):
        """
        :param logger: Logger exposing ``debug``, ``error`` and
            ``process_output``.
        :param binary: Path or name of the server executable.
        :param host: Host used to build ``url`` and polled for readiness.
        :param port: Port to use; None means a free port (searched upward
            from 4444) is picked the first time ``port`` is read.
        :param base_path: URL path prefix; "" selects ``default_base_path``.
        :param env: Environment for the child process; defaults to a copy
            of ``os.environ``.
        """
        self.logger = logger
        self.binary = binary
        self.host = host
        if base_path == "":
            self.base_path = self.default_base_path
        else:
            self.base_path = base_path
        self.env = os.environ.copy() if env is None else env

        self._port = port
        self._cmd = None
        self._proc = None

    @abc.abstractmethod
    def make_command(self):
        """Returns the full command for starting the server process as a list."""

    def start(self, block=True):
        """Launch the server and wait until it accepts connections.

        :param block: If True, additionally wait for the process to exit.

        A KeyboardInterrupt raised while starting or waiting stops the
        server and is not propagated.
        """
        try:
            self._run(block)
        except KeyboardInterrupt:
            self.stop()

    def _run(self, block):
        """Spawn the process, wait for its port, optionally wait for exit.

        :raises IOError: if the executable cannot be found.
        :raises: whatever ``wait_for_service`` raised if the server never
            became reachable (after logging it).
        """
        self._cmd = self.make_command()
        self._proc = mozprocess.ProcessHandler(
            self._cmd,
            processOutputLine=self.on_output,
            env=self.env,
            storeOutput=False)

        try:
            self._proc.run()
        except OSError as e:
            if e.errno == errno.ENOENT:
                raise IOError(
                    "WebDriver HTTP server executable not found: %s" % self.binary)
            raise

        self.logger.debug(
            "Waiting for server to become accessible: %s" % self.url)
        try:
            wait_for_service((self.host, self.port))
        except Exception:
            # Not a bare except: a KeyboardInterrupt should not be logged as
            # a startup timeout (start() handles it by stopping the server).
            self.logger.error(
                "WebDriver HTTP server was not accessible "
                "within the timeout:\n%s" % traceback.format_exc())
            # poll() is None while the process runs; any other value,
            # including 0, is its exit code, so test against None rather
            # than truthiness.
            exit_code = self._proc.poll()
            if exit_code is not None:
                self.logger.error("Webdriver server process exited with code %i" %
                                  exit_code)
            raise

        if block:
            self._proc.wait()

    def stop(self):
        """Kill the server process if it is running.

        :returns: True if the process is not alive afterwards.
        """
        if self.is_alive:
            # kill()'s return value is an exit code, not a success flag, so
            # report liveness instead to keep the return type consistent.
            self._proc.kill()
        return not self.is_alive

    @property
    def is_alive(self):
        """True if the process has been spawned and has not yet exited."""
        return (self._proc is not None and
                self._proc.proc is not None and
                self._proc.poll() is None)

    def on_output(self, line):
        """Forward one line of process output to the logger.

        ``line`` is decoded as UTF-8, replacing undecodable sequences.
        """
        self.logger.process_output(self.pid,
                                   line.decode("utf8", "replace"),
                                   command=" ".join(self._cmd))

    @property
    def pid(self):
        """PID of the child process, or None if it was never started."""
        if self._proc is not None:
            return self._proc.pid
        return None

    @property
    def url(self):
        """Base URL of the server, e.g. ``http://127.0.0.1:4444/``."""
        return "http://%s:%i%s" % (self.host, self.port, self.base_path)

    @property
    def port(self):
        """Server port; allocated lazily if none was given."""
        if self._port is None:
            self._port = self._find_next_free_port()
        return self._port

    @staticmethod
    def _find_next_free_port():
        """Return a currently unbound port not yet given to any instance."""
        port = get_free_port(4444, exclude=WebDriverServer._used_ports)
        WebDriverServer._used_ports.add(port)
        return port
121
122
class SeleniumServer(WebDriverServer):
    """Launches a Selenium server jar with ``java -jar``.

    The WebDriver endpoint is served under ``/wd/hub`` unless a different
    base path is passed to the constructor.
    """
    default_base_path = "/wd/hub"

    def make_command(self):
        """Build the ``java -jar <binary> -port <port>`` command line."""
        launcher = ["java", "-jar", self.binary]
        return launcher + ["-port", str(self.port)]
128
129
class ChromeDriverServer(WebDriverServer):
    """Launches ChromeDriver; the endpoint defaults to ``/wd/hub``."""
    default_base_path = "/wd/hub"

    def __init__(self, logger, binary="chromedriver", port=None,
                 base_path=""):
        """
        :param logger: Logger passed through to WebDriverServer.
        :param binary: ChromeDriver executable.
        :param port: Port to listen on; None picks a free port lazily.
        :param base_path: URL path prefix; "" selects ``/wd/hub``.
        """
        WebDriverServer.__init__(
            self, logger, binary, port=port, base_path=base_path)

    def make_command(self):
        """Return the chromedriver command line as a list."""
        cmd = [self.binary, cmd_arg("port", str(self.port))]
        # Omit the flag entirely when there is no base path, rather than
        # placing an empty-string argument in argv.
        if self.base_path:
            cmd.append(cmd_arg("url-base", self.base_path))
        return cmd
142
143
class GeckoDriverServer(WebDriverServer):
    """Launches geckodriver (default binary name ``wires``).

    It is started with ``--connect-existing`` and the given Marionette
    port, and the child runs with ``RUST_BACKTRACE=1`` set.
    """

    def __init__(self, logger, marionette_port=2828, binary="wires",
                 host="127.0.0.1", port=None):
        child_env = dict(os.environ)
        # Ask the Rust runtime to print a backtrace if the driver panics.
        child_env["RUST_BACKTRACE"] = "1"
        WebDriverServer.__init__(self, logger, binary, host=host, port=port,
                                 env=child_env)
        self.marionette_port = marionette_port

    def make_command(self):
        """Return the geckodriver command line as a list."""
        options = ["--connect-existing",
                   "--marionette-port", str(self.marionette_port),
                   "--host", self.host,
                   "--port", str(self.port)]
        return [self.binary] + options
158
159
def cmd_arg(name, value=None):
    """Format a command-line flag for the current platform.

    On Windows the flag gets a single leading dash (``-name``); elsewhere
    it gets two (``--name``). If *value* is not None, ``=value`` is
    appended; *value* must already be a string.
    """
    dashes = "--"
    if platform.system() == "Windows":
        dashes = "-"
    flag = dashes + name
    if value is None:
        return flag
    return flag + "=" + value
166
167
def get_free_port(start_port, exclude=None):
    """Return the lowest port >= start_port that can currently be bound.

    Each candidate is tested by binding a throwaway socket to 127.0.0.1 and
    closing it immediately.

    :param start_port: Integer port number at which to start testing.
    :param exclude: Set of port numbers to skip"""
    skip = exclude or set()
    candidate = start_port
    while True:
        if candidate not in skip:
            probe = socket.socket()
            try:
                probe.bind(("127.0.0.1", candidate))
                return candidate
            except socket.error:
                pass  # already in use; try the next one
            finally:
                probe.close()
        candidate += 1
188
189
def wait_for_service(addr, timeout=15):
    """Waits until network service given as a tuple of (host, port) becomes
    available or the `timeout` duration is reached, at which point
    ``socket.error`` is raised.

    :param addr: ``(host, port)`` tuple to connect to.
    :param timeout: Overall time budget in seconds.
    :returns: True once a TCP connection succeeds.
    :raises socket.error: if the deadline passes, or if a connection attempt
        fails with any error other than "connection refused"."""
    end = time.time() + timeout
    while end > time.time():
        so = socket.socket()
        # Bound each attempt by the remaining budget so that a single
        # connect() cannot block far past the deadline. The small floor
        # avoids a zero timeout, which would make the socket non-blocking.
        so.settimeout(max(end - time.time(), 0.01))
        try:
            so.connect(addr)
        except socket.timeout:
            pass
        except socket.error as e:
            # "Refused" means nothing is listening yet, so retry. Use .errno:
            # exception objects are not indexable (e[0]) on Python 3.
            if e.errno != errno.ECONNREFUSED:
                raise
        else:
            return True
        finally:
            so.close()
        time.sleep(0.5)
    raise socket.error("Service is unavailable: %s:%i" % addr)
210