1import asyncio
2import itertools
3import logging
4import os
5import threading
6import warnings
7import weakref
8from collections import deque, namedtuple
9
10from tornado.concurrent import Future
11from tornado.ioloop import IOLoop
12
13from ..protocol import nested_deserialize
14from ..utils import get_ip
15from .core import Comm, CommClosedError, Connector, Listener
16from .registry import Backend, backends
17
logger = logging.getLogger(__name__)

# Request a connector hands to a listener (see InProcConnector.connect and
# InProcListener._listen): the client->server and server->client queues of
# the new comm pair, the connector's IOLoop, the connector's own address,
# and the event the listener sets once its side of the comm exists.
ConnectionRequest = namedtuple(
    "ConnectionRequest",
    ["c2s_q", "s2c_q", "c_loop", "c_addr", "conn_event"],
)
23
24
class Manager:
    """
    An object coordinating listeners and their addresses.

    Addresses handled here are the scheme-less form ``"<ip>/<pid>/<n>"``:
    the host IP detected at construction, the current process id, and a
    per-manager counter starting at 1.
    """

    def __init__(self):
        # Listeners are referenced weakly: an entry disappears once its
        # listener object is garbage-collected.
        self.listeners = weakref.WeakValueDictionary()
        self.addr_suffixes = itertools.count(1)
        with warnings.catch_warnings():
            # Avoid immediate warning for unreachable network
            # (will still warn for other get_ip() calls when actually used)
            warnings.simplefilter("ignore")
            try:
                self.ip = get_ip()
            except OSError:
                self.ip = "127.0.0.1"
        # Guards self.listeners: listeners and connectors may live on
        # different threads.
        self.lock = threading.Lock()

    def add_listener(self, addr, listener):
        """
        Register *listener* under *addr*.

        Raises RuntimeError if another listener is already registered there.
        """
        with self.lock:
            if addr in self.listeners:
                raise RuntimeError(f"already listening on {addr!r}")
            self.listeners[addr] = listener

    def remove_listener(self, addr):
        """Unregister the listener at *addr*; a no-op if there is none."""
        with self.lock:
            try:
                del self.listeners[addr]
            except KeyError:
                pass

    def get_listener_for(self, addr):
        """
        Return the listener registered at *addr*, or None.

        Raises ValueError if *addr* is malformed or belongs to another host
        or process (see validate_address).
        """
        # Validation only reads immutable state, so it needs no lock.
        self.validate_address(addr)
        with self.lock:
            return self.listeners.get(addr)

    def new_address(self):
        """Return a new unique ``"ip/pid/n"`` address for this process."""
        return "%s/%d/%s" % (self.ip, os.getpid(), next(self.addr_suffixes))

    def validate_address(self, addr):
        """
        Validate the address' IP and pid.

        Raises ValueError if *addr* is not of the form ``"ip/pid/suffix"``
        with an integer pid, or if its IP or pid do not match this process.
        """
        parts = addr.split("/")
        if len(parts) != 3:
            raise ValueError(
                f"malformed inproc address {addr!r}, expected 'ip/pid/suffix'"
            )
        ip, pid_str, _suffix = parts
        try:
            pid = int(pid_str)
        except ValueError:
            raise ValueError(
                f"malformed inproc address {addr!r}: pid {pid_str!r} is not an integer"
            ) from None
        if ip != self.ip or pid != os.getpid():
            raise ValueError(
                "inproc address %r does not match host (%r) or pid (%r)"
                % (addr, self.ip, os.getpid())
            )
74
75
# Process-wide Manager shared by module-level new_address(), every
# InProcListener and InProcBackend (which all refer to it).
global_manager = Manager()
77
78
def new_address():
    """
    Generate a new address.

    Returns the global manager's next ``"ip/pid/n"`` address with the
    ``"inproc://"`` scheme prepended.
    """
    return f"inproc://{global_manager.new_address()}"
84
85
class QueueEmpty(Exception):
    """Raised by Queue.get_nowait() and Queue.peek() when no item is queued."""
88
89
class Queue:
    """
    A single-reader, single-writer, non-threadsafe, peekable queue.

    ``get()`` returns a Future.  When the queue is empty that Future is kept
    as the pending read and resolved directly by the next ``put_nowait()``.
    If the pending Future gets cancelled (e.g. the task awaiting it was
    cancelled by a timeout), it is discarded: the next value is queued
    instead of being lost, and a new ``get()`` is allowed.
    """

    def __init__(self):
        self._q = deque()
        # Future of the reader waiting on an empty queue, if any.
        self._read_future = None

    def get_nowait(self):
        """Pop and return the next item; raise QueueEmpty if there is none."""
        q = self._q
        if not q:
            raise QueueEmpty
        return q.popleft()

    def get(self):
        """
        Return a Future resolving to the next item.

        Only one read may be pending at a time.
        """
        pending = self._read_future
        if pending is not None and pending.done():
            # The previous reader gave up (its Future was cancelled);
            # it no longer counts as the pending read.
            self._read_future = pending = None
        assert pending is None, "Only one reader allowed"
        fut = Future()
        q = self._q
        if q:
            fut.set_result(q.popleft())
        else:
            self._read_future = fut
        return fut

    def put_nowait(self, value):
        """Hand *value* to the pending reader if there is one, else enqueue it."""
        q = self._q
        fut = self._read_future
        self._read_future = None
        if fut is not None and not fut.done():
            assert len(q) == 0
            fut.set_result(value)
        else:
            # No live reader: calling set_result() on a cancelled Future
            # would raise InvalidStateError and drop the value, so queue it.
            q.append(value)

    put = put_nowait

    _omitted = object()

    def peek(self, default=_omitted):
        """
        Get the next object in the queue without removing it from the queue.

        Returns *default* when the queue is empty if one was given,
        otherwise raises QueueEmpty.
        """
        q = self._q
        if q:
            return q[0]
        elif default is not self._omitted:
            return default
        else:
            raise QueueEmpty
140
141
# End-of-stream sentinel pushed through the comm queues: InProc.read() turns
# it into CommClosedError, and InProc.closed() looks for it at the head of
# the read queue.
_EOF = object()
143
144
class InProc(Comm):
    """
    An established communication based on a pair of in-process queues.

    Reminder: a Comm must always be used from a single thread.
    Its peer Comm can be running in any thread.

    Incoming messages are taken from ``read_q``.  Outgoing messages are put
    on ``write_q`` (the peer's read queue) through
    ``write_loop.add_callback`` so that queue is only fed from the loop that
    reads it.  The ``_EOF`` sentinel travels through the same queues to
    signal closing.
    """

    # Class-level default that stays False until the end of __init__.
    # closed() skips the read-queue check while it is False because, per
    # the NOTE in closed(), repr() can reach it during __init__ (repr(self)
    # is evaluated in _get_finalizer() before self._finalizer is assigned).
    _initialized = False

    def __init__(
        self,
        local_addr: str,
        peer_addr: str,
        read_q,
        write_q,
        write_loop,
        deserialize=True,
    ):
        """
        Parameters
        ----------
        local_addr, peer_addr : str
            Addresses reported by local_address / peer_address.
        read_q : Queue
            Queue this comm reads from (the peer writes into it).
        write_q : Queue
            Queue the peer reads from; only fed via write_loop.
        write_loop : IOLoop
            Loop on which puts to write_q are scheduled.
        deserialize : bool
            Whether read() passes messages through nested_deserialize.
        """
        super().__init__()
        self._local_addr = local_addr
        self._peer_addr = peer_addr
        self.deserialize = deserialize
        self._read_q = read_q
        self._write_q = write_q
        self._write_loop = write_loop
        self._closed = False

        # If this comm is garbage-collected without being closed, tell the
        # peer via _EOF.  Not run at interpreter exit (atexit = False).
        self._finalizer = weakref.finalize(self, self._get_finalizer())
        self._finalizer.atexit = False
        self._initialized = True

    def _get_finalizer(self):
        """
        Build the weakref.finalize callback for this comm.

        Everything it needs is bound as default arguments so the callback
        holds no reference to ``self`` (a reference would keep the object
        alive and the finalizer would never fire).  When run, it logs a
        warning and schedules _EOF onto the peer's queue.
        """
        def finalize(write_q=self._write_q, write_loop=self._write_loop, r=repr(self)):
            logger.warning(f"Closing dangling queue in {r}")
            write_loop.add_callback(write_q.put_nowait, _EOF)

        return finalize

    @property
    def local_address(self) -> str:
        return self._local_addr

    @property
    def peer_address(self) -> str:
        return self._peer_addr

    async def read(self, deserializers="ignored"):
        """
        Wait for and return the next message from the peer.

        ``deserializers`` is not used.  Raises CommClosedError if this comm
        is already closed or if the _EOF sentinel is received (which also
        marks this comm closed).
        """
        if self._closed:
            raise CommClosedError()

        msg = await self._read_q.get()
        if msg is _EOF:
            self._closed = True
            self._finalizer.detach()
            raise CommClosedError()

        if self.deserialize:
            msg = nested_deserialize(msg)
        return msg

    async def write(self, msg, serializers=None, on_error=None):
        """
        Send *msg* to the peer as-is (``serializers``/``on_error`` unused).

        Raises CommClosedError if closed() is true.  Returns 1 —
        NOTE(review): presumably a nominal byte count expected by the Comm
        interface; nothing is serialized here.
        """
        if self.closed():
            raise CommClosedError()

        # Ensure we feed the queue in the same thread it is read from.
        self._write_loop.add_callback(self._write_q.put_nowait, msg)

        return 1

    async def close(self):
        """Close the comm; same as abort()."""
        self.abort()

    def abort(self):
        """
        Close the comm immediately; a no-op if already closed.

        Schedules _EOF onto the peer's queue, puts _EOF on our own read
        queue (waking a pending read(), which then raises CommClosedError),
        and drops the references to both queues.
        """
        if not self.closed():
            # Putting EOF is cheap enough that we do it on abort() too
            self._write_loop.add_callback(self._write_q.put_nowait, _EOF)
            self._read_q.put_nowait(_EOF)
            self._write_q = self._read_q = None
            self._closed = True
            self._finalizer.detach()

    def closed(self) -> bool:
        """
        Whether this comm is closed.  An InProc comm is closed if:
            1) close() or abort() was called on this comm
            2) close() or abort() was called on the other end and the
               read queue is empty
        """
        if self._closed:
            return True
        # NOTE: repr() is called by finalize() during __init__()...
        if self._initialized and self._read_q.peek(None) is _EOF:
            self._closed = True
            self._finalizer.detach()
            return True
        else:
            return False
243
244
class InProcListener(Listener):
    """
    Listener accepting inproc connections on one address.

    Connectors (possibly on other threads) hand ConnectionRequests to
    connect_threadsafe().  They are consumed one at a time by the _listen()
    task, which runs on the IOLoop that called start().
    """

    prefix = "inproc"

    def __init__(self, address, comm_handler, deserialize=True):
        """
        Parameters
        ----------
        address : str or falsy
            Scheme-less ``"ip/pid/n"`` address; a fresh one is generated
            from the global manager when falsy.
        comm_handler : callable
            Scheduled on the current IOLoop with each accepted InProc comm.
        deserialize : bool
            Passed to the server-side InProc comms.
        """
        self.manager = global_manager
        self.address = address or self.manager.new_address()
        self.comm_handler = comm_handler
        self.deserialize = deserialize
        self.listen_q = Queue()

    async def _listen(self):
        """Accept connection requests until stop() enqueues None."""
        while True:
            conn_req = await self.listen_q.get()
            if conn_req is None:
                break
            comm = InProc(
                local_addr="inproc://" + self.address,
                peer_addr="inproc://" + conn_req.c_addr,
                read_q=conn_req.c2s_q,
                write_q=conn_req.s2c_q,
                write_loop=conn_req.c_loop,
                deserialize=self.deserialize,
            )
            # Notify connector
            conn_req.c_loop.add_callback(conn_req.conn_event.set)
            try:
                await self.on_connection(comm)
            except CommClosedError:
                # Only this connection failed its handshake: keep accepting
                # the others instead of ending the accept loop (returning
                # here would leave every later connector waiting forever
                # on its conn_event).
                logger.debug("Connection closed before handshake completed")
                continue
            IOLoop.current().add_callback(self.comm_handler, comm)

    def connect_threadsafe(self, conn_req):
        """Queue *conn_req*; safe to call from any thread."""
        self.loop.add_callback(self.listen_q.put_nowait, conn_req)

    async def start(self):
        """Bind to the current IOLoop, register the address and start accepting."""
        self.loop = IOLoop.current()
        # Register before spawning _listen(): if the address is already in
        # use, add_listener() raises and no orphaned task is left waiting on
        # listen_q forever.
        self.manager.add_listener(self.address, self)
        self._listen_future = asyncio.ensure_future(self._listen())

    def stop(self):
        """Stop the accept loop and unregister the address."""
        self.listen_q.put_nowait(None)
        self.manager.remove_listener(self.address)

    @property
    def listen_address(self):
        return "inproc://" + self.address

    @property
    def contact_address(self):
        return "inproc://" + self.address
296
297
class InProcConnector(Connector):
    """Connector creating InProc comm pairs with listeners of this process."""

    def __init__(self, manager):
        self.manager = manager

    async def connect(self, address, deserialize=True, **connection_args):
        """
        Connect to the inproc listener registered at *address*.

        Raises OSError when no listener is registered there.  Waits until
        the listener has created its side of the comm before returning.
        """
        mgr = self.manager
        listener = mgr.get_listener_for(address)
        if listener is None:
            raise OSError(f"no endpoint for inproc address {address!r}")

        # This side reads server_to_client and writes client_to_server; the
        # listener uses the same two queues the other way round.
        client_to_server = Queue()
        server_to_client = Queue()
        my_loop = IOLoop.current()
        my_addr = mgr.new_address()
        acknowledged = asyncio.Event()
        listener.connect_threadsafe(
            ConnectionRequest(
                c2s_q=client_to_server,
                s2c_q=server_to_client,
                c_loop=my_loop,
                c_addr=my_addr,
                conn_event=acknowledged,
            )
        )
        # Wait for connection acknowledgement
        # (do not pretend we're connected if the other comm never gets
        #  created, for example if the listener was stopped in the meantime)
        await acknowledged.wait()

        return InProc(
            local_addr="inproc://" + my_addr,
            peer_addr="inproc://" + address,
            read_q=server_to_client,
            write_q=client_to_server,
            write_loop=listener.loop,
            deserialize=deserialize,
        )
329
330
class InProcBackend(Backend):
    """Backend for ``inproc://`` comms, backed by the process-wide manager."""

    manager = global_manager

    # I/O

    def get_connector(self):
        """Return a connector bound to this backend's manager."""
        return InProcConnector(manager=self.manager)

    def get_listener(self, loc, handle_comm, deserialize, **connection_args):
        """Return a listener on *loc*; extra connection args are ignored."""
        return InProcListener(
            address=loc, comm_handler=handle_comm, deserialize=deserialize
        )

    # Address handling

    def get_address_host(self, loc):
        """Validate *loc* and return the manager's host IP."""
        mgr = self.manager
        mgr.validate_address(loc)
        return mgr.ip

    def resolve_address(self, loc):
        """Inproc addresses need no resolution; return *loc* unchanged."""
        return loc

    def get_local_address_for(self, loc):
        """Validate *loc* and return a freshly generated local address."""
        mgr = self.manager
        mgr.validate_address(loc)
        return mgr.new_address()
354
355
# Register the backend under the "inproc" key of the backends registry
# (presumably looked up by the "inproc://" address scheme; see .registry).
backends["inproc"] = InProcBackend()
357