1#!/usr/bin/env python
2# Copyright 2016 The LUCI Authors. All rights reserved.
3# Use of this source code is governed under the Apache License, Version 2.0
4# that can be found in the LICENSE file.
5
6import base64
7import collections
8import json
9import logging
10import os
11import re
12import sys
13import threading
14import time
15
16import six
17from six.moves import BaseHTTPServer
18from six.moves import socketserver
19
# An OAuth access token paired with the moment it stops being valid.
#
# Fields:
#   access_token: urlsafe str with the token.
#   expiry: expiration time as unix timestamp in seconds.
AccessToken = collections.namedtuple(
    'AccessToken', ['access_token', 'expiry'])
25
26
class TokenError(Exception):
  """Fatal token creation failure reported by a TokenProvider.

  Carries a numeric `code` next to the regular exception message. See
  TokenProvider docs for more info.
  """

  def __init__(self, code, msg):
    # Only the message goes into the exception args; the code is kept aside.
    Exception.__init__(self, msg)
    self.code = code
36
37
class RPCError(Exception):
  """Signals that an RPC should be answered with an HTTP error status.

  Raised by LocalAuthServer RPC handlers; `code` is the HTTP status code and
  the exception message becomes the reply text.
  """

  def __init__(self, code, msg):
    # Only the message goes into the exception args; the code is kept aside.
    Exception.__init__(self, msg)
    self.code = code
44
45
# One logical account: its ID and its email.
Account = collections.namedtuple('Account', 'id email')
48
49
class TokenProvider(object):
  """Interface for an object that can create OAuth tokens on demand.

  It is a concrete class only so that the contract has a place to be
  documented; the method below always raises NotImplementedError.
  """

  def generate_token(self, account_id, scopes):
    """Generates a new access token with given scopes.

    LocalAuthServer calls this whenever it has no usable cached token for
    the (account_id, scopes) pair. Calls come from the server's request
    threads, possibly concurrently, and are made without holding the server
    lock, so implementations must do their own synchronization.

    Args:
      account_id: ID of the logical account to get a token for.
      scopes: sorted tuple of unique OAuth scope strings.

    Returns:
      AccessToken on success.

    Raises:
      RPCError: converted right away into the corresponding RPC error reply
          (e.g. HTTP 500). Nothing is cached. Appropriate for low-level or
          transient errors.
      TokenError: converted into a GetOAuthToken reply with non-zero
          error_code. The error is cached, so the provider is not asked
          again for the same account and scopes while the server keeps
          running. Appropriate for high-level fatal errors.
    """
    raise NotImplementedError()
74
75
class LocalAuthServer(object):
  """LocalAuthServer handles /rpc/LuciLocalAuthService.* requests.

  It exposes an HTTP JSON RPC API that is used by task processes to grab an
  access token for the service account associated with the task.

  It implements RPC handling details and in-memory cache for the tokens, but
  defers to the supplied TokenProvider for the actual token generation.
  """

  def __init__(self):
    self._lock = threading.Lock() # guards everything below
    self._accept_thread = None
    self._cache = {}  # dict ((account_id, scopes) => AccessToken | TokenError).
    self._token_provider = None
    self._accounts = frozenset()  # set of Account tuples
    self._rpc_secret = None
    self._server = None

  def start(self, token_provider, accounts, default_account_id, port=0):
    """Starts the local auth RPC server on some 127.0.0.1 port.

    Args:
      token_provider: instance of TokenProvider to use for making tokens.
      accounts: a list of Account tuples to allow getting a token for.
      default_account_id: goes directly into LUCI_CONTEXT['local_auth'].
      port: local TCP port to bind to, or 0 to bind to any available port.

    Returns:
      A dict to put into 'local_auth' section of LUCI_CONTEXT.

    Raises:
      AssertionError: if the arguments are inconsistent or the server is
          already running.
    """
    assert all(isinstance(acc, Account) for acc in accounts), accounts

    # 'default_account_id' is either not set, or one of the supported accounts.
    assert (
        not default_account_id or
        any(default_account_id == acc.id for acc in accounts))

    # Binds and starts listening on the socket right away.
    server = _HTTPServer(self, ('127.0.0.1', port))

    # This secret will be placed in a file on disk accessible only to current
    # user processes. RPC requests are expected to send this secret verbatim.
    # That way we authenticate RPCs as coming from current user's processes.
    rpc_secret = base64.b64encode(os.urandom(48)).decode('ascii')

    with self._lock:
      if self._server:
        # Release the socket we have just bound: nothing else references this
        # server object, so it would otherwise stay open. Raised explicitly
        # (rather than via 'assert') so that running with -O can't silently
        # replace a live server.
        server.server_close()
        raise AssertionError('Already running')
      logging.info('Local auth server: http://127.0.0.1:%d', server.server_port)
      self._token_provider = token_provider
      self._accounts = frozenset(accounts)
      self._rpc_secret = rpc_secret
      self._server = server
      self._accept_thread = threading.Thread(target=self._server.serve_forever)
      self._accept_thread.start()
      local_auth = {
          'rpc_port':
              self._server.server_port,
          'secret':
              self._rpc_secret,
          'accounts': [{
              'id': acc.id,
              'email': acc.email
          } for acc in sorted(accounts)],
      }
      # TODO(vadimsh): Some clients don't understand 'null' value for
      # default_account_id, so just omit it completely for now.
      if default_account_id:
        local_auth['default_account_id'] = default_account_id
      return local_auth

  def stop(self):
    """Stops the server and resets the state.

    Does nothing if the server is not running. Otherwise blocks until the
    accept thread exits and the listening socket is closed.
    """
    with self._lock:
      if not self._server:
        return
      server, self._server = self._server, None
      thread, self._accept_thread = self._accept_thread, None
      self._token_provider = None
      self._accounts = frozenset()
      self._rpc_secret = None
      self._cache.clear()
    # The shutdown happens outside of the lock: request threads that are still
    # running need the lock to notice that the server is gone.
    logging.debug('Stopping the local auth server...')
    server.shutdown()
    thread.join()
    server.server_close()
    logging.info('The local auth server is stopped')

  def handle_rpc(self, method, request):
    """Called by _RequestHandler to handle one RPC call.

    Called from internal server thread. May be called even if the server is
    already stopped (due to BaseHTTPServer.HTTPServer implementation that
    stupidly leaks handler threads).

    Args:
      method: name of the invoked RPC method, e.g. "GetOAuthToken".
      request: JSON dict with the request body.

    Returns:
      JSON dict with the response body.

    Raises:
      RPCError to return non-200 HTTP code and an error message as plain text.
    """
    if method == 'GetOAuthToken':
      return self.handle_get_oauth_token(request)
    raise RPCError(404, 'Unknown RPC method "%s".' % method)

  ### RPC method handlers. Called from internal threads.

  def handle_get_oauth_token(self, request):
    """Returns an OAuth token representing the task service account.

    A cached token is reused unless should_refresh() says it expires soon;
    otherwise a new one is requested from the token provider and cached.
    Fatal provider errors (TokenError) are cached and replayed as well.

    Request body:
    {
      "account_id": <str>,
      "scopes": [<str scope1>, <str scope2>, ...],
      "secret": <str from LUCI_CONTEXT.local_auth.secret>,
    }

    Response body:
    {
      "error_code": <int, 0 or missing on success>,
      "error_message": <str, optional>,
      "access_token": <str with actual token (on success)>,
      "expiry": <int with unix timestamp in seconds (on success)>
    }

    Raises:
      RPCError: 400 on a malformed request, 403 on a wrong secret, 404 on an
          unknown account, 503 if the server is already stopped. Also
          whatever RPCError the token provider raises.
    """
    # Logical account to get a token for (e.g. "task" or "system").
    account_id = request.get('account_id')
    if not account_id:
      raise RPCError(400, 'Field "account_id" is required.')
    if not isinstance(account_id, six.string_types):
      raise RPCError(400, 'Field "account_id" must be a string')
    account_id = str(account_id)

    # Validate scopes. It is conceptually a set, so remove duplicates.
    scopes = request.get('scopes')
    if not scopes:
      raise RPCError(400, 'Field "scopes" is required.')
    if (not isinstance(scopes, list) or
        not all(isinstance(s, six.string_types) for s in scopes)):
      raise RPCError(400, 'Field "scopes" must be a list of strings.')
    # Sorted tuple, so that it is hashable and independent of request order.
    scopes = tuple(sorted(set(map(str, scopes))))

    # Validate the secret format.
    secret = request.get('secret')
    if not secret:
      raise RPCError(400, 'Field "secret" is required.')
    if not isinstance(secret, six.string_types):
      raise RPCError(400, 'Field "secret" must be a string.')
    secret = str(secret)

    # Grab the state from the lock-guarded state.
    with self._lock:
      if not self._server:
        raise RPCError(503, 'Stopped already.')
      rpc_secret = self._rpc_secret
      accounts = self._accounts
      token_provider = self._token_provider

    # Use constant time check to prevent malicious processes from discovering
    # the secret byte-by-byte measuring response time.
    if not constant_time_equals(secret, rpc_secret):
      raise RPCError(403, 'Invalid "secret".')

    # Make sure we know about the requested account.
    if not any(account_id == acc.id for acc in accounts):
      raise RPCError(404, 'Unrecognized account ID %r.' % account_id)

    # Grab the token (or a fatal error) from the memory cache, checks token
    # expiration time. A cached TokenError never triggers a refresh.
    cache_key = (account_id, scopes)
    tok_or_err = None
    need_refresh = False
    with self._lock:
      if not self._server:
        raise RPCError(503, 'Stopped already.')
      tok_or_err = self._cache.get(cache_key)
      need_refresh = (
          not tok_or_err or
          isinstance(tok_or_err, AccessToken) and should_refresh(tok_or_err))

    # Do the refresh outside of the RPC server lock to unblock other clients
    # that are hitting the cache. The token provider should implement its own
    # synchronization.
    if need_refresh:
      try:
        tok_or_err = token_provider.generate_token(account_id, scopes)
        assert isinstance(tok_or_err, AccessToken), tok_or_err
      except TokenError as exc:
        tok_or_err = exc
      # Cache the token or fatal errors (to avoid useless retry later).
      with self._lock:
        if not self._server:
          raise RPCError(503, 'Stopped already.')
        self._cache[cache_key] = tok_or_err

    # Done.
    if isinstance(tok_or_err, AccessToken):
      return {
        'access_token': tok_or_err.access_token,
        'expiry': int(tok_or_err.expiry),
      }
    if isinstance(tok_or_err, TokenError):
      return {
          'error_code': tok_or_err.code,
          'error_message': str(tok_or_err) or 'unknown',
      }
    raise AssertionError('impossible')
288
289
def constant_time_equals(a, b):
  """Compares two strings in constant time regardless of their content.

  Strings of different length are reported as unequal right away; for equal
  lengths every character pair is examined, whether or not a mismatch has
  already been seen.
  """
  if len(a) != len(b):
    return False
  # sum() always consumes the whole generator (no short-circuit). XOR of two
  # code points is never negative, so the total is 0 only if all pairs match.
  return sum(ord(x) ^ ord(y) for x, y in zip(a, b)) == 0
298
299
def should_refresh(tok):
  """Tells whether the token expires too soon and has to be refreshed.

  Args:
    tok: an object with 'expiry' attribute (unix timestamp in seconds).

  Returns:
    True if the current time is within 3 minutes of the expiry, or past it.
  """
  # LUCI_CONTEXT protocol requires that returned tokens are alive for at least
  # 2.5 min. See LUCI_CONTEXT.md. Add 30 sec extra of leeway.
  refresh_window_sec = 3*60
  return time.time() > tok.expiry - refresh_window_sec
305
306
class _HTTPServer(socketserver.ThreadingMixIn, BaseHTTPServer.HTTPServer):
  """Threaded HTTP server used internally by LocalAuthServer.

  Keeps a reference to the owning LocalAuthServer in 'local_auth_server', so
  that request handlers can reach it through their 'server' attribute.
  """

  # How often to poll 'select' in local HTTP server.
  #
  # Defines minimal amount of time 'stop' would block. Overridden in tests to
  # speed them up.
  poll_interval = 0.5

  # From socketserver.ThreadingMixIn.
  daemon_threads = True
  # From BaseHTTPServer.HTTPServer.
  request_queue_size = 50

  def __init__(self, local_auth_server, addr):
    # The base class is initialized explicitly, with _RequestHandler as the
    # request handler class.
    BaseHTTPServer.HTTPServer.__init__(self, addr, _RequestHandler)
    self.local_auth_server = local_auth_server

  def serve_forever(self, poll_interval=None):
    """Serves requests, polling with the class-level interval by default."""
    # Any falsy value (None, 0) selects the class-level default.
    if poll_interval:
      interval = poll_interval
    else:
      interval = self.poll_interval
    BaseHTTPServer.HTTPServer.serve_forever(self, interval)

  def handle_error(self, request, client_address):
    """Overrides default handle_error that dumps stuff to stdout."""
    logging.exception('local auth server: Exception happened')
333
334
class _RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  """Used internally by LocalAuthServer.

  Parses the request, serializes and writes the response.
  """

  # Buffer the reply, no need to send each line separately.
  wbufsize = -1

  def log_message(self, fmt, *args):
    """Overrides default log_message to not abuse stderr."""
    logging.debug('local auth server: ' + fmt, *args)

  def send_error(self, code, message=None, explain=None):
    """Overrides default send_error to send 'text/plain' response.

    Args:
      code: HTTP status code to reply with.
      message: text to put into the response body, may be None.
      explain: optional longer explanation. Accepted because the Python 3
          base class passes it positionally; it is only logged, never sent.
    """
    message = message or ''
    if explain:
      logging.warning(
          'local auth server: HTTP %d - %s (%s)', code, message, explain)
    else:
      logging.warning('local auth server: HTTP %d - %s', code, message)
    self._send_body(code, 'text/plain', message + '\n')

  def _send_body(self, code, content_type, body):
    """Sends a complete response: status line, headers and the body."""
    # Encode first: Content-Length counts bytes on the wire, not characters.
    if not isinstance(body, bytes):
      body = body.encode('utf-8')
    self.send_response(code)
    self.send_header('Connection', 'close')
    self.send_header('Content-Length', str(len(body)))
    self.send_header('Content-Type', content_type)
    self.end_headers()
    self.wfile.write(body)

  def do_POST(self):
    """Implements POST handler.

    Replies with a JSON document on success and with a plain text error
    message otherwise.
    """
    # Parse URL to extract method name.
    m = re.match(r'^/rpc/LuciLocalAuthService\.([a-zA-Z0-9_]+)$', self.path)
    if not m:
      self.send_error(404, 'Expecting /rpc/LuciLocalAuthService.*')
      return
    method = m.group(1)

    # The request body MUST be JSON. Ignore charset, we don't care. Media
    # types are case-insensitive and may be followed by whitespace.
    ct = self.headers.get('content-type')
    if not ct or ct.split(';')[0].strip().lower() != 'application/json':
      self.send_error(
          400, 'Expecting "application/json" Content-Type, got %r' % ct)
      return

    # Read the body. Chunked transfer encoding or compression is not
    # supported.
    try:
      # A missing header yields None, which makes int() raise TypeError.
      content_len = int(self.headers.get('content-length'))
      if content_len < 0:
        # read() with a negative size would block until the client hangs up.
        raise ValueError('negative Content-Length')
    except (TypeError, ValueError):
      self.send_error(400, 'Missing on invalid Content-Length header')
      return
    try:
      req = json.loads(self.rfile.read(content_len))
    except ValueError as exc:
      self.send_error(400, 'Not a JSON: %s' % exc)
      return
    if not isinstance(req, dict):
      self.send_error(400, 'Not a JSON dictionary')
      return

    # Let the LocalAuthServer handle the request. Prepare the response body.
    try:
      resp = self.server.local_auth_server.handle_rpc(method, req)
      response_body = json.dumps(resp) + '\n'
    except RPCError as exc:
      self.send_error(exc.code, str(exc))
      return
    except Exception as exc:
      self.send_error(500, 'Internal error: %s' % exc)
      return

    # Send the response.
    self._send_body(200, 'application/json', response_body)
409
410
def testing_main():
  """Runs a local HTTP auth service with fake tokens until Ctrl+C.

  Intended for development and manual testing only: it serves three example
  accounts on the fixed port 11111 and prints the LUCI_CONTEXT value to
  export in another shell.
  """
  # Don't mess with sys.path outside of adhoc testing.
  this_file = os.path.abspath(__file__)
  sys.path.insert(0, os.path.dirname(os.path.dirname(this_file)))
  from libs import luci_context

  logging.basicConfig(level=logging.DEBUG)

  class DumbProvider(object):
    # Hands out fake tokens that expire 80 seconds after creation.
    def generate_token(self, account_id, scopes):
      logging.info('generate_token(%r, %r) called', account_id, scopes)
      fake_token = 'fake_tok_for_%s' % account_id
      return AccessToken(fake_token, time.time() + 80)

  example_accounts = [
      Account(name, '%s@example.com' % name) for name in ('a', 'b', 'c')
  ]

  server = LocalAuthServer()
  ctx = server.start(
      token_provider=DumbProvider(),
      accounts=example_accounts,
      default_account_id='a',
      port=11111)
  try:
    # NOTE(review): presumably luci_context.write() exports LUCI_CONTEXT env
    # var for the duration of the 'with' block -- confirm against libs.
    with luci_context.write(local_auth=ctx):
      print('Copy-paste this into another shell:')
      print('export LUCI_CONTEXT=%s' % os.getenv('LUCI_CONTEXT'))
      # Idle until interrupted; the server runs in its own threads.
      while True:
        time.sleep(1)
  except KeyboardInterrupt:
    pass
  finally:
    server.stop()
448
449
# Launch the manual testing server when this file is executed as a script.
if __name__ == '__main__':
  testing_main()
452