1#-----------------------------------------------------------------------------
2# Copyright (c) 2012 - 2021, Anaconda, Inc., and Bokeh Contributors.
3# All rights reserved.
4#
5# The full license is in the file LICENSE.txt, distributed with this software.
6#-----------------------------------------------------------------------------
7''' Provide some utility functions useful for implementing different
8components in ``bokeh.server``.
9
10'''
11
12#-----------------------------------------------------------------------------
13# Boilerplate
14#-----------------------------------------------------------------------------
15import logging # isort:skip
16log = logging.getLogger(__name__)
17
18#-----------------------------------------------------------------------------
19# Imports
20#-----------------------------------------------------------------------------
21
22# External imports
23from tornado import netutil
24
25#-----------------------------------------------------------------------------
26# Globals and constants
27#-----------------------------------------------------------------------------
28
29__all__ = (
30    'bind_sockets',
31    'check_allowlist',
32    'create_hosts_allowlist',
33    'match_host',
34)
35
36#-----------------------------------------------------------------------------
37# General API
38#-----------------------------------------------------------------------------
39
def bind_sockets(address, port):
    ''' Bind listening sockets for a port on an address.

    The actual binding is delegated to ``tornado.netutil.bind_sockets``; this
    wrapper additionally reports which port number ended up being bound.

    Args:
        address (str) :
            An address to bind a port on, e.g. ``"localhost"``

        port (int) :
            A port number to bind.

            Pass 0 (or any other falsy value) to have the OS automatically
            choose a free port.

    Returns:
        (sockets, port) :
            A 2-tuple whose first element is the collection of sockets
            returned by tornado, and whose second element is the port number
            those sockets are bound to. (Useful when passing 0 as a port
            number to bind any free port.)

    '''
    requested_port = port or 0
    sockets = netutil.bind_sockets(port=requested_port, address=address)
    assert len(sockets)

    # Every returned socket is expected to report the same port number.
    bound_ports = set()
    for sock in sockets:
        bound_ports.add(sock.getsockname()[1])
    assert len(bound_ports) == 1, "Multiple ports assigned??"

    bound_port = bound_ports.pop()

    # When a specific port was requested, it must be the one that got bound.
    assert not port or bound_port == port
    return sockets, bound_port
68
def check_allowlist(host, allowlist):
    ''' Check a given request host against an allowlist.

    The host is first looked up verbatim in ``allowlist``; if that fails, it
    is compared against every entry as a pattern using ``match_host``.

    Args:
        host (str) :
            A host string to compare against an allowlist.

            If the host does not contain a ``":"``, then ``":80"`` is
            appended before any comparison is made.

        allowlist (seq[str]) :
            A list of host patterns to match against

    Returns:
        ``True``, if ``host`` matches any pattern in ``allowlist``, otherwise
        ``False``

    '''
    # Normalize to "<name>:<port>" form, defaulting to port 80.
    candidate = host if ':' in host else host + ':80'

    # Cheap exact-match test before falling back to pattern matching.
    if candidate in allowlist:
        return True

    for pattern in allowlist:
        if match_host(candidate, pattern):
            return True
    return False
94
def create_hosts_allowlist(host_list, port):
    ''' Build a validated list of ``<name>:<port>`` allowlist entries.

    This allowlist can be used to restrict websocket or other connections to
    only those explicitly originating from approved hosts.

    Args:
        host_list (seq[str]) :
            A list of string `<name>` or `<name>:<port>` values to add to the
            allowlist.

            If no port is specified in a host string, then ``":80"``  is
            appended to it. The bare wildcard ``"*"`` is the one exception:
            it is kept unchanged.

        port (int) :
            If ``host_list`` is empty or ``None``, then the allowlist will
            be the single item list `` [ 'localhost:<port>' ]``

            If ``host_list`` is not empty, this parameter has no effect.

    Returns:
        list[str]

    Raises:
        ValueError, if a host name is empty, a port is not an integer, or a
        host string contains more than one ``":"``

    Note:
        If any host in ``host_list`` contains a wildcard ``*`` a warning will
        be logged regarding permissive websocket connections.

    '''
    if not host_list:
        return ['localhost:' + str(port)]

    allowlist = []
    for entry in host_list:
        if '*' in entry:
            log.warning(
                "Host wildcard %r will allow connections originating "
                "from multiple (or possibly all) hostnames or IPs. Use non-wildcard "
                "values to restrict access explicitly", entry)

        if entry == '*':
            # A bare wildcard accepts any port, so no ":80" suffix is added.
            allowlist.append(entry)
            continue

        pieces = entry.split(':')

        # More than one colon cannot be a "<name>:<port>" value.
        if len(pieces) > 2:
            raise ValueError("Invalid host value: %s" % entry)

        if len(pieces) == 2:
            name, port_text = pieces
            # The port is validated before the name, so ":x" reports the port.
            try:
                int(port_text)
            except ValueError:
                raise ValueError("Invalid port in host value: %s" % entry)
            if name == "":
                raise ValueError("Empty host value")
            allowlist.append(entry)
        else:
            if pieces[0] == "":
                raise ValueError("Empty host value")
            allowlist.append(entry + ":80")

    return allowlist
157
def match_host(host, pattern):
    ''' Match a host string against a pattern

    Args:
        host (str)
            A hostname to compare to the given pattern

        pattern (str)
            A string representing a hostname pattern, possibly including
            wildcards for ip address octets or ports.

    This function will return ``True`` if the hostname matches the pattern,
    including any wildcards. If the pattern contains a port, the host string
    must also contain a matching port.

    Host names are compared label by label (split on ``"."``). A ``*`` label
    in the pattern matches any single host label. A host may have more labels
    than the pattern only when the pattern ends with a ``*`` label, in which
    case that trailing wildcard also covers all remaining host labels. A
    pattern without a trailing wildcard never matches a longer host, so
    ``foo.example.com`` does not match ``foo.example.com.evil.net``.

    Returns:
        bool

    Examples:

        >>> match_host('192.168.0.1:80', '192.168.0.1:80')
        True
        >>> match_host('192.168.0.1:80', '192.168.0.1')
        True
        >>> match_host('192.168.0.1:80', '192.168.0.1:8080')
        False
        >>> match_host('192.168.0.1', '192.168.0.2')
        False
        >>> match_host('192.168.0.1', '192.168.*.*')
        True
        >>> match_host('192.168.0.1', '192.168.*')
        True
        >>> match_host('alice', 'alice')
        True
        >>> match_host('alice:80', 'alice')
        True
        >>> match_host('alice', 'bob')
        False
        >>> match_host('foo.example.com', 'foo.example.com.net')
        False
        >>> match_host('foo.example.com.net', 'foo.example.com')
        False
        >>> match_host('alice', '*')
        True
        >>> match_host('foo.example.com', '*')
        True
        >>> match_host('alice', '*:*')
        True
        >>> match_host('alice:80', '*')
        True
        >>> match_host('alice:80', '*:80')
        True
        >>> match_host('alice:8080', '*:80')
        False

    '''
    # Split an optional ":<port>" suffix off the host.
    host_port = None
    if ':' in host:
        host, host_port = host.rsplit(':', 1)

    # Same for the pattern; a "*" port places no constraint on the host port.
    pattern_port = None
    if ':' in pattern:
        pattern, pattern_port = pattern.rsplit(':', 1)
        if pattern_port == '*':
            pattern_port = None

    if pattern_port is not None and host_port != pattern_port:
        return False

    host_labels = host.split('.')
    pattern_labels = pattern.split('.')

    if len(pattern_labels) > len(host_labels):
        return False

    # zip() below stops at the shorter sequence. Without this guard a literal
    # pattern would also match every host that merely *starts* with it, e.g.
    # "trusted.com.attacker.net" against "trusted.com". Surplus host labels
    # are only acceptable when a trailing wildcard is there to absorb them.
    if len(host_labels) > len(pattern_labels) and pattern_labels[-1] != '*':
        return False

    return all(
        p == '*' or h == p
        for h, p in zip(host_labels, pattern_labels)
    )
235
236#-----------------------------------------------------------------------------
237# Dev API
238#-----------------------------------------------------------------------------
239
240#-----------------------------------------------------------------------------
241# Private API
242#-----------------------------------------------------------------------------
243
244#-----------------------------------------------------------------------------
245# Code
246#-----------------------------------------------------------------------------
247