"""Fork server: a long-lived process that forks new child processes on request.

``ensure_running()`` spawns the server once.  After that,
``connect_to_new_process()`` asks the server, over an AF_UNIX socket, to fork
a child.  The child's pid and its exit code come back over a pipe.
"""

import errno
import os
import selectors
import signal
import socket
import struct
import sys
import threading

from . import connection
from . import process
from . import reduction
from . import semaphore_tracker
from . import spawn
from . import util

__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']

#
#
#

# Upper bound on the number of fds passed to the server in one request.
# Each request carries 4 bookkeeping fds followed by the caller's fds
# (see ForkServer.connect_to_new_process and _serve_one).
MAXFDS_TO_SEND = 256
UNSIGNED_STRUCT = struct.Struct('Q')     # large enough for pid_t

#
# Forkserver class
#

class ForkServer(object):
    '''State and operations for talking to a fork server process.

    A single module-level instance (``_forkserver``) backs the public
    functions exported by this module.
    '''

    def __init__(self):
        # Address of the server's listening AF_UNIX socket.
        self._forkserver_address = None
        # Write end of the "alive" pipe.  None means no server is known yet,
        # which makes ensure_running() start one.
        self._forkserver_alive_fd = None
        # Extra fds received from the server (set in _serve_one); stays None
        # in processes not started by the fork server.
        self._inherited_fds = None
        # Serializes ensure_running() between threads.
        self._lock = threading.Lock()
        # Modules the server tries to import before it starts forking.
        self._preload_modules = ['__main__']

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        *modules_names* may be any iterable of ``str``.  It is copied into a
        list, because its ``repr()`` is later embedded in the server's
        command line.

        Raises TypeError if *modules_names* is a bare string, is not
        iterable, or contains anything other than ``str`` objects.
        '''
        # A bare str is iterable over its characters and would pass the
        # per-item check below, so reject it explicitly.
        if isinstance(modules_names, str):
            raise TypeError('module_names must be a list of strings')
        # Materialize first so a one-shot iterable is not consumed by the
        # validation and then stored empty.
        modules_names = list(modules_names)
        # Validate the *new* value (previously the old list was checked).
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        *fds* are extra file descriptors forwarded to the child.  Raises
        ValueError if there are too many of them (see MAXFDS_TO_SEND).
        '''
        self.ensure_running()
        # Four fds always precede the caller's fds in allfds below.
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            child_r, parent_w = os.pipe()
            # Order matters: _serve_one unpacks these positionally.
            allfds = [child_r, child_w, self._forkserver_alive_fd,
                      semaphore_tracker.getfd()]
            allfds += fds
            try:
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                # The caller never receives the parent ends on failure.
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # Whether or not sending succeeded, this process keeps only
                # the parent ends; the child ends were meant for the server.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            semaphore_tracker.ensure_running()
            if self._forkserver_alive_fd is not None:
                return

            # Program run by the new interpreter.  The placeholders are the
            # listener fd, the alive_r fd, the preload list, and keyword
            # arguments for main().
            cmd = ('from multiprocess.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                # Only these preparation entries are forwarded to main().
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {x: y for x, y in data.items() if x in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                # Restrict access to the socket file to its owner.
                os.chmod(address, 0o600)
                listener.listen(100)

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    util.spawnv_passfds(exe, args, fds_to_pass)
                except BaseException:
                    os.close(alive_w)
                    raise
                finally:
                    # alive_r was handed to the server via fds_to_pass; this
                    # process keeps only alive_w.
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w

#
#
#

def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    *listener_fd* is the listening AF_UNIX socket and *alive_r* the read end
    of the "alive" pipe, both created by ForkServer.ensure_running().
    *preload* lists modules to try to import before serving requests.
    '__main__' is imported from *main_path* when that is given.  *sys_path*
    is accepted so that the keyword arguments built by ensure_running() are
    valid; it is not used here.

    Serves fork requests until every client has closed its copy of the
    alive pipe, then raises SystemExit.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                # Preloading is best-effort.
                pass

    # close sys.stdin
    if sys.stdin is not None:
        try:
            sys.stdin.close()
            sys.stdin = open(os.devnull)
        except (OSError, ValueError):
            pass

    # ignoring SIGCHLD means no need to reap zombie processes
    handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)

        while True:
            try:
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    assert os.read(alive_r, 1) == b''
                    raise SystemExit

                assert listener in rfds
                with listener.accept()[0] as s:
                    # Exit status of the forked child process itself.  The
                    # child's real return code is written to the client's
                    # status pipe by _serve_one().
                    code = 1
                    if os.fork() == 0:
                        try:
                            _serve_one(s, listener, alive_r, handler)
                        except Exception:
                            sys.excepthook(*sys.exc_info())
                            sys.stderr.flush()
                        finally:
                            # The child must never return to the server loop.
                            os._exit(code)

            except InterruptedError:
                pass
            except OSError as e:
                if e.errno != errno.ECONNABORTED:
                    raise

def _serve_one(s, listener, alive_r, handler):
    '''Run inside a freshly forked child: receive fds, then run the process.

    *s* is the accepted client connection.  *listener* and *alive_r* are the
    server's resources, which the child closes.  *handler* is the SIGCHLD
    disposition that was in place before main() ignored SIGCHLD.
    '''
    # close unnecessary stuff and reset SIGCHLD handler
    listener.close()
    os.close(alive_r)
    signal.signal(signal.SIGCHLD, handler)

    # receive fds from parent process
    fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
    s.close()
    assert len(fds) <= MAXFDS_TO_SEND
    # Same positional layout as allfds in ForkServer.connect_to_new_process.
    (child_r, child_w, _forkserver._forkserver_alive_fd,
     stfd, *_forkserver._inherited_fds) = fds
    semaphore_tracker._semaphore_tracker._fd = stfd

    # send pid to client processes
    write_unsigned(child_w, os.getpid())

    # reseed random number generator
    if 'random' in sys.modules:
        import random
        random.seed()

    # run process object received over pipe
    code = spawn._main(child_r)

    # write the exit code to the pipe
    write_unsigned(child_w, code)

#
# Read and write unsigned numbers
#

def read_unsigned(fd):
    '''Read one UNSIGNED_STRUCT-encoded integer from *fd*.

    Loops until all UNSIGNED_STRUCT.size bytes have arrived.  Raises EOFError
    if the fd reaches EOF first.
    '''
    data = b''
    length = UNSIGNED_STRUCT.size
    while len(data) < length:
        while True:
            try:
                s = os.read(fd, length - len(data))
            except InterruptedError:
                pass
            else:
                break
        if not s:
            raise EOFError('unexpected EOF')
        data += s
    return UNSIGNED_STRUCT.unpack(data)[0]

def write_unsigned(fd, n):
    '''Write *n* to *fd* encoded with UNSIGNED_STRUCT, handling short writes.

    Raises RuntimeError if os.write() reports 0 bytes written.
    '''
    msg = UNSIGNED_STRUCT.pack(n)
    while msg:
        while True:
            try:
                nbytes = os.write(fd, msg)
            except InterruptedError:
                pass
            else:
                break
        if nbytes == 0:
            raise RuntimeError('should not get here')
        msg = msg[nbytes:]

#
#
#

_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload