1import errno
2import os
3import selectors
4import signal
5import socket
6import struct
7import sys
8import threading
9import warnings
10
11from . import connection
12from . import process
13from .context import reduction
14from . import resource_tracker
15from . import spawn
16from . import util
17
__all__ = [
    'ensure_running',
    'get_inherited_fds',
    'connect_to_new_process',
    'set_forkserver_preload',
]

#
#
#

# Upper bound on how many fds one fork request may carry over the
# forkserver's unix socket; the server treats a larger batch as an error.
MAXFDS_TO_SEND = 256
# Native-byte-order 'q' (long long) used to frame pids and exit codes that
# are exchanged over pipes (see read_signed / write_signed).
SIGNED_STRUCT = struct.Struct('q')     # large enough for pid_t
27
28#
29# Forkserver class
30#
31
class ForkServer(object):
    '''Client-side handle on a single forkserver process.

    Tracks the address of the server's listening socket, the write end of
    the "alive" pipe shared with the server, and the server's pid.  Starting
    and stopping the server is serialized through ``self._lock``.
    '''

    def __init__(self):
        # AF_UNIX address the forkserver listens on (None until launched)
        self._forkserver_address = None
        # write end of the "alive" pipe; closing every copy stops the server
        self._forkserver_alive_fd = None
        self._forkserver_pid = None
        # fds handed to this process by a forkserver (None when this process
        # was not started by one)
        self._inherited_fds = None
        self._lock = threading.Lock()
        self._preload_modules = ['__main__']

    def _stop(self):
        # Method used by unit tests to stop the server
        with self._lock:
            self._stop_unlocked()

    def _stop_unlocked(self):
        '''Stop the forkserver and reap it; the caller holds ``self._lock``.'''
        if self._forkserver_pid is None:
            return

        # closing the "alive" file descriptor asks the server to stop
        os.close(self._forkserver_alive_fd)
        self._forkserver_alive_fd = None

        os.waitpid(self._forkserver_pid, 0)
        self._forkserver_pid = None

        # a socket outside the abstract namespace has a filesystem entry
        # that must be removed explicitly
        if not util.is_abstract_socket_namespace(self._forkserver_address):
            os.unlink(self._forkserver_address)
        self._forkserver_address = None

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        Raises TypeError if any entry of *modules_names* is not a str.
        '''
        # Validate the incoming names (not the currently stored list) so that
        # bad input is rejected here rather than failing later inside the
        # spawned server process.
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if *fds* is too long to send in one request.
        '''
        self.ensure_running()
        # Four fds always travel with *fds*: the child's ends of the two
        # pipes, the "alive" fd and the resource tracker's fd.
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            try:
                child_r, parent_w = os.pipe()
            except BaseException:
                # don't leak the first pipe if the second cannot be created
                os.close(parent_r)
                os.close(child_w)
                raise
            try:
                allfds = [child_r, child_w, self._forkserver_alive_fd,
                          resource_tracker.getfd()]
                allfds += fds
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # the child's ends are either in the server now or unusable;
                # this process never needs them
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            resource_tracker.ensure_running()
            if self._forkserver_pid is not None:
                # forkserver was launched before, is it still running?
                pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG)
                if not pid:
                    # still alive
                    return
                # dead, launch it again
                os.close(self._forkserver_alive_fd)
                self._forkserver_address = None
                self._forkserver_alive_fd = None
                self._forkserver_pid = None

            # Placeholders: listener fd, read end of the "alive" pipe, the
            # preload list, and keyword arguments for main().
            cmd = ('from multiprocess.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                # forward only what main() uses to import __main__
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {x: y for x, y in data.items() if x in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                if not util.is_abstract_socket_namespace(address):
                    # only the owning user may connect to the socket file
                    os.chmod(address, 0o600)
                listener.listen()

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    pid = util.spawnv_passfds(exe, args, fds_to_pass)
                except BaseException:
                    os.close(alive_w)
                    raise
                finally:
                    # only the server keeps the read end
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
                self._forkserver_pid = pid
162
163#
164#
165#
166
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    Event loop of the forkserver process.  *listener_fd* is the fd of an
    AF_UNIX socket that is already bound and listening.  *alive_r* is the
    read end of the "alive" pipe; EOF on it makes the server exit.
    *preload* lists module names imported once, before any fork, so that
    forked children start with them already loaded.  If '__main__' is in
    *preload* and *main_path* is given, that path is imported as the main
    module.  *sys_path* is accepted but not used by this function.

    For every connection on the listener the server receives fds, forks a
    child that runs _serve_one(), writes the child's pid to the client, and
    later writes the child's exit code to the same pipe once it is reaped.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                # preloading is best-effort; missing modules are skipped
                pass

    util._close_stdin()

    # Self-pipe used as the signal wakeup fd; both ends are non-blocking.
    sig_r, sig_w = os.pipe()
    os.set_blocking(sig_r, False)
    os.set_blocking(sig_w, False)

    def sigchld_handler(*_unused):
        # Dummy signal handler, doesn't do anything
        pass

    handlers = {
        # unblocking SIGCHLD allows the wakeup fd to notify our event loop
        signal.SIGCHLD: sigchld_handler,
        # protect the process from ^C
        signal.SIGINT: signal.SIG_IGN,
        }
    # Remember the previous handlers so forked children can restore them.
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # calling os.write() in the Python signal handler is racy
    signal.set_wakeup_fd(sig_w)

    # map child pids to client fds
    pid_to_fd = {}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)
        selector.register(sig_r, selectors.EVENT_READ)

        while True:
            try:
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    # NOTE(review): under `python -O` this assert (including
                    # the os.read inside it) is stripped; SystemExit is
                    # raised either way.
                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
                    raise SystemExit

                if sig_r in rfds:
                    # Got SIGCHLD
                    os.read(sig_r, 65536)  # exhaust
                    while True:
                        # Scan for child processes
                        try:
                            pid, sts = os.waitpid(-1, os.WNOHANG)
                        except ChildProcessError:
                            # no children at all
                            break
                        if pid == 0:
                            # children exist but none has exited yet
                            break
                        child_w = pid_to_fd.pop(pid, None)
                        if child_w is not None:
                            returncode = os.waitstatus_to_exitcode(sts)
                            # Send exit code to client process
                            try:
                                write_signed(child_w, returncode)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            os.close(child_w)
                        else:
                            # This shouldn't happen really
                            warnings.warn('forkserver: waitpid returned '
                                          'unexpected pid %d' % pid)

                if listener in rfds:
                    # Incoming fork request
                    with listener.accept()[0] as s:
                        # Receive fds from client
                        # (ask for one more than allowed so an oversized
                        # batch can be detected)
                        fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
                        if len(fds) > MAXFDS_TO_SEND:
                            # NOTE(review): this escapes the loop (only
                            # ECONNABORTED OSErrors are caught below), so it
                            # terminates the server, and the received fds
                            # are not closed.
                            raise RuntimeError(
                                "Too many ({0:n}) fds to send".format(
                                    len(fds)))
                        child_r, child_w, *fds = fds
                        # close the connection before forking so the child
                        # does not inherit it
                        s.close()
                        pid = os.fork()
                        if pid == 0:
                            # Child
                            code = 1
                            try:
                                listener.close()
                                selector.close()
                                unused_fds = [alive_r, child_w, sig_r, sig_w]
                                unused_fds.extend(pid_to_fd.values())
                                code = _serve_one(child_r, fds,
                                                  unused_fds,
                                                  old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                # never return into the server's loop
                                os._exit(code)
                        else:
                            # Send pid to client process
                            try:
                                write_signed(child_w, pid)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            # keep child_w to report the exit code later
                            pid_to_fd[pid] = child_w
                            os.close(child_r)
                            for fd in fds:
                                os.close(fd)

            except OSError as e:
                # a client that disconnects mid-accept is not fatal
                if e.errno != errno.ECONNABORTED:
                    raise
296
297
def _serve_one(child_r, fds, unused_fds, handlers):
    '''Run one job inside a freshly forked child of the forkserver.

    *handlers* maps signal numbers to the handlers to reinstall, replacing
    the ones the server installed.  *unused_fds* are closed.  *fds* begins
    with the "alive" fd and the resource tracker fd; any remaining entries
    become this process's inherited fds.  Returns the exit code produced
    by ``spawn._main``.
    '''
    # Undo the server's signal plumbing before running user code.
    signal.set_wakeup_fd(-1)
    for signum, handler in handlers.items():
        signal.signal(signum, handler)
    for unused in unused_fds:
        os.close(unused)

    alive_fd, tracker_fd, *inherited = fds
    _forkserver._forkserver_alive_fd = alive_fd
    resource_tracker._resource_tracker._fd = tracker_fd
    _forkserver._inherited_fds = inherited

    # Run process object received over pipe; a duplicate of child_r serves
    # as the parent sentinel.
    return spawn._main(child_r, os.dup(child_r))
315
316
317#
318# Read and write signed numbers
319#
320
def read_signed(fd):
    '''Read exactly one SIGNED_STRUCT-encoded integer from *fd*.

    Short reads are retried until the full record has arrived.  Raises
    EOFError if the stream ends before that.
    '''
    needed = SIGNED_STRUCT.size
    chunks = []
    received = 0
    while received < needed:
        chunk = os.read(fd, needed - received)
        if not chunk:
            raise EOFError('unexpected EOF')
        chunks.append(chunk)
        received += len(chunk)
    return SIGNED_STRUCT.unpack(b''.join(chunks))[0]
330
def write_signed(fd, n):
    '''Write *n* to *fd* as one SIGNED_STRUCT record, retrying short writes.

    Raises RuntimeError if os.write reports that zero bytes were written.
    '''
    payload = SIGNED_STRUCT.pack(n)
    offset = 0
    while offset < len(payload):
        written = os.write(fd, payload[offset:])
        if written == 0:
            raise RuntimeError('should not get here')
        offset += written
338
339#
340#
341#
342
_forkserver = ForkServer()

# Module-level API: bound methods of the single module-wide ForkServer.
(ensure_running, get_inherited_fds, connect_to_new_process,
 set_forkserver_preload) = (_forkserver.ensure_running,
                            _forkserver.get_inherited_fds,
                            _forkserver.connect_to_new_process,
                            _forkserver.set_forkserver_preload)
348