# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2019-2021 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Union,
)

from matrix_common.regex import glob_to_regex
from prometheus_client import Counter, Gauge, Histogram

from twisted.internet.abstract import isIPAddress
from twisted.python import failure

from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership
from synapse.api.errors import (
    AuthError,
    Codes,
    FederationError,
    IncompatibleRoomVersionError,
    NotFoundError,
    SynapseError,
    UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.crypto.event_signing import compute_event_signature
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
    make_deferred_yieldable,
    nested_logging_context,
    run_in_background,
)
from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
from synapse.logging.utils import log_function
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.replication.http.federation import (
    ReplicationFederationSendEduRestServlet,
    ReplicationGetQueryRestServlet,
)
from synapse.storage.databases.main.lock import Lock
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_server_name

if TYPE_CHECKING:
    from synapse.server import HomeServer

# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
TRANSACTION_CONCURRENCY_LIMIT = 10

logger = logging.getLogger(__name__)

received_pdus_counter = Counter("synapse_federation_server_received_pdus", "")

received_edus_counter = Counter("synapse_federation_server_received_edus", "")

received_queries_counter = Counter(
    "synapse_federation_server_received_queries", "", ["type"]
)

pdu_process_time = Histogram(
    "synapse_federation_server_pdu_process_time",
    "Time taken to process an event",
)

last_pdu_ts_metric = Gauge(
    "synapse_federation_last_received_pdu_time",
    "The timestamp of the last PDU which was successfully received from the given domain",
    labelnames=("server_name",),
)


# The name of the lock to use when processing events in a room received over
# federation.
_INBOUND_EVENT_HANDLING_LOCK_NAME = "federation_inbound_pdu"


class FederationServer(FederationBase):
    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)

        self.handler = hs.get_federation_handler()
        self.storage = hs.get_storage()
        self._federation_event_handler = hs.get_federation_event_handler()
        self.state = hs.get_state_handler()
        self._event_auth_handler = hs.get_event_auth_handler()

        self.device_handler = hs.get_device_handler()

        # Ensure the following handlers are loaded since they register callbacks
        # with FederationHandlerRegistry.
        hs.get_directory_handler()

        self._server_linearizer = Linearizer("fed_server")

        # origins that we are currently processing a transaction from.
        # a dict from origin to txn id.
        self._active_transactions: Dict[str, str] = {}

        # We cache results for transactions with the same ID
        self._transaction_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
            hs.get_clock(), "fed_txn_handler", timeout_ms=30000
        )
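
        # Illustrative sketch (not executed here): a retried transaction keyed
        # on the same (origin, txn_id) pair is served from this cache rather
        # than being processed twice, e.g.
        #
        #   await self._transaction_resp_cache.wrap(
        #       ("remote.example.com", "txn_1"), handler, ...
        #   )
        #
        # where "remote.example.com", "txn_1" and `handler` are hypothetical.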

        self.transaction_actions = TransactionActions(self.store)

        self.registry = hs.get_federation_registry()

        # We cache responses to state queries, as they take a while and often
        # come in waves.
        self._state_resp_cache: ResponseCache[
            Tuple[str, Optional[str]]
        ] = ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000)
        self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
            hs.get_clock(), "state_ids_resp", timeout_ms=30000
        )

        self._federation_metrics_domains = (
            hs.config.federation.federation_metrics_domains
        )

        self._room_prejoin_state_types = hs.config.api.room_prejoin_state

        # Whether we have started handling old events in the staging area.
        self._started_handling_of_staged_events = False

    @wrap_as_background_process("_handle_old_staged_events")
    async def _handle_old_staged_events(self) -> None:
        """Handle old staged events by fetching all rooms that have staged
        events and starting the processing of each of those rooms.
        """

        # Get all the room IDs with staged events.
        room_ids = await self.store.get_all_rooms_with_staged_incoming_events()

        # We then shuffle them so that if there are multiple instances doing
        # this work they're less likely to collide.
        random.shuffle(room_ids)

        for room_id in room_ids:
            room_version = await self.store.get_room_version(room_id)

            # Try to acquire the processing lock for the room; if we get it,
            # start a background process to handle the events in the room.
            lock = await self.store.try_acquire_lock(
                _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
            )
            if lock:
                logger.info("Handling old staged inbound events in %s", room_id)
                self._process_incoming_pdus_in_room_inner(
                    room_id,
                    room_version,
                    lock,
                )

            # We pause a bit so that we don't start handling all rooms at once.
            await self._clock.sleep(random.uniform(0, 0.1))

    async def on_backfill_request(
        self, origin: str, room_id: str, versions: List[str], limit: int
    ) -> Tuple[int, Dict[str, Any]]:
        with (await self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            await self.check_server_matches_acl(origin_host, room_id)

            pdus = await self.handler.on_backfill_request(
                origin, room_id, versions, limit
            )

            res = self._transaction_dict_from_pdus(pdus)

        return 200, res

    async def on_timestamp_to_event_request(
        self, origin: str, room_id: str, timestamp: int, direction: str
    ) -> Tuple[int, Dict[str, Any]]:
        """When we receive a federated `/timestamp_to_event` request,
        handle all of the logic for validating and fetching the event.

        Args:
            origin: The server we received the request from
            room_id: Room to fetch the event from
            timestamp: The point in time (inclusive) we should navigate from in
                the given direction to find the closest event.
            direction: ["f"|"b"] to indicate whether we should navigate forward
                or backward from the given timestamp to find the closest event.

        Returns:
            Tuple indicating the response status code and dictionary response
            body including `event_id`.
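
            For example, a successful lookup might return (hypothetical
            values)::

                (200, {"event_id": "$abc123:example.com",
                       "origin_server_ts": 1626267000000})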
        """
        with (await self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            await self.check_server_matches_acl(origin_host, room_id)

            # We only try to fetch data from the local database
            event_id = await self.store.get_event_id_for_timestamp(
                room_id, timestamp, direction
            )
            if event_id:
                event = await self.store.get_event(
                    event_id, allow_none=False, allow_rejected=False
                )

                return 200, {
                    "event_id": event_id,
                    "origin_server_ts": event.origin_server_ts,
                }

        raise SynapseError(
            404,
            "Unable to find event from %s in direction %s" % (timestamp, direction),
            errcode=Codes.NOT_FOUND,
        )

    async def on_incoming_transaction(
        self,
        origin: str,
        transaction_id: str,
        destination: str,
        transaction_data: JsonDict,
    ) -> Tuple[int, JsonDict]:
        # If we receive a transaction, we should make sure that we kick off
        # handling of any old events in the staging area.
        if not self._started_handling_of_staged_events:
            self._started_handling_of_staged_events = True
            self._handle_old_staged_events()

            # Start a periodic check for old staged events. This is to handle
            # the case where locks time out, e.g. if another process gets killed
            # without dropping its locks.
            self._clock.looping_call(self._handle_old_staged_events, 60 * 1000)

        # keep this as early as possible to make the calculated origin ts as
        # accurate as possible.
        request_time = self._clock.time_msec()

        transaction = Transaction(
            transaction_id=transaction_id,
            destination=destination,
            origin=origin,
            origin_server_ts=transaction_data.get("origin_server_ts"),  # type: ignore
            pdus=transaction_data.get("pdus"),  # type: ignore
            edus=transaction_data.get("edus"),
        )
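
        # For reference, `transaction_data` is the JSON body of a federation
        # `/send/{txnId}` request, roughly of this shape (illustrative only):
        #
        #   {
        #       "origin": "remote.example.com",
        #       "origin_server_ts": 1626267000000,
        #       "pdus": [...],
        #       "edus": [...],
        #   }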

        if not transaction_id:
            raise Exception("Transaction missing transaction_id")

        logger.debug("[%s] Got transaction", transaction_id)

        # Reject malformed transactions early: reject if too many PDUs/EDUs
        if len(transaction.pdus) > 50 or len(transaction.edus) > 100:
            logger.info("Transaction PDU or EDU count too large. Returning 400")
            return 400, {}

        # we only process one transaction from each origin at a time. We need to do
        # this check here, rather than in _on_incoming_transaction_inner so that we
        # don't cache the rejection in _transaction_resp_cache (so that if the txn
        # arrives again later, we can process it).
        current_transaction = self._active_transactions.get(origin)
        if current_transaction and current_transaction != transaction_id:
            logger.warning(
                "Received another txn %s from %s while still processing %s",
                transaction_id,
                origin,
                current_transaction,
            )
            return 429, {
                "errcode": Codes.UNKNOWN,
                "error": "Too many concurrent transactions",
            }

        # CRITICAL SECTION: we must now not await until we populate _active_transactions
        # in _on_incoming_transaction_inner.

        # We wrap in a ResponseCache so that we de-duplicate retried
        # transactions.
        return await self._transaction_resp_cache.wrap(
            (origin, transaction_id),
            self._on_incoming_transaction_inner,
            origin,
            transaction,
            request_time,
        )

    async def _on_incoming_transaction_inner(
        self, origin: str, transaction: Transaction, request_time: int
    ) -> Tuple[int, Dict[str, Any]]:
        # CRITICAL SECTION: the first thing we must do (before awaiting) is
        # add an entry to _active_transactions.
        assert origin not in self._active_transactions
        self._active_transactions[origin] = transaction.transaction_id

        try:
            result = await self._handle_incoming_transaction(
                origin, transaction, request_time
            )
            return result
        finally:
            del self._active_transactions[origin]

    async def _handle_incoming_transaction(
        self, origin: str, transaction: Transaction, request_time: int
    ) -> Tuple[int, Dict[str, Any]]:
        """Process an incoming transaction and return the HTTP response

        Args:
            origin: the server making the request
            transaction: incoming transaction
            request_time: timestamp that the HTTP request arrived at

        Returns:
            HTTP response code and body
        """
        existing_response = await self.transaction_actions.have_responded(
            origin, transaction
        )

        if existing_response:
            logger.debug(
                "[%s] We've already responded to this request",
                transaction.transaction_id,
            )
            return existing_response

        logger.debug("[%s] Transaction is new", transaction.transaction_id)

        # We process PDUs and EDUs in parallel. This is important as we don't
        # want to block things like to-device messages from reaching clients
        # behind the potentially expensive handling of PDUs.
        pdu_results, _ = await make_deferred_yieldable(
            gather_results(
                (
                    run_in_background(
                        self._handle_pdus_in_txn, origin, transaction, request_time
                    ),
                    run_in_background(self._handle_edus_in_txn, origin, transaction),
                ),
                consumeErrors=True,
            ).addErrback(unwrapFirstError)
        )

        response = {"pdus": pdu_results}

        logger.debug("Returning: %s", str(response))

        await self.transaction_actions.set_response(origin, transaction, 200, response)
        return 200, response

    async def _handle_pdus_in_txn(
        self, origin: str, transaction: Transaction, request_time: int
    ) -> Dict[str, dict]:
        """Process the PDUs in a received transaction.

        Args:
            origin: the server making the request
            transaction: incoming transaction
            request_time: timestamp that the HTTP request arrived at

        Returns:
            A map from event ID of a processed PDU to any errors we should
            report back to the sending server.
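
            For example (hypothetical event IDs)::

                {
                    "$accepted:example.com": {},
                    "$rejected:example.com": {"error": "..."},
                }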
        """

        received_pdus_counter.inc(len(transaction.pdus))

        origin_host, _ = parse_server_name(origin)

        pdus_by_room: Dict[str, List[EventBase]] = {}

        newest_pdu_ts = 0

        for p in transaction.pdus:
            # FIXME (richardv): I don't think this works:
            #  https://github.com/matrix-org/synapse/issues/8429
            if "unsigned" in p:
                unsigned = p["unsigned"]
                if "age" in unsigned:
                    p["age"] = unsigned["age"]
            if "age" in p:
                p["age_ts"] = request_time - int(p["age"])
                del p["age"]

            # We try and pull out an event ID so that if later checks fail we
            # can log something sensible. We don't mandate an event ID here in
            # case future event formats get rid of the key.
            possible_event_id = p.get("event_id", "<Unknown>")

            # Now we get the room ID so that we can check that we know the
            # version of the room.
            room_id = p.get("room_id")
            if not room_id:
                logger.info(
                    "Ignoring PDU as it does not have a room_id. Event ID: %s",
                    possible_event_id,
                )
                continue

            try:
                room_version = await self.store.get_room_version(room_id)
            except NotFoundError:
                logger.info("Ignoring PDU for unknown room_id: %s", room_id)
                continue
            except UnsupportedRoomVersionError as e:
                # this can happen if support for a given room version is withdrawn,
                # in which case we may still receive events for that room.
                logger.info("Ignoring PDU: %s", e)
                continue

            event = event_from_pdu_json(p, room_version)
            pdus_by_room.setdefault(room_id, []).append(event)

            if event.origin_server_ts > newest_pdu_ts:
                newest_pdu_ts = event.origin_server_ts

        pdu_results = {}

        # we can process different rooms in parallel (which is useful if they
        # require callouts to other servers to fetch missing events), but
        # impose a limit to avoid going too crazy with ram/cpu.

        async def process_pdus_for_room(room_id: str) -> None:
            with nested_logging_context(room_id):
                logger.debug("Processing PDUs for %s", room_id)

                try:
                    await self.check_server_matches_acl(origin_host, room_id)
                except AuthError as e:
                    logger.warning(
                        "Ignoring PDUs for room %s from banned server", room_id
                    )
                    for pdu in pdus_by_room[room_id]:
                        event_id = pdu.event_id
                        pdu_results[event_id] = e.error_dict()
                    return

                for pdu in pdus_by_room[room_id]:
                    pdu_results[pdu.event_id] = await process_pdu(pdu)

        async def process_pdu(pdu: EventBase) -> JsonDict:
            event_id = pdu.event_id
            with nested_logging_context(event_id):
                try:
                    await self._handle_received_pdu(origin, pdu)
                    return {}
                except FederationError as e:
                    logger.warning("Error handling PDU %s: %s", event_id, e)
                    return {"error": str(e)}
                except Exception as e:
                    f = failure.Failure()
                    logger.error(
                        "Failed to handle PDU %s",
                        event_id,
                        exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
                    )
                    return {"error": str(e)}

        await concurrently_execute(
            process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
        )

        if newest_pdu_ts and origin in self._federation_metrics_domains:
            last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)

        return pdu_results

    async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None:
        """Process the EDUs in a received transaction."""
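
        # Each entry in `transaction.edus` is a dict along these lines
        # (illustrative values):
        #
        #   {"edu_type": "m.typing",
        #    "content": {"room_id": "!room:example.com", ...}}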

        async def _process_edu(edu_dict: JsonDict) -> None:
            received_edus_counter.inc()

            edu = Edu(
                origin=origin,
                destination=self.server_name,
                edu_type=edu_dict["edu_type"],
                content=edu_dict["content"],
            )
            await self.registry.on_edu(edu.edu_type, origin, edu.content)

        await concurrently_execute(
            _process_edu,
            transaction.edus,
            TRANSACTION_CONCURRENCY_LIMIT,
        )

    async def on_room_state_request(
        self, origin: str, room_id: str, event_id: Optional[str]
    ) -> Tuple[int, JsonDict]:
        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, room_id)

        in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
        if not in_room:
            raise AuthError(403, "Host not in room.")

        # we grab the linearizer to protect ourselves from servers which hammer
        # us. In theory we might already have the response to this query
        # in the cache so we could return it without waiting for the linearizer
        # - but that's non-trivial to get right, and anyway somewhat defeats
        # the point of the linearizer.
        with (await self._server_linearizer.queue((origin, room_id))):
            resp: JsonDict = dict(
                await self._state_resp_cache.wrap(
                    (room_id, event_id),
                    self._on_context_state_request_compute,
                    room_id,
                    event_id,
                )
            )

        room_version = await self.store.get_room_version_id(room_id)
        resp["room_version"] = room_version

        return 200, resp

    async def on_state_ids_request(
        self, origin: str, room_id: str, event_id: str
    ) -> Tuple[int, JsonDict]:
        if not event_id:
            raise NotImplementedError("Specify an event")

        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, room_id)

        in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
        if not in_room:
            raise AuthError(403, "Host not in room.")

        resp = await self._state_ids_resp_cache.wrap(
            (room_id, event_id),
            self._on_state_ids_request_compute,
            room_id,
            event_id,
        )

        return 200, resp

    async def _on_state_ids_request_compute(
        self, room_id: str, event_id: str
    ) -> JsonDict:
        state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
        auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
        return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}

    async def _on_context_state_request_compute(
        self, room_id: str, event_id: Optional[str]
    ) -> Dict[str, list]:
        if event_id:
            pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu(
                room_id, event_id
            )
        else:
            pdus = (await self.state.get_current_state(room_id)).values()

        auth_chain = await self.store.get_auth_chain(
            room_id, [pdu.event_id for pdu in pdus]
        )

        return {
            "pdus": [pdu.get_pdu_json() for pdu in pdus],
            "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
        }

    async def on_pdu_request(
        self, origin: str, event_id: str
    ) -> Tuple[int, Union[JsonDict, str]]:
        pdu = await self.handler.get_persisted_pdu(origin, event_id)

        if pdu:
            return 200, self._transaction_dict_from_pdus([pdu])
        else:
            return 404, ""

    async def on_query_request(
        self, query_type: str, args: Dict[str, str]
    ) -> Tuple[int, Dict[str, Any]]:
        received_queries_counter.labels(query_type).inc()
        resp = await self.registry.on_query(query_type, args)
        return 200, resp

    async def on_make_join_request(
        self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
    ) -> Dict[str, Any]:
        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, room_id)

        room_version = await self.store.get_room_version_id(room_id)
        if room_version not in supported_versions:
            logger.warning(
                "Room version %s not in %s", room_version, supported_versions
            )
            raise IncompatibleRoomVersionError(room_version=room_version)

        pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
        return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}

    async def on_invite_request(
        self, origin: str, content: JsonDict, room_version_id: str
    ) -> Dict[str, Any]:
        room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
        if not room_version:
            raise SynapseError(
                400,
                "Homeserver does not support this room version",
                Codes.UNSUPPORTED_ROOM_VERSION,
            )

        pdu = event_from_pdu_json(content, room_version)
        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, pdu.room_id)
        pdu = await self._check_sigs_and_hash(room_version, pdu)
        ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
        time_now = self._clock.time_msec()
        return {"event": ret_pdu.get_pdu_json(time_now)}

    async def on_send_join_request(
        self, origin: str, content: JsonDict, room_id: str
    ) -> Dict[str, Any]:
        event, context = await self._on_send_membership_event(
            origin, content, Membership.JOIN, room_id
        )

        prev_state_ids = await context.get_prev_state_ids()
        state_ids = list(prev_state_ids.values())
        auth_chain = await self.store.get_auth_chain(room_id, state_ids)
        state = await self.store.get_events(state_ids)

        time_now = self._clock.time_msec()
        event_json = event.get_pdu_json()
        return {
            # TODO Remove the unstable prefix when servers have updated.
            "org.matrix.msc3083.v2.event": event_json,
            "event": event_json,
            "state": [p.get_pdu_json(time_now) for p in state.values()],
            "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
        }

    async def on_make_leave_request(
        self, origin: str, room_id: str, user_id: str
    ) -> Dict[str, Any]:
        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, room_id)
        pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)

        room_version = await self.store.get_room_version_id(room_id)

        return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}

    async def on_send_leave_request(
        self, origin: str, content: JsonDict, room_id: str
    ) -> dict:
        logger.debug("on_send_leave_request: content: %s", content)
        await self._on_send_membership_event(origin, content, Membership.LEAVE, room_id)
        return {}

    async def on_make_knock_request(
        self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
    ) -> JsonDict:
        """We've received a /make_knock/ request, so we create a partial knock
        event for the room and hand that back, along with the room version, to the knocking
        homeserver. We do *not* persist or process this event until the other server has
        signed it and sent it back.

        Args:
            origin: The (verified) server name of the requesting server.
            room_id: The room to create the knock event in.
            user_id: The user to create the knock for.
            supported_versions: The room versions supported by the requesting server.

        Returns:
            The partial knock event.
        """
        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, room_id)

        room_version = await self.store.get_room_version(room_id)

        # Check that this room version is supported by the remote homeserver
        if room_version.identifier not in supported_versions:
            logger.warning(
                "Room version %s not in %s", room_version.identifier, supported_versions
            )
            raise IncompatibleRoomVersionError(room_version=room_version.identifier)

        # Check that this room supports knocking as defined by its room version
        if not room_version.msc2403_knocking:
            raise SynapseError(
                403,
                "This room version does not support knocking",
                errcode=Codes.FORBIDDEN,
            )

        pdu = await self.handler.on_make_knock_request(origin, room_id, user_id)
        return {
            "event": pdu.get_templated_pdu_json(),
            "room_version": room_version.identifier,
        }

    async def on_send_knock_request(
        self,
        origin: str,
        content: JsonDict,
        room_id: str,
    ) -> Dict[str, List[JsonDict]]:
        """
        We have received a knock event for a room. Verify and send the event into the room
        on the knocking homeserver's behalf. Then reply with some stripped state from the
        room for the knockee.

        Args:
            origin: The remote homeserver of the knocking user.
            content: The content of the request.
            room_id: The ID of the room to knock on.

        Returns:
            The stripped room state.
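
            For example (hypothetical, trimmed for brevity)::

                {"knock_state_events": [
                    {"type": "m.room.name", "state_key": "",
                     "sender": "@admin:example.com",
                     "content": {"name": "Some room"}},
                ]}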
        """
        _, context = await self._on_send_membership_event(
            origin, content, Membership.KNOCK, room_id
        )

        # Retrieve stripped state events from the room and send them back to the remote
        # server. This will allow the remote server's clients to display information
        # related to the room while the knock request is pending.
        stripped_room_state = (
            await self.store.get_stripped_room_state_from_event_context(
                context, self._room_prejoin_state_types
            )
        )
        return {"knock_state_events": stripped_room_state}

    async def _on_send_membership_event(
        self, origin: str, content: JsonDict, membership_type: str, room_id: str
    ) -> Tuple[EventBase, EventContext]:
        """Handle an on_send_{join,leave,knock} request

        Does some preliminary validation before passing the request on to the
        federation handler.

        Args:
            origin: The (authenticated) requesting server
            content: The body of the send_* request - a complete membership event
            membership_type: The expected membership type (join, leave or
                knock, depending on the endpoint)
            room_id: The room_id from the request, to be validated against the room_id
                in the event

        Returns:
            The event and context of the event after inserting it into the room graph.

        Raises:
            SynapseError if there is a problem with the request, including things like
               the room_id not matching or the event not being authorized.
        """
        assert_params_in_dict(content, ["room_id"])
        if content["room_id"] != room_id:
            raise SynapseError(
                400,
                "Room ID in body does not match that in request path",
                Codes.BAD_JSON,
            )

        room_version = await self.store.get_room_version(room_id)

        if membership_type == Membership.KNOCK and not room_version.msc2403_knocking:
            raise SynapseError(
                403,
                "This room version does not support knocking",
                errcode=Codes.FORBIDDEN,
            )

        event = event_from_pdu_json(content, room_version)

        if event.type != EventTypes.Member or not event.is_state():
            raise SynapseError(400, "Not an m.room.member event", Codes.BAD_JSON)

        if event.content.get("membership") != membership_type:
            raise SynapseError(400, "Not a %s event" % membership_type, Codes.BAD_JSON)

        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, event.room_id)

        logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)

        # Sign the event since we're vouching on behalf of the remote server that
        # the event is valid to be sent into the room. Currently this is only done
        # if the user is being joined via restricted join rules.
        if (
            room_version.msc3083_join_rules
            and event.membership == Membership.JOIN
            and EventContentFields.AUTHORISING_USER in event.content
        ):
            # We can only authorise our own users.
            authorising_server = get_domain_from_id(
                event.content[EventContentFields.AUTHORISING_USER]
            )
            if authorising_server != self.server_name:
                raise SynapseError(
                    400,
                    f"Cannot authorise request from resident server: {authorising_server}",
                )

            event.signatures.update(
                compute_event_signature(
                    room_version,
                    event.get_pdu_json(),
                    self.hs.hostname,
                    self.hs.signing_key,
                )
            )

        event = await self._check_sigs_and_hash(room_version, event)

        return await self._federation_event_handler.on_send_membership_event(
            origin, event
        )

    async def on_event_auth(
        self, origin: str, room_id: str, event_id: str
    ) -> Tuple[int, Dict[str, Any]]:
        with (await self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            await self.check_server_matches_acl(origin_host, room_id)

            time_now = self._clock.time_msec()
            auth_pdus = await self.handler.on_event_auth(event_id)
            res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
        return 200, res

    @log_function
    async def on_query_client_keys(
        self, origin: str, content: Dict[str, str]
    ) -> Tuple[int, Dict[str, Any]]:
        return await self.on_query_request("client_keys", content)

    async def on_query_user_devices(
        self, origin: str, user_id: str
    ) -> Tuple[int, Dict[str, Any]]:
        keys = await self.device_handler.on_federation_query_user_devices(user_id)
        return 200, keys

    @trace
    async def on_claim_client_keys(
        self, origin: str, content: JsonDict
    ) -> Dict[str, Any]:
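        # `content` is the body of a federation `/user/keys/claim` request,
        # roughly of this shape (illustrative values):
        #
        #   {"one_time_keys": {
        #       "@alice:example.com": {"DEVICEID": "signed_curve25519"}}}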
        query = []
        for user_id, device_keys in content.get("one_time_keys", {}).items():
            for device_id, algorithm in device_keys.items():
                query.append((user_id, device_id, algorithm))

        log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
        results = await self.store.claim_e2e_one_time_keys(query)

        json_result: Dict[str, Dict[str, dict]] = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_str in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json_decoder.decode(json_str)
                    }

        logger.info(
            "Claimed one-time-keys: %s",
            ",".join(
                (
                    "%s for %s:%s" % (key_id, user_id, device_id)
                    for user_id, user_keys in json_result.items()
                    for device_id, device_keys in user_keys.items()
                    for key_id, _ in device_keys.items()
                )
            ),
        )

        return {"one_time_keys": json_result}

    async def on_get_missing_events(
        self,
        origin: str,
        room_id: str,
        earliest_events: List[str],
        latest_events: List[str],
        limit: int,
    ) -> Dict[str, list]:
        with (await self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            await self.check_server_matches_acl(origin_host, room_id)

            logger.debug(
                "on_get_missing_events: earliest_events: %r, latest_events: %r,"
                " limit: %d",
                earliest_events,
                latest_events,
                limit,
            )

            missing_events = await self.handler.on_get_missing_events(
                origin, room_id, earliest_events, latest_events, limit
            )

            if len(missing_events) < 5:
                logger.debug(
                    "Returning %d events: %r", len(missing_events), missing_events
                )
            else:
                logger.debug("Returning %d events", len(missing_events))

            time_now = self._clock.time_msec()

        return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}

    @log_function
    async def on_openid_userinfo(self, token: str) -> Optional[str]:
        ts_now_ms = self._clock.time_msec()
        return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)

    def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict:
        """Returns a new Transaction containing the given PDUs suitable for
        transmission.
        """
        time_now = self._clock.time_msec()
        pdus = [p.get_pdu_json(time_now) for p in pdu_list]
        return Transaction(
            # Just need a dummy transaction ID and destination since they won't be used.
            transaction_id="",
            origin=self.server_name,
            pdus=pdus,
            origin_server_ts=int(time_now),
            destination="",
        ).get_dict()

    async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
        """Process a PDU received in a federation /send/ transaction.

        If the event is invalid, then this method throws a FederationError.
        (The error will then be logged and sent back to the sender (which
        probably won't do anything with it), and other events in the
        transaction will be processed as normal).

        It is likely that we'll then receive other events which refer to
        this rejected_event in their prev_events, etc.  When that happens,
        we'll attempt to fetch the rejected event again, which will presumably
        fail, so those second-generation events will also get rejected.

        Eventually, we get to the point where there are more than 10 events
        between any new events and the original rejected event. Since we
        only try to backfill 10 events deep on received pdu, we then accept the
        new event, possibly introducing a discontinuity in the DAG, with new
        forward extremities, so normal service is approximately returned,
        until we try to backfill across the discontinuity.

        Args:
            origin: server which sent the pdu
            pdu: received pdu

        Raises: FederationError if the signatures / hash do not match, or
            if the event was unacceptable for any other reason (eg, too large,
            too many prev_events, couldn't find the prev_events)
        """

        # We've already checked that we know the room version by this point
        room_version = await self.store.get_room_version(pdu.room_id)

        # Check signature.
        try:
            pdu = await self._check_sigs_and_hash(room_version, pdu)
        except SynapseError as e:
            raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)

        # Add the event to our staging area
        await self.store.insert_received_event_to_staging(origin, pdu)

        # Try to acquire the processing lock for the room; if we get it, start
        # a background process to handle the events in the room.
        lock = await self.store.try_acquire_lock(
            _INBOUND_EVENT_HANDLING_LOCK_NAME, pdu.room_id
        )
        if lock:
            self._process_incoming_pdus_in_room_inner(
                pdu.room_id, room_version, lock, origin, pdu
            )

    @wrap_as_background_process("_process_incoming_pdus_in_room_inner")
    async def _process_incoming_pdus_in_room_inner(
        self,
        room_id: str,
        room_version: RoomVersion,
        lock: Lock,
        latest_origin: Optional[str] = None,
        latest_event: Optional[EventBase] = None,
    ) -> None:
        """Process events in the staging area for the given room.

        The latest_origin and latest_event args are the latest origin and event
        received (or None to simply pull the next event from the database).
        """

        # The common path is for the event we just received to be the only event
        # in the room, so instead of pulling the event out of the DB and parsing
        # it we just pull out the next event ID and check if that matches.
        if latest_event is not None and latest_origin is not None:
            result = await self.store.get_next_staged_event_id_for_room(room_id)
            if result is None:
                latest_origin = None
                latest_event = None
            else:
                next_origin, next_event_id = result
                if (
                    next_origin != latest_origin
                    or next_event_id != latest_event.event_id
                ):
                    latest_origin = None
                    latest_event = None

        if latest_origin is None or latest_event is None:
            next = await self.store.get_next_staged_event_for_room(
                room_id, room_version
            )
            if not next:
                await lock.release()
                return

            origin, event = next
        else:
            origin = latest_origin
            event = latest_event

        # We loop round until there are no more events in the room in the
        # staging area, or we fail to get the lock (which means another process
        # has started processing).
        while True:
            async with lock:
                logger.info("handling received PDU: %s", event)
                try:
                    with nested_logging_context(event.event_id):
                        await self._federation_event_handler.on_receive_pdu(
                            origin, event
                        )
                except FederationError as e:
                    # XXX: Ideally we'd inform the remote we failed to process
                    # the event, but we can't return an error in the transaction
                    # response (as we've already responded).
                    logger.warning("Error handling PDU %s: %s", event.event_id, e)
                except Exception:
                    f = failure.Failure()
                    logger.error(
                        "Failed to handle PDU %s",
                        event.event_id,
                        exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
                    )

                received_ts = await self.store.remove_received_event_from_staging(
                    origin, event.event_id
                )
                if received_ts is not None:
                    pdu_process_time.observe(
                        (self._clock.time_msec() - received_ts) / 1000
                    )

            # We need to do this check outside the lock to avoid a race between
            # a new event being inserted by another instance and it attempting
            # to acquire the lock.
            next = await self.store.get_next_staged_event_for_room(
                room_id, room_version
            )
            if not next:
                break

            origin, event = next

            # Prune the event queue if it's getting large.
            #
            # We do this *after* handling the first event as the common case is
            # that the queue is empty (/has the single event in), and so there's
            # no need to do this check.
            pruned = await self.store.prune_staged_events_in_room(room_id, room_version)
            if pruned:
                # If we have pruned the queue, we need to refetch the next
                # event to handle.
                next = await self.store.get_next_staged_event_for_room(
                    room_id, room_version
                )
                if not next:
                    break

                origin, event = next

            new_lock = await self.store.try_acquire_lock(
                _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
            )
            if not new_lock:
                return
            lock = new_lock

    def __str__(self) -> str:
        return "<ReplicationLayer(%s)>" % self.server_name

    async def exchange_third_party_invite(
        self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
    ) -> None:
        await self.handler.exchange_third_party_invite(
            sender_user_id, target_user_id, room_id, signed
        )

    async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
        await self.handler.on_exchange_third_party_invite_request(event_dict)

    async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
        """Check if the given server is allowed by the server ACLs in the room

        Args:
            server_name: name of server, *without any port part*
            room_id: ID of the room to check

        Raises:
            AuthError if the server does not match the ACL
        """
        state_ids = await self.store.get_current_state_ids(room_id)
        acl_event_id = state_ids.get((EventTypes.ServerACL, ""))

        if not acl_event_id:
            return

        acl_event = await self.store.get_event(acl_event_id)
        if server_matches_acl_event(server_name, acl_event):
            return

        raise AuthError(code=403, msg="Server is banned from room")


def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
    """Check if the given server is allowed by the ACL event

    Args:
        server_name: name of server, without any port part
        acl_event: m.room.server_acl event

    Returns:
        True if this server is allowed by the ACLs
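
    For example, given an ACL event whose content is (illustrative)::

        {"allow": ["*"], "deny": ["evil.example.com"],
         "allow_ip_literals": false}

    "good.example.org" would be allowed, while "evil.example.com" and the
    IP literal "1.2.3.4" would not.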
    """
    logger.debug("Checking %s against acl %s", server_name, acl_event.content)

    # first of all, check if literal IPs are blocked, and if so, whether the
    # server name is a literal IP
    allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
    if not isinstance(allow_ip_literals, bool):
        logger.warning("Ignoring non-bool allow_ip_literals flag")
        allow_ip_literals = True
    if not allow_ip_literals:
        # check for ipv6 literals. These start with '['.
        if server_name[0] == "[":
            return False

        # check for ipv4 literals. We can just lift the routine from twisted.
        if isIPAddress(server_name):
            return False

    # next, check the deny list
    deny = acl_event.content.get("deny", [])
    if not isinstance(deny, (list, tuple)):
        logger.warning("Ignoring non-list deny ACL %s", deny)
        deny = []
    for e in deny:
        if _acl_entry_matches(server_name, e):
            # logger.info("%s matched deny rule %s", server_name, e)
            return False

    # then the allow list.
    allow = acl_event.content.get("allow", [])
    if not isinstance(allow, (list, tuple)):
        logger.warning("Ignoring non-list allow ACL %s", allow)
        allow = []
    for e in allow:
        if _acl_entry_matches(server_name, e):
            # logger.info("%s matched allow rule %s", server_name, e)
            return True

    # everything else should be rejected.
    # logger.info("%s fell through", server_name)
    return False


def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
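    # Illustratively, glob_to_regex("*.example.com") yields a pattern which
    # matches "foo.example.com" but not "example.org" (hypothetical inputs).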
    if not isinstance(acl_entry, str):
        logger.warning(
            "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
        )
        return False
    regex = glob_to_regex(acl_entry)
    return bool(regex.match(server_name))


class FederationHandlerRegistry:
    """Allows classes to register themselves as handlers for a given EDU or
    query type for incoming federation traffic.
    """

    def __init__(self, hs: "HomeServer"):
        self.config = hs.config
        self.clock = hs.get_clock()
        self._instance_name = hs.get_instance_name()

        # These are safe to load in monolith mode, but will explode if we try
        # and use them. However we have guards before we use them to ensure that
        # we don't route to ourselves, and in monolith mode that will always be
        # the case.
        self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
        self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)

        self.edu_handlers: Dict[str, Callable[[str, dict], Awaitable[None]]] = {}
        self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}

        # Map from type to instance names that we should route EDU handling to.
        # We randomly choose one instance from the list to route to for each new
        # EDU received.
        self._edu_type_to_instance: Dict[str, List[str]] = {}

    def register_edu_handler(
        self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
    ) -> None:
        """Sets the handler callable that will be used to handle an incoming
        federation EDU of the given type.

        Args:
            edu_type: The type of the incoming EDU to register a handler for
            handler: A callable invoked for each incoming EDU
                of the given type. The arguments are the origin server name and
                the EDU contents.
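
        A minimal sketch of registering a handler (hypothetical handler
        name)::

            async def on_typing_edu(origin: str, content: JsonDict) -> None:
                ...

            registry.register_edu_handler("m.typing", on_typing_edu)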
        """
        if edu_type in self.edu_handlers:
            raise KeyError("Already have an EDU handler for %s" % (edu_type,))

        logger.info("Registering federation EDU handler for %r", edu_type)

        self.edu_handlers[edu_type] = handler

    def register_query_handler(
        self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
    ) -> None:
        """Sets the handler callable that will be used to handle an incoming
        federation query of the given type.

        Args:
            query_type: Category name of the query, which should match
                the string used by make_query.
            handler: Invoked to handle
                incoming queries of this type. The result is awaited and
                used as the response to the query request.
        """
        if query_type in self.query_handlers:
            raise KeyError("Already have a Query handler for %s" % (query_type,))

        logger.info("Registering federation query handler for %r", query_type)

        self.query_handlers[query_type] = handler

    def register_instances_for_edu(
        self, edu_type: str, instance_names: List[str]
    ) -> None:
        """Register that the EDU handler is on multiple instances."""
        self._edu_type_to_instance[edu_type] = instance_names

    async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
        if not self.config.server.use_presence and edu_type == EduTypes.Presence:
            return

        # Check if we have a handler on this instance
        handler = self.edu_handlers.get(edu_type)
        if handler:
            with start_active_span_from_edu(content, "handle_edu"):
                try:
                    await handler(origin, content)
                except SynapseError as e:
                    logger.info("Failed to handle edu %r: %r", edu_type, e)
                except Exception:
                    logger.exception("Failed to handle edu %r", edu_type)
            return

        # Check if we can route it somewhere else that isn't us
        instances = self._edu_type_to_instance.get(edu_type, ["master"])
        if self._instance_name not in instances:
            # Pick an instance randomly so that we don't overload one.
            route_to = random.choice(instances)

            try:
                await self._send_edu(
                    instance_name=route_to,
                    edu_type=edu_type,
                    origin=origin,
                    content=content,
                )
            except SynapseError as e:
                logger.info("Failed to handle edu %r: %r", edu_type, e)
            except Exception:
                logger.exception("Failed to handle edu %r", edu_type)
            return

        # Oh well, let's just log and move on.
        logger.warning("No handler registered for EDU type %s", edu_type)

    async def on_query(self, query_type: str, args: dict) -> JsonDict:
        handler = self.query_handlers.get(query_type)
        if handler:
            return await handler(args)

        # Check if we can route it somewhere else that isn't us
        if self._instance_name == "master":
            return await self._get_query_client(query_type=query_type, args=args)

        # Uh oh, no handler! Let's raise an exception so the request returns an
        # error.
        logger.warning("No handler registered for query type %s", query_type)
        raise NotFoundError("No handler for Query type '%s'" % (query_type,))
