# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

'''Validation helpers for network addresses and services, plus small
asyncio and introspection utilities.'''

__all__ = ('instantiate_coroutine', 'is_valid_hostname', 'classify_host',
           'validate_port', 'validate_protocol', 'Service', 'ServicePart', 'NetAddress')


import asyncio
from collections import namedtuple
from enum import IntEnum
from functools import partial
import inspect
from ipaddress import ip_address, IPv4Address, IPv6Address
import re


# See http://stackoverflow.com/questions/2532053/validate-a-hostname-string
# Note underscores are valid in domain names, but strictly invalid in host
# names.  We ignore that distinction.
#
# All patterns below are applied with fullmatch(): a '$' anchor used with
# match() would also match just before a trailing '\n', letting e.g.
# 'example.com\n' validate.

# A letter followed by letters, digits, '+', '.' or '-'.  The hyphen is last in
# the class so it is a literal character rather than a range operator.
PROTOCOL_REGEX = re.compile(r'[A-Za-z][A-Za-z0-9+.-]+')
# re.ASCII stops IGNORECASE from folding non-ASCII letters (KELVIN SIGN U+212A,
# LATIN SMALL LETTER LONG S U+017F, dotless i, dotted capital I) onto a-z.
LABEL_REGEX = re.compile(r'[a-z0-9_]([a-z0-9_-]{0,61}[a-z0-9_])?',
                         re.IGNORECASE | re.ASCII)
NUMERIC_REGEX = re.compile(r'[0-9]+')


def is_valid_hostname(hostname):
    '''Return True if hostname is valid, otherwise False.

    A single trailing dot is permitted and ignored.  After that the name must
    be non-empty and at most 253 characters long, every dot-separated label
    must match LABEL_REGEX, and the last label must not be all-numeric (so a
    dotted IPv4 string is not a hostname).

    Raises TypeError if hostname is not a string.
    '''
    if not isinstance(hostname, str):
        raise TypeError('hostname must be a string')
    # strip exactly one dot from the right, if present
    if hostname.endswith('.'):
        hostname = hostname[:-1]
    if not hostname or len(hostname) > 253:
        return False
    labels = hostname.split('.')
    # the TLD must be not all-numeric
    if NUMERIC_REGEX.fullmatch(labels[-1]):
        return False
    return all(LABEL_REGEX.fullmatch(label) for label in labels)


def classify_host(host):
    '''Host is an IPv4Address, IPv6Address or a string.

    An IPv4Address or IPv6Address is returned unchanged.  A string that is a
    valid hostname (see is_valid_hostname) is returned unchanged.  Any other
    string is converted to an IPv4Address or IPv6Address object.

    Checking the hostname first is safe: a valid hostname has a non-numeric
    last label and no ':', so it cannot also be a dotted IPv4 or an IPv6 string.

    Raises ValueError if a string is neither a valid hostname nor an IP
    address, and TypeError if host is of any other type.
    '''
    if isinstance(host, (IPv4Address, IPv6Address)):
        return host
    if is_valid_hostname(host):
        return host
    return ip_address(host)


def validate_port(port):
    '''Validate port and return it as an integer.

    An integer, or a string of decimal digits representing one, is accepted;
    the value must be in the range 1-65535.

    Raises TypeError if port is not an int or str (bool is rejected even though
    it subclasses int), and ValueError if the value is non-numeric or out of
    range.
    '''
    if isinstance(port, bool) or not isinstance(port, (str, int)):
        raise TypeError(f'port must be an integer or string: {port}')
    if isinstance(port, str) and port.isdigit():
        try:
            port = int(port)
        except ValueError:
            # isdigit() accepts characters such as superscript digits that
            # int() rejects; report those with the usual message below.
            pass
    if isinstance(port, int) and 0 < port <= 65535:
        return port
    raise ValueError(f'invalid port: {port}')


def validate_protocol(protocol):
    '''Validate a protocol string and return it in lower case.

    The protocol must be a letter followed by one or more letters, digits,
    '+', '.' or '-'.

    Raises TypeError if protocol is not a string and ValueError if it is
    malformed.
    '''
    if not isinstance(protocol, str):
        raise TypeError(f'protocol must be a string: {protocol}')
    if not PROTOCOL_REGEX.fullmatch(protocol):
        raise ValueError(f'invalid protocol: {protocol}')
    return protocol.lower()


class ServicePart(IntEnum):
    '''Identifies which part of a service default_func is asked to supply.'''
    PROTOCOL = 0
    HOST = 1
    PORT = 2


def _split_address(string):
    '''Split an address string into (host, port) strings.

    A leading "[...]" host (as used for IPv6, e.g. "[::1]:80") has its brackets
    removed when it ends the string or is followed by ':'.  Otherwise the string
    is split at its first colon.  A missing port is returned as ''.
    '''
    if string.startswith('['):
        end = string.find(']')
        if end != -1:
            if len(string) == end + 1:
                return string[1:end], ''
            if string[end + 1] == ':':
                return string[1:end], string[end + 2:]
    colon = string.find(':')
    if colon == -1:
        return string, ''
    return string[:colon], string[colon + 1:]


class NetAddress:
    '''A validated network address: a classified host and an integer port.'''

    def __init__(self, host, port):
        '''Construct a NetAddress from a host and a port.

        Host is classified with classify_host and port is validated with
        validate_port; their exceptions propagate.'''
        self._host = classify_host(host)
        self._port = validate_port(port)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of raising AttributeError on other._host.
        if not isinstance(other, NetAddress):
            return NotImplemented
        # pylint: disable=protected-access
        return self._host == other._host and self._port == other._port

    def __hash__(self):
        return hash((self._host, self._port))

    @classmethod
    def from_string(cls, string, *, default_func=None):
        '''Construct a NetAddress from a "host:port" string and return it.

        IPv6 hosts are written in brackets, e.g. "[::1]:50001".  If the host or
        the port (or both) is missing and default_func is provided, it is called
        with ServicePart.HOST or ServicePart.PORT to get a default.

        Raises TypeError if string is not a str, and ValueError if the host or
        port is still missing after defaults are applied, or is invalid.
        '''
        if not isinstance(string, str):
            raise TypeError(f'address must be a string: {string}')
        host, port = _split_address(string)
        if default_func:
            host = host or default_func(ServicePart.HOST)
            port = port or default_func(ServicePart.PORT)
        if not host or not port:
            raise ValueError(f'invalid address string: {string}')
        return cls(host, port)

    @property
    def host(self):
        '''The host: an IPv4Address, IPv6Address or hostname string.'''
        return self._host

    @property
    def port(self):
        '''The port as an integer.'''
        return self._port

    def __str__(self):
        if isinstance(self._host, IPv6Address):
            # Brackets keep the address's colons apart from the port separator.
            return f'[{self._host}]:{self._port}'
        return f'{self._host}:{self._port}'

    def __repr__(self):
        return f'NetAddress({self.host!r}, {self.port})'

    @classmethod
    def default_host_and_port(cls, host, port):
        '''Return a default_func for from_string supplying host and port.'''
        def func(kind):
            return host if kind == ServicePart.HOST else port
        return func

    @classmethod
    def default_host(cls, host):
        '''Return a default_func for from_string supplying only a host.'''
        return cls.default_host_and_port(host, None)

    @classmethod
    def default_port(cls, port):
        '''Return a default_func for from_string supplying only a port.'''
        return cls.default_host_and_port(None, port)


class Service:
    '''A validated protocol, address pair.'''

    def __init__(self, protocol, address):
        '''Construct a service from a protocol string and a NetAddress object.

        The protocol is validated and lower-cased by validate_protocol.  address
        may also be a "host:port" string, parsed with NetAddress.from_string.
        '''
        self._protocol = validate_protocol(protocol)
        if not isinstance(address, NetAddress):
            address = NetAddress.from_string(address)
        self._address = address

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of raising AttributeError on other._protocol.
        if not isinstance(other, Service):
            return NotImplemented
        # pylint: disable=protected-access
        return self._protocol == other._protocol and self._address == other._address

    def __hash__(self):
        return hash((self._protocol, self._address))

    @property
    def protocol(self):
        '''The lower-cased protocol string.'''
        return self._protocol

    @property
    def address(self):
        '''The NetAddress of the service.'''
        return self._address

    @property
    def host(self):
        '''Shortcut for self.address.host.'''
        return self._address.host

    @property
    def port(self):
        '''Shortcut for self.address.port.'''
        return self._address.port

    @classmethod
    def from_string(cls, string, *, default_func=None):
        '''Construct a Service from a string such as "protocol://host:port".

        If default_func is provided and any ServicePart is missing, it is called
        as default_func(protocol, part) to obtain the missing part.  A string
        without "://" is taken as a bare protocol if default_func gives both a
        host and a port for it.  Otherwise it is taken as an address, and the
        protocol comes from default_func(None, ServicePart.PROTOCOL).

        Raises TypeError if string is not a str, and ValueError if no protocol
        can be determined or the protocol or address is invalid.
        '''
        if not isinstance(string, str):
            raise TypeError(f'service must be a string: {string}')

        head, sep, tail = string.partition('://')
        protocol, address = (head, tail) if sep else (None, head)
        if not sep and default_func:
            # NOTE(review): this probe passes head in its original case, whereas
            # the address defaults below receive protocol.lower(); confirm the
            # asymmetry is intended.
            if default_func(head, ServicePart.HOST) and default_func(head, ServicePart.PORT):
                protocol, address = head, ''
            else:
                protocol = default_func(None, ServicePart.PROTOCOL)
        if not protocol:
            raise ValueError(f'invalid service string: {string}')

        if default_func:
            # NetAddress.from_string calls its default_func with only the part.
            default_func = partial(default_func, protocol.lower())
        address = NetAddress.from_string(address, default_func=default_func)
        return cls(protocol, address)

    def __str__(self):
        return f'{self._protocol}://{self._address}'

    def __repr__(self):
        return f"Service({self._protocol!r}, '{self._address}')"


def instantiate_coroutine(corofunc, args):
    '''Return a coroutine object built from corofunc and args.

    If corofunc is already a coroutine object it is returned unchanged, and
    args must be empty.  Otherwise the result of corofunc(*args) is returned.

    Raises ValueError if a coroutine object is passed with non-empty args.
    '''
    if asyncio.iscoroutine(corofunc):
        # Any empty sequence (not just the tuple ()) means "no arguments".
        if args:
            raise ValueError('args cannot be passed with a coroutine')
        return corofunc
    return corofunc(*args)


def is_async_call(func):
    '''inspect.iscoroutinefunction that looks through partials.'''
    while isinstance(func, partial):
        func = func.func
    return inspect.iscoroutinefunction(func)


# Describes how a callable may be invoked (see signature_info):
#   min_args, max_args: bounds on the number of positional arguments;
#       max_args is None when the callable takes *args.
#   required_names: positional-or-keyword parameters without a default.
#   other_names: other parameters that may be passed by name.  The builtin
#       `any` means any name is good (**kwargs); None means the callable cannot
#       be called with keyword arguments only (it has positional-only params).
SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args '
                           'required_names other_names')


def signature_info(func):
    '''Return a SignatureInfo describing how func may be called.'''
    params = inspect.signature(func).parameters
    min_args = max_args = 0
    required_names = []
    other_names = []
    no_names = False
    # inspect yields parameters in declaration order (positional-only,
    # positional-or-keyword, *args, keyword-only, **kwargs), so max_args is
    # still an int whenever it is incremented and other_names is still a list
    # whenever it is appended to.
    for p in params.values():
        if p.kind == p.POSITIONAL_OR_KEYWORD:
            max_args += 1
            if p.default is p.empty:
                min_args += 1
                required_names.append(p.name)
            else:
                other_names.append(p.name)
        elif p.kind == p.KEYWORD_ONLY:
            # NOTE(review): required keyword-only parameters (no default) are
            # reported here rather than in required_names; confirm callers
            # expect that.
            other_names.append(p.name)
        elif p.kind == p.VAR_POSITIONAL:
            max_args = None
        elif p.kind == p.VAR_KEYWORD:
            other_names = any
        elif p.kind == p.POSITIONAL_ONLY:
            max_args += 1
            if p.default is p.empty:
                min_args += 1
            no_names = True

    if no_names:
        other_names = None

    return SignatureInfo(min_args, max_args, required_names, other_names)


def check_task(logger, task):
    '''Log, via logger.error with the traceback, if task raised an exception.

    Cancelled tasks are ignored.  NOTE(review): task is presumably already done
    (e.g. this is used as a done-callback); on a pending task result() raises
    InvalidStateError, which would be logged as a crash.
    '''
    if not task.cancelled():
        try:
            task.result()
        except Exception:  # pylint: disable=broad-except
            logger.error('task crashed: %r', task, exc_info=True)