1# Copyright 2018 New Vector Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
import abc
import hmac
import logging
import re
import urllib
import urllib.parse
from inspect import signature
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
21
22from prometheus_client import Counter, Gauge
23
24from synapse.api.errors import HttpResponseException, SynapseError
25from synapse.http import RequestTimedOutError
26from synapse.logging import opentracing
27from synapse.logging.opentracing import trace
28from synapse.util.caches.response_cache import ResponseCache
29from synapse.util.stringutils import random_string
30
31if TYPE_CHECKING:
32    from synapse.server import HomeServer
33
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Gauge of in-flight outgoing replication requests, labelled by the
# endpoint's NAME.
_pending_outgoing_requests = Gauge(
    name="synapse_pending_outgoing_replication_requests",
    documentation=(
        "Number of active outgoing replication requests, by replication method name"
    ),
    labelnames=["name"],
)

# Counter of finished outgoing replication requests, labelled by the
# endpoint's NAME and the outcome (an HTTP status code, or "ERR").
_outgoing_request_counter = Counter(
    name="synapse_outgoing_replication_requests",
    documentation=(
        "Number of outgoing replication requests, by replication method name and result"
    ),
    labelnames=["name", "code"],
)
47
48
class ReplicationEndpoint(metaclass=abc.ABCMeta):
    """Helper base class for defining new replication HTTP endpoints.

    This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
    (with a `/:txn_id` suffix for cached requests), where NAME is a name,
    PATH_ARGS are a tuple of parameters to be encoded in the URL.

    For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`,
    with `CACHE` set to true then this generates an endpoint:

        /_synapse/replication/send_event/:event_id/:txn_id

    For POST/PUT requests the payload is serialized to json and sent as the
    body, while for GET requests the payload is added as query parameters. See
    `_serialize_payload` for details.

    Incoming requests are handled by overriding `_handle_request`. Servers
    must call `register` to register the path with the HTTP server.

    Requests can be sent by calling the client returned by `make_client`.
    Requests are sent to master process by default, but can be sent to other
    named processes by specifying an `instance_name` keyword argument.

    Attributes:
        NAME (str): A name for the endpoint, added to the path as well as used
            in logging and metrics.
        PATH_ARGS (tuple[str]): A list of parameters to be added to the path.
            Adding parameters to the path (rather than payload) can make it
            easier to follow along in the log files.
        METHOD (str): The method of the HTTP request, defaults to POST. Can be
            one of POST, PUT or GET. If GET then the payload is sent as query
            parameters rather than a JSON body.
        CACHE (bool): Whether server should cache the result of the request.
            If true then transparently adds a txn_id to all requests, and the
            awaited result of `_handle_request` is cached against that txn_id.
        RETRY_ON_TIMEOUT (bool): Whether or not to retry the request when it
            times out (i.e. a `RequestTimedOutError` is raised by the client).
    """

    NAME: str = abc.abstractproperty()  # type: ignore
    PATH_ARGS: Tuple[str, ...] = abc.abstractproperty()  # type: ignore
    METHOD = "POST"
    CACHE = True
    RETRY_ON_TIMEOUT = True

    def __init__(self, hs: "HomeServer"):
        """Set up the response cache (if enabled) and validate the subclass.

        Args:
            hs: The homeserver; used for its clock and its worker config.
        """
        if self.CACHE:
            # Responses are cached per txn_id for 30 minutes.
            self.response_cache: ResponseCache[str] = ResponseCache(
                hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
            )

        # We reserve `instance_name` as a parameter to sending requests, so we
        # assert here that sub classes don't try and use the name.
        assert (
            "instance_name" not in self.PATH_ARGS
        ), "`instance_name` is a reserved parameter name"
        assert (
            "instance_name"
            not in signature(self.__class__._serialize_payload).parameters
        ), "`instance_name` is a reserved parameter name"

        assert self.METHOD in ("PUT", "POST", "GET")

        # The shared secret that incoming requests must present, or None if
        # no secret is configured (in which case auth is skipped).
        self._replication_secret = None
        if hs.config.worker.worker_replication_secret:
            self._replication_secret = hs.config.worker.worker_replication_secret

    def _check_auth(self, request) -> None:
        """Check that the request carries the configured replication secret.

        Args:
            request: The incoming request; its `Authorization` header must be
                exactly `Bearer <secret>`.

        Raises:
            RuntimeError: If the header is missing, is given more than once,
                is malformed, or carries the wrong secret.
        """
        # Get the authorization header.
        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

        # `getRawHeaders` gives back None when the header is absent, so this
        # must be checked before looking at the length or first element.
        if not auth_headers:
            raise RuntimeError("Missing Authorization header.")
        if len(auth_headers) > 1:
            raise RuntimeError("Too many Authorization headers.")

        parts = auth_headers[0].split(b" ")
        secret = self._replication_secret
        if secret is not None and len(parts) == 2 and parts[0] == b"Bearer":
            # Compare as bytes so that a non-ASCII token is simply rejected
            # rather than raising a decode error, and use a constant-time
            # comparison so the secret cannot be probed via timing.
            if hmac.compare_digest(parts[1], secret.encode("utf-8")):
                # Success!
                return

        raise RuntimeError("Invalid Authorization header.")

    @staticmethod
    @abc.abstractmethod
    async def _serialize_payload(**kwargs):
        """Static method that is called when creating a request.

        Concrete implementations should have explicit parameters (rather than
        kwargs) so that an appropriate exception is raised if the client is
        called with unexpected parameters. All PATH_ARGS must appear in
        argument list.

        Returns:
            dict: If POST/PUT request then dictionary must be JSON serialisable,
            otherwise must be appropriate for adding as query args.
        """
        return {}

    @abc.abstractmethod
    async def _handle_request(self, request, **kwargs):
        """Handle incoming request.

        This is called with the request object and PATH_ARGS.

        Returns:
            tuple[int, dict]: HTTP status code and a JSON serialisable dict
            to be used as response body of request.
        """
        pass

    @classmethod
    def make_client(cls, hs: "HomeServer"):
        """Create a client that makes requests.

        Returns a callable that accepts the same parameters as
        `_serialize_payload`, and also accepts an optional `instance_name`
        parameter to specify which instance to hit (the instance must be in
        the `instance_map` config).

        The returned callable raises `SynapseError`: either the error
        converted from an `HttpResponseException` returned by the remote
        process, or a 502 for any other failure while sending the request.
        """
        clock = hs.get_clock()
        client = hs.get_simple_http_client()
        local_instance_name = hs.get_instance_name()

        master_host = hs.config.worker.worker_replication_host
        master_port = hs.config.worker.worker_replication_http_port

        instance_map = hs.config.worker.instance_map

        outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)

        # Pre-encode the secret once; it is sent as a bearer token on every
        # request when configured.
        replication_secret = None
        if hs.config.worker.worker_replication_secret:
            replication_secret = hs.config.worker.worker_replication_secret.encode(
                "ascii"
            )

        @trace(opname="outgoing_replication_request")
        async def send_request(*, instance_name="master", **kwargs):
            with outgoing_gauge.track_inprogress():
                if instance_name == local_instance_name:
                    raise Exception("Trying to send HTTP request to self")
                if instance_name == "master":
                    host = master_host
                    port = master_port
                elif instance_name in instance_map:
                    host = instance_map[instance_name].host
                    port = instance_map[instance_name].port
                else:
                    raise Exception(
                        "Instance %r not in 'instance_map' config" % (instance_name,)
                    )

                data = await cls._serialize_payload(**kwargs)

                # Path arguments are fully percent-encoded (including "/") so
                # that each one occupies exactly one path segment.
                url_args = [
                    urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
                ]

                if cls.CACHE:
                    # A random transaction ID lets the receiving side
                    # de-duplicate retries of this same request.
                    txn_id = random_string(10)
                    url_args.append(txn_id)

                if cls.METHOD == "POST":
                    request_func: Callable[
                        ..., Awaitable[Any]
                    ] = client.post_json_get_json
                elif cls.METHOD == "PUT":
                    request_func = client.put_json
                elif cls.METHOD == "GET":
                    request_func = client.get_json
                else:
                    # We have already asserted in the constructor that a
                    # compatible was picked, but lets be paranoid.
                    raise Exception(
                        "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
                    )

                uri = "http://%s:%s/_synapse/replication/%s/%s" % (
                    host,
                    port,
                    cls.NAME,
                    "/".join(url_args),
                )

                try:
                    # We keep retrying the same request for timeouts. This is so that we
                    # have a good idea that the request has either succeeded or failed
                    # on the master, and so whether we should clean up or not.
                    while True:
                        headers: Dict[bytes, List[bytes]] = {}
                        # Add an authorization header, if configured.
                        if replication_secret:
                            headers[b"Authorization"] = [
                                b"Bearer " + replication_secret
                            ]
                        opentracing.inject_header_dict(headers, check_destination=False)
                        try:
                            result = await request_func(uri, data, headers=headers)
                            break
                        except RequestTimedOutError:
                            if not cls.RETRY_ON_TIMEOUT:
                                raise

                        logger.warning("%s request timed out; retrying", cls.NAME)

                        # If we timed out we probably don't need to worry about backing
                        # off too much, but lets just wait a little anyway.
                        await clock.sleep(1)
                except HttpResponseException as e:
                    # We convert to SynapseError as we know that it was a SynapseError
                    # on the main process that we should send to the client. (And
                    # importantly, not stack traces everywhere)
                    _outgoing_request_counter.labels(cls.NAME, e.code).inc()
                    raise e.to_synapse_error()
                except Exception as e:
                    _outgoing_request_counter.labels(cls.NAME, "ERR").inc()
                    raise SynapseError(502, "Failed to talk to main process") from e

                _outgoing_request_counter.labels(cls.NAME, 200).inc()
                return result

        return send_request

    def register(self, http_server):
        """Called by the server to register this as a handler to the
        appropriate path.

        Args:
            http_server: The server to register with; its `register_paths`
                is called with the method, the path pattern, the handler and
                this class's name.
        """

        url_args = list(self.PATH_ARGS)
        method = self.METHOD

        if self.CACHE:
            url_args.append("txn_id")

        # Each path argument becomes a named group matching a single path
        # segment, which is then passed to the handler as a keyword argument.
        args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
        pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))

        http_server.register_paths(
            method,
            [pattern],
            self._check_auth_and_handle,
            self.__class__.__name__,
        )

    async def _check_auth_and_handle(self, request, **kwargs):
        """Called on new incoming requests. Checks the replication secret (if
        one is configured) and then dispatches to `_handle_request`.

        When caching is enabled, checks if there is a cached response for the
        request's txn_id and returns that, otherwise calls `_handle_request`
        and caches its response.

        Raises:
            RuntimeError: If a secret is configured and the request fails the
                authorization check.
        """
        # Check the authorization headers before handling the request.
        if self._replication_secret:
            self._check_auth(request)

        if self.CACHE:
            # We just use the txn_id here, but we probably also want to use the
            # other PATH_ARGS as well.
            txn_id = kwargs.pop("txn_id")

            return await self.response_cache.wrap(
                txn_id, self._handle_request, request, **kwargs
            )

        return await self._handle_request(request, **kwargs)
312