#-----------------------------------------------------------------------------
# Copyright (c) 2012 - 2021, Anaconda, Inc., and Bokeh Contributors.
# All rights reserved.
#
# The full license is in the file LICENSE.txt, distributed with this software.
#-----------------------------------------------------------------------------
''' Provide some utility functions useful for implementing different
components in ``bokeh.server``.

'''

#-----------------------------------------------------------------------------
# Boilerplate
#-----------------------------------------------------------------------------
import logging # isort:skip
log = logging.getLogger(__name__)

#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------

# External imports
from tornado import netutil

#-----------------------------------------------------------------------------
# Globals and constants
#-----------------------------------------------------------------------------

__all__ = (
    'bind_sockets',
    'check_allowlist',
    'create_hosts_allowlist',
    'match_host',
)

#-----------------------------------------------------------------------------
# General API
#-----------------------------------------------------------------------------

def bind_sockets(address, port):
    ''' Bind a socket to a port on an address.

    Args:
        address (str) :
            An address to bind a port on, e.g. ``"localhost"``

        port (int) :
            A port number to bind.

            Pass 0 to have the OS automatically choose a free port.

    This function returns a 2-tuple with the new socket as the first element,
    and the port that was bound as the second. (Useful when passing 0 as a port
    number to bind any free port.)

    Returns:
        (socket, port)

    '''
    # ``port or 0`` maps a falsy port (e.g. ``None``) to 0, i.e. "let the OS
    # pick a free port".
    ss = netutil.bind_sockets(port=port or 0, address=address)
    assert len(ss)

    # Tornado may return several sockets for one address (one per resolved
    # IP address); all of them are expected to share the same port.
    ports = {s.getsockname()[1] for s in ss}
    assert len(ports) == 1, "Multiple ports assigned??"
    actual_port = ports.pop()

    # An explicitly requested port must be the one actually bound.
    if port:
        assert actual_port == port
    return ss, actual_port

def check_allowlist(host, allowlist):
    ''' Check a given request host against an allowlist.

    Args:
        host (str) :
            A host string to compare against an allowlist.

            If the host does not specify a port, then ``":80"`` is implicitly
            assumed.

        allowlist (seq[str]) :
            A list of host patterns to match against

    Returns:
        ``True``, if ``host`` matches any pattern in ``allowlist``, otherwise
        ``False``

    '''
    # Normalize to "<name>:<port>" so the host can be compared against
    # entries produced by ``create_hosts_allowlist``.
    if ':' not in host:
        host = host + ':80'

    # Fast path: an exact entry needs no pattern matching.
    if host in allowlist:
        return True

    return any(match_host(host, pattern) for pattern in allowlist)

def create_hosts_allowlist(host_list, port):
    ''' Create a list of allowed host patterns from user-supplied host values.

    This allowlist can be used to restrict websocket or other connections to
    only those explicitly originating from approved hosts.

    Args:
        host_list (seq[str]) :
            A list of string `<name>` or `<name>:<port>` values to add to the
            allowlist.

            If no port is specified in a host string, then ``":80"`` is
            implicitly assumed.

        port (int) :
            If ``host_list`` is empty or ``None``, then the allowlist will
            be the single item list `` [ 'localhost:<port>' ]``

            If ``host_list`` is not empty, this parameter has no effect.

    Returns:
        list[str]

    Raises:
        ValueError, if host or port values are invalid

    Note:
        If any host in ``host_list`` contains a wildcard ``*`` a warning will
        be logged regarding permissive websocket connections.

    '''
    if not host_list:
        return ['localhost:' + str(port)]

    hosts = []
    for host in host_list:
        if '*' in host:
            log.warning(
                "Host wildcard %r will allow connections originating "
                "from multiple (or possibly all) hostnames or IPs. Use non-wildcard "
                "values to restrict access explicitly", host)
        if host == '*':
            # do not append the :80 port suffix in that case: any port is
            # accepted
            hosts.append(host)
            continue
        parts = host.split(':')
        if len(parts) == 1:
            if parts[0] == "":
                raise ValueError("Empty host value")
            hosts.append(host+":80")
        elif len(parts) == 2:
            # Validate the port before the name, preserving the original
            # error precedence (e.g. ":abc" reports an invalid port).
            try:
                int(parts[1])
            except ValueError:
                raise ValueError("Invalid port in host value: %s" % host) from None
            if parts[0] == "":
                raise ValueError("Empty host value")
            hosts.append(host)
        else:
            raise ValueError("Invalid host value: %s" % host)
    return hosts

def match_host(host, pattern):
    ''' Match a host string against a pattern

    Args:
        host (str)
            A hostname to compare to the given pattern

        pattern (str)
            A string representing a hostname pattern, possibly including
            wildcards for ip address octets or ports.

    This function will return ``True`` if the hostname matches the pattern,
    including any wildcards. If the pattern contains a port, the host string
    must also contain a matching port.

    A pattern whose name part is a bare ``*`` (e.g. ``"*"``, ``"*:80"``,
    ``"*:*"``) matches any host name. Otherwise the host and the pattern must
    have the same number of dot-separated labels. Each pattern label must
    either equal the corresponding host label or be ``*``. A pattern is
    never treated as a prefix of a longer host name, so ``"example.com"``
    does not match ``"example.com.attacker.net"``.

    Returns:
        bool

    Examples:

        >>> match_host('192.168.0.1:80', '192.168.0.1:80')
        True
        >>> match_host('192.168.0.1:80', '192.168.0.1')
        True
        >>> match_host('192.168.0.1:80', '192.168.0.1:8080')
        False
        >>> match_host('192.168.0.1', '192.168.0.2')
        False
        >>> match_host('192.168.0.1', '192.168.*.*')
        True
        >>> match_host('alice', 'alice')
        True
        >>> match_host('alice:80', 'alice')
        True
        >>> match_host('alice', 'bob')
        False
        >>> match_host('foo.example.com', 'foo.example.com.net')
        False
        >>> match_host('foo.example.com.net', 'foo.example.com')
        False
        >>> match_host('alice', '*')
        True
        >>> match_host('alice', '*:*')
        True
        >>> match_host('alice:80', '*')
        True
        >>> match_host('alice:80', '*:80')
        True
        >>> match_host('alice:8080', '*:80')
        False
        >>> match_host('foo.example.com:5006', '*')
        True

    '''
    # Split off the port; rsplit keeps any earlier ':' in the name part.
    if ':' in host:
        host, host_port = host.rsplit(':', 1)
    else:
        host_port = None

    # A missing or "*" pattern port means "any port".
    if ':' in pattern:
        pattern, pattern_port = pattern.rsplit(':', 1)
        if pattern_port == '*':
            pattern_port = None
    else:
        pattern_port = None

    if pattern_port is not None and host_port != pattern_port:
        return False

    # A bare "*" name pattern accepts any host, however many labels it has.
    # ``create_hosts_allowlist`` emits exactly this for "allow everything".
    if pattern == '*':
        return True

    host_labels = host.split('.')
    pattern_labels = pattern.split('.')

    # Require an equal number of labels. Comparing only the common prefix
    # (as ``zip`` alone would) lets "example.com" match
    # "example.com.attacker.net" and defeats the allowlist.
    if len(pattern_labels) != len(host_labels):
        return False

    return all(p == '*' or h == p for h, p in zip(host_labels, pattern_labels))

#-----------------------------------------------------------------------------
# Dev API
#-----------------------------------------------------------------------------

#-----------------------------------------------------------------------------
# Private API
#-----------------------------------------------------------------------------

#-----------------------------------------------------------------------------
# Code
#-----------------------------------------------------------------------------