1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Transport adapter for Async HTTP (aiohttp).
16
17NOTE: This async support is experimental and marked internal. This surface may
18change in minor releases.
19"""
20
21from __future__ import absolute_import
22
23import asyncio
24import functools
25
26import aiohttp
27import six
28import urllib3
29
30from google.auth import exceptions
31from google.auth import transport
32from google.auth.transport import requests
33
# Default per-request timeout for this async transport, in seconds. It is
# deliberately 60 seconds longer than the synchronous transport's 120-second
# default, and may be re-tuned to suit async workloads.
_DEFAULT_TIMEOUT = 120 + 60  # in seconds
37
38
class _CombinedResponse(transport.Response):
    """
    Async response adapter that exposes both the raw and the decoded body.

    In order to more closely resemble the ``requests`` interface, where both the
    raw and the decoded (decompressed) content can be accessed, this class
    lazily reads the aiohttp response stream once and caches the raw bytes, so
    :meth:`raw_content` and :meth:`content` can both be used.

    Sessions used with this transport are created with
    ``auto_decompress=False``, so aiohttp returns the body exactly as it was
    sent. :meth:`content` therefore performs gzip/deflate content decoding
    itself, mirroring the synchronous ``requests`` implementation.

    Args:
        response (aiohttp.ClientResponse): The raw aiohttp response.
    """

    # Content codings that :meth:`content` is able to undo.
    _DECODABLE_CODINGS = frozenset(("gzip", "deflate"))

    def __init__(self, response):
        self._response = response
        # Raw body bytes; filled in on the first call to ``raw_content``.
        self._raw_content = None

    def _content_codings(self):
        """Return the codings listed in ``Content-Encoding``, normalized.

        The header may list several comma-separated codings, and coding names
        are case-insensitive, so each entry is stripped and lower-cased. Empty
        entries are dropped; a missing header yields an empty list.
        """
        header = self._response.headers.get("Content-Encoding", "")
        return [
            coding.strip().lower() for coding in header.split(",") if coding.strip()
        ]

    def _is_compressed(self):
        """Return True if the body uses only codings this class can decode."""
        codings = self._content_codings()
        return bool(codings) and all(
            coding in self._DECODABLE_CODINGS for coding in codings
        )

    @property
    def status(self):
        return self._response.status

    @property
    def headers(self):
        return self._response.headers

    @property
    def data(self):
        # The underlying body stream (not bytes); see ``raw_content``/``content``.
        return self._response.content

    async def raw_content(self):
        """Return the undecoded body bytes, reading the stream only once."""
        if self._raw_content is None:
            self._raw_content = await self._response.content.read()
        return self._raw_content

    async def content(self):
        """Return the body with gzip/deflate content codings removed.

        If the body has no content coding, or uses a coding this class cannot
        decode, the raw bytes are returned unchanged.
        """
        raw = await self.raw_content()
        if not self._is_compressed():
            return raw
        # Hand urllib3's MultiDecoder the normalized, comma-separated codings.
        decoder = urllib3.response.MultiDecoder(",".join(self._content_codings()))
        return decoder.decompress(raw)
91
92
class _Response(transport.Response):
    """
    aiohttp transport response adapter.

    Unlike :class:`_CombinedResponse`, this adapter does not read or decode
    the body; :attr:`data` is the response's body stream as-is.

    Args:
        response (aiohttp.ClientResponse): The raw aiohttp response.
    """

    def __init__(self, response):
        self._response = response

    @property
    def status(self):
        return self._response.status

    @property
    def headers(self):
        return self._response.headers

    @property
    def data(self):
        # The body stream object, not bytes; it must be read asynchronously.
        return self._response.content
115
116
class Request(transport.Request):
    """aiohttp request adapter.

    This class is used internally for making requests using asyncio transports
    in a consistent way. If you use :class:`AuthorizedSession` you do not need
    to construct or use this class directly.

    This class can be useful if you want to manually refresh a
    :class:`~google.auth._credentials_async.Credentials` instance. Because
    calling this adapter is a coroutine, the refresh must be awaited::

        import google.auth.transport.aiohttp_requests

        request = google.auth.transport.aiohttp_requests.Request()

        await credentials.refresh(request)

    Args:
        session (aiohttp.ClientSession): An instance :class:`aiohttp.ClientSession` used
            to make HTTP requests. If not specified, a session will be created
            on first use. Response bodies are decoded by this transport itself,
            so the session must have been created with ``auto_decompress=False``.

    Raises:
        ValueError: If ``session`` was created with ``auto_decompress=True``.

    .. automethod:: __call__
    """

    def __init__(self, session=None):
        # TODO: Use auto_decompress property for aiohttp 3.7+
        if session is not None and session._auto_decompress:
            raise ValueError(
                "Client sessions with auto_decompress=True are not supported."
            )
        self.session = session

    async def __call__(
        self,
        url,
        method="GET",
        body=None,
        headers=None,
        timeout=_DEFAULT_TIMEOUT,
        **kwargs,
    ):
        """
        Make an HTTP request using aiohttp.

        Args:
            url (str): The URL to be requested.
            method (Optional[str]):
                The HTTP method to use for the request. Defaults to 'GET'.
            body (Optional[bytes]):
                The payload or body in HTTP request.
            headers (Optional[Mapping[str, str]]):
                Request headers.
            timeout (Optional[Union[float, aiohttp.ClientTimeout]]): How long
                to wait for the server. Defaults to ``_DEFAULT_TIMEOUT``
                (180 seconds). The value is forwarded unchanged to
                :meth:`aiohttp.ClientSession.request`; passing ``None``
                explicitly forwards ``None`` rather than selecting the default.
            kwargs: Additional arguments passed through to the underlying
                :meth:`aiohttp.ClientSession.request` method.

        Returns:
            google.auth.transport.Response: The HTTP response, wrapped in a
            :class:`_CombinedResponse`.

        Raises:
            google.auth.exceptions.TransportError: If aiohttp raises an
                :class:`aiohttp.ClientError` or the request raises
                :class:`asyncio.TimeoutError`.
        """

        try:
            if self.session is None:  # pragma: NO COVER
                # auto_decompress stays off because _CombinedResponse decodes
                # bodies itself. NOTE: this class never closes this session.
                self.session = aiohttp.ClientSession(
                    auto_decompress=False
                )  # pragma: NO COVER
            requests._LOGGER.debug("Making request: %s %s", method, url)
            response = await self.session.request(
                method, url, data=body, headers=headers, timeout=timeout, **kwargs
            )
            return _CombinedResponse(response)

        except (aiohttp.ClientError, asyncio.TimeoutError) as caught_exc:
            new_exc = exceptions.TransportError(caught_exc)
            six.raise_from(new_exc, caught_exc)
199
200
class AuthorizedSession(aiohttp.ClientSession):
    """An async, aiohttp-based session that attaches credentials to requests.

    This mirrors :class:`google.auth.transport.requests.AuthorizedSession`,
    except that it is an :class:`aiohttp.ClientSession` and its
    :meth:`request` method is a coroutine.

    This class is used to perform requests to API endpoints that require
    authorization::

        from google.auth.transport import aiohttp_requests

        async with aiohttp_requests.AuthorizedSession(credentials) as authed_session:
            response = await authed_session.request(
                'GET', 'https://www.googleapis.com/storage/v1/b')

    The underlying :meth:`request` implementation handles adding the
    credentials' headers to the request and refreshing credentials as needed.

    Args:
        credentials (google.auth._credentials_async.Credentials):
            The credentials to add to the request.
        refresh_status_codes (Sequence[int]): Which HTTP status codes indicate
            that credentials should be refreshed and the request should be
            retried.
        max_refresh_attempts (int): The maximum number of times to attempt to
            refresh the credentials and retry the request.
        refresh_timeout (Optional[int]): The timeout value in seconds for
            credential refresh HTTP requests. It is stored on the session but
            is not currently consulted by :meth:`request`.
        auth_request (google.auth.transport.aiohttp_requests.Request):
            (Optional) An instance of
            :class:`~google.auth.transport.aiohttp_requests.Request` used when
            refreshing credentials. If not passed, each call to :meth:`request`
            creates a :class:`~google.auth.transport.aiohttp_requests.Request`
            bound to a short-lived :class:`aiohttp.ClientSession`.
        auto_decompress (bool): Whether aiohttp automatically decompresses the
            bodies of responses returned by this session. Defaults to
            ``False``. It does not affect the internal session used for
            credential refresh requests, which never auto-decompresses.
    """

    def __init__(
        self,
        credentials,
        refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES,
        max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS,
        refresh_timeout=None,
        auth_request=None,
        auto_decompress=False,
    ):
        # Forward auto_decompress to aiohttp so it governs this session's own
        # responses.
        super(AuthorizedSession, self).__init__(auto_decompress=auto_decompress)
        self.credentials = credentials
        self._refresh_status_codes = refresh_status_codes
        self._max_refresh_attempts = max_refresh_attempts
        self._refresh_timeout = refresh_timeout
        self._is_mtls = False
        self._auth_request = auth_request
        # Only rebind _auth_request per call when the caller did not supply one.
        self._auth_request_is_default = auth_request is None
        self._auth_request_session = None
        self._loop = asyncio.get_event_loop()
        self._refresh_lock = asyncio.Lock()

    @staticmethod
    def _decode_bytes_headers(headers):
        """Return a copy of ``headers`` with ``bytes`` values decoded as UTF-8.

        Some callers (e.g. the resumable media libraries) supply header values
        as bytes where str is expected. Working on a copy leaves the caller's
        mapping untouched. Falsy ``headers`` (``None`` or empty) are returned
        as-is.
        """
        if not headers:
            return headers
        decoded = headers.copy()
        for key in list(decoded.keys()):
            value = decoded[key]
            if isinstance(value, bytes):
                decoded[key] = value.decode("utf-8")
        return decoded

    async def _refresh_credentials(self, auth_request):
        """Refresh ``self.credentials`` while holding the refresh lock.

        ``credentials.refresh`` is called in the default executor so that a
        synchronous implementation does not block the event loop. When
        ``refresh`` is a coroutine function, that call only produces a
        coroutine object, so the result is awaited here; otherwise the refresh
        would silently never run.
        """
        async with self._refresh_lock:
            result = await self._loop.run_in_executor(
                None, self.credentials.refresh, auth_request
            )
            if asyncio.iscoroutine(result):
                await result

    async def request(
        self,
        method,
        url,
        data=None,
        headers=None,
        max_allowed_time=None,
        timeout=_DEFAULT_TIMEOUT,
        auto_decompress=False,
        **kwargs,
    ):
        """Implementation of Authorized Session aiohttp request.

        Args:
            method (str):
                The http request method used (e.g. GET, PUT, DELETE)
            url (str):
                The url at which the http request is sent.
            data (Optional[dict]): Dictionary, list of tuples, bytes, or file-like
                object to send in the body of the Request.
            headers (Optional[dict]): Dictionary of HTTP Headers to send with the
                Request. ``bytes`` values are decoded as UTF-8; the caller's
                mapping is not modified.
            timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
                The amount of time in seconds to wait for the server response
                with each individual request. Can also be passed as an
                ``aiohttp.ClientTimeout`` object. When not ``None`` it is also
                applied to the credential refresh requests.
            max_allowed_time (Optional[float]):
                If the method runs longer than this, a ``Timeout`` exception is
                automatically raised. Unlike the ``timeout`` parameter, this
                value applies to the total method execution time, even if
                multiple requests are made under the hood.

                Mind that it is not guaranteed that the timeout error is raised
                at ``max_allowed_time``. It might take longer, for example, if
                an underlying request takes a lot of time, but the request
                itself does not timeout, e.g. if a large file is being
                transmitted. The timeout error will be raised after such
                request completes.
            auto_decompress (bool): Accepted for backward compatibility and
                ignored; configure decompression with the constructor's
                ``auto_decompress`` argument instead.
            kwargs: Additional arguments passed through to
                :meth:`aiohttp.ClientSession.request`.

        Returns:
            aiohttp.ClientResponse: The response to the final attempt.

        Raises:
            asyncio.TimeoutError: If ``max_allowed_time`` is exceeded (checked
                via ``requests.TimeoutGuard`` after each step).
        """
        headers = self._decode_bytes_headers(headers)

        # Use a kwarg for this instead of an attribute to maintain
        # thread-safety.
        _credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0)

        # The auth session must never auto-decompress: Request rejects such
        # sessions, and _CombinedResponse decodes refresh responses itself.
        async with aiohttp.ClientSession(
            auto_decompress=False
        ) as self._auth_request_session:
            if self._auth_request_is_default:
                self._auth_request = Request(self._auth_request_session)

            # Make a copy of the headers. They will be modified by the credentials
            # and we want to pass the original headers if we recurse.
            request_headers = headers.copy() if headers is not None else {}

            # Do not apply the timeout unconditionally in order to not override the
            # _auth_request's default timeout.
            auth_request = (
                self._auth_request
                if timeout is None
                else functools.partial(self._auth_request, timeout=timeout)
            )

            remaining_time = max_allowed_time

            with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard:
                await self.credentials.before_request(
                    auth_request, method, url, request_headers
                )
            # Charge time spent in before_request (which may refresh) against
            # the overall budget, as the synchronous implementation does.
            remaining_time = guard.remaining_timeout

            with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard:
                response = await super(AuthorizedSession, self).request(
                    method,
                    url,
                    data=data,
                    headers=request_headers,
                    timeout=timeout,
                    **kwargs,
                )
            remaining_time = guard.remaining_timeout

            if (
                response.status in self._refresh_status_codes
                and _credential_refresh_attempt < self._max_refresh_attempts
            ):
                requests._LOGGER.info(
                    "Refreshing credentials due to a %s response. Attempt %s/%s.",
                    response.status,
                    _credential_refresh_attempt + 1,
                    self._max_refresh_attempts,
                )

                with requests.TimeoutGuard(
                    remaining_time, asyncio.TimeoutError
                ) as guard:
                    await self._refresh_credentials(auth_request)
                remaining_time = guard.remaining_timeout

                return await self.request(
                    method,
                    url,
                    data=data,
                    headers=headers,
                    max_allowed_time=remaining_time,
                    timeout=timeout,
                    _credential_refresh_attempt=_credential_refresh_attempt + 1,
                    **kwargs,
                )

        return response
389