1# Copyright 2020 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15"""Transport adapter for Async HTTP (aiohttp). 16 17NOTE: This async support is experimental and marked internal. This surface may 18change in minor releases. 19""" 20 21from __future__ import absolute_import 22 23import asyncio 24import functools 25 26import aiohttp 27import six 28import urllib3 29 30from google.auth import exceptions 31from google.auth import transport 32from google.auth.transport import requests 33 34# Timeout can be re-defined depending on async requirement. Currently made 60s more than 35# sync timeout. 36_DEFAULT_TIMEOUT = 180 # in seconds 37 38 39class _CombinedResponse(transport.Response): 40 """ 41 In order to more closely resemble the `requests` interface, where a raw 42 and deflated content could be accessed at once, this class lazily reads the 43 stream in `transport.Response` so both return forms can be used. 44 45 The gzip and deflate transfer-encodings are automatically decoded for you 46 because the default parameter for autodecompress into the ClientSession is set 47 to False, and therefore we add this class to act as a wrapper for a user to be 48 able to access both the raw and decoded response bodies - mirroring the sync 49 implementation. 50 """ 51 52 def __init__(self, response): 53 self._response = response 54 self._raw_content = None 55 56 def _is_compressed(self): 57 headers = self._response.headers 58 return "Content-Encoding" in headers and ( 59 headers["Content-Encoding"] == "gzip" 60 or headers["Content-Encoding"] == "deflate" 61 ) 62 63 @property 64 def status(self): 65 return self._response.status 66 67 @property 68 def headers(self): 69 return self._response.headers 70 71 @property 72 def data(self): 73 return self._response.content 74 75 async def raw_content(self): 76 if self._raw_content is None: 77 self._raw_content = await self._response.content.read() 78 return self._raw_content 79 80 async def content(self): 81 # Load raw_content if necessary 82 await self.raw_content() 83 if self._is_compressed(): 84 decoder = urllib3.response.MultiDecoder( 85 self._response.headers["Content-Encoding"] 86 ) 87 decompressed = decoder.decompress(self._raw_content) 88 return decompressed 89 90 return self._raw_content 91 92 93class _Response(transport.Response): 94 """ 95 Requests transport response adapter. 96 97 Args: 98 response (requests.Response): The raw Requests response. 99 """ 100 101 def __init__(self, response): 102 self._response = response 103 104 @property 105 def status(self): 106 return self._response.status 107 108 @property 109 def headers(self): 110 return self._response.headers 111 112 @property 113 def data(self): 114 return self._response.content 115 116 117class Request(transport.Request): 118 """Requests request adapter. 119 120 This class is used internally for making requests using asyncio transports 121 in a consistent way. If you use :class:`AuthorizedSession` you do not need 122 to construct or use this class directly. 123 124 This class can be useful if you want to manually refresh a 125 :class:`~google.auth.credentials.Credentials` instance:: 126 127 import google.auth.transport.aiohttp_requests 128 129 request = google.auth.transport.aiohttp_requests.Request() 130 131 credentials.refresh(request) 132 133 Args: 134 session (aiohttp.ClientSession): An instance :class:`aiohttp.ClientSession` used 135 to make HTTP requests. If not specified, a session will be created. 136 137 .. automethod:: __call__ 138 """ 139 140 def __init__(self, session=None): 141 # TODO: Use auto_decompress property for aiohttp 3.7+ 142 if session is not None and session._auto_decompress: 143 raise ValueError( 144 "Client sessions with auto_decompress=True are not supported." 145 ) 146 self.session = session 147 148 async def __call__( 149 self, 150 url, 151 method="GET", 152 body=None, 153 headers=None, 154 timeout=_DEFAULT_TIMEOUT, 155 **kwargs, 156 ): 157 """ 158 Make an HTTP request using aiohttp. 159 160 Args: 161 url (str): The URL to be requested. 162 method (Optional[str]): 163 The HTTP method to use for the request. Defaults to 'GET'. 164 body (Optional[bytes]): 165 The payload or body in HTTP request. 166 headers (Optional[Mapping[str, str]]): 167 Request headers. 168 timeout (Optional[int]): The number of seconds to wait for a 169 response from the server. If not specified or if None, the 170 requests default timeout will be used. 171 kwargs: Additional arguments passed through to the underlying 172 requests :meth:`requests.Session.request` method. 173 174 Returns: 175 google.auth.transport.Response: The HTTP response. 176 177 Raises: 178 google.auth.exceptions.TransportError: If any exception occurred. 179 """ 180 181 try: 182 if self.session is None: # pragma: NO COVER 183 self.session = aiohttp.ClientSession( 184 auto_decompress=False 185 ) # pragma: NO COVER 186 requests._LOGGER.debug("Making request: %s %s", method, url) 187 response = await self.session.request( 188 method, url, data=body, headers=headers, timeout=timeout, **kwargs 189 ) 190 return _CombinedResponse(response) 191 192 except aiohttp.ClientError as caught_exc: 193 new_exc = exceptions.TransportError(caught_exc) 194 six.raise_from(new_exc, caught_exc) 195 196 except asyncio.TimeoutError as caught_exc: 197 new_exc = exceptions.TransportError(caught_exc) 198 six.raise_from(new_exc, caught_exc) 199 200 201class AuthorizedSession(aiohttp.ClientSession): 202 """This is an async implementation of the Authorized Session class. We utilize an 203 aiohttp transport instance, and the interface mirrors the google.auth.transport.requests 204 Authorized Session class, except for the change in the transport used in the async use case. 205 206 A Requests Session class with credentials. 207 208 This class is used to perform requests to API endpoints that require 209 authorization:: 210 211 from google.auth.transport import aiohttp_requests 212 213 async with aiohttp_requests.AuthorizedSession(credentials) as authed_session: 214 response = await authed_session.request( 215 'GET', 'https://www.googleapis.com/storage/v1/b') 216 217 The underlying :meth:`request` implementation handles adding the 218 credentials' headers to the request and refreshing credentials as needed. 219 220 Args: 221 credentials (google.auth._credentials_async.Credentials): 222 The credentials to add to the request. 223 refresh_status_codes (Sequence[int]): Which HTTP status codes indicate 224 that credentials should be refreshed and the request should be 225 retried. 226 max_refresh_attempts (int): The maximum number of times to attempt to 227 refresh the credentials and retry the request. 228 refresh_timeout (Optional[int]): The timeout value in seconds for 229 credential refresh HTTP requests. 230 auth_request (google.auth.transport.aiohttp_requests.Request): 231 (Optional) An instance of 232 :class:`~google.auth.transport.aiohttp_requests.Request` used when 233 refreshing credentials. If not passed, 234 an instance of :class:`~google.auth.transport.aiohttp_requests.Request` 235 is created. 236 """ 237 238 def __init__( 239 self, 240 credentials, 241 refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES, 242 max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, 243 refresh_timeout=None, 244 auth_request=None, 245 auto_decompress=False, 246 ): 247 super(AuthorizedSession, self).__init__() 248 self.credentials = credentials 249 self._refresh_status_codes = refresh_status_codes 250 self._max_refresh_attempts = max_refresh_attempts 251 self._refresh_timeout = refresh_timeout 252 self._is_mtls = False 253 self._auth_request = auth_request 254 self._auth_request_session = None 255 self._loop = asyncio.get_event_loop() 256 self._refresh_lock = asyncio.Lock() 257 self._auto_decompress = auto_decompress 258 259 async def request( 260 self, 261 method, 262 url, 263 data=None, 264 headers=None, 265 max_allowed_time=None, 266 timeout=_DEFAULT_TIMEOUT, 267 auto_decompress=False, 268 **kwargs, 269 ): 270 271 """Implementation of Authorized Session aiohttp request. 272 273 Args: 274 method (str): 275 The http request method used (e.g. GET, PUT, DELETE) 276 url (str): 277 The url at which the http request is sent. 278 data (Optional[dict]): Dictionary, list of tuples, bytes, or file-like 279 object to send in the body of the Request. 280 headers (Optional[dict]): Dictionary of HTTP Headers to send with the 281 Request. 282 timeout (Optional[Union[float, aiohttp.ClientTimeout]]): 283 The amount of time in seconds to wait for the server response 284 with each individual request. Can also be passed as an 285 ``aiohttp.ClientTimeout`` object. 286 max_allowed_time (Optional[float]): 287 If the method runs longer than this, a ``Timeout`` exception is 288 automatically raised. Unlike the ``timeout`` parameter, this 289 value applies to the total method execution time, even if 290 multiple requests are made under the hood. 291 292 Mind that it is not guaranteed that the timeout error is raised 293 at ``max_allowed_time``. It might take longer, for example, if 294 an underlying request takes a lot of time, but the request 295 itself does not timeout, e.g. if a large file is being 296 transmitted. The timout error will be raised after such 297 request completes. 298 """ 299 # Headers come in as bytes which isn't expected behavior, the resumable 300 # media libraries in some cases expect a str type for the header values, 301 # but sometimes the operations return these in bytes types. 302 if headers: 303 for key in headers.keys(): 304 if type(headers[key]) is bytes: 305 headers[key] = headers[key].decode("utf-8") 306 307 async with aiohttp.ClientSession( 308 auto_decompress=self._auto_decompress 309 ) as self._auth_request_session: 310 auth_request = Request(self._auth_request_session) 311 self._auth_request = auth_request 312 313 # Use a kwarg for this instead of an attribute to maintain 314 # thread-safety. 315 _credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0) 316 # Make a copy of the headers. They will be modified by the credentials 317 # and we want to pass the original headers if we recurse. 318 request_headers = headers.copy() if headers is not None else {} 319 320 # Do not apply the timeout unconditionally in order to not override the 321 # _auth_request's default timeout. 322 auth_request = ( 323 self._auth_request 324 if timeout is None 325 else functools.partial(self._auth_request, timeout=timeout) 326 ) 327 328 remaining_time = max_allowed_time 329 330 with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: 331 await self.credentials.before_request( 332 auth_request, method, url, request_headers 333 ) 334 335 with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: 336 response = await super(AuthorizedSession, self).request( 337 method, 338 url, 339 data=data, 340 headers=request_headers, 341 timeout=timeout, 342 **kwargs, 343 ) 344 345 remaining_time = guard.remaining_timeout 346 347 if ( 348 response.status in self._refresh_status_codes 349 and _credential_refresh_attempt < self._max_refresh_attempts 350 ): 351 352 requests._LOGGER.info( 353 "Refreshing credentials due to a %s response. Attempt %s/%s.", 354 response.status, 355 _credential_refresh_attempt + 1, 356 self._max_refresh_attempts, 357 ) 358 359 # Do not apply the timeout unconditionally in order to not override the 360 # _auth_request's default timeout. 361 auth_request = ( 362 self._auth_request 363 if timeout is None 364 else functools.partial(self._auth_request, timeout=timeout) 365 ) 366 367 with requests.TimeoutGuard( 368 remaining_time, asyncio.TimeoutError 369 ) as guard: 370 async with self._refresh_lock: 371 await self._loop.run_in_executor( 372 None, self.credentials.refresh, auth_request 373 ) 374 375 remaining_time = guard.remaining_timeout 376 377 return await self.request( 378 method, 379 url, 380 data=data, 381 headers=headers, 382 max_allowed_time=remaining_time, 383 timeout=timeout, 384 _credential_refresh_attempt=_credential_refresh_attempt + 1, 385 **kwargs, 386 ) 387 388 return response 389