#!/usr/bin/env python
# Copyright 2016 The LUCI Authors. All rights reserved.
# Use of this source code is governed under the Apache License, Version 2.0
# that can be found in the LICENSE file.

import base64
import collections
import json
import logging
import os
import re
import sys
import threading
import time

import six
from six.moves import BaseHTTPServer
from six.moves import socketserver

# OAuth access token with its expiration time.
AccessToken = collections.namedtuple('AccessToken', [
    'access_token',  # urlsafe str with the token
    'expiry',  # expiration time as unix timestamp in seconds
])


class TokenError(Exception):
  """Raised by TokenProvider if the token can't be created (fatal error).

  See TokenProvider docs for more info.
  """

  def __init__(self, code, msg):
    super(TokenError, self).__init__(msg)
    self.code = code


class RPCError(Exception):
  """Raised by LocalAuthServer RPC handlers to reply with HTTP error status."""

  def __init__(self, code, msg):
    super(RPCError, self).__init__(msg)
    self.code = code


# Account describes one logical account.
Account = collections.namedtuple('Account', ['id', 'email'])


class TokenProvider(object):
  """Interface for an object that can create OAuth tokens on demand.

  Defined as a concrete class only for documentation purposes.
  """

  def generate_token(self, account_id, scopes):
    """Generates a new access token with given scopes.

    Will be called from multiple threads (possibly concurrently) whenever
    LocalAuthServer needs to refresh a token with particular scopes.

    Can raise RPCError exceptions. They will be immediately converted to
    corresponding RPC error replies (e.g. HTTP 500). This is appropriate for
    low-level or transient errors.

    Can also raise TokenError. It will be converted to GetOAuthToken reply with
    non-zero error_code. It will also be cached (until the server is stopped),
    so that the provider would never be called again for the same account and
    set of scopes. This is appropriate for high-level fatal errors.

    Returns AccessToken on success.
    """
    raise NotImplementedError()


class LocalAuthServer(object):
  """LocalAuthServer handles /rpc/LuciLocalAuthService.* requests.

  It exposes an HTTP JSON RPC API that is used by task processes to grab an
  access token for the service account associated with the task.

  It implements RPC handling details and in-memory cache for the tokens, but
  defers to the supplied TokenProvider for the actual token generation.
  """

  def __init__(self):
    self._lock = threading.Lock()  # guards everything below
    self._accept_thread = None
    self._cache = {}  # dict ((account_id, scopes) => AccessToken | TokenError).
    self._token_provider = None
    self._accounts = frozenset()  # set of Account tuples
    self._rpc_secret = None
    self._server = None

  def start(self, token_provider, accounts, default_account_id, port=0):
    """Starts the local auth RPC server on some 127.0.0.1 port.

    Args:
      token_provider: instance of TokenProvider to use for making tokens.
      accounts: an iterable of Account tuples to allow getting a token for.
      default_account_id: goes directly into LUCI_CONTEXT['local_auth'].
      port: local TCP port to bind to, or 0 to bind to any available port.

    Returns:
      A dict to put into 'local_auth' section of LUCI_CONTEXT.
    """
    # Materialize once: 'accounts' is iterated several times below, which would
    # silently misbehave if the caller passed a one-shot iterator.
    accounts = list(accounts)
    assert all(isinstance(acc, Account) for acc in accounts), accounts

    # 'default_account_id' is either not set, or one of the supported accounts.
    assert (
        not default_account_id or
        any(default_account_id == acc.id for acc in accounts))

    # This secret will be placed in a file on disk accessible only to current
    # user processes. RPC requests are expected to send this secret verbatim.
    # That way we authenticate RPCs as coming from current user's processes.
    rpc_secret = base64.b64encode(os.urandom(48)).decode('ascii')

    with self._lock:
      assert not self._server, 'Already running'
      # Bind only after the "already running" check, so that a failed start()
      # doesn't leak a listening socket.
      server = _HTTPServer(self, ('127.0.0.1', port))
      logging.info('Local auth server: http://127.0.0.1:%d', server.server_port)
      self._token_provider = token_provider
      self._accounts = frozenset(accounts)
      self._rpc_secret = rpc_secret
      self._server = server
      self._accept_thread = threading.Thread(target=server.serve_forever)
      self._accept_thread.start()
      local_auth = {
          'rpc_port': server.server_port,
          'secret': rpc_secret,
          'accounts': [{
              'id': acc.id,
              'email': acc.email,
          } for acc in sorted(accounts)],
      }
      # TODO(vadimsh): Some clients don't understand 'null' value for
      # default_account_id, so just omit it completely for now.
      if default_account_id:
        local_auth['default_account_id'] = default_account_id
      return local_auth

  def stop(self):
    """Stops the server and resets the state. No-op if not running."""
    with self._lock:
      if not self._server:
        return
      server, self._server = self._server, None
      thread, self._accept_thread = self._accept_thread, None
      self._token_provider = None
      self._accounts = frozenset()
      self._rpc_secret = None
      self._cache.clear()
    # Shut down outside the lock: in-flight handlers may need the lock to
    # notice that the server is stopped.
    logging.debug('Stopping the local auth server...')
    try:
      server.shutdown()
      thread.join()
    finally:
      # Always release the listening socket, even if shutdown/join failed.
      server.server_close()
    logging.info('The local auth server is stopped')

  def handle_rpc(self, method, request):
    """Called by _RequestHandler to handle one RPC call.

    Called from internal server thread. May be called even if the server is
    already stopped (because BaseHTTPServer.HTTPServer doesn't wait for
    handler threads to finish when shutting down).

    Args:
      method: name of the invoked RPC method, e.g. "GetOAuthToken".
      request: JSON dict with the request body.

    Returns:
      JSON dict with the response body.

    Raises:
      RPCError to return non-200 HTTP code and an error message as plain text.
    """
    if method == 'GetOAuthToken':
      return self.handle_get_oauth_token(request)
    raise RPCError(404, 'Unknown RPC method "%s".' % method)

  ### RPC method handlers. Called from internal threads.

  def handle_get_oauth_token(self, request):
    """Returns an OAuth token representing the task service account.

    A cached token is returned only while it has more than 3 min of life left
    (see should_refresh); otherwise the token provider is asked for a new one,
    which is returned as is.

    Request body:
    {
      "account_id": <str>,
      "scopes": [<str scope1>, <str scope2>, ...],
      "secret": <str from LUCI_CONTEXT.local_auth.secret>,
    }

    Response body:
    {
      "error_code": <int, 0 or missing on success>,
      "error_message": <str, optional>,
      "access_token": <str with actual token (on success)>,
      "expiry": <int with unix timestamp in seconds (on success)>
    }
    """
    # Logical account to get a token for (e.g. "task" or "system").
    account_id = request.get('account_id')
    if not account_id:
      raise RPCError(400, 'Field "account_id" is required.')
    if not isinstance(account_id, six.string_types):
      raise RPCError(400, 'Field "account_id" must be a string')
    account_id = str(account_id)

    # Validate scopes. It is conceptually a set, so remove duplicates.
    scopes = request.get('scopes')
    if not scopes:
      raise RPCError(400, 'Field "scopes" is required.')
    if (not isinstance(scopes, list) or
        not all(isinstance(s, six.string_types) for s in scopes)):
      raise RPCError(400, 'Field "scopes" must be a list of strings.')
    scopes = tuple(sorted(set(map(str, scopes))))

    # Validate the secret format.
    secret = request.get('secret')
    if not secret:
      raise RPCError(400, 'Field "secret" is required.')
    if not isinstance(secret, six.string_types):
      raise RPCError(400, 'Field "secret" must be a string.')
    secret = str(secret)

    # Grab the state from the lock-guarded state.
    with self._lock:
      if not self._server:
        raise RPCError(503, 'Stopped already.')
      rpc_secret = self._rpc_secret
      accounts = self._accounts
      token_provider = self._token_provider

    # Use constant time check to prevent malicious processes from discovering
    # the secret byte-by-byte measuring response time.
    if not constant_time_equals(secret, rpc_secret):
      raise RPCError(403, 'Invalid "secret".')

    # Make sure we know about the requested account.
    if not any(account_id == acc.id for acc in accounts):
      raise RPCError(404, 'Unrecognized account ID %r.' % account_id)

    # Grab the token (or a fatal error) from the memory cache, checks token
    # expiration time.
    cache_key = (account_id, scopes)
    with self._lock:
      if not self._server:
        raise RPCError(503, 'Stopped already.')
      tok_or_err = self._cache.get(cache_key)
      # Cached TokenErrors are never refreshed: they are fatal by contract.
      need_refresh = (
          tok_or_err is None or
          (isinstance(tok_or_err, AccessToken) and should_refresh(tok_or_err)))

    # Do the refresh outside of the RPC server lock to unblock other clients
    # that are hitting the cache. The token provider should implement its own
    # synchronization.
    if need_refresh:
      try:
        tok_or_err = token_provider.generate_token(account_id, scopes)
        assert isinstance(tok_or_err, AccessToken), tok_or_err
      except TokenError as exc:
        tok_or_err = exc
      # Cache the token or fatal errors (to avoid useless retry later).
      with self._lock:
        if not self._server:
          raise RPCError(503, 'Stopped already.')
        self._cache[cache_key] = tok_or_err

    # Done.
    if isinstance(tok_or_err, AccessToken):
      return {
          'access_token': tok_or_err.access_token,
          'expiry': int(tok_or_err.expiry),
      }
    if isinstance(tok_or_err, TokenError):
      return {
          'error_code': tok_or_err.code,
          'error_message': str(tok_or_err) or 'unknown',
      }
    raise AssertionError('impossible')


def constant_time_equals(a, b):
  """Compares two strings in constant time regardless of their content.

  Only the length comparison short-circuits; this leaks just the length,
  which is fixed for the RPC secret.
  """
  if len(a) != len(b):
    return False
  result = 0
  for x, y in zip(a, b):
    result |= ord(x) ^ ord(y)
  return result == 0


def should_refresh(tok):
  """Returns True if the token must be refreshed because it expires soon."""
  # LUCI_CONTEXT protocol requires that returned tokens are alive for at least
  # 2.5 min. See LUCI_CONTEXT.md. Add 30 sec extra of leeway.
  return time.time() > tok.expiry - 3*60


class _HTTPServer(socketserver.ThreadingMixIn, BaseHTTPServer.HTTPServer):
  """Used internally by LocalAuthServer."""

  # How often to poll 'select' in local HTTP server.
  #
  # Defines minimal amount of time 'stop' would block. Overridden in tests to
  # speed them up.
  poll_interval = 0.5

  # From socketserver.ThreadingMixIn.
  daemon_threads = True
  # From socketserver.TCPServer (base of BaseHTTPServer.HTTPServer).
  request_queue_size = 50

  def __init__(self, local_auth_server, addr):
    BaseHTTPServer.HTTPServer.__init__(self, addr, _RequestHandler)
    self.local_auth_server = local_auth_server

  def serve_forever(self, poll_interval=None):
    """Overrides default poll interval."""
    BaseHTTPServer.HTTPServer.serve_forever(
        self, poll_interval or self.poll_interval)

  def handle_error(self, request, client_address):
    """Overrides default handle_error that dumps stuff to stdout."""
    logging.exception('local auth server: Exception happened')


class _RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  """Used internally by LocalAuthServer.

  Parses the request, serializes and writes the response.
  """

  # Buffer the reply, no need to send each line separately.
  wbufsize = -1

  def log_message(self, fmt, *args):
    """Overrides default log_message to not abuse stderr."""
    logging.debug('local auth server: ' + fmt, *args)

  def send_error(self, code, message=None, explain=None):
    """Overrides default send_error to send 'text/plain' response.

    Signature matches the Python 3 base class, which calls this method itself
    with no message (e.g. for HTTP 414) or with an extra 'explain' argument
    (e.g. for too long header lines).

    Args:
      code: HTTP status code to reply with.
      message: short error message, may be None.
      explain: optional longer explanation, appended to the body if given.
    """
    text = message or ''
    if explain:
      text = '%s: %s' % (text, explain) if text else explain
    logging.warning('local auth server: HTTP %d - %s', code, text)
    body = text + '\n'
    if not isinstance(body, bytes):
      body = body.encode('utf-8')
    self.send_response(code)
    self.send_header('Connection', 'close')
    # Content-Length must count encoded bytes, not characters.
    self.send_header('Content-Length', str(len(body)))
    self.send_header('Content-Type', 'text/plain')
    self.end_headers()
    self.wfile.write(body)

  def do_POST(self):
    """Implements POST handler."""
    # Parse URL to extract method name.
    m = re.match(r'^/rpc/LuciLocalAuthService\.([a-zA-Z0-9_]+)$', self.path)
    if not m:
      self.send_error(404, 'Expecting /rpc/LuciLocalAuthService.*')
      return
    method = m.group(1)

    # The request body MUST be JSON. Ignore charset, we don't care. Media type
    # names are case-insensitive, so normalize before comparing.
    ct = self.headers.get('content-type')
    if not ct or ct.split(';')[0].strip().lower() != 'application/json':
      self.send_error(
          400, 'Expecting "application/json" Content-Type, got %r' % ct)
      return

    # Read the body. Chunked transfer encoding or compression is not supported.
    # A missing header yields None (TypeError in int()); a negative length would
    # make rfile.read() block until the client closes the connection.
    try:
      content_len = int(self.headers.get('content-length'))
    except (TypeError, ValueError):
      content_len = -1
    if content_len < 0:
      self.send_error(400, 'Missing or invalid Content-Length header')
      return
    try:
      # UnicodeDecodeError is a subclass of ValueError, handled below.
      req = json.loads(self.rfile.read(content_len).decode('utf-8'))
    except ValueError as exc:
      self.send_error(400, 'Not a JSON: %s' % exc)
      return
    if not isinstance(req, dict):
      self.send_error(400, 'Not a JSON dictionary')
      return

    # Let the LocalAuthServer handle the request. Prepare the response body.
    try:
      resp = self.server.local_auth_server.handle_rpc(method, req)
      response_body = json.dumps(resp) + '\n'
    except RPCError as exc:
      self.send_error(exc.code, str(exc))
      return
    except Exception as exc:
      # Keep the traceback in logs; the client only gets a short message.
      logging.exception('local auth server: RPC %s failed', method)
      self.send_error(500, 'Internal error: %s' % exc)
      return

    # Send the response.
    body = response_body.encode('utf-8')
    self.send_response(200)
    self.send_header('Connection', 'close')
    self.send_header('Content-Length', str(len(body)))
    self.send_header('Content-Type', 'application/json')
    self.end_headers()
    self.wfile.write(body)


def testing_main():
  """Launches a local HTTP auth service and waits for Ctrl+C.

  Useful during development and manual testing.
  """
  # Don't mess with sys.path outside of adhoc testing.
  ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  sys.path.insert(0, ROOT_DIR)
  from libs import luci_context

  logging.basicConfig(level=logging.DEBUG)

  class DumbProvider(object):

    def generate_token(self, account_id, scopes):
      logging.info('generate_token(%r, %r) called', account_id, scopes)
      # 80 sec lifetime is below should_refresh's 3 min threshold, so every
      # GetOAuthToken call regenerates the token.
      return AccessToken('fake_tok_for_%s' % account_id, time.time() + 80)

  server = LocalAuthServer()
  ctx = server.start(
      token_provider=DumbProvider(),
      accounts=[
          Account('a', 'a@example.com'),
          Account('b', 'b@example.com'),
          Account('c', 'c@example.com'),
      ],
      default_account_id='a',
      port=11111)
  try:
    with luci_context.write(local_auth=ctx):
      print('Copy-paste this into another shell:')
      print('export LUCI_CONTEXT=%s' % os.getenv('LUCI_CONTEXT'))
      while True:
        time.sleep(1)
  except KeyboardInterrupt:
    pass
  finally:
    server.stop()


if __name__ == '__main__':
  testing_main()