# channel.py
#
# Support for a message passing channel that can send bytes or pickled
# Python objects on a stream.  Compatible with the Connection class in the
# multiprocessing module, but rewritten for a purely asynchronous runtime.

__all__ = ['Channel']

# -- Standard Library

import os
import pickle
import struct
import hmac
import multiprocessing.connection as mpc
import logging

log = logging.getLogger(__name__)

# -- Curio

from . import socket
from .errors import CurioError, TaskTimeout
from .io import StreamBase, FileStream
from . import thread
from .time import timeout_after, sleep

# Authentication parameters (copied from multiprocessing)

AUTH_MESSAGE_LENGTH = mpc.MESSAGE_LENGTH    # length of the random challenge nonce
CHALLENGE = mpc.CHALLENGE                   # b'#CHALLENGE#'
WELCOME = mpc.WELCOME                       # b'#WELCOME#'
FAILURE = mpc.FAILURE                       # b'#FAILURE#'


class ConnectionError(CurioError):
    '''
    Base class for errors raised by channel connections.

    Note: within this module this name shadows the builtin ConnectionError.
    '''


class AuthenticationError(ConnectionError):
    '''
    Raised when the HMAC challenge/response handshake fails.
    '''


class Connection(object):
    '''
    A communication channel for sending size-prefixed messages of bytes
    or pickled Python objects. Must be passed a pair of reader/writer
    streams for performing the underlying communication.

    Wire format: each message is a 4-byte big-endian signed length
    header (struct format '!i') followed by that many payload bytes.
    '''

    def __init__(self, reader, writer):
        assert isinstance(reader, StreamBase) and isinstance(writer, StreamBase)
        self._reader = reader
        self._writer = writer

    @classmethod
    def from_Connection(cls, conn):
        '''
        Creates a channel from a multiprocessing Connection. Note: The
        multiprocessing connection is detached by having its handle set to None.

        This method can be used to make curio talk over Pipes as created by
        multiprocessing. For example:

            p1, p2 = multiprocessing.Pipe()
            p1 = Connection.from_Connection(p1)
            p2 = Connection.from_Connection(p2)

        '''
        assert isinstance(conn, mpc._ConnectionBase)
        # The reader owns the file descriptor; the writer shares it
        # (closefd=False) so the descriptor is closed only once.
        reader = FileStream(open(conn._handle, 'rb', buffering=0))
        writer = FileStream(open(conn._handle, 'wb', buffering=0, closefd=False))
        conn._handle = None
        return cls(reader, writer)

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        await self.close()

    def __enter__(self):
        return thread.AWAIT(self.__aenter__())

    def __exit__(self, *args):
        return thread.AWAIT(self.__aexit__(*args))

    async def close(self):
        '''
        Close the underlying stream(s). The writer is closed separately
        only when it is a different object from the reader.
        '''
        await self._reader.close()
        if self._writer is not self._reader:
            await self._writer.close()

    async def send_bytes(self, buf, offset=0, size=None):
        '''
        Send a buffer of bytes as a single message.

        offset and size optionally select a slice of buf (in bytes).
        Returns the number of payload bytes sent. Raises ValueError if
        offset/size are negative or extend past the end of the buffer.
        '''
        m = memoryview(buf)
        if m.itemsize > 1:
            m = memoryview(bytes(m))
        n = len(m)
        if offset < 0:
            raise ValueError("offset is negative")
        if n < offset:
            raise ValueError("buffer length < offset")
        if size is None:
            size = n - offset
        elif size < 0:
            raise ValueError("size is negative")
        elif offset + size > n:
            raise ValueError("buffer length < offset + size")

        header = struct.pack('!i', size)
        if size >= 16384:
            # Large payload: write header and payload separately rather
            # than copying the payload into one concatenated bytes object.
            await self._writer.write(header)
            await self._writer.write(m[offset:offset + size])
        else:
            msg = header + bytes(m[offset:offset + size])
            await self._writer.write(msg)
        return size

    async def recv_bytes(self, maxlength=None):
        '''
        Receive a message of bytes as a single message.

        If maxlength is given (and nonzero) and the incoming message is
        larger, IOError is raised. The oversized payload is NOT consumed,
        so the connection should not be used for further messages after
        that error.
        '''
        header = await self._reader.read_exactly(4)
        size, = struct.unpack('!i', header)
        if maxlength and size > maxlength:
            raise IOError("Message too large")
        return await self._reader.read_exactly(size)

    async def recv_bytes_into(self, buf, offset=0):
        '''
        Receive bytes into a writable memory buffer. The buffer must be large enough to
        hold the message. The number of bytes received in the message is returned.

        Raises ValueError (before reading anything) if offset is negative
        or past the end of buf. If the message does not fit, it is drained
        from the stream and IOError is raised. Raises EOFError if the
        stream ends before the whole message arrives.
        '''
        with memoryview(buf).cast('B') as m:
            if offset < 0:
                raise ValueError("offset is negative")
            if len(m) < offset:
                raise ValueError("buffer length < offset")

            header = await self._reader.read_exactly(4)
            size, = struct.unpack('!i', header)

            if size > (len(m) - offset):
                # Message is too large to fit in allotted space.
                # Drain the I/O (keeping the stream in sync) and raise an error.
                remaining = size
                while remaining > 0:
                    data = await self._reader.read(remaining)
                    if not data:
                        break
                    remaining -= len(data)
                raise IOError('Message is too large to fit')

            # readinto() may return fewer bytes than requested, so keep
            # reading until the whole message has been received.
            target = m[offset:offset + size]
            nread = 0
            while nread < size:
                count = await self._reader.readinto(target[nread:])
                if not count:
                    raise EOFError('Unexpected end of data')
                nread += count
            return nread

    async def send(self, obj):
        '''
        Send an arbitrary Python object. Uses pickle to serialize.
        '''
        await self.send_bytes(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL))

    async def recv(self):
        '''
        Receive a Python object. Uses pickle to unserialize.

        Security note: this unpickles data sent by the peer, which can
        execute arbitrary code. Only use it with trusted peers (for
        example, connections authenticated with an authkey).
        '''
        msg = await self.recv_bytes()
        return pickle.loads(msg)

    async def _deliver_challenge(self, authkey):
        '''
        Send a random challenge and verify the peer's HMAC-MD5 response.

        Sends WELCOME if the digest matches; otherwise sends FAILURE and
        raises AuthenticationError.
        '''
        message = os.urandom(AUTH_MESSAGE_LENGTH)
        await self.send_bytes(CHALLENGE + message)
        digest = hmac.new(authkey, message, 'md5').digest()
        response = await self.recv_bytes(maxlength=256)
        # Constant-time comparison avoids leaking digest bytes via timing.
        if hmac.compare_digest(response, digest):
            await self.send_bytes(WELCOME)
        else:
            await self.send_bytes(FAILURE)
            raise AuthenticationError('digest received was wrong')

    async def _answer_challenge(self, authkey):
        '''
        Answer the peer's challenge with an HMAC-MD5 digest of it.

        Raises AuthenticationError if the peer's first message is not a
        challenge, or if the peer does not reply with WELCOME.
        '''
        message = await self.recv_bytes(maxlength=256)
        # Validate explicitly (not with assert, which -O strips): this is
        # untrusted network input.
        if message[:len(CHALLENGE)] != CHALLENGE:
            raise AuthenticationError(
                f'Protocol error, expected challenge: message = {message!r}')
        message = message[len(CHALLENGE):]
        digest = hmac.new(authkey, message, 'md5').digest()
        await self.send_bytes(digest)
        response = await self.recv_bytes(maxlength=256)

        if response != WELCOME:
            raise AuthenticationError('digest sent was rejected')

    async def authenticate_server(self, authkey):
        '''
        Server side of mutual authentication: challenge the peer first,
        then answer the peer's challenge.
        '''
        await self._deliver_challenge(authkey)
        await self._answer_challenge(authkey)

    async def authenticate_client(self, authkey):
        '''
        Client side of mutual authentication: answer the peer's challenge
        first, then challenge the peer.
        '''
        await self._answer_challenge(authkey)
        await self._deliver_challenge(authkey)


class Channel(object):
    '''
    An address-based endpoint for creating Connection objects over a
    stream socket. A server calls accept() (binding on first use); a
    client calls connect(). If an authkey is given, a mutual HMAC
    challenge/response handshake is performed before a Connection is
    returned.
    '''

    def __init__(self, address, family=socket.AF_INET, check_address=None):
        self.address = address
        self.family = family
        self.sock = None
        if check_address:
            # Stored as an instance attribute, so it is called as a plain
            # function taking only the client address.
            self.check_address = check_address

    def __repr__(self):
        return f'Channel({self.address!r}, {self.family!r})'

    async def __aenter__(self):
        return self

    async def __aexit__(self, ty, val, tb):
        await self.close()

    def __getstate__(self):
        # Only the address and family are pickled; the listening socket
        # and any custom check_address are not carried over.
        return (self.address, self.family)

    def __setstate__(self, state):
        self.address, self.family = state
        self.sock = None

    def bind(self):
        '''
        Create, bind and listen on the server socket. Updates self.address
        with the actual bound address (useful when binding to port 0).
        '''
        self.sock = socket.socket(self.family, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
        self.sock.bind(self.address)
        self.sock.listen(5)
        self.address = self.sock.getsockname()

    def check_address(self, addr):
        '''
        Default address filter: accept every client.
        '''
        return True

    async def accept(self, *, authkey=None):
        '''
        Wait for a client and return a Connection to it.

        Clients rejected by check_address(), or whose authentication fails,
        times out, or errors out at the socket level, are logged, closed
        and skipped; accept() keeps waiting for the next client.
        '''
        if self.sock is None:
            self.bind()

        while True:
            client, addr = await self.sock.accept()
            if not self.check_address(addr):
                log.warning('Channel connection from %s rejected', addr)
                await client.close()
                continue

            client_stream = client.as_stream()
            c = Connection(client_stream, client_stream)
            c.address = addr
            try:
                async with timeout_after(1):
                    if authkey:
                        await c.authenticate_server(authkey)
            except (TaskTimeout, AuthenticationError, EOFError, OSError):
                # A single misbehaving client must not take down the
                # accept loop.
                log.warning('Channel connection from %s failed', addr, exc_info=True)
                await c.close()
                continue
            except BaseException:
                # Cancellation or unexpected error: don't leak the client.
                await c.close()
                raise
            return c

    async def connect(self, *, authkey=None):
        '''
        Connect to self.address and return a Connection.

        Raises TimeoutError (an OSError) if authentication does not finish
        within the timeout. The socket is closed on every failure path.
        '''
        sock = socket.socket(self.family, socket.SOCK_STREAM)
        try:
            await sock.connect(self.address)
        except BaseException:
            await sock.close()
            raise
        sock_stream = sock.as_stream()
        c = Connection(sock_stream, sock_stream)
        try:
            async with timeout_after(1):
                if authkey:
                    await c.authenticate_client(authkey)
        except TaskTimeout:
            log.warning('Channel connection to %s timed out', self.address)
            await c.close()
            # Note: Raising an OSError.
            raise TimeoutError("Connection timed out")
        except BaseException:
            # Authentication failure, EOF, cancellation, ...: don't leak
            # the connection.
            await c.close()
            raise
        return c

    async def close(self):
        '''
        Close the listening socket, if any.
        '''
        if self.sock:
            await self.sock.close()
            self.sock = None