1# SPDX-FileCopyrightText: 2015 Eric Larson
2#
3# SPDX-License-Identifier: Apache-2.0
4
import functools
import types
import weakref
import zlib

from requests.adapters import HTTPAdapter

from .controller import CacheController, PERMANENT_REDIRECT_STATUSES
from .cache import DictCache
from .filewrapper import CallbackFileWrapper
14
15
class CacheControlAdapter(HTTPAdapter):
    """A requests ``HTTPAdapter`` that consults and fills an HTTP cache.

    Requests whose method is in ``cacheable_methods`` are first looked up via
    the cache controller; responses fetched from the network for those methods
    are handed to the controller for storage. Successful responses to one of
    ``invalidating_methods`` delete the cache entry for the request URL.
    """

    # Successful (``resp.ok``) responses to these methods delete the cache
    # entry keyed by the request URL (see ``build_response``).
    invalidating_methods = {"PUT", "PATCH", "DELETE"}

    def __init__(
        self,
        cache=None,
        cache_etags=True,
        controller_class=None,
        serializer=None,
        heuristic=None,
        cacheable_methods=None,
        *args,
        **kw
    ):
        """Create the adapter and its cache controller.

        :param cache: Cache backend; a new ``DictCache`` is used when None.
        :param cache_etags: Forwarded to the controller as ``cache_etags``.
        :param controller_class: Controller factory, called as
            ``controller_class(cache, cache_etags=..., serializer=...)``;
            defaults to ``CacheController``.
        :param serializer: Forwarded to the controller as ``serializer``.
        :param heuristic: Optional object whose ``apply(response)`` result
            replaces a network response before it is considered for caching.
        :param cacheable_methods: HTTP methods eligible for caching; any falsy
            value falls back to ``("GET",)``.
        :param args: Passed through to ``HTTPAdapter.__init__``.
        :param kw: Passed through to ``HTTPAdapter.__init__``.
        """
        super(CacheControlAdapter, self).__init__(*args, **kw)
        self.cache = DictCache() if cache is None else cache
        self.heuristic = heuristic
        self.cacheable_methods = cacheable_methods or ("GET",)

        controller_factory = controller_class or CacheController
        self.controller = controller_factory(
            self.cache, cache_etags=cache_etags, serializer=serializer
        )

    def send(self, request, cacheable_methods=None, **kw):
        """Send a request, answering it from the cache when possible.

        For a cacheable method, a cache hit is returned directly (with
        ``from_cache=True``); on a miss, the controller's conditional headers
        (e.g. for ETag revalidation) are added before the request goes out.

        :param request: The prepared request to send.
        :param cacheable_methods: Per-call override of
            ``self.cacheable_methods`` for the cache lookup.
            NOTE(review): this override is not forwarded to the
            ``build_response`` call made by ``HTTPAdapter.send``, so storing
            the network response still uses ``self.cacheable_methods``.
        :param kw: Forwarded unchanged to ``HTTPAdapter.send``.
        :returns: A ``requests.Response`` carrying a ``from_cache`` attribute.
        """
        cacheable = cacheable_methods or self.cacheable_methods
        if request.method in cacheable:
            try:
                cached_response = self.controller.cached_request(request)
            except zlib.error:
                # Presumably a corrupt compressed cache entry; treat it as a
                # miss and go to the network instead.
                cached_response = None
            # Truthiness (not ``is not None``): the lookup may signal a miss
            # with any falsy value.
            if cached_response:
                return self.build_response(request, cached_response, from_cache=True)

            # check for etags and add headers if appropriate
            request.headers.update(self.controller.conditional_headers(request))

        resp = super(CacheControlAdapter, self).send(request, **kw)

        return resp

    def build_response(
        self, request, response, from_cache=False, cacheable_methods=None
    ):
        """Wrap a low-level response in a ``requests.Response``, caching it if possible.

        ``send`` calls this directly for cache hits (``from_cache=True``); the
        base ``HTTPAdapter.send`` calls it for responses fetched from the
        network. For a network response to a cacheable method:

        * a 304 is passed to ``controller.update_cached_response`` and the
          result is returned in its place (``from_cache`` becomes True when
          that result is a different object than the 304);
        * a status in ``PERMANENT_REDIRECT_STATUSES`` is handed to
          ``controller.cache_response`` immediately;
        * any other response is handed to ``controller.cache_response`` once
          its body has been fully read.

        :param request: The prepared request the response belongs to.
        :param response: The low-level response object to wrap.
        :param from_cache: True when ``response`` came from the cache; such
            responses are never stored again.
        :param cacheable_methods: Override of ``self.cacheable_methods``.
        :returns: A ``requests.Response`` with a ``from_cache`` attribute.
        """
        cacheable = cacheable_methods or self.cacheable_methods
        if not from_cache and request.method in cacheable:
            # Check for any heuristics that might update headers
            # before trying to cache.
            if self.heuristic:
                response = self.heuristic.apply(response)

            if response.status == 304:
                # We must have sent an ETag request. This could mean
                # that we've been expired already or that we simply
                # have an etag. In either case, we want to try and
                # update the cache if that is the case.
                cached_response = self.controller.update_cached_response(
                    request, response
                )

                if cached_response is not response:
                    from_cache = True

                # We are done with the server response, read a
                # possible response body (compliant servers will
                # not return one, but we cannot be 100% sure) and
                # release the connection back to the pool.
                response.read(decode_content=False)
                response.release_conn()

                response = cached_response

            # We always cache the 301 responses
            elif int(response.status) in PERMANENT_REDIRECT_STATUSES:
                self.controller.cache_response(request, response)
            else:
                self._cache_when_consumed(request, response)

        resp = super(CacheControlAdapter, self).build_response(request, response)

        # See if we should invalidate the cache.
        if request.method in self.invalidating_methods and resp.ok:
            cache_url = self.controller.cache_url(request.url)
            self.cache.delete(cache_url)

        # Give the request a from_cache attr to let people use it
        resp.from_cache = from_cache

        return resp

    def _cache_when_consumed(self, request, response):
        """Hand ``response`` to the controller once its body has been read.

        The callbacks installed on ``response`` hold only a weak reference to
        it. A strong reference (``response._fp`` -> callback -> ``response``)
        would form a reference cycle, so the response and its buffered body
        could be freed only by the cyclic garbage collector.
        """
        response_ref = weakref.ref(response)
        cache_response = self.controller.cache_response

        def _cache(*args, **kwargs):
            # Invoked while the response is being read, so the referent is
            # expected to be alive; if it is gone, there is nothing to cache.
            resp = response_ref()
            if resp is not None:
                cache_response(request, resp, *args, **kwargs)

        # Wrap the response file with a wrapper that will cache the
        # response when the stream has been consumed.
        response._fp = CallbackFileWrapper(response._fp, _cache)

        if response.chunked:
            # Use the class-level function rather than the bound method so the
            # replacement below does not keep ``response`` alive.
            super_update_chunk_length = type(response)._update_chunk_length

            def _update_chunk_length(resp_ref):
                resp = resp_ref()
                if resp is None:
                    return
                super_update_chunk_length(resp)
                # On the terminating zero-length chunk, close the wrapper.
                # Presumably this fires the cache callback because the wrapper
                # does not see EOF on its own for chunked bodies — confirm.
                if resp.chunk_left == 0:
                    resp._fp._close()

            response._update_chunk_length = functools.partial(
                _update_chunk_length, response_ref
            )

    def close(self):
        """Close the cache backend, then the adapter's own resources.

        ``HTTPAdapter.close`` runs even if closing the cache raises, so the
        connection pools are never left open.
        """
        try:
            self.cache.close()
        finally:
            super(CacheControlAdapter, self).close()
138