1# Copyright (c) 2019, Neil Booth
2#
3# All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining
8# a copy of this software and associated documentation files (the
9# "Software"), to deal in the Software without restriction, including
10# without limitation the rights to use, copy, modify, merge, publish,
11# distribute, sublicense, and/or sell copies of the Software, and to
12# permit persons to whom the Software is furnished to do so, subject to
13# the following conditions:
14#
15# The above copyright notice and this permission notice shall be
16# included in all copies or substantial portions of the Software.
17#
18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
19# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
21# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
22# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
24# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25
'''Asyncio protocol abstraction: a stream transport for aiorpcx sessions.

RSTransport adapts an asyncio connection to a session object; connect_rs
(RSClient) and serve_rs open client connections and start servers using it.
'''

# Names exported by ``from <this module> import *``.
__all__ = ('connect_rs', 'serve_rs')
29
30
31import asyncio
32from functools import partial
33
34from aiorpcx.curio import Event, timeout_after, TaskTimeout
35from aiorpcx.session import RPCSession, SessionBase, SessionKind
36from aiorpcx.util import NetAddress
37
38
class ConnectionLostError(Exception):
    '''Signals that the underlying connection has gone away.

    RSTransport.connection_lost() hands an instance to the framer's fail();
    RSTransport.process_messages() treats it as a normal end of processing.
    '''
41
42
class RSTransport(asyncio.Protocol):
    '''An asyncio protocol that connects a session to a byte-stream transport.

    Incoming bytes are announced to the session and fed to the framer; the
    session pulls messages back out through receive_message().  Outgoing
    messages are framed and written, and writers wait while asyncio reports
    the send buffer as full.
    '''

    def __init__(self, session_factory, framer, kind):
        # Construction inputs
        self.session_factory = session_factory
        self.kind = kind
        self._framer = framer
        # Event loop current at construction time
        self.loop = asyncio.get_event_loop()
        # Filled in once the connection exists (a SOCKS proxy may set
        # _proxy and _remote_address itself before connection_made runs)
        self.session = None
        self._proxy = None
        self._asyncio_transport = None
        self._remote_address = None
        self._process_messages_task = None
        # Set while writes may proceed; cleared when the send buffer is full
        self._can_send = Event()
        self._can_send.set()
        # Set once message processing has finished
        self._closed_event = Event()

    async def process_messages(self):
        '''Run the session's message loop until it ends.

        A ConnectionLostError ends the loop quietly.  Whatever happens, the
        transport is then marked closed and the session is told the
        connection is gone, in that order.
        '''
        try:
            await self.session.process_messages(self.receive_message)
        except ConnectionLostError:
            # Expected after connection_lost() fails the framer.
            pass
        finally:
            self._closed_event.set()
            await self.session.connection_lost()

    async def receive_message(self):
        '''Return the next message produced by the framer.'''
        framer = self._framer
        return await framer.receive_message()

    def connection_made(self, transport):
        '''asyncio callback: the connection is established.

        Records the transport, determines the peer address (unless a SOCKS
        proxy already supplied it), builds the session and starts the
        message-processing task.
        '''
        self._asyncio_transport = transport
        if self._proxy is None:
            # Without a proxy, ask the transport who the peer is.  On a
            # closed SSL transport this could raise before asyncio's fix in
            # Python 3.6.1 / 3.5.4.
            peer = transport.get_extra_info('peername')
            self._remote_address = NetAddress(peer[0], peer[1])
        self.session = self.session_factory(self)
        if not self._framer:
            self._framer = self.session.default_framer()
        self._process_messages_task = self.loop.create_task(self.process_messages())

    def connection_lost(self, exc):
        '''asyncio callback: the connection has closed.

        Undoes connection_made(): wakes any writer waiting for buffer space
        and fails the framer with ConnectionLostError.
        '''
        # uvloop may call this without a prior connection_made(); only tear
        # down if we were set up.  See
        # https://github.com/MagicStack/uvloop/issues/246
        if self._asyncio_transport:
            self._can_send.set()
            self._framer.fail(ConnectionLostError())

    def data_received(self, data):
        '''asyncio callback: bytes have arrived from the peer.'''
        self.session.data_received(data)
        self._framer.received_bytes(data)

    def pause_writing(self):
        '''asyncio callback: the send buffer is full.

        Blocks writers and stops reading until resume_writing(); ignored once
        the connection is closing.
        '''
        if self.is_closing():
            return
        self._can_send.clear()
        self._asyncio_transport.pause_reading()

    def resume_writing(self):
        '''asyncio callback: the send buffer has room again.'''
        if self._can_send.is_set():
            return
        self._can_send.set()
        self._asyncio_transport.resume_reading()

    # API exposed to session
    async def write(self, message):
        '''Frame and send message, first waiting for send-buffer room.

        The message is dropped silently if the connection is closing.
        '''
        await self._can_send.wait()
        if self.is_closing():
            return
        payload = self._framer.frame(message)
        self._asyncio_transport.write(payload)

    async def close(self, force_after):
        '''Close the connection and return once it is closed.

        If processing has not finished within force_after (as interpreted by
        timeout_after), the transport is aborted and we keep waiting.
        '''
        transport = self._asyncio_transport
        if not transport:
            return
        transport.close()
        try:
            async with timeout_after(force_after):
                await self._closed_event.wait()
        except TaskTimeout:
            await self.abort()
            await self._closed_event.wait()

    async def abort(self):
        '''Abort the transport immediately, if one exists.'''
        transport = self._asyncio_transport
        if transport:
            transport.abort()

    def is_closing(self):
        '''Return True if processing has ended or the transport is closing.'''
        if self._closed_event.is_set():
            return True
        return self._asyncio_transport.is_closing()

    def proxy(self):
        '''Return the SOCKS proxy used for this connection, or None.'''
        return self._proxy

    def remote_address(self):
        '''Return the peer's address, or None if it is not yet known.'''
        return self._remote_address
144
145
class RSClient:
    '''Async context manager that opens a client connection and yields its session.

    Usage: ``async with RSClient(host, port) as session: ...``.  The
    connection is made with the proxy's create_connection() if a proxy is
    given, otherwise the event loop's; on exit the session is closed.

    Keyword arguments:
      framer          -- framer handed to the RSTransport; if falsy, the
                         session's default_framer() is used.
      session_factory -- callable that builds the session from the transport
                         (default RPCSession).
      loop            -- event loop to connect with (default: the current
                         event loop).  It is consumed here and NOT forwarded
                         to create_connection().
      anything else   -- passed through to create_connection().
    '''

    def __init__(self, host=None, port=None, proxy=None, *, framer=None, **kwargs):
        session_factory = kwargs.pop('session_factory', RPCSession)
        self.protocol_factory = partial(RSTransport, session_factory, framer,
                                        SessionKind.CLIENT)
        self.host = host
        self.port = port
        self.proxy = proxy
        self.session = None
        # Pop (not get) 'loop': asyncio's loop.create_connection() has no
        # 'loop' parameter, so leaving it in kwargs made create_connection()
        # fail with TypeError.  (Presumably a proxy's create_connection()
        # does not take one either -- confirm if proxies change.)
        # Only query the current loop when none was supplied.
        loop = kwargs.pop('loop', None)
        self.loop = loop if loop is not None else asyncio.get_event_loop()
        # Remaining kwargs are forwarded verbatim to create_connection().
        self.kwargs = kwargs

    async def create_connection(self):
        '''Initiate a connection; return the (transport, protocol) pair.'''
        connector = self.proxy or self.loop
        return await connector.create_connection(
            self.protocol_factory, self.host, self.port, **self.kwargs)

    async def __aenter__(self):
        '''Connect and return the session created by the protocol.'''
        _transport, protocol = await self.create_connection()
        self.session = protocol.session
        # Internal sanity check on what the session factory produced.
        assert isinstance(self.session, SessionBase)
        return self.session

    async def __aexit__(self, exc_type, exc_value, traceback):
        '''Close the session; exceptions from the body are not suppressed.'''
        await self.session.close()
173
174
async def serve_rs(session_factory, host=None, port=None, *, framer=None, loop=None, **kwargs):
    '''Start a server whose accepted connections are each handled by an RSTransport.

    Every connection gets an RSTransport of kind SessionKind.SERVER, which
    builds its session with session_factory and uses framer (or, if framer
    is falsy, the session's default framer).  If loop is falsy the current
    event loop is used.  Extra keyword arguments go straight to
    loop.create_server(), whose result is returned.
    '''
    if not loop:
        loop = asyncio.get_event_loop()
    make_protocol = partial(RSTransport, session_factory, framer, SessionKind.SERVER)
    server = await loop.create_server(make_protocol, host, port, **kwargs)
    return server
179
180
# Public name for RSClient:  ``async with connect_rs(host, port) as session: ...``
connect_rs = RSClient
182