1# channel.py
2#
3# Support for a message passing channel that can send bytes or pickled
4# Python objects on a stream.  Compatible with the Connection class in the
5# multiprocessing module, but rewritten for a purely asynchronous runtime.
6
# Only Channel is exported by "from ... import *"; Connection and the
# exception classes defined below remain importable by explicit name.
__all__ = ['Channel']
8
9# -- Standard Library
10
11import os
12import pickle
13import struct
14import hmac
15import multiprocessing.connection as mpc
16import logging
17
# Module-level logger; Channel uses it to report rejected, failed or
# timed-out connections.
log = logging.getLogger(__name__)
19
20# -- Curio
21
22from . import socket
23from .errors import CurioError, TaskTimeout
24from .io import StreamBase, FileStream
25from . import thread
26from .time import timeout_after, sleep
27
# Authentication parameters.
#
# These values were originally read from multiprocessing.connection, but
# Python 3.12 made CHALLENGE/WELCOME/FAILURE private (and changed
# MESSAGE_LENGTH to 40), which broke importing this module.  The values
# this module's handshake has always used are therefore spelled out here
# so the wire protocol no longer depends on multiprocessing internals.

AUTH_MESSAGE_LENGTH = 20          # bytes of random challenge data
CHALLENGE = b'#CHALLENGE#'        # prefix of a challenge message
WELCOME = b'#WELCOME#'            # reply sent when the digest matches
FAILURE = b'#FAILURE#'            # reply sent when the digest is wrong
34
35
36
class ConnectionError(CurioError):
    '''
    Base class for errors raised while establishing a channel connection.

    Note: within this module this name shadows the built-in
    ``ConnectionError`` (which is an ``OSError`` subclass); this one is a
    ``CurioError`` instead.
    '''
39
40
class AuthenticationError(ConnectionError):
    '''
    Raised when the HMAC challenge/response handshake between two peers
    fails, i.e. a digest was wrong or was rejected by the other side.
    '''
43
44
class Connection(object):
    '''
    A communication channel for sending size-prefixed messages of bytes
    or pickled Python objects.  Must be passed a pair of reader/writer
    streams for performing the underlying communication.

    Each message is written as a 4-byte big-endian signed length followed
    by the payload.  Payloads longer than 0x7fffffff bytes use the extended
    header that multiprocessing's Connection also uses: a length of -1
    followed by an 8-byte big-endian unsigned length.
    '''

    # Largest payload size that fits in the 4-byte signed length header.
    _MAX_SHORT_SIZE = 0x7fffffff

    def __init__(self, reader, writer):
        assert isinstance(reader, StreamBase) and isinstance(writer, StreamBase)
        self._reader = reader
        self._writer = writer

    @classmethod
    def from_Connection(cls, conn):
        '''
        Creates a channel from a multiprocessing Connection. Note: The
        multiprocessing connection is detached by having its handle set to None.

        This method can be used to make curio talk over Pipes as created by
        multiprocessing.  For example:

              p1, p2 = multiprocessing.Pipe()
              p1 = Connection.from_Connection(p1)
              p2 = Connection.from_Connection(p2)

        '''
        assert isinstance(conn, mpc._ConnectionBase)
        reader = FileStream(open(conn._handle, 'rb', buffering=0))
        # The writer shares the reader's descriptor; closefd=False leaves
        # ownership of the descriptor with the reader's file object.
        writer = FileStream(open(conn._handle, 'wb', buffering=0, closefd=False))
        conn._handle = None
        return cls(reader, writer)

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        await self.close()

    def __enter__(self):
        return thread.AWAIT(self.__aenter__())

    def __exit__(self, *args):
        return thread.AWAIT(self.__aexit__(*args))

    async def close(self):
        '''
        Close the reader stream and, if it is a different object, the
        writer stream as well.
        '''
        await self._reader.close()
        if self._reader != self._writer:
            await self._writer.close()

    @classmethod
    def _pack_size(cls, size):
        '''
        Return the wire header announcing a payload of *size* bytes.
        '''
        if size > cls._MAX_SHORT_SIZE:
            # Extended header, wire-compatible with multiprocessing.
            return struct.pack('!iQ', -1, size)
        return struct.pack('!i', size)

    async def _recv_size(self):
        '''
        Read a message header from the reader and return the payload size.

        Raises IOError if the peer announces a negative size other than the
        -1 marker of the extended header.
        '''
        header = await self._reader.read_exactly(4)
        size, = struct.unpack('!i', header)
        if size == -1:
            header = await self._reader.read_exactly(8)
            size, = struct.unpack('!Q', header)
        elif size < 0:
            raise IOError('Invalid message size')
        return size

    async def send_bytes(self, buf, offset=0, size=None):
        '''
        Send a buffer of bytes as a single message.

        buf    -- any bytes-like object
        offset -- index of the first byte of buf to send
        size   -- number of bytes to send (default: the rest of buf)

        Returns the number of payload bytes sent.  Raises ValueError if
        offset/size do not describe a valid region of buf.
        '''
        m = memoryview(buf)
        if m.itemsize > 1:
            # Work in bytes, not items, for multi-byte formats.
            m = memoryview(bytes(m))
        n = len(m)
        if offset < 0:
            raise ValueError("offset is negative")
        if n < offset:
            raise ValueError("buffer length < offset")
        if size is None:
            size = n - offset
        elif size < 0:
            raise ValueError("size is negative")
        elif offset + size > n:
            raise ValueError("buffer length < offset + size")

        header = self._pack_size(size)
        if size >= 16384:
            # Large payload: avoid copying it just to prepend the header.
            await self._writer.write(header)
            await self._writer.write(m[offset:offset + size])
        else:
            # Small payload: one write, so header and data are not split
            # into separate tiny writes.
            await self._writer.write(header + bytes(m[offset:offset + size]))
        return size

    async def recv_bytes(self, maxlength=None):
        '''
        Receive a message of bytes as a single message.

        If maxlength is given (and non-zero) and the announced message is
        longer, IOError is raised before the payload is read.
        '''
        size = await self._recv_size()
        if maxlength and size > maxlength:
            raise IOError("Message too large")
        return await self._reader.read_exactly(size)

    async def recv_bytes_into(self, buf, offset=0):
        '''
        Receive bytes into a writable memory buffer.  The buffer must be large enough to
        hold the message.  The number of bytes received in the message is returned.

        Raises ValueError for a negative offset or one past the end of buf,
        IOError if the message does not fit (the message is drained from the
        stream first), and EOFError if the stream ends mid-message.
        '''
        if offset < 0:
            raise ValueError("negative offset")
        with memoryview(buf).cast('B') as m:
            if offset > len(m):
                raise ValueError("offset too large")
            size = await self._recv_size()
            if size > (len(m) - offset):
                # Message is too large to fit in allotted space
                # Drain the I/O and raise an error
                while size > 0:
                    data = await self._reader.read(size)
                    if not data:
                        break
                    size -= len(data)
                raise IOError('Message is too large to fit')

            # readinto() may return fewer bytes than requested (e.g. a socket
            # delivering a large message in pieces), so keep filling the
            # target region until the whole message has arrived.
            target = m[offset:offset + size]
            nread = 0
            while nread < size:
                count = await self._reader.readinto(target[nread:])
                if not count:
                    raise EOFError('Expected end of data')
                nread += count
            return nread

    async def send(self, obj):
        '''
        Send an arbitrary Python object. Uses pickle to serialize.
        '''
        await self.send_bytes(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL))

    async def recv(self):
        '''
        Receive a Python object. Uses pickle to unserialize.

        Security: unpickling data from an untrusted peer can execute
        arbitrary code; only use this with authenticated/trusted peers.
        '''
        msg = await self.recv_bytes()
        return pickle.loads(msg)

    async def _deliver_challenge(self, authkey):
        '''
        Send a random challenge and verify the peer's HMAC-MD5 response.
        Sends WELCOME on success; sends FAILURE and raises
        AuthenticationError otherwise.
        '''
        message = os.urandom(AUTH_MESSAGE_LENGTH)
        await self.send_bytes(CHALLENGE + message)
        digest = hmac.new(authkey, message, 'md5').digest()
        response = await self.recv_bytes(maxlength=256)
        # Constant-time comparison so response timing does not reveal how
        # much of the expected digest a forged response got right.
        if hmac.compare_digest(response, digest):
            await self.send_bytes(WELCOME)
        else:
            await self.send_bytes(FAILURE)
            raise AuthenticationError('digest received was wrong')

    async def _answer_challenge(self, authkey):
        '''
        Receive a challenge, reply with its HMAC-MD5 digest and require a
        WELCOME reply.  Raises AuthenticationError on a malformed challenge
        or if the digest is rejected.
        '''
        message = await self.recv_bytes(maxlength=256)
        # Explicit check rather than `assert`, which is stripped under -O.
        if message[:len(CHALLENGE)] != CHALLENGE:
            raise AuthenticationError(
                f'Protocol error, expected challenge: message = {message!r}')
        message = message[len(CHALLENGE):]
        digest = hmac.new(authkey, message, 'md5').digest()
        await self.send_bytes(digest)
        response = await self.recv_bytes(maxlength=256)

        if response != WELCOME:
            raise AuthenticationError('digest sent was rejected')

    async def authenticate_server(self, authkey):
        '''
        Server side of mutual authentication: challenge the peer first,
        then answer the peer's challenge.
        '''
        await self._deliver_challenge(authkey)
        await self._answer_challenge(authkey)

    async def authenticate_client(self, authkey):
        '''
        Client side of mutual authentication: answer the server's challenge
        first, then challenge the server.
        '''
        await self._answer_challenge(authkey)
        await self._deliver_challenge(authkey)
197
class Channel(object):
    '''
    An endpoint that either listens for (accept) or opens (connect) stream
    socket connections and wraps each one in a Connection, optionally
    performing a mutual HMAC handshake with a shared authkey.
    '''

    def __init__(self, address, family=socket.AF_INET, check_address=None):
        self.address = address
        self.family = family
        self.sock = None
        if check_address:
            # Instance attribute overrides the permissive default method;
            # it is called as check_address(addr), without self.
            self.check_address = check_address

    def __repr__(self):
        return f'Channel({self.address!r}, {self.family!r})'

    async def __aenter__(self):
        return self

    async def __aexit__(self, ty, val, tb):
        await self.close()

    def __getstate__(self):
        # Only the address and family are pickled; the listening socket
        # and any custom check_address are not carried across.
        return (self.address, self.family)

    def __setstate__(self, state):
        self.address, self.family = state
        self.sock = None

    def bind(self):
        '''
        Create, bind and start listening on the server socket.  Updates
        self.address with the actual bound address (useful with port 0).
        '''
        self.sock = socket.socket(self.family, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
        self.sock.bind(self.address)
        self.sock.listen(5)
        self.address = self.sock.getsockname()

    def check_address(self, addr):
        '''
        Default address filter: accept connections from any address.
        '''
        return True

    async def accept(self, *, authkey=None):
        '''
        Wait for a client and return a Connection to it (with the peer
        address stored as its ``address`` attribute).

        Binds lazily on first use.  Clients rejected by check_address, or
        whose authentication fails, times out (1 second) or hits an I/O
        error, are logged and closed, and the next client is awaited.
        '''
        if self.sock is None:
            self.bind()

        while True:
            client, addr = await self.sock.accept()
            if not self.check_address(addr):
                log.warning('Channel connection from %s rejected', addr)
                await client.close()
                continue

            client_stream = client.as_stream()
            c = Connection(client_stream, client_stream)
            c.address = addr
            try:
                async with timeout_after(1):
                    if authkey:
                        await c.authenticate_server(authkey)
            except (TaskTimeout, AuthenticationError, EOFError, OSError):
                # A single misbehaving client (bad digest, oversize message,
                # reset connection) must not take down the accept loop.
                log.warning('Channel connection from %s failed', addr, exc_info=True)
                await c.close()
            except BaseException:
                # Anything else (e.g. cancellation) propagates, but the
                # half-established connection is not leaked.
                await c.close()
                raise
            else:
                return c

    async def connect(self, *, authkey=None):
        '''
        Connect to self.address and return a Connection, authenticating
        with authkey when given.

        Raises TimeoutError (an OSError) if authentication does not finish
        within 1 second.  The socket is closed on any failure.
        '''
        sock = socket.socket(self.family, socket.SOCK_STREAM)
        try:
            await sock.connect(self.address)
        except BaseException:
            await sock.close()
            raise

        sock_stream = sock.as_stream()
        c = Connection(sock_stream, sock_stream)
        try:
            async with timeout_after(1):
                if authkey:
                    await c.authenticate_client(authkey)
        except TaskTimeout:
            log.warning('Channel connection to %s timed out', self.address)
            await c.close()
            # Note: Raising an OSError.
            raise TimeoutError("Connection timed out")
        except BaseException:
            await c.close()
            raise
        return c

    async def close(self):
        '''
        Close the listening socket, if any.
        '''
        if self.sock:
            await self.sock.close()
        self.sock = None
282