1# Copyright 2012 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#    http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing,
10# software distributed under the License is distributed on an
11# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
12# either express or implied. See the License for the specific
13# language governing permissions and limitations under the License.
14
15"""Base and helper classes for Google RESTful APIs."""
16
17
18
19
20
# Names exported by `from <this module> import *`; the remaining module-level
# names are underscore-prefixed implementation details.
__all__ = ['add_sync_methods']
22
23import httplib
24import random
25import time
26
27from . import api_utils
28
29try:
30  from google.appengine.api import app_identity
31  from google.appengine.ext import ndb
32except ImportError:
33  from google.appengine.api import app_identity
34  from google.appengine.ext import ndb
35
36
37def _make_sync_method(name):
38  """Helper to synthesize a synchronous method from an async method name.
39
40  Used by the @add_sync_methods class decorator below.
41
42  Args:
43    name: The name of the synchronous method.
44
45  Returns:
46    A method (with first argument 'self') that retrieves and calls
47    self.<name>, passing its own arguments, expects it to return a
48    Future, and then waits for and returns that Future's result.
49  """
50
51  def sync_wrapper(self, *args, **kwds):
52    method = getattr(self, name)
53    future = method(*args, **kwds)
54    return future.get_result()
55
56  return sync_wrapper
57
58
def add_sync_methods(cls):
  """Class decorator to add synchronous methods corresponding to async methods.

  This modifies the class in place, adding additional methods to it.
  For every attribute defined directly on cls whose name ends in '_async',
  a synchronous method named without that suffix is added. If a synchronous
  method of a given name already exists (including via inheritance) it is
  not replaced.

  Args:
    cls: A class.

  Returns:
    The same class, modified in place.
  """
  suffix = '_async'
  # Snapshot the names first: setattr() below adds entries to cls.__dict__,
  # and mutating a dict while iterating a live view of it raises RuntimeError
  # on Python 3 (Python 2's keys() happened to return a copy).
  for name in list(cls.__dict__):
    if name.endswith(suffix):
      sync_name = name[:-len(suffix)]
      if not hasattr(cls, sync_name):
        setattr(cls, sync_name, _make_sync_method(name))
  return cls
78
79
class _AE_TokenStorage_(ndb.Model):
  """Entity to store app_identity tokens in memcache.

  _RestApi.get_token_async() reads and writes it with the ndb context cache
  and memcache enabled, and also the datastore when
  retry_params.save_access_token is set. The entity id is
  '<service_account_id>,<scope1>,<scope2>,...'.
  """

  # The access token string produced by the token maker.
  token = ndb.StringProperty()
  # Token expiration time, in seconds since the epoch.
  expires = ndb.FloatProperty()
85
86
@ndb.tasklet
def _make_token_async(scopes, service_account_id):
  """Mint a new access token through the App Identity service.

  This is the default token maker used by _RestApi.

  Args:
    scopes: A list of scopes.
    service_account_id: Internal-use only.

  Returns:
    A Future (this is an ndb tasklet) whose result is a tuple
    (token, expiration_time); expiration_time is in seconds since the epoch.
  """
  token_rpc = app_identity.create_rpc()
  app_identity.make_get_access_token_call(
      token_rpc, scopes, service_account_id)
  # Yielding the RPC suspends the tasklet until the service answers.
  access_token, expiration = yield token_rpc
  raise ndb.Return((access_token, expiration))
103
104
class _RestApi(object):
  """Base class for REST-based API wrapper classes.

  This class manages authentication tokens and request retries.  All
  APIs are available as synchronous and async methods; synchronous
  methods are synthesized from async ones by the add_sync_methods()
  function in this module.

  WARNING: Do NOT directly use this api. It's an implementation detail
  and is subject to change at any release.
  """

  # A stored token is treated as expired this many seconds before its actual
  # expiration time. The value is drawn once, when the class body runs at
  # import time; presumably the jitter keeps many processes from refreshing
  # their tokens at the same moment.
  _TOKEN_EXPIRATION_HEADROOM = random.randint(60, 600)

  def __init__(self, scopes, service_account_id=None, token_maker=None,
               retry_params=None):
    """Constructor.

    Args:
      scopes: A scope or a list of scopes.
      service_account_id: Internal use only.
      token_maker: An asynchronous function of the form
        (scopes, service_account_id) -> (token, expires). If None,
        _make_token_async (the App Identity service) is used.
      retry_params: An instance of api_utils.RetryParams. If None, the
        default for current thread will be used.
    """

    if isinstance(scopes, basestring):
      scopes = [scopes]
    self.scopes = scopes
    self.service_account_id = service_account_id
    self.make_token_async = token_maker or _make_token_async
    self.token = None
    if not retry_params:
      retry_params = api_utils._get_default_retry_params()
    self.retry_params = retry_params

  def __getstate__(self):
    """Store state as part of serialization/pickling.

    The default token maker is recorded as None instead of being pickled;
    __setstate__ hands None back to __init__, which restores the default.
    """
    return {'token': self.token,
            'scopes': self.scopes,
            'id': self.service_account_id,
            'a_maker': None if self.make_token_async == _make_token_async
            else self.make_token_async,
            'retry_params': self.retry_params}

  def __setstate__(self, state):
    """Restore state as part of deserialization/unpickling."""
    self.__init__(state['scopes'],
                  service_account_id=state['id'],
                  token_maker=state['a_maker'],
                  retry_params=state['retry_params'])
    # __init__ clears the token; reinstate the one that was pickled.
    self.token = state['token']

  @ndb.tasklet
  def do_request_async(self, url, method='GET', headers=None, payload=None,
                       deadline=None, callback=None):
    """Issue one HTTP request.

    This is an async wrapper around urlfetch(). It adds an authentication
    header and, on a 401 status code, refreshes the token and tries once
    more. Upon other retriable errors (an exception in
    api_utils._RETRIABLE_EXCEPTIONS, or a response for which
    api_utils._should_retry() is true), it performs blocking retries via
    api_utils._retry_fetch().

    Args:
      url: The URL to fetch.
      method: The HTTP method; default 'GET'.
      headers: Optional mapping of request headers. It is copied, not
        mutated.
      payload: Optional request body.
      deadline: Optional urlfetch deadline. Falsy values fall back to
        self.retry_params.urlfetch_timeout.
      callback: Optional callback passed to urlfetch_async(). It is not
        passed to the blocking retries.

    Returns:
      A Future whose result is a tuple (status_code, headers, content).

    Raises:
      The original retriable exception, if neither the request nor the
      blocking retries produced a response.
    """
    headers = {} if headers is None else dict(headers)
    if self.token is None:
      self.token = yield self.get_token_async()
    headers['authorization'] = 'OAuth ' + self.token

    deadline = deadline or self.retry_params.urlfetch_timeout

    def blocking_retry():
      # Synchronous retries; reads the current (possibly refreshed) headers.
      return api_utils._retry_fetch(
          url, retry_params=self.retry_params, payload=payload, method=method,
          headers=headers, follow_redirects=False, deadline=deadline)

    resp = None
    try:
      resp = yield self.urlfetch_async(url, payload=payload, method=method,
                                       headers=headers, follow_redirects=False,
                                       deadline=deadline, callback=callback)
      if resp.status_code == httplib.UNAUTHORIZED:
        # The token was rejected; mint a new one and try exactly once more.
        self.token = yield self.get_token_async(refresh=True)
        headers['authorization'] = 'OAuth ' + self.token
        resp = yield self.urlfetch_async(
            url, payload=payload, method=method, headers=headers,
            follow_redirects=False, deadline=deadline, callback=callback)
    except api_utils._RETRIABLE_EXCEPTIONS:
      retry_resp = blocking_retry()
      if retry_resp:
        resp = retry_resp
      elif not resp:
        # Still inside the handler, so the bare raise re-raises the caught
        # exception. Outside the except clause it relied on leftover
        # exception state (and fails outright on Python 3).
        raise
    else:
      if api_utils._should_retry(resp):
        retry_resp = blocking_retry()
        if retry_resp:
          resp = retry_resp

    raise ndb.Return((resp.status_code, resp.headers, resp.content))

  @ndb.tasklet
  def get_token_async(self, refresh=False):
    """Get an authentication token.

    Unless refresh is True, self.token is returned when it is set. Otherwise
    a token stored in _AE_TokenStorage_ is reused if it does not expire
    within _TOKEN_EXPIRATION_HEADROOM seconds. The entity is keyed by
    service_account_id and scopes. It lives in the ndb context cache and
    memcache, plus the datastore when retry_params.save_access_token is set.
    In every other case a new token is minted with self.make_token_async()
    and written back to storage.

    Args:
      refresh: If True, ignore any cached token (on the instance and in
        storage) and mint a new one; default False.

    Returns:
      A Future whose result is the authentication token; the token is also
      saved in self.token.
    """
    if self.token is not None and not refresh:
      raise ndb.Return(self.token)
    key = '%s,%s' % (self.service_account_id, ','.join(self.scopes))
    ts = None
    if not refresh:
      # On refresh the stored entity is skipped on purpose: it may hold the
      # very token that was just rejected, which would defeat the refresh.
      ts = yield _AE_TokenStorage_.get_by_id_async(
          key, use_cache=True, use_memcache=True,
          use_datastore=self.retry_params.save_access_token)
    if ts is None or ts.expires < (time.time() +
                                   self._TOKEN_EXPIRATION_HEADROOM):
      token, expires_at = yield self.make_token_async(
          self.scopes, self.service_account_id)
      timeout = int(expires_at - time.time())
      ts = _AE_TokenStorage_(id=key, token=token, expires=expires_at)
      if timeout > 0:
        # Overwrites any stale entry so later lookups by this key see the
        # new token.
        yield ts.put_async(memcache_timeout=timeout,
                           use_datastore=self.retry_params.save_access_token,
                           use_cache=True, use_memcache=True)
    self.token = ts.token
    raise ndb.Return(self.token)

  def urlfetch_async(self, url, **kwds):
    """Make an async urlfetch() call.

    This just passes the url and keyword arguments to NDB's async
    urlfetch() wrapper in the current context.

    This returns a Future despite not being decorated with @ndb.tasklet!
    """
    ctx = ndb.get_context()
    return ctx.urlfetch(url, **kwds)
244
245
# Add synchronous counterparts (do_request, get_token, urlfetch) for the
# *_async methods defined directly on _RestApi. add_sync_methods modifies the
# class in place and returns it, so this rebinding yields the same class.
_RestApi = add_sync_methods(_RestApi)
247