"""
In-process communication backend, registered under the ``"inproc"`` scheme.

Comms created by this backend exchange messages through in-memory queues
rather than sockets.  Addresses embed the host IP and the current process id
(see ``Manager.validate_address``), so both endpoints must live in the same
process.
"""
import asyncio
import itertools
import logging
import os
import threading
import warnings
import weakref
from collections import deque, namedtuple

from tornado.concurrent import Future
from tornado.ioloop import IOLoop

from ..protocol import nested_deserialize
from ..utils import get_ip
from .core import Comm, CommClosedError, Connector, Listener
from .registry import Backend, backends

logger = logging.getLogger(__name__)

# Request sent by a connector to a listener:
# - c2s_q / s2c_q: client->server and server->client message queues
# - c_loop: the connector's IOLoop (writes towards the client are scheduled
#   on it, and conn_event is set on it)
# - c_addr: the connector's local address, without the "inproc://" prefix
# - conn_event: set once the listener has created its side of the comm
ConnectionRequest = namedtuple(
    "ConnectionRequest", ("c2s_q", "s2c_q", "c_loop", "c_addr", "conn_event")
)


class Manager:
    """
    An object coordinating listeners and their addresses.

    Addresses have the form ``"<ip>/<pid>/<n>"`` where ``n`` comes from a
    per-manager counter starting at 1.
    """

    def __init__(self):
        # Weak values: a listener that gets garbage-collected drops out of
        # the registry even without an explicit remove_listener() call.
        self.listeners = weakref.WeakValueDictionary()
        self.addr_suffixes = itertools.count(1)
        with warnings.catch_warnings():
            # Avoid immediate warning for unreachable network
            # (will still warn for other get_ip() calls when actually used)
            warnings.simplefilter("ignore")
            try:
                self.ip = get_ip()
            except OSError:
                self.ip = "127.0.0.1"
        # Guards ``listeners`` against concurrent access from several threads.
        self.lock = threading.Lock()

    def add_listener(self, addr, listener):
        """
        Register ``listener`` under ``addr``.

        Raises RuntimeError if a live listener is already registered there.
        """
        with self.lock:
            if addr in self.listeners:
                raise RuntimeError(f"already listening on {addr!r}")
            self.listeners[addr] = listener

    def remove_listener(self, addr):
        """Unregister the listener at ``addr``; a no-op if there is none."""
        with self.lock:
            try:
                del self.listeners[addr]
            except KeyError:
                pass

    def get_listener_for(self, addr):
        """
        Return the listener registered at ``addr``, or None.

        Raises ValueError if ``addr`` does not belong to this host/process.
        """
        with self.lock:
            self.validate_address(addr)
            return self.listeners.get(addr)

    def new_address(self):
        """Return a fresh ``"<ip>/<pid>/<n>"`` address (no scheme prefix)."""
        return "%s/%d/%s" % (self.ip, os.getpid(), next(self.addr_suffixes))

    def validate_address(self, addr):
        """
        Validate the address' IP and pid.

        Raises ValueError if ``addr`` is not of the form ``ip/pid/suffix`` or
        if its IP or pid differs from this manager's IP and the current
        process id.
        """
        ip, pid, suffix = addr.split("/")
        if ip != self.ip or int(pid) != os.getpid():
            raise ValueError(
                "inproc address %r does not match host (%r) or pid (%r)"
                % (addr, self.ip, os.getpid())
            )


global_manager = Manager()


def new_address():
    """
    Generate a new address.

    The result is prefixed with the ``inproc://`` scheme.
    """
    return "inproc://" + global_manager.new_address()


class QueueEmpty(Exception):
    """Raised when reading from an empty ``Queue`` without waiting."""


class Queue:
    """
    A single-reader, single-writer, non-threadsafe, peekable queue.

    ``get()`` returns a Future.  If no item is available, the Future is kept
    as the pending read and the next ``put_nowait()`` resolves it directly
    instead of enqueuing the value.
    """

    def __init__(self):
        self._q = deque()
        self._read_future = None

    def get_nowait(self):
        """Pop and return the next item, or raise QueueEmpty."""
        q = self._q
        if not q:
            raise QueueEmpty
        return q.popleft()

    def get(self):
        """Return a Future resolved with the next item."""
        assert not self._read_future, "Only one reader allowed"
        fut = Future()
        q = self._q
        if q:
            fut.set_result(q.popleft())
        else:
            self._read_future = fut
        return fut

    def put_nowait(self, value):
        """Hand ``value`` to a pending reader, or enqueue it."""
        q = self._q
        fut = self._read_future
        if fut is not None:
            # A reader only waits when the queue is empty.
            assert len(q) == 0
            self._read_future = None
            fut.set_result(value)
        else:
            q.append(value)

    put = put_nowait

    _omitted = object()

    def peek(self, default=_omitted):
        """
        Get the next object in the queue without removing it from the queue.

        Returns ``default`` when the queue is empty and a default was given;
        otherwise raises QueueEmpty.
        """
        q = self._q
        if q:
            return q[0]
        elif default is not self._omitted:
            return default
        else:
            raise QueueEmpty


# Sentinel put on a queue to signal that the writing side has closed.
_EOF = object()


class InProc(Comm):
    """
    An established communication based on a pair of in-process queues.

    Reminder: a Comm must always be used from a single thread.
    Its peer Comm can be running in any thread.
    """

    _initialized = False

    def __init__(
        self,
        local_addr: str,
        peer_addr: str,
        read_q,
        write_q,
        write_loop,
        deserialize=True,
    ):
        super().__init__()
        self._local_addr = local_addr
        self._peer_addr = peer_addr
        self.deserialize = deserialize
        self._read_q = read_q
        self._write_q = write_q
        # Loop of the peer that reads ``write_q``; all writes are scheduled
        # on it because Queue is not thread-safe.
        self._write_loop = write_loop
        self._closed = False

        # If this comm is garbage-collected without being closed, warn and
        # send EOF to the peer.  Not run at interpreter exit.
        self._finalizer = weakref.finalize(self, self._get_finalizer())
        self._finalizer.atexit = False
        self._initialized = True

    def _get_finalizer(self):
        # The callback captures the queue, loop and repr by value (as default
        # arguments) so it does not keep a reference to ``self``, which would
        # prevent the finalizer from ever firing.
        def finalize(write_q=self._write_q, write_loop=self._write_loop, r=repr(self)):
            logger.warning("Closing dangling queue in %s", r)
            write_loop.add_callback(write_q.put_nowait, _EOF)

        return finalize

    @property
    def local_address(self) -> str:
        return self._local_addr

    @property
    def peer_address(self) -> str:
        return self._peer_addr

    async def read(self, deserializers="ignored"):
        """
        Return the next message from the peer.

        Raises CommClosedError if this comm is closed or the peer sent EOF.
        """
        if self._closed:
            raise CommClosedError()

        msg = await self._read_q.get()
        if msg is _EOF:
            self._closed = True
            self._finalizer.detach()
            raise CommClosedError()

        if self.deserialize:
            msg = nested_deserialize(msg)
        return msg

    async def write(self, msg, serializers=None, on_error=None):
        """
        Send ``msg`` to the peer without serializing it.

        Returns 1.  Raises CommClosedError if the comm is closed.
        """
        if self.closed():
            raise CommClosedError()

        # Ensure we feed the queue in the same thread it is read from.
        self._write_loop.add_callback(self._write_q.put_nowait, msg)

        return 1

    async def close(self):
        self.abort()

    def abort(self):
        """
        Close this comm immediately.

        Sends EOF to the peer and also wakes up any local pending ``read()``
        with EOF.
        """
        if not self.closed():
            # Putting EOF is cheap enough that we do it on abort() too
            self._write_loop.add_callback(self._write_q.put_nowait, _EOF)
            self._read_q.put_nowait(_EOF)
            self._write_q = self._read_q = None
            self._closed = True
            self._finalizer.detach()

    def closed(self):
        """
        Whether this comm is closed.

        An InProc comm is closed if:
        1) close() or abort() was called on this comm
        2) close() or abort() was called on the other end and the
           read queue is empty
        """
        if self._closed:
            return True
        # NOTE: repr() is called by finalize() during __init__()...
        if self._initialized and self._read_q.peek(None) is _EOF:
            self._closed = True
            self._finalizer.detach()
            return True
        else:
            return False


class InProcListener(Listener):
    """
    Accept in-process connections on an ``inproc`` address.

    Connectors submit ConnectionRequests through ``connect_threadsafe()``.
    The listener's loop consumes them one at a time in ``_listen()``.
    """

    prefix = "inproc"

    def __init__(self, address, comm_handler, deserialize=True):
        self.manager = global_manager
        self.address = address or self.manager.new_address()
        self.comm_handler = comm_handler
        self.deserialize = deserialize
        self.listen_q = Queue()

    async def _listen(self):
        """
        Serve connection requests until ``stop()`` enqueues ``None``.

        For each request: create the server-side comm, notify the connector,
        complete the handshake, then hand the comm to ``comm_handler``.
        """
        while True:
            conn_req = await self.listen_q.get()
            if conn_req is None:
                break
            comm = InProc(
                local_addr="inproc://" + self.address,
                peer_addr="inproc://" + conn_req.c_addr,
                read_q=conn_req.c2s_q,
                write_q=conn_req.s2c_q,
                write_loop=conn_req.c_loop,
                deserialize=self.deserialize,
            )
            # Notify connector (on its own loop, which owns conn_event)
            conn_req.c_loop.add_callback(conn_req.conn_event.set)
            try:
                await self.on_connection(comm)
            except CommClosedError:
                # Only this connection is lost.  Keep serving other requests:
                # returning here would silently stop the accept loop while the
                # listener is still registered, leaving future connectors
                # waiting forever on their conn_event.
                logger.debug("Connection closed before handshake completed")
                continue
            IOLoop.current().add_callback(self.comm_handler, comm)

    def connect_threadsafe(self, conn_req):
        """Submit ``conn_req`` from any thread; it is queued on our loop."""
        self.loop.add_callback(self.listen_q.put_nowait, conn_req)

    async def start(self):
        """
        Register with the manager and start accepting connections.

        Raises RuntimeError if the address is already in use.
        """
        # The loop must be known before registration: once registered,
        # connectors (possibly in other threads) call connect_threadsafe().
        self.loop = IOLoop.current()
        # Register before spawning the accept task so that a RuntimeError for
        # an address already in use does not leave an orphaned task waiting
        # on listen_q forever.  There is no await between these two lines, so
        # any request submitted meanwhile simply waits in listen_q.
        self.manager.add_listener(self.address, self)
        self._listen_future = asyncio.ensure_future(self._listen())

    def stop(self):
        """Stop the accept loop and unregister the address."""
        self.listen_q.put_nowait(None)
        self.manager.remove_listener(self.address)

    @property
    def listen_address(self):
        return "inproc://" + self.address

    @property
    def contact_address(self):
        return "inproc://" + self.address


class InProcConnector(Connector):
    """Open in-process comms to listeners registered with ``manager``."""

    def __init__(self, manager):
        self.manager = manager

    async def connect(self, address, deserialize=True, **connection_args):
        """
        Connect to the listener at ``address`` (without scheme prefix).

        Raises OSError if no listener is registered there, and ValueError if
        the address belongs to another host or process.
        """
        listener = self.manager.get_listener_for(address)
        if listener is None:
            raise OSError(f"no endpoint for inproc address {address!r}")

        conn_req = ConnectionRequest(
            c2s_q=Queue(),
            s2c_q=Queue(),
            c_loop=IOLoop.current(),
            c_addr=self.manager.new_address(),
            conn_event=asyncio.Event(),
        )
        listener.connect_threadsafe(conn_req)
        # Wait for connection acknowledgement
        # (do not pretend we're connected if the other comm never gets
        # created, for example if the listener was stopped in the meantime)
        await conn_req.conn_event.wait()

        comm = InProc(
            local_addr="inproc://" + conn_req.c_addr,
            peer_addr="inproc://" + address,
            read_q=conn_req.s2c_q,
            write_q=conn_req.c2s_q,
            write_loop=listener.loop,
            deserialize=deserialize,
        )
        return comm


class InProcBackend(Backend):
    """Backend exposing in-process comms under the ``inproc`` scheme."""

    manager = global_manager

    # I/O

    def get_connector(self):
        return InProcConnector(self.manager)

    def get_listener(self, loc, handle_comm, deserialize, **connection_args):
        return InProcListener(loc, handle_comm, deserialize)

    # Address handling

    def get_address_host(self, loc):
        """Return this process's IP after checking ``loc`` belongs to it."""
        self.manager.validate_address(loc)
        return self.manager.ip

    def resolve_address(self, loc):
        # inproc addresses need no resolution.
        return loc

    def get_local_address_for(self, loc):
        """Return a fresh local address (without scheme) after validating ``loc``."""
        self.manager.validate_address(loc)
        return self.manager.new_address()


backends["inproc"] = InProcBackend()