import errno
import os
import selectors
import signal
import socket
import struct
import sys
import threading

from . import connection
from . import process
from .context import reduction
from . import semaphore_tracker
from . import spawn
from . import util

__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']

#
# Constants shared by the client side and the forkserver side
#

# Upper bound on the number of fds passed in one fork request;
# connect_to_new_process() reserves 4 of them for its own bookkeeping fds.
MAXFDS_TO_SEND = 256
# Wire format used for pids and exit codes written to the status pipe.
UNSIGNED_STRUCT = struct.Struct('Q')     # large enough for pid_t

#
# Forkserver class
#

class ForkServer(object):
    '''Handle to a lazily started fork server process.'''

    def __init__(self):
        # Address of the AF_UNIX socket the forkserver listens on.
        self._forkserver_address = None
        # Write end of the "alive" pipe; the forkserver exits once every
        # holder of this end has closed it (see main()).
        self._forkserver_alive_fd = None
        # Pid of the forkserver launched by this process, if any.
        self._forkserver_pid = None
        # Filled in by _serve_one() inside children forked by the forkserver.
        self._inherited_fds = None
        self._lock = threading.Lock()
        self._preload_modules = ['__main__']

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        Raises TypeError if any element of *modules_names* is not a str.
        The list is only used when ensure_running() next launches a
        forkserver; an already running forkserver is not affected.
        '''
        # Validate the *new* value.  Checking self._preload_modules here
        # would only re-check the previously accepted list.
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if *fds* is too long to send in one request.
        '''
        self.ensure_running()
        # 4 slots are taken by child_r, child_w, the alive fd and the
        # semaphore tracker fd.
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            try:
                child_r, parent_w = os.pipe()
            except:
                # Don't leak the first pipe if the second cannot be created
                # (e.g. the process ran out of file descriptors).
                os.close(parent_r)
                os.close(child_w)
                raise
            try:
                # The order here must match the unpacking in _serve_one().
                allfds = [child_r, child_w, self._forkserver_alive_fd,
                          semaphore_tracker.getfd()]
                allfds += fds
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except:
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # The child-side ends now live in the forked child (or the
                # request failed); this process no longer needs them.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            semaphore_tracker.ensure_running()
            if self._forkserver_pid is not None:
                # forkserver was launched before, is it still running?
                pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG)
                if not pid:
                    # still alive
                    return
                # dead, launch it again
                os.close(self._forkserver_alive_fd)
                self._forkserver_address = None
                self._forkserver_alive_fd = None
                self._forkserver_pid = None

            cmd = ('from multiprocess.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                # Only forward what main() needs to import __main__.
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {key: value for key, value in data.items()
                        if key in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                os.chmod(address, 0o600)
                listener.listen()

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    pid = util.spawnv_passfds(exe, args, fds_to_pass)
                except:
                    os.close(alive_w)
                    raise
                finally:
                    # Only the forkserver (which received it via
                    # spawnv_passfds) should hold the read end.
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
                self._forkserver_pid = pid

#
# Forkserver process side
#

def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    Imports the *preload* modules, ignores SIGCHLD and SIGINT, then forks
    one child per connection accepted on *listener_fd*.  Exits (via
    SystemExit) once *alive_r* reports EOF, i.e. all clients are gone.

    NOTE(review): *sys_path* is accepted but not used by this function.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                pass

    util._close_stdin()

    handlers = {
        # no need to reap zombie processes
        signal.SIGCHLD: signal.SIG_IGN,
        # protect the process from ^C
        signal.SIGINT: signal.SIG_IGN,
        }
    # Remember the previous handlers so forked children can restore them.
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)

        while True:
            try:
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left.  The read
                    # is kept outside the assert so it still runs under -O.
                    eof = os.read(alive_r, 1)
                    assert eof == b''
                    raise SystemExit

                assert listener in rfds
                with listener.accept()[0] as s:
                    if os.fork() == 0:
                        # Child: must never return into the accept loop, so
                        # always leave through os._exit().
                        code = 1
                        try:
                            code = _serve_one(s, listener, alive_r,
                                              old_handlers)
                        except Exception:
                            sys.excepthook(*sys.exc_info())
                            sys.stderr.flush()
                        finally:
                            os._exit(code)

            except OSError as e:
                if e.errno != errno.ECONNABORTED:
                    raise

def _serve_one(s, listener, alive_r, handlers):
    '''Run one process inside a child forked by main().

    Closes the forkserver's own resources, restores *handlers*, receives
    the fds sent by connect_to_new_process(), writes this pid and then the
    process's exit code to the status pipe, and returns that exit code.
    '''
    # close unnecessary stuff and reset signal handlers
    listener.close()
    os.close(alive_r)
    for sig, val in handlers.items():
        signal.signal(sig, val)

    # receive fds from parent process
    fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
    s.close()
    assert len(fds) <= MAXFDS_TO_SEND
    # Layout matches the allfds list built in connect_to_new_process().
    (child_r, child_w, _forkserver._forkserver_alive_fd,
     stfd, *_forkserver._inherited_fds) = fds
    semaphore_tracker._semaphore_tracker._fd = stfd

    # send pid to client processes
    write_unsigned(child_w, os.getpid())

    # reseed random number generator
    if 'random' in sys.modules:
        import random
        random.seed()

    # run process object received over pipe
    code = spawn._main(child_r)

    # write the exit code to the pipe
    write_unsigned(child_w, code)
    return code

#
# Read and write unsigned numbers
#

def read_unsigned(fd):
    '''Read one UNSIGNED_STRUCT-encoded integer from *fd*.

    Raises EOFError if *fd* reaches EOF before a full value was read.
    '''
    data = b''
    length = UNSIGNED_STRUCT.size
    while len(data) < length:
        s = os.read(fd, length - len(data))
        if not s:
            raise EOFError('unexpected EOF')
        data += s
    return UNSIGNED_STRUCT.unpack(data)[0]

def write_unsigned(fd, n):
    '''Write *n* to *fd* as one UNSIGNED_STRUCT-encoded integer.'''
    msg = UNSIGNED_STRUCT.pack(n)
    while msg:
        nbytes = os.write(fd, msg)
        if nbytes == 0:
            raise RuntimeError('should not get here')
        msg = msg[nbytes:]

#
# Module-level singleton and its public entry points
#

_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload