# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import logging
import re
import urllib.parse
from inspect import signature
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple

from prometheus_client import Counter, Gauge

from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
from synapse.logging import opentracing
from synapse.logging.opentracing import trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

_pending_outgoing_requests = Gauge(
    "synapse_pending_outgoing_replication_requests",
    "Number of active outgoing replication requests, by replication method name",
    ["name"],
)

_outgoing_request_counter = Counter(
    "synapse_outgoing_replication_requests",
    "Number of outgoing replication requests, by replication method name and result",
    ["name", "code"],
)


class ReplicationEndpoint(metaclass=abc.ABCMeta):
    """Helper base class for defining new replication HTTP endpoints.

    This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
    (with a `/:txn_id` suffix for cached requests), where NAME is a name,
    PATH_ARGS are a tuple of parameters to be encoded in the URL.

    For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`,
    with `CACHE` set to true then this generates an endpoint:

        /_synapse/replication/send_event/:event_id/:txn_id

    For POST/PUT requests the payload is serialized to json and sent as the
    body, while for GET requests the payload is added as query parameters. See
    `_serialize_payload` for details.

    Incoming requests are handled by overriding `_handle_request`. Servers
    must call `register` to register the path with the HTTP server.

    Requests can be sent by calling the client returned by `make_client`.
    Requests are sent to master process by default, but can be sent to other
    named processes by specifying an `instance_name` keyword argument.

    Attributes:
        NAME (str): A name for the endpoint, added to the path as well as used
            in logging and metrics.
        PATH_ARGS (tuple[str]): A list of parameters to be added to the path.
            Adding parameters to the path (rather than payload) can make it
            easier to follow along in the log files.
        METHOD (str): The method of the HTTP request, defaults to POST. Can be
            one of POST, PUT or GET. If GET then the payload is sent as query
            parameters rather than a JSON body.
        CACHE (bool): Whether server should cache the result of the request.
            If true then transparently adds a txn_id to all requests, and
            `_handle_request` must return a Deferred.
        RETRY_ON_TIMEOUT(bool): Whether or not to retry the request when a 504
            is received.
    """

    NAME: str = abc.abstractproperty()  # type: ignore
    PATH_ARGS: Tuple[str, ...] = abc.abstractproperty()  # type: ignore
    METHOD = "POST"
    CACHE = True
    RETRY_ON_TIMEOUT = True

    def __init__(self, hs: "HomeServer"):
        """Set up the response cache (if enabled) and sanity-check the subclass.

        Args:
            hs: The homeserver; its clock backs the response cache and its
                worker config supplies the optional replication secret.
        """
        if self.CACHE:
            self.response_cache: ResponseCache[str] = ResponseCache(
                hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
            )

        # We reserve `instance_name` as a parameter to sending requests, so we
        # assert here that sub classes don't try and use the name.
        assert (
            "instance_name" not in self.PATH_ARGS
        ), "`instance_name` is a reserved parameter name"
        assert (
            "instance_name"
            not in signature(self.__class__._serialize_payload).parameters
        ), "`instance_name` is a reserved parameter name"

        assert self.METHOD in ("PUT", "POST", "GET")

        self._replication_secret = None
        if hs.config.worker.worker_replication_secret:
            self._replication_secret = hs.config.worker.worker_replication_secret

    def _check_auth(self, request) -> None:
        """Check that the request carries the configured replication secret.

        The secret is expected as a single `Authorization: Bearer <secret>`
        header.

        Raises:
            RuntimeError: if the header is missing, is given more than once,
                is malformed, or carries the wrong secret.
        """
        # Get the authorization header.
        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

        # The lookup gives back None when the header was not sent at all;
        # reject that explicitly rather than failing on `len(None)`.
        if not auth_headers:
            raise RuntimeError("Missing Authorization header.")
        if len(auth_headers) > 1:
            raise RuntimeError("Too many Authorization headers.")

        parts = auth_headers[0].split(b" ")
        if parts[0] == b"Bearer" and len(parts) == 2:
            received_secret = parts[1].decode("ascii")
            if self._replication_secret == received_secret:
                # Success!
                return

        raise RuntimeError("Invalid Authorization header.")

    @abc.abstractmethod
    async def _serialize_payload(**kwargs):
        """Static method that is called when creating a request.

        Concrete implementations should have explicit parameters (rather than
        kwargs) so that an appropriate exception is raised if the client is
        called with unexpected parameters. All PATH_ARGS must appear in
        argument list.

        Returns:
            dict: If POST/PUT request then dictionary must be JSON serialisable,
            otherwise must be appropriate for adding as query args.
        """
        return {}

    @abc.abstractmethod
    async def _handle_request(self, request, **kwargs):
        """Handle incoming request.

        This is called with the request object and PATH_ARGS.

        Returns:
            tuple[int, dict]: HTTP status code and a JSON serialisable dict
            to be used as response body of request.
        """
        pass

    @classmethod
    def make_client(cls, hs: "HomeServer"):
        """Create a client that makes requests.

        Returns a callable that accepts the same parameters as
        `_serialize_payload`, and also accepts an optional `instance_name`
        parameter to specify which instance to hit (the instance must be in
        the `instance_map` config).
        """
        clock = hs.get_clock()
        client = hs.get_simple_http_client()
        local_instance_name = hs.get_instance_name()

        master_host = hs.config.worker.worker_replication_host
        master_port = hs.config.worker.worker_replication_http_port

        instance_map = hs.config.worker.instance_map

        outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)

        replication_secret = None
        if hs.config.worker.worker_replication_secret:
            replication_secret = hs.config.worker.worker_replication_secret.encode(
                "ascii"
            )

        @trace(opname="outgoing_replication_request")
        async def send_request(*, instance_name="master", **kwargs):
            with outgoing_gauge.track_inprogress():
                if instance_name == local_instance_name:
                    raise Exception("Trying to send HTTP request to self")
                if instance_name == "master":
                    host = master_host
                    port = master_port
                elif instance_name in instance_map:
                    host = instance_map[instance_name].host
                    port = instance_map[instance_name].port
                else:
                    raise Exception(
                        "Instance %r not in 'instance_map' config" % (instance_name,)
                    )

                data = await cls._serialize_payload(**kwargs)

                # Path arguments are percent-encoded in full (including "/")
                # so that each one stays a single path segment.
                url_args = [
                    urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
                ]

                if cls.CACHE:
                    txn_id = random_string(10)
                    url_args.append(txn_id)

                if cls.METHOD == "POST":
                    request_func: Callable[
                        ..., Awaitable[Any]
                    ] = client.post_json_get_json
                elif cls.METHOD == "PUT":
                    request_func = client.put_json
                elif cls.METHOD == "GET":
                    request_func = client.get_json
                else:
                    # We have already asserted in the constructor that a
                    # compatible method was picked, but lets be paranoid.
                    raise Exception(
                        "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
                    )

                uri = "http://%s:%s/_synapse/replication/%s/%s" % (
                    host,
                    port,
                    cls.NAME,
                    "/".join(url_args),
                )

                try:
                    # We keep retrying the same request for timeouts. This is so that we
                    # have a good idea that the request has either succeeded or failed
                    # on the master, and so whether we should clean up or not.
                    while True:
                        headers: Dict[bytes, List[bytes]] = {}
                        # Add an authorization header, if configured.
                        if replication_secret:
                            headers[b"Authorization"] = [
                                b"Bearer " + replication_secret
                            ]
                        opentracing.inject_header_dict(headers, check_destination=False)
                        try:
                            result = await request_func(uri, data, headers=headers)
                            break
                        except RequestTimedOutError:
                            if not cls.RETRY_ON_TIMEOUT:
                                raise

                        logger.warning("%s request timed out; retrying", cls.NAME)

                        # If we timed out we probably don't need to worry about backing
                        # off too much, but lets just wait a little anyway.
                        await clock.sleep(1)
                except HttpResponseException as e:
                    # We convert to SynapseError as we know that it was a SynapseError
                    # on the main process that we should send to the client. (And
                    # importantly, not stack traces everywhere)
                    _outgoing_request_counter.labels(cls.NAME, e.code).inc()
                    raise e.to_synapse_error()
                except Exception as e:
                    _outgoing_request_counter.labels(cls.NAME, "ERR").inc()
                    raise SynapseError(502, "Failed to talk to main process") from e

                _outgoing_request_counter.labels(cls.NAME, 200).inc()
                return result

        return send_request

    def register(self, http_server):
        """Called by the server to register this as a handler to the
        appropriate path.
        """

        url_args = list(self.PATH_ARGS)
        method = self.METHOD

        if self.CACHE:
            url_args.append("txn_id")

        # One named capture group per path argument, each matching a single
        # path segment.
        args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
        pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))

        http_server.register_paths(
            method,
            [pattern],
            self._check_auth_and_handle,
            self.__class__.__name__,
        )

    async def _check_auth_and_handle(self, request, **kwargs):
        """Called on new incoming requests.

        If a replication secret is configured, the request's Authorization
        header is checked first. Then, when caching is enabled, checks if
        there is a cached response for the request and returns that,
        otherwise calls `_handle_request` and caches its response. When
        caching is disabled `_handle_request` is called directly.
        """
        # We just use the txn_id here, but we probably also want to use the
        # other PATH_ARGS as well.

        # Check the authorization headers before handling the request.
        if self._replication_secret:
            self._check_auth(request)

        if self.CACHE:
            txn_id = kwargs.pop("txn_id")

            return await self.response_cache.wrap(
                txn_id, self._handle_request, request, **kwargs
            )

        return await self._handle_request(request, **kwargs)