# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

"""trio async I/O library query support"""

import socket
import trio
import trio.socket  # type: ignore

import dns._asyncbackend
import dns.exception
import dns.inet


def _maybe_timeout(timeout):
    """Return a context manager bounding the enclosed code by *timeout*.

    If *timeout* is ``None`` no limit is applied and a
    ``dns._asyncbackend.NullContext`` is returned.  Otherwise a
    ``trio.move_on_after(timeout)`` cancel scope is returned.  When that
    scope expires, the enclosed code is cancelled and execution resumes
    after the ``with`` statement, so callers must detect that case and
    raise ``dns.exception.Timeout`` themselves.

    A timeout of ``0`` is a real (already expired) deadline, not "no
    timeout".  Testing truthiness here would turn an exhausted time
    budget into an unbounded wait.
    """
    if timeout is not None:
        return trio.move_on_after(timeout)
    return dns._asyncbackend.NullContext()


# for brevity
_lltuple = dns.inet.low_level_address_tuple

# pylint: disable=redefined-outer-name


class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket wrapper around a ``trio.socket`` socket."""

    def __init__(self, socket):
        self.socket = socket
        self.family = socket.family

    async def sendto(self, what, destination, timeout):
        """Send *what* to *destination*.

        Returns the result of the underlying socket's ``sendto()``.
        Raises ``dns.exception.Timeout`` if *timeout* expires first.
        """
        with _maybe_timeout(timeout):
            return await self.socket.sendto(what, destination)
        # Only reached when the cancel scope fired before sendto() finished.
        raise dns.exception.Timeout(timeout=timeout)  # pragma: no cover

    async def recvfrom(self, size, timeout):
        """Receive up to *size* bytes.

        Returns the result of the underlying socket's ``recvfrom()``.
        Raises ``dns.exception.Timeout`` if *timeout* expires first.
        """
        with _maybe_timeout(timeout):
            return await self.socket.recvfrom(size)
        # Only reached when the cancel scope fired before recvfrom() finished.
        raise dns.exception.Timeout(timeout=timeout)

    async def close(self):
        """Close the underlying socket."""
        self.socket.close()

    async def getpeername(self):
        """Return the remote address of the underlying socket."""
        return self.socket.getpeername()

    async def getsockname(self):
        """Return the local address of the underlying socket."""
        return self.socket.getsockname()


class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket wrapper around a trio stream.

    *stream* is a ``trio.SocketStream``, or, when *tls* is true, a
    ``trio.SSLStream`` wrapping one (see ``Backend.make_socket``).
    """

    def __init__(self, family, stream, tls=False):
        self.family = family
        self.stream = stream
        self.tls = tls

    async def sendall(self, what, timeout):
        """Send all of *what*.

        Raises ``dns.exception.Timeout`` if *timeout* expires first.
        """
        with _maybe_timeout(timeout):
            return await self.stream.send_all(what)
        # Only reached when the cancel scope fired before send_all() finished.
        raise dns.exception.Timeout(timeout=timeout)

    async def recv(self, size, timeout):
        """Receive up to *size* bytes from the stream.

        Raises ``dns.exception.Timeout`` if *timeout* expires first.
        """
        with _maybe_timeout(timeout):
            return await self.stream.receive_some(size)
        # Only reached when the cancel scope fired before receive_some() did.
        raise dns.exception.Timeout(timeout=timeout)

    async def close(self):
        """Close the stream (and, for TLS, the transport beneath it)."""
        await self.stream.aclose()

    def _raw_socket(self):
        # For TLS the SSLStream wraps a SocketStream; the trio socket lives
        # on that transport stream rather than on the SSLStream itself.
        if self.tls:
            return self.stream.transport_stream.socket
        return self.stream.socket

    async def getpeername(self):
        """Return the remote address of the underlying socket."""
        return self._raw_socket().getpeername()

    async def getsockname(self):
        """Return the local address of the underlying socket."""
        return self._raw_socket().getsockname()


class Backend(dns._asyncbackend.Backend):
    """dnspython async backend implemented on top of trio."""

    def name(self):
        return 'trio'

    async def make_socket(self, af, socktype, proto=0, source=None,
                          destination=None, timeout=None,
                          ssl_context=None, server_hostname=None):
        """Create a trio socket and wrap it in a backend socket object.

        *af*, *socktype* and *proto* are passed to ``trio.socket.socket``.
        If *source* is given, the socket is bound to it.  For
        ``SOCK_STREAM``, the socket is connected to *destination*, bounded
        by *timeout*.  If *ssl_context* is given, the stream is wrapped in
        a ``trio.SSLStream`` using *server_hostname*.

        Returns a ``DatagramSocket`` for ``SOCK_DGRAM`` or a
        ``StreamSocket`` for ``SOCK_STREAM``.

        Raises ``dns.exception.Timeout`` if the connect does not complete
        within *timeout*, and ``NotImplementedError`` for other socket
        types.
        """
        s = trio.socket.socket(af, socktype, proto)
        try:
            if source:
                await s.bind(_lltuple(source, af))
            if socktype == socket.SOCK_STREAM:
                connected = False
                with _maybe_timeout(timeout):
                    await s.connect(_lltuple(destination, af))
                    connected = True
                if not connected:
                    # move_on_after swallowed the cancellation; without
                    # this the caller would get an unconnected stream.
                    raise dns.exception.Timeout(timeout=timeout)
        except BaseException:  # pragma: no cover
            # BaseException so trio.Cancelled (not an Exception subclass)
            # also releases the socket; close() is synchronous and safe here.
            s.close()
            raise
        if socktype == socket.SOCK_DGRAM:
            return DatagramSocket(s)
        if socktype == socket.SOCK_STREAM:
            stream = trio.SocketStream(s)
            tls = False
            if ssl_context:
                tls = True
                try:
                    stream = trio.SSLStream(stream, ssl_context,
                                            server_hostname=server_hostname)
                except Exception:  # pragma: no cover
                    await stream.aclose()
                    raise
            return StreamSocket(af, stream, tls)
        # Unsupported socket type: don't leak the socket created above.
        s.close()
        raise NotImplementedError('unsupported socket ' +
                                  f'type {socktype}')  # pragma: no cover

    async def sleep(self, interval):
        """Sleep for *interval* seconds using trio's clock."""
        await trio.sleep(interval)