1import errno
2import os
3import selectors
4import signal
5import socket
6import struct
7import sys
8import threading
9
10from . import connection
11from . import process
12from .context import reduction
13from . import semaphore_tracker
14from . import spawn
15from . import util
16
__all__ = [
    'ensure_running',
    'get_inherited_fds',
    'connect_to_new_process',
    'set_forkserver_preload',
]
19
20#
21#
22#
23
# Upper bound on the number of fds exchanged per fork request; both the
# sending side (connect_to_new_process) and the receiving side (_serve_one)
# check against it.
MAXFDS_TO_SEND = 256

# Codec for the integers (child pid, exit code) sent back over the status
# pipe.  'Q' is a native unsigned long long, large enough for pid_t.
UNSIGNED_STRUCT = struct.Struct('Q')
26
27#
28# Forkserver class
29#
30
class ForkServer(object):
    '''Handle to a fork server process, as seen from the current process.

    Records the address of the server's listening AF_UNIX socket, the write
    end of the "alive" pipe, the server pid, the fds inherited by a process
    that the server started, and the modules the server should preload.
    '''

    def __init__(self):
        self._forkserver_address = None
        self._forkserver_alive_fd = None
        self._forkserver_pid = None
        self._inherited_fds = None
        # serialises ensure_running() so concurrent callers launch at most
        # one server
        self._lock = threading.Lock()
        self._preload_modules = ['__main__']

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        The names are copied into a new list, so later mutation of the
        caller's sequence (or exhaustion of an iterator) has no effect.  The
        setting is only read when a fork server is launched by
        ensure_running().

        Raises TypeError if *modules_names* is a bare string, is not
        iterable, or contains a non-str item.
        '''
        # A bare string is iterable but would be treated as a list of
        # one-character module names.
        if isinstance(modules_names, str):
            raise TypeError('module_names must be a list of strings')
        modules_names = list(modules_names)
        # Validate the NEW names, not the currently stored ones.
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if *fds* plus the four fds that are always sent
        would reach MAXFDS_TO_SEND.
        '''
        self.ensure_running()
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            # parent_r <- child_w carries status (pid, exit code);
            # parent_w -> child_r carries the pickled process data.
            parent_r, child_w = os.pipe()
            child_r, parent_w = os.pipe()
            # The order here must match the unpacking in _serve_one().
            allfds = [child_r, child_w, self._forkserver_alive_fd,
                      semaphore_tracker.getfd()]
            allfds += fds
            try:
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except:
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # The child ends now belong to the forked process; the
                # caller only keeps its own ends.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            semaphore_tracker.ensure_running()
            if self._forkserver_pid is not None:
                # forkserver was launched before, is it still running?
                pid, _status = os.waitpid(self._forkserver_pid, os.WNOHANG)
                if not pid:
                    # still alive
                    return
                # dead, launch it again
                os.close(self._forkserver_alive_fd)
                self._forkserver_address = None
                self._forkserver_alive_fd = None
                self._forkserver_pid = None

            cmd = ('from multiprocess.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                # Forward only the preparation entries that main() accepts
                # as keyword arguments.
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {key: value for key, value in data.items()
                        if key in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                # only the owner may connect to the socket file
                os.chmod(address, 0o600)
                listener.listen()

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    pid = util.spawnv_passfds(exe, args, fds_to_pass)
                except:
                    os.close(alive_w)
                    raise
                finally:
                    # only the server needs the read end
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
                self._forkserver_pid = pid
141
142#
143#
144#
145
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    Entry point of the fork server process; ForkServer.ensure_running()
    launches it with ``python -c`` and passes these arguments.

    listener_fd -- fd of the bound, listening AF_UNIX socket that clients
        connect to.
    alive_r -- read end of the "alive" pipe.  When it becomes readable
        (EOF, i.e. no client holds the write end any more) the server exits.
    preload -- module names to import before serving; ImportError is
        ignored.  If it contains '__main__', main_path is imported first.
    main_path -- path of the parent's main module (may be None).
    sys_path -- accepted as part of the preparation data but not used in
        this function.

    Does not return normally: it leaves the loop by raising SystemExit, or
    by propagating an OSError other than ECONNABORTED.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            # NOTE(review): this flag appears to be consulted by
            # process/spawn while the main module is being imported;
            # confirm its exact effect there.
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        # Best effort: modules that fail to import are skipped.
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                pass

    # helper lives in util (not shown); presumably detaches stdin
    util._close_stdin()

    handlers = {
        # no need to reap zombie processes;
        signal.SIGCHLD: signal.SIG_IGN,
        # protect the process from ^C
        signal.SIGINT: signal.SIG_IGN,
        }
    # Keep the previous dispositions so each forked child can restore them
    # (_serve_one re-installs them).
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # Record the address in this process's singleton so children forked
    # below inherit it in their copy of _forkserver.
    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)

        while True:
            try:
                # Block until select() reports at least one ready fd.
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    # NOTE(review): under -O the assert (and thus the read)
                    # is skipped; SystemExit is still raised either way.
                    assert os.read(alive_r, 1) == b''
                    raise SystemExit

                assert listener in rfds
                with listener.accept()[0] as s:
                    # The forked child's own exit status is always 1: the
                    # real exit code is reported over the status pipe by
                    # _serve_one, and SIGCHLD is ignored above.
                    code = 1
                    if os.fork() == 0:
                        # Child: serve this one request, then os._exit so
                        # control never returns into the server loop.
                        try:
                            _serve_one(s, listener, alive_r, old_handlers)
                        except Exception:
                            sys.excepthook(*sys.exc_info())
                            sys.stderr.flush()
                        finally:
                            os._exit(code)
                    # Parent: the `with` closes its copy of the connection.

            except OSError as e:
                # ECONNABORTED (e.g. a pending connection aborted before
                # accept() completed) is ignored; other errors are fatal.
                if e.errno != errno.ECONNABORTED:
                    raise
206
def _serve_one(s, listener, alive_r, handlers):
    '''Turn a freshly forked child of the fork server into the new process.

    s -- the accepted client connection carrying the fds to receive.
    listener, alive_r -- server resources this child does not need.
    handlers -- signal dispositions to re-install (as saved by main()).

    Reports the child's pid and, after the process body has run, its exit
    code on the status pipe received from the client.
    '''
    # This child will not serve further requests: drop the server's
    # resources and put back the original signal dispositions.
    listener.close()
    os.close(alive_r)
    for signum, disposition in handlers.items():
        signal.signal(signum, disposition)

    # Fetch the client's fds; the assert rejects a batch larger than
    # MAXFDS_TO_SEND.  The order matches connect_to_new_process().
    received = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
    s.close()
    assert len(received) <= MAXFDS_TO_SEND
    child_r, child_w, alive_fd, tracker_fd, *inherited = received
    _forkserver._forkserver_alive_fd = alive_fd
    _forkserver._inherited_fds = inherited
    semaphore_tracker._semaphore_tracker._fd = tracker_fd

    # First message on the status pipe: our pid.
    write_unsigned(child_w, os.getpid())

    # Forked children share the server's RNG state; reseed if loaded.
    if 'random' in sys.modules:
        import random
        random.seed()

    # Execute the process data read from child_r, then report its result.
    exitcode = spawn._main(child_r)
    write_unsigned(child_w, exitcode)
235
236#
237# Read and write unsigned numbers
238#
239
def read_unsigned(fd):
    '''Read one UNSIGNED_STRUCT-encoded integer from *fd* and return it.

    os.read() may return fewer bytes than requested, so keep reading until
    the full encoded size has arrived.

    Raises EOFError if end-of-file is hit before a complete value is read.
    '''
    needed = UNSIGNED_STRUCT.size
    chunks = []
    received = 0
    while received < needed:
        chunk = os.read(fd, needed - received)
        if not chunk:
            raise EOFError('unexpected EOF')
        chunks.append(chunk)
        received += len(chunk)
    return UNSIGNED_STRUCT.unpack(b''.join(chunks))[0]
249
def write_unsigned(fd, n):
    '''Encode *n* with UNSIGNED_STRUCT and write all of it to *fd*.

    os.write() may write only part of the buffer, so loop until everything
    is written.  UNSIGNED_STRUCT.pack() raises struct.error if *n* does not
    fit the format.  Raises RuntimeError if os.write() reports writing zero
    bytes.
    '''
    payload = UNSIGNED_STRUCT.pack(n)
    offset = 0
    while offset < len(payload):
        written = os.write(fd, payload[offset:])
        if written == 0:
            raise RuntimeError('should not get here')
        offset += written
257
258#
259#
260#
261
# Module-level ForkServer instance.  The public functions below are bound
# methods of this one object, so all callers in a process share the same
# server handle (address, alive fd, pid, inherited fds).  main() and
# _serve_one() also update this instance directly.
_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload
267