1# Copyright (c) 2018, Neil Booth
2#
3# All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining
8# a copy of this software and associated documentation files (the
9# "Software"), to deal in the Software without restriction, including
10# without limitation the rights to use, copy, modify, merge, publish,
11# distribute, sublicense, and/or sell copies of the Software, and to
12# permit persons to whom the Software is furnished to do so, subject to
13# the following conditions:
14#
15# The above copyright notice and this permission notice shall be
16# included in all copies or substantial portions of the Software.
17#
18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
19# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
21# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
22# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
24# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25
# Names exported by ``from <this module> import *``.  Helpers defined in this
# module but not listed here can still be imported explicitly by name.
__all__ = (
    'instantiate_coroutine',
    'is_valid_hostname',
    'classify_host',
    'validate_port',
    'validate_protocol',
    'Service',
    'ServicePart',
    'NetAddress',
)
28
29
30import asyncio
31from collections import namedtuple
32from enum import IntEnum
33from functools import partial
34import inspect
35from ipaddress import ip_address, IPv4Address, IPv6Address
36import re
37
38
# Validation patterns.  They are anchored with \Z rather than $ because $ also
# matches just before a trailing newline, which would let e.g. "tcp\n" through.
#
# PROTOCOL_REGEX: a letter followed by one or more letters, digits, "+", "." or
# "-".  The "-" is last in the character class so it is a literal rather than
# forming a range (a "+-." range would also admit ",").
#
# LABEL_REGEX: a single DNS label of 1 to 63 characters drawn from letters,
# digits, "_" and "-", neither starting nor ending with "-".  See
# http://stackoverflow.com/questions/2532053/validate-a-hostname-string
# Note underscores are valid in domain names, but strictly invalid in host
# names.  We ignore that distinction.  re.ASCII stops IGNORECASE from letting
# [a-z] match a few non-ASCII letters such as U+212A KELVIN SIGN.
#
# NUMERIC_REGEX: a string made only of ASCII digits (used to reject an
# all-numeric TLD).
PROTOCOL_REGEX = re.compile(r'[A-Za-z][A-Za-z0-9+.-]+\Z')
LABEL_REGEX = re.compile(r'^[a-z0-9_]([a-z0-9_-]{0,61}[a-z0-9_])?\Z',
                         re.IGNORECASE | re.ASCII)
NUMERIC_REGEX = re.compile(r'[0-9]+\Z')
45
46
def is_valid_hostname(hostname):
    '''Return True if hostname is a syntactically valid hostname, otherwise False.

    A single trailing dot (fully-qualified form) is allowed and ignored.  The rest
    must be non-empty, ASCII, at most 253 characters, consist of dot-separated
    labels each matching LABEL_REGEX in full, and its last label (the TLD) must
    not be all digits.

    Raises TypeError if hostname is not a str.
    '''
    if not isinstance(hostname, str):
        raise TypeError('hostname must be a string')
    # strip exactly one dot from the right, if present
    if hostname.endswith('.'):
        hostname = hostname[:-1]
    if not hostname or len(hostname) > 253:
        return False
    # Hostnames are ASCII.  An IGNORECASE [a-z] class in a str pattern can also
    # match a few non-ASCII letters (e.g. U+212A KELVIN SIGN), so check here.
    if any(ord(ch) > 127 for ch in hostname):
        return False
    labels = hostname.split('.')
    # The TLD must not be all-numeric.  fullmatch is used throughout because a
    # pattern ending in $ would otherwise accept a trailing "\n".
    if NUMERIC_REGEX.fullmatch(labels[-1]):
        return False
    return all(LABEL_REGEX.fullmatch(label) for label in labels)
61
62
def classify_host(host):
    '''Classify host as an IP address object or a hostname string.

    An IPv4Address or IPv6Address is returned unchanged.  A string that passes
    is_valid_hostname() is also returned unchanged; any other string is handed to
    ip_address(), which returns an IPv4Address or IPv6Address.

    Raises ValueError if a string is neither a valid hostname nor an IP address.
    Values of other types make is_valid_hostname() raise TypeError.
    '''
    already_ip = isinstance(host, (IPv4Address, IPv6Address))
    if already_ip or is_valid_hostname(host):
        return host
    return ip_address(host)
78
79
def validate_port(port):
    '''Validate port and return it as an integer.

    An integer, or a string of ASCII decimal digits representing one, is accepted;
    the result must lie in 1..65535.

    Raises TypeError if port is not an int or str (bools are rejected), and
    ValueError if it is out of range or not a plain decimal string.
    '''
    # bool is a subclass of int, but True/False are never meaningful ports.
    if isinstance(port, bool) or not isinstance(port, (str, int)):
        raise TypeError(f'port must be an integer or string: {port}')
    # str.isdigit() also accepts non-ASCII digits such as '²' or '١', so only
    # plain ASCII decimal strings are converted.
    if isinstance(port, str) and port and all(ch in '0123456789' for ch in port):
        port = int(port)
    if isinstance(port, int) and 0 < port <= 65535:
        return port
    raise ValueError(f'invalid port: {port}')
91
92
def validate_protocol(protocol):
    '''Validate a protocol name and return it lower-cased.

    The whole of protocol must match PROTOCOL_REGEX.

    Raises TypeError if protocol is not a str, and ValueError if it is malformed.
    '''
    if not isinstance(protocol, str):
        raise TypeError(f'protocol must be a string: {protocol}')
    # fullmatch, so a trailing newline is rejected even if the pattern ends in $
    if not PROTOCOL_REGEX.fullmatch(protocol):
        raise ValueError(f'invalid protocol: {protocol}')
    return protocol.lower()
98
99
class ServicePart(IntEnum):
    '''Names one component of a service: its protocol, host or port.

    Values of this enum are passed to default_func callbacks to say which
    missing part is being requested.
    '''
    PROTOCOL, HOST, PORT = range(3)
104
105
def _split_address(string):
    '''Split an address string into (host, port) strings.

    A bracketed host such as "[::1]:8000" or "[::1]" is unwrapped when the
    closing bracket is followed by ":" or ends the string.  Otherwise the string
    is split at its first ":"; with no ":" the port is the empty string.
    '''
    if string.startswith('['):
        close = string.find(']')
        if close != -1:
            tail = string[close + 1:]
            if not tail or tail[0] == ':':
                return string[1:close], tail[1:]
    host, _sep, port = string.partition(':')
    return host, port
118
119
class NetAddress:
    '''A validated network address: a host and a port.

    The host is stored as returned by classify_host() (an IPv4Address, IPv6Address
    or hostname string) and the port as returned by validate_port() (an int).
    Instances are hashable and compare equal when both host and port are equal.
    '''

    def __init__(self, host, port):
        '''Construct a NetAddress from a host and a port.

        host is passed through classify_host() and port through validate_port();
        invalid values raise TypeError or ValueError from those functions.
        '''
        self._host = classify_host(host)
        self._port = validate_port(port)

    def __eq__(self, other):
        # Comparing with a non-NetAddress (e.g. ``addr == None``) must not raise
        # AttributeError; NotImplemented lets Python fall back to its default.
        if not isinstance(other, NetAddress):
            return NotImplemented
        # pylint: disable=protected-access
        return self._host == other._host and self._port == other._port

    def __hash__(self):
        # Hash the same pair that __eq__ compares, so equal addresses hash equally.
        return hash((self._host, self._port))

    @classmethod
    def from_string(cls, string, *, default_func=None):
        '''Construct a NetAddress from a "host:port" string and return it.

        An IPv6 host may be enclosed in brackets, e.g. "[::1]:8000".  If either
        the host or the port (or both) is missing and default_func is provided, it
        is called with ServicePart.HOST or ServicePart.PORT to get a default.

        Raises TypeError if string is not a str.  A missing or malformed host or
        port raises ValueError.
        '''
        if not isinstance(string, str):
            raise TypeError(f'address must be a string: {string}')
        host, port = _split_address(string)
        if default_func:
            # ``or`` consults default_func only for a part that is empty.
            host = host or default_func(ServicePart.HOST)
            port = port or default_func(ServicePart.PORT)
            if not host or not port:
                raise ValueError(f'invalid address string: {string}')
        return cls(host, port)

    @property
    def host(self):
        '''The host: an IPv4Address, IPv6Address or hostname string.'''
        return self._host

    @property
    def port(self):
        '''The port as an int.'''
        return self._port

    def __str__(self):
        # IPv6 hosts are bracketed so the ":" before the port is unambiguous.
        if isinstance(self._host, IPv6Address):
            return f'[{self._host}]:{self._port}'
        return f'{self._host}:{self._port}'

    def __repr__(self):
        return f'NetAddress({self.host!r}, {self.port})'

    @classmethod
    def default_host_and_port(cls, host, port):
        '''Return a default_func for from_string() supplying host and port.'''
        def func(kind):
            return host if kind == ServicePart.HOST else port
        return func

    @classmethod
    def default_host(cls, host):
        '''Return a default_func for from_string() supplying only a host.'''
        return cls.default_host_and_port(host, None)

    @classmethod
    def default_port(cls, port):
        '''Return a default_func for from_string() supplying only a port.'''
        return cls.default_host_and_port(None, port)
182
183
class Service:
    '''A validated protocol, address pair.'''

    def __init__(self, protocol, address):
        '''Construct a service from a protocol string and an address.

        protocol is validated and lower-cased by validate_protocol().  address is
        either a NetAddress or a string, which is parsed with
        NetAddress.from_string().
        '''
        self._protocol = validate_protocol(protocol)
        if not isinstance(address, NetAddress):
            address = NetAddress.from_string(address)
        self._address = address

    def __eq__(self, other):
        # Return NotImplemented for foreign types instead of raising
        # AttributeError, e.g. for ``service == None``.
        if not isinstance(other, Service):
            return NotImplemented
        # pylint: disable=protected-access
        return self._protocol == other._protocol and self._address == other._address

    def __hash__(self):
        # Hash the same pair that __eq__ compares.
        return hash((self._protocol, self._address))

    @property
    def protocol(self):
        '''The lower-cased protocol string.'''
        return self._protocol

    @property
    def address(self):
        '''The NetAddress of this service.'''
        return self._address

    @property
    def host(self):
        '''Shortcut for self.address.host.'''
        return self._address.host

    @property
    def port(self):
        '''Shortcut for self.address.port.'''
        return self._address.port

    @classmethod
    def from_string(cls, string, *, default_func=None):
        '''Construct a Service from a string such as "tcp://host:port".

        If default_func is provided and any ServicePart is missing, it is called
        as default_func(protocol, part) to obtain the missing part.  A string
        without "://" is taken as a bare protocol when default_func supplies both
        a host and a port for it.  Otherwise it is taken as an address, and the
        protocol comes from default_func(None, ServicePart.PROTOCOL).

        Raises TypeError if string is not a str, and ValueError if it cannot be
        turned into a valid service.
        '''
        if not isinstance(string, str):
            raise TypeError(f'service must be a string: {string}')

        parts = string.split('://', 1)
        if len(parts) == 2:
            protocol, address = parts
        else:
            item, = parts
            protocol = None
            if default_func:
                if default_func(item, ServicePart.HOST) and default_func(item, ServicePart.PORT):
                    protocol, address = item, ''
                else:
                    protocol, address = default_func(None, ServicePart.PROTOCOL), item
            if not protocol:
                raise ValueError(f'invalid service string: {string}')

        if default_func:
            # Bind the protocol so NetAddress.from_string can call default_func(part).
            default_func = partial(default_func, protocol.lower())
        address = NetAddress.from_string(address, default_func=default_func)
        return cls(protocol, address)

    def __str__(self):
        return f'{self._protocol}://{self._address}'

    def __repr__(self):
        return f"Service({self._protocol!r}, '{self._address}')"
251
252
def instantiate_coroutine(corofunc, args):
    '''Return a coroutine object built from corofunc and args.

    If corofunc is already a coroutine object it is returned unchanged.  Passing
    non-empty args with it raises ValueError, as there is nothing to pass them to.
    Otherwise corofunc is called as corofunc(*args) and the result is returned.
    '''
    if asyncio.iscoroutine(corofunc):
        # Any empty sequence ((), [] ...) means "no arguments".
        if args:
            raise ValueError('args cannot be passed with a coroutine')
        return corofunc
    return corofunc(*args)
259
260
def is_async_call(func):
    '''Return True if func is a coroutine function, looking through partials.

    Behaves like inspect.iscoroutinefunction, except that any functools.partial
    wrappers around func are unwrapped first.
    '''
    if isinstance(func, partial):
        return is_async_call(func.func)
    return inspect.iscoroutinefunction(func)
266
267
# Summary of how a callable may be invoked, as built by signature_info():
#   min_args, max_args - bounds on the number of positional arguments;
#                        max_args is None when the callable takes *args
#   required_names     - positional-or-keyword parameters without defaults
#   other_names        - positional-or-keyword parameters with defaults plus all
#                        keyword-only parameters; None means the callable cannot
#                        be called with keyword arguments only, and the builtin
#                        ``any`` means any name is good (it takes **kwargs)
SignatureInfo = namedtuple(
    'SignatureInfo',
    ['min_args', 'max_args', 'required_names', 'other_names'],
)
272
273
def signature_info(func):
    '''Summarise how func can be called and return a SignatureInfo.

    min_args / max_args count the positional parameters (max_args is None when
    func takes *args).  required_names lists positional-or-keyword parameters
    without defaults.  other_names lists those with defaults plus every
    keyword-only parameter, and becomes ``any`` if func takes **kwargs.  It is
    None whenever func has a positional-only parameter.
    '''
    min_args = 0
    max_args = 0
    required_names = []
    other_names = []
    has_positional_only = False
    for param in inspect.signature(func).parameters.values():
        kind = param.kind
        has_default = param.default is not param.empty
        if kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
            max_args += 1
            if not has_default:
                min_args += 1
            if kind == param.POSITIONAL_ONLY:
                # Such a parameter can never be supplied by keyword.
                has_positional_only = True
            elif has_default:
                other_names.append(param.name)
            else:
                required_names.append(param.name)
        elif kind == param.KEYWORD_ONLY:
            other_names.append(param.name)
        elif kind == param.VAR_POSITIONAL:
            max_args = None
        elif kind == param.VAR_KEYWORD:
            # the builtin ``any`` is used as a marker for "any name accepted"
            other_names = any

    return SignatureInfo(min_args, max_args, required_names,
                         None if has_positional_only else other_names)
304
305
def check_task(logger, task):
    '''Log an error via logger if task finished by raising an Exception.

    Cancelled tasks are ignored.  Exceptions that are not Exception subclasses
    propagate from task.result() to the caller.
    '''
    if task.cancelled():
        return
    try:
        task.result()
    except Exception:   # pylint: disable=broad-except
        logger.error('task crashed: %r', task, exc_info=True)
312