1import errno
2import os
3import selectors
4import signal
5import socket
6import struct
7import sys
8import threading
9
10from . import connection
11from . import process
12from . import reduction
13from . import semaphore_tracker
14from . import spawn
15from . import util
16
# Public API of this module.  Each exported name is bound, at the bottom of
# the file, to the corresponding method of the module-level ForkServer
# instance.
__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']
19
20#
21#
22#
23
# Limit on the number of fds exchanged in a single request to the fork
# server.  connect_to_new_process() checks the caller's fds (plus the four
# bookkeeping fds it adds) against it, and _serve_one() checks the number
# of fds actually received against it.
MAXFDS_TO_SEND = 256
# Wire format used by read_unsigned()/write_unsigned(); _serve_one() uses
# it to send the child's pid and exit code: one native unsigned long long.
UNSIGNED_STRUCT = struct.Struct('Q')     # large enough for pid_t
26
27#
28# Forkserver class
29#
30
class ForkServer(object):
    '''Client-side handle for starting and talking to the fork server.

    A single instance is created at module level; its bound methods are
    exported as the module's public functions.
    '''

    def __init__(self):
        # Address of the server's AF_UNIX listening socket, or None while
        # no server has been started from (or inherited by) this process.
        self._forkserver_address = None
        # Write end of the "alive" pipe.  The server watches the read end
        # and exits once it becomes readable (see main()).
        self._forkserver_alive_fd = None
        # Extra fds handed over by the requesting process; only assigned
        # in processes created by the fork server (see _serve_one()).
        self._inherited_fds = None
        # Serializes ensure_running() so only one server gets started.
        self._lock = threading.Lock()
        # Modules the server tries to import before serving requests.
        self._preload_modules = ['__main__']

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        Raises TypeError if any element of modules_names is not a str; the
        previously configured list is left untouched in that case.
        '''
        # Validate the *argument* (not the currently stored list).  The
        # exact-type check is deliberate: the names are later embedded in
        # the server's command line with %r.
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if fds, together with the four bookkeeping fds
        added here, would reach MAXFDS_TO_SEND.
        '''
        self.ensure_running()
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            # Two pipes: child -> parent (pid / exit code) and
            # parent -> child (pickled process data).
            parent_r, child_w = os.pipe()
            child_r, parent_w = os.pipe()
            # The order matters: _serve_one() unpacks the first four fds
            # positionally and treats the rest as inherited fds.
            allfds = [child_r, child_w, self._forkserver_alive_fd,
                      semaphore_tracker.getfd()]
            allfds += fds
            try:
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                # The caller never sees these fds, so release them here
                # before propagating whatever went wrong.
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # The child's ends are not used by this process after the
                # send attempt, whether it succeeded or not.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            semaphore_tracker.ensure_running()
            # NOTE: only checks that a server was started at some point;
            # it does not verify that the server process is still alive.
            if self._forkserver_alive_fd is not None:
                return

            cmd = ('from multiprocess.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                # Only these two entries of the preparation data are
                # forwarded to main() as keyword arguments.
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {key: value for key, value in data.items()
                        if key in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                # Restrict the socket file to its owner.
                os.chmod(address, 0o600)
                listener.listen(100)

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    pid = util.spawnv_passfds(exe, args, fds_to_pass)
                except BaseException:
                    # No server was started, so nothing will ever watch
                    # the pipe: drop our end as well.
                    os.close(alive_w)
                    raise
                finally:
                    # The read end is only used by the server process.
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
130
131#
132#
133#
134
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    listener_fd is the fd of the listening AF_UNIX socket and alive_r the
    read end of the "alive" pipe; the loop below exits (SystemExit) once
    alive_r becomes readable.  preload is a list of module names to import
    before serving; modules that fail with ImportError are skipped.  If
    '__main__' is in preload and main_path is given, the main module is
    imported from main_path first.  sys_path is accepted (the command line
    built by ForkServer.ensure_running() may pass it) but is not used in
    this function.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            # Flag the current process as "inheriting" only for the
            # duration of the main-module import.
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            # Preloading is best effort: a missing module is not an error.
            try:
                __import__(modname)
            except ImportError:
                pass

    # close sys.stdin
    if sys.stdin is not None:
        try:
            sys.stdin.close()
            sys.stdin = open(os.devnull)
        except (OSError, ValueError):
            pass

    # ignoring SIGCHLD means no need to reap zombie processes
    # (the previous handler is kept so children can restore it)
    handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        # Record the address on the module-level ForkServer instance.
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)

        while True:
            try:
                # Wait until at least one registered object is ready.
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    # NOTE: the read happens inside the assert, so it is
                    # skipped under -O; SystemExit is raised either way.
                    assert os.read(alive_r, 1) == b''
                    raise SystemExit

                assert listener in rfds
                with listener.accept()[0] as s:
                    # NOTE: code is never reassigned, so the forked child
                    # always exits with status 1; the real exit code is
                    # written to the status pipe by _serve_one().
                    code = 1
                    if os.fork() == 0:
                        # Child: serve this one request, then leave via
                        # os._exit() so control never returns to the
                        # server loop.
                        try:
                            _serve_one(s, listener, alive_r, handler)
                        except Exception:
                            sys.excepthook(*sys.exc_info())
                            sys.stderr.flush()
                        finally:
                            os._exit(code)

            except InterruptedError:
                # Interrupted system call: go round the loop again.
                pass
            except OSError as e:
                # An aborted connection is ignored; any other OSError is
                # propagated.
                if e.errno != errno.ECONNABORTED:
                    raise
196
def _serve_one(s, listener, alive_r, handler):
    '''Handle one request; called by main() in a freshly forked child.

    s is the accepted connection to the requesting process, listener and
    alive_r are the server's own resources (closed here because the child
    does not need them), and handler is the SIGCHLD handler that was
    installed before the server started ignoring SIGCHLD.

    Raises RuntimeError if more than MAXFDS_TO_SEND fds are received.
    '''
    # close unnecessary stuff and reset SIGCHLD handler
    listener.close()
    os.close(alive_r)
    signal.signal(signal.SIGCHLD, handler)

    # receive fds from the requesting process; ask for one more than the
    # limit so that an oversized request can be detected
    fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
    s.close()
    # Explicit check rather than an assert: asserts are removed under -O,
    # and interpreter flags are forwarded to the server process.
    if len(fds) > MAXFDS_TO_SEND:
        raise RuntimeError(
            'too many ({0:n}) fds received'.format(len(fds)))
    # Positional layout matches ForkServer.connect_to_new_process():
    # data pipe (read), status pipe (write), alive fd, semaphore tracker
    # fd, then any fds the caller asked to be inherited.
    (child_r, child_w, _forkserver._forkserver_alive_fd,
     stfd, *_forkserver._inherited_fds) = fds
    semaphore_tracker._semaphore_tracker._fd = stfd

    # send pid to client processes
    write_unsigned(child_w, os.getpid())

    # reseed random number generator
    if 'random' in sys.modules:
        import random
        random.seed()

    # run process object received over pipe
    code = spawn._main(child_r)

    # write the exit code to the pipe
    write_unsigned(child_w, code)
224
225#
226# Read and write unsigned numbers
227#
228
def read_unsigned(fd):
    '''Read one UNSIGNED_STRUCT-encoded number from fd and return it.

    Keeps reading until the whole record has arrived, retrying reads that
    are interrupted.  Raises EOFError if fd reaches end-of-file first.
    '''
    wanted = UNSIGNED_STRUCT.size
    buf = b''
    while len(buf) < wanted:
        try:
            chunk = os.read(fd, wanted - len(buf))
        except InterruptedError:
            # interrupted before any data was returned: try again
            continue
        if not chunk:
            raise EOFError('unexpected EOF')
        buf += chunk
    (value,) = UNSIGNED_STRUCT.unpack(buf)
    return value
244
def write_unsigned(fd, n):
    '''Write n to fd as a single UNSIGNED_STRUCT-encoded record.

    Keeps writing until every byte has been accepted, retrying writes
    that are interrupted.
    '''
    remaining = UNSIGNED_STRUCT.pack(n)
    while remaining:
        try:
            written = os.write(fd, remaining)
        except InterruptedError:
            # interrupted before anything was written: try again
            continue
        if written == 0:
            raise RuntimeError('should not get here')
        remaining = remaining[written:]
258
259#
260#
261#
262
# Single module-level ForkServer instance.  The public functions named in
# __all__ are its bound methods; main() and _serve_one() also store state
# on it.
_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload
268