1# Copyright (c) 2019, Neil Booth 2# 3# All rights reserved. 4# 5# The MIT License (MIT) 6# 7# Permission is hereby granted, free of charge, to any person obtaining 8# a copy of this software and associated documentation files (the 9# "Software"), to deal in the Software without restriction, including 10# without limitation the rights to use, copy, modify, merge, publish, 11# distribute, sublicense, and/or sell copies of the Software, and to 12# permit persons to whom the Software is furnished to do so, subject to 13# the following conditions: 14# 15# The above copyright notice and this permission notice shall be 16# included in all copies or substantial portions of the Software. 17# 18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 19# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 20# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 21# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 22# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 24# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
'''Asyncio protocol abstraction.'''

__all__ = ('connect_rs', 'serve_rs')


import asyncio
from functools import partial

from aiorpcx.curio import Event, timeout_after, TaskTimeout
from aiorpcx.session import RPCSession, SessionBase, SessionKind
from aiorpcx.util import NetAddress


class ConnectionLostError(Exception):
    '''Passed to the framer's fail() when asyncio reports the connection lost.

    process_messages() treats it as a normal end of the message loop.
    '''


class RSTransport(asyncio.Protocol):
    '''asyncio Protocol bridging an asyncio transport and a session.

    The session is created by ``session_factory`` once asyncio calls
    ``connection_made``; the methods under "API exposed to session" are what
    the session uses to write, close and query the connection.
    '''

    def __init__(self, session_factory, framer, kind):
        self.session_factory = session_factory
        self.loop = asyncio.get_event_loop()
        self.session = None
        # SessionKind.CLIENT or SessionKind.SERVER, as passed by RSClient / serve_rs
        self.kind = kind
        self._proxy = None
        self._asyncio_transport = None
        self._remote_address = None
        # May be None; connection_made then uses the session's default framer
        self._framer = framer
        # Cleared when the send socket is full
        self._can_send = Event()
        self._can_send.set()
        # Set when process_messages() finishes for any reason
        self._closed_event = Event()
        self._process_messages_task = None

    async def process_messages(self):
        '''Run the session's message loop until it ends.

        A ConnectionLostError ends the loop quietly.  Whatever the outcome,
        the closed event is set and the session's connection_lost() awaited.
        '''
        try:
            await self.session.process_messages(self.receive_message)
        except ConnectionLostError:
            pass
        finally:
            self._closed_event.set()
            await self.session.connection_lost()

    async def receive_message(self):
        '''Return the next message from the framer.'''
        return await self._framer.receive_message()

    def connection_made(self, transport):
        '''Called by asyncio when a connection is established.'''
        self._asyncio_transport = transport
        # If the SOCKS proxy was used then _proxy and _remote_address are already set
        if self._proxy is None:
            # This would throw if called on a closed SSL transport.  Fixed in
            # asyncio in Python 3.6.1 and 3.5.4
            peername = transport.get_extra_info('peername')
            self._remote_address = NetAddress(peername[0], peername[1])
        self.session = self.session_factory(self)
        self._framer = self._framer or self.session.default_framer()
        self._process_messages_task = self.loop.create_task(self.process_messages())

    def connection_lost(self, exc):
        '''Called by asyncio when the connection closes.

        Tear down things done in connection_made.'''
        # This works around a uvloop bug; see https://github.com/MagicStack/uvloop/issues/246
        if not self._asyncio_transport:
            return
        # Release waiting tasks: blocked writers, and the message loop via the framer
        self._can_send.set()
        self._framer.fail(ConnectionLostError())

    def data_received(self, data):
        '''Called by asyncio when a message comes in.'''
        self.session.data_received(data)
        self._framer.received_bytes(data)

    def pause_writing(self):
        '''Called by asyncio when the send buffer is full.'''
        if not self.is_closing():
            # Block write() and stop reading until the buffer drains
            self._can_send.clear()
            self._asyncio_transport.pause_reading()

    def resume_writing(self):
        '''Called by asyncio when the send buffer has room.'''
        if not self._can_send.is_set():
            self._can_send.set()
            self._asyncio_transport.resume_reading()

    # API exposed to session
    async def write(self, message):
        '''Frame and send ``message``, waiting while the send buffer is full.

        The message is silently dropped if the connection is closing.
        '''
        await self._can_send.wait()
        if not self.is_closing():
            framed_message = self._framer.frame(message)
            self._asyncio_transport.write(framed_message)

    async def close(self, force_after):
        '''Close the connection and return when closed.

        If the connection has not closed within ``force_after`` (as passed to
        timeout_after; presumably seconds) the transport is aborted instead.
        '''
        if self._asyncio_transport:
            self._asyncio_transport.close()
            try:
                async with timeout_after(force_after):
                    await self._closed_event.wait()
            except TaskTimeout:
                await self.abort()
                await self._closed_event.wait()

    async def abort(self):
        '''Abort the underlying transport, if there is one.'''
        if self._asyncio_transport:
            self._asyncio_transport.abort()

    def is_closing(self):
        '''Return True if the connection is closing.'''
        return self._closed_event.is_set() or self._asyncio_transport.is_closing()

    def proxy(self):
        '''Return the proxy used for the connection, or None.'''
        return self._proxy

    def remote_address(self):
        '''Return the remote NetAddress of the connection.'''
        return self._remote_address


class RSClient:
    '''Async context manager that opens a client connection.

    ``async with RSClient(host, port) as session:`` yields the session created
    for the connection and closes it on exit.  ``session_factory`` and
    ``loop`` may be passed as keyword arguments; all other keyword arguments
    are forwarded to ``create_connection``.
    '''

    def __init__(self, host=None, port=None, proxy=None, *, framer=None, **kwargs):
        session_factory = kwargs.pop('session_factory', RPCSession)
        self.protocol_factory = partial(RSTransport, session_factory, framer,
                                        SessionKind.CLIENT)
        self.host = host
        self.port = port
        self.proxy = proxy
        self.session = None
        # Pop (not get) 'loop' so it is not forwarded to create_connection(),
        # which does not accept it; only fall back to the default loop when
        # none was supplied (matches serve_rs).
        self.loop = kwargs.pop('loop', None) or asyncio.get_event_loop()
        self.kwargs = kwargs

    async def create_connection(self):
        '''Initiate a connection.

        Uses the proxy's create_connection if a proxy was given, otherwise
        the event loop's.  Returns whatever that call returns.
        '''
        connector = self.proxy or self.loop
        return await connector.create_connection(
            self.protocol_factory, self.host, self.port, **self.kwargs)

    async def __aenter__(self):
        _transport, protocol = await self.create_connection()
        self.session = protocol.session
        assert isinstance(self.session, SessionBase)
        return self.session

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.session.close()


async def serve_rs(session_factory, host=None, port=None, *, framer=None, loop=None, **kwargs):
    '''Start a server whose connections each get an RSTransport.

    Each accepted connection creates its session with ``session_factory``.
    Extra keyword arguments are forwarded to ``loop.create_server``, whose
    result is returned.
    '''
    loop = loop or asyncio.get_event_loop()
    protocol_factory = partial(RSTransport, session_factory, framer, SessionKind.SERVER)
    return await loop.create_server(protocol_factory, host, port, **kwargs)


connect_rs = RSClient