1# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
2
3"""trio async I/O library query support"""
4
5import socket
6import trio
7import trio.socket  # type: ignore
8
9import dns._asyncbackend
10import dns.exception
11import dns.inet
12
13
def _maybe_timeout(timeout):
    """Return a context manager that limits its body to *timeout* seconds.

    *timeout* is a number of seconds, or ``None`` for no limit.

    If *timeout* is ``None``, a ``dns._asyncbackend.NullContext`` is
    returned.  Otherwise the result is ``trio.move_on_after(timeout)``;
    when that deadline expires the ``with`` body is cancelled and the
    block exits without raising, so callers must detect expiry
    themselves (the socket wrappers in this module do so by raising
    ``dns.exception.Timeout`` after the block).
    """
    # Test against None rather than truthiness: a timeout of 0 means the
    # deadline has already passed, not that the caller wants to wait
    # forever.
    if timeout is None:
        return dns._asyncbackend.NullContext()
    return trio.move_on_after(timeout)
19
20
# Short alias for dns.inet.low_level_address_tuple, which is called with an
# (address, family) pair to build the arguments for bind() and connect().
_lltuple = dns.inet.low_level_address_tuple
23
24# pylint: disable=redefined-outer-name
25
26
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket wrapper that delegates to a trio socket object."""

    def __init__(self, socket):
        """Wrap *socket* and record its address family."""
        self.socket = socket
        self.family = socket.family

    async def sendto(self, what, destination, timeout):
        """Send *what* to *destination* and return the socket's result.

        The send is wrapped in ``_maybe_timeout(timeout)``; if that
        context ends the operation early, ``dns.exception.Timeout`` is
        raised.
        """
        with _maybe_timeout(timeout):
            sent = await self.socket.sendto(what, destination)
            return sent
        # Reached only when the timeout context cut the send short.
        raise dns.exception.Timeout(timeout=timeout)  # pragma: no cover

    async def recvfrom(self, size, timeout):
        """Receive up to *size* bytes and return the socket's result.

        The receive is wrapped in ``_maybe_timeout(timeout)``; if that
        context ends the operation early, ``dns.exception.Timeout`` is
        raised.
        """
        with _maybe_timeout(timeout):
            received = await self.socket.recvfrom(size)
            return received
        # Reached only when the timeout context cut the receive short.
        raise dns.exception.Timeout(timeout=timeout)

    async def close(self):
        """Close the wrapped socket."""
        self.socket.close()

    async def getpeername(self):
        """Return the wrapped socket's ``getpeername()`` result."""
        peer = self.socket.getpeername()
        return peer

    async def getsockname(self):
        """Return the wrapped socket's ``getsockname()`` result."""
        local = self.socket.getsockname()
        return local
50
51
# Derives from the stream base class; it previously derived from the
# datagram base class by mistake.
class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket wrapper that delegates to a trio stream object."""

    def __init__(self, family, stream, tls=False):
        """Wrap *stream*.

        *family* is stored as the socket's address family.  *tls*
        records whether *stream* is a TLS wrapper, in which case the
        underlying socket is reached through ``stream.transport_stream``.
        """
        self.family = family
        self.stream = stream
        self.tls = tls

    async def sendall(self, what, timeout):
        """Send all of *what* on the stream.

        Raises ``dns.exception.Timeout`` if the timeout context from
        ``_maybe_timeout(timeout)`` ends the send early.
        """
        with _maybe_timeout(timeout):
            return await self.stream.send_all(what)
        # Reached only when the timeout context cut the send short.
        raise dns.exception.Timeout(timeout=timeout)

    async def recv(self, size, timeout):
        """Receive and return up to *size* bytes from the stream.

        Raises ``dns.exception.Timeout`` if the timeout context from
        ``_maybe_timeout(timeout)`` ends the receive early.
        """
        with _maybe_timeout(timeout):
            return await self.stream.receive_some(size)
        # Reached only when the timeout context cut the receive short.
        raise dns.exception.Timeout(timeout=timeout)

    async def close(self):
        """Close the stream."""
        await self.stream.aclose()

    async def getpeername(self):
        """Return the underlying socket's ``getpeername()`` result."""
        if self.tls:
            # The TLS wrapper holds the socket stream as transport_stream.
            return self.stream.transport_stream.socket.getpeername()
        else:
            return self.stream.socket.getpeername()

    async def getsockname(self):
        """Return the underlying socket's ``getsockname()`` result."""
        if self.tls:
            # The TLS wrapper holds the socket stream as transport_stream.
            return self.stream.transport_stream.socket.getsockname()
        else:
            return self.stream.socket.getsockname()
82
83
class Backend(dns._asyncbackend.Backend):
    """Async backend implemented on top of trio."""

    def name(self):
        """Return the backend name, ``'trio'``."""
        return 'trio'

    async def make_socket(self, af, socktype, proto=0, source=None,
                          destination=None, timeout=None,
                          ssl_context=None, server_hostname=None):
        """Create a socket and return it wrapped for this backend.

        *af*, *socktype* and *proto* are passed to ``trio.socket.socket``.

        *source*, if set, is the address the socket is bound to.

        *destination* is the address a ``SOCK_STREAM`` socket is
        connected to; it is not used for ``SOCK_DGRAM`` sockets.

        *timeout* limits the connect step only, via ``_maybe_timeout``.

        *ssl_context*, if set, causes a stream socket to be wrapped in
        ``trio.SSLStream`` using *server_hostname*.

        Returns a ``DatagramSocket`` for ``SOCK_DGRAM`` and a
        ``StreamSocket`` for ``SOCK_STREAM``.

        Raises ``dns.exception.Timeout`` if the connect does not finish
        within *timeout*, and ``NotImplementedError`` for any other
        socket type.
        """
        s = trio.socket.socket(af, socktype, proto)
        try:
            if source:
                await s.bind(_lltuple(source, af))
            if socktype == socket.SOCK_STREAM:
                # A trio timeout scope exits the with-block silently on
                # expiry, so track completion explicitly; otherwise an
                # unconnected socket would be handed back to the caller.
                connected = False
                with _maybe_timeout(timeout):
                    await s.connect(_lltuple(destination, af))
                    connected = True
                if not connected:
                    raise dns.exception.Timeout(timeout=timeout)
        except BaseException:
            # BaseException rather than Exception: trio.Cancelled derives
            # from BaseException, and the socket must not leak when the
            # surrounding task is cancelled.  The error is re-raised.
            s.close()
            raise
        if socktype == socket.SOCK_DGRAM:
            return DatagramSocket(s)
        if socktype == socket.SOCK_STREAM:
            stream = trio.SocketStream(s)
            tls = False
            if ssl_context:
                tls = True
                try:
                    stream = trio.SSLStream(stream, ssl_context,
                                            server_hostname=server_hostname)
                except Exception:  # pragma: no cover
                    await stream.aclose()
                    raise
            return StreamSocket(af, stream, tls)
        # Unsupported type: release the socket created above before failing.
        s.close()
        raise NotImplementedError('unsupported socket ' +
                                  f'type {socktype}')    # pragma: no cover

    async def sleep(self, interval):
        """Sleep for *interval* seconds using ``trio.sleep``."""
        await trio.sleep(interval)
122