1# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
import abc
import logging
import urllib
import urllib.parse
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple

import attr
from signedjson.key import (
    decode_verify_key_bytes,
    encode_verify_key_base64,
    get_verify_key,
    is_signing_algorithm_supported,
)
from signedjson.sign import (
    SignatureVerifyException,
    encode_canonical_json,
    signature_ids,
    verify_signed_json,
)
from signedjson.types import VerifyKey
from unpaddedbase64 import decode_base64

from twisted.internet import defer

from synapse.api.errors import (
    Codes,
    HttpResponseException,
    RequestSendFailed,
    SynapseError,
)
from synapse.config.key import TrustedKeyServer
from synapse.events import EventBase
from synapse.events.utils import prune_event_dict
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.batching_queue import BatchingQueue
from synapse.util.retryutils import NotRetryingDestination
54
55if TYPE_CHECKING:
56    from synapse.server import HomeServer
57
# Module-level logger shared by the keyring and all of the key fetchers below.
logger = logging.getLogger(__name__)
59
60
@attr.s(slots=True, cmp=False, auto_attribs=True)
class VerifyJsonRequest:
    """
    A request to verify the signatures a given server made on a JSON object.

    Attributes:
        server_name: The name of the server whose signatures must be verified.

        get_json_object: A zero-argument callable returning the JSON object
            to verify. Building the object lazily (e.g. the redacted copy of
            an event) keeps memory usage down when many requests are in
            flight at once.

        minimum_valid_until_ts: time at which we require the signing key to
            be valid. (0 implies we don't care)

        key_ids: The IDs of the keys that could be used to verify the JSON
            object.
    """

    server_name: str
    get_json_object: Callable[[], JsonDict]
    minimum_valid_until_ts: int
    key_ids: List[str]

    @staticmethod
    def from_json_object(
        server_name: str,
        json_object: JsonDict,
        minimum_valid_until_ms: int,
    ) -> "VerifyJsonRequest":
        """Build a request that checks every signature `server_name` placed on
        an already-built signed JSON object.
        """

        def _get_json_object() -> JsonDict:
            return json_object

        return VerifyJsonRequest(
            server_name=server_name,
            get_json_object=_get_json_object,
            minimum_valid_until_ts=minimum_valid_until_ms,
            key_ids=signature_ids(json_object, server_name),
        )

    @staticmethod
    def from_event(
        server_name: str,
        event: EventBase,
        minimum_valid_until_ms: int,
    ) -> "VerifyJsonRequest":
        """Build a request that checks every signature `server_name` placed on
        an event.
        """

        def _get_json_object() -> JsonDict:
            # The redacted JSON is only built when actually needed, since it
            # takes far more memory than the event object itself.
            return prune_event_dict(event.room_version, event.get_pdu_json())

        signing_key_ids = list(event.signatures.get(server_name, []))
        return VerifyJsonRequest(
            server_name=server_name,
            get_json_object=_get_json_object,
            minimum_valid_until_ts=minimum_valid_until_ms,
            key_ids=signing_key_ids,
        )
121
122
class KeyLookupError(ValueError):
    """Raised when a server's signing keys could not be fetched or validated."""
125
126
@attr.s(slots=True, auto_attribs=True)
class _FetchKeyRequest:
    """A request for some of a server's signing keys.

    Fetching continues until every key in `key_ids` has been found with a
    sufficient `valid_until_ts`, or until there are no more places left to
    fetch keys from.

    Attributes:
        server_name: The server that owns the keys.
        minimum_valid_until_ts: The timestamp the keys must be valid until.
        key_ids: The IDs of the keys to try to fetch.
    """

    server_name: str
    minimum_valid_until_ts: int
    key_ids: List[str]
144
145
class Keyring:
    """Handles verifying signed JSON objects and fetching the keys needed to do
    so.
    """

    def __init__(
        self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
    ):
        self.clock = hs.get_clock()

        if key_fetchers is None:
            key_fetchers = (
                StoreKeyFetcher(hs),
                PerspectivesKeyFetcher(hs),
                ServerKeyFetcher(hs),
            )
        # The fetchers are walked once for every key lookup, so materialise
        # them: a one-shot iterable (e.g. a generator) would otherwise be
        # exhausted by the first lookup and every later lookup would find
        # nothing.
        self._key_fetchers: "Tuple[KeyFetcher, ...]" = tuple(key_fetchers)

        # Key fetches are queued per origin server (see `process_request`) and
        # each batch is handled by `_inner_fetch_key_requests`.
        self._server_queue: BatchingQueue[
            _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]
        ] = BatchingQueue(
            "keyring_server",
            clock=hs.get_clock(),
            process_batch_callback=self._inner_fetch_key_requests,
        )

        self._hostname = hs.hostname

        # build a FetchKeyResult for each of our own keys, to shortcircuit the
        # fetcher.
        self._local_verify_keys: Dict[str, FetchKeyResult] = {}
        for key_id, key in hs.config.key.old_signing_keys.items():
            self._local_verify_keys[key_id] = FetchKeyResult(
                verify_key=key, valid_until_ts=key.expired_ts
            )

        vk = get_verify_key(hs.signing_key)
        self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult(
            verify_key=vk,
            valid_until_ts=2 ** 63,  # fake future timestamp
        )

    async def verify_json_for_server(
        self,
        server_name: str,
        json_object: JsonDict,
        validity_time: int,
    ) -> None:
        """Verify that a JSON object has been signed by a given server

        Completes if the object was correctly signed, otherwise raises.

        Args:
            server_name: name of the server which must have signed this object

            json_object: object to be checked

            validity_time: timestamp at which we require the signing key to
                be valid. (0 implies we don't care)
        """

        request = VerifyJsonRequest.from_json_object(
            server_name,
            json_object,
            validity_time,
        )
        return await self.process_request(request)

    def verify_json_objects_for_server(
        self, server_and_json: Iterable[Tuple[str, dict, int]]
    ) -> List[defer.Deferred]:
        """Bulk verifies signatures of json objects, bulk fetching keys as
        necessary.

        Args:
            server_and_json:
                Iterable of (server_name, json_object, validity_time)
                tuples.

                validity_time is a timestamp at which the signing key must be
                valid.

        Returns:
            List<Deferred[None]>: for each input triplet, a deferred indicating success
                or failure to verify each json object's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        return [
            run_in_background(
                self.process_request,
                VerifyJsonRequest.from_json_object(
                    server_name,
                    json_object,
                    validity_time,
                ),
            )
            for server_name, json_object, validity_time in server_and_json
        ]

    async def verify_event_for_server(
        self,
        server_name: str,
        event: EventBase,
        validity_time: int,
    ) -> None:
        """Verify that `event` carries a valid signature from `server_name`,
        using a key valid at `validity_time`. Raises on failure.
        """
        await self.process_request(
            VerifyJsonRequest.from_event(
                server_name,
                event,
                validity_time,
            )
        )

    async def process_request(self, verify_request: VerifyJsonRequest) -> None:
        """Processes the `VerifyJsonRequest`. Raises if the object is not signed
        by the server, the signatures don't match or we failed to fetch the
        necessary keys.

        Raises:
            SynapseError: 400 if the object carries no signature IDs for the
                server; 401 if no sufficiently valid key was found, or if a
                signature failed to verify.
        """

        if not verify_request.key_ids:
            raise SynapseError(
                400,
                f"Not signed by {verify_request.server_name}",
                Codes.UNAUTHORIZED,
            )

        found_keys: Dict[str, FetchKeyResult] = {}

        # If we are the originating server, short-circuit the key-fetch for any keys
        # we already have
        if verify_request.server_name == self._hostname:
            for key_id in verify_request.key_ids:
                if key_id in self._local_verify_keys:
                    found_keys[key_id] = self._local_verify_keys[key_id]

        key_ids_to_find = set(verify_request.key_ids) - found_keys.keys()
        if key_ids_to_find:
            # Add the keys we need to verify to the queue for retrieval. We queue
            # up requests for the same server so we don't end up with many in flight
            # requests for the same keys.
            key_request = _FetchKeyRequest(
                server_name=verify_request.server_name,
                minimum_valid_until_ts=verify_request.minimum_valid_until_ts,
                key_ids=list(key_ids_to_find),
            )
            found_keys_by_server = await self._server_queue.add_to_queue(
                key_request, key=verify_request.server_name
            )

            # Since we batch up requests the returned set of keys may contain keys
            # from other servers, so we pull out only the ones we care about.
            found_keys.update(found_keys_by_server.get(verify_request.server_name, {}))

        # Verify each signature we got valid keys for, raising if we can't
        # verify any of them.
        verified = False
        for key_id in verify_request.key_ids:
            key_result = found_keys.get(key_id)
            if not key_result:
                continue

            if key_result.valid_until_ts < verify_request.minimum_valid_until_ts:
                continue

            await self._process_json(key_result.verify_key, verify_request)
            verified = True

        if not verified:
            # Describe every key we were asked about. We must not rely on the
            # fetch request built above: when all the key IDs were found
            # locally (but none was valid for long enough) it was never
            # created.
            failed_request = _FetchKeyRequest(
                server_name=verify_request.server_name,
                minimum_valid_until_ts=verify_request.minimum_valid_until_ts,
                key_ids=list(verify_request.key_ids),
            )
            raise SynapseError(
                401,
                f"Failed to find any key to satisfy: {failed_request}",
                Codes.UNAUTHORIZED,
            )

    async def _process_json(
        self, verify_key: VerifyKey, verify_request: VerifyJsonRequest
    ) -> None:
        """Processes the `VerifyJsonRequest`. Raises if the signature can't be
        verified.

        Raises:
            SynapseError: 401 if `verify_signed_json` rejects the signature.
        """
        try:
            verify_signed_json(
                verify_request.get_json_object(),
                verify_request.server_name,
                verify_key,
            )
        except SignatureVerifyException as e:
            logger.debug(
                "Error verifying signature for %s:%s:%s with key %s: %s",
                verify_request.server_name,
                verify_key.alg,
                verify_key.version,
                encode_verify_key_base64(verify_key),
                str(e),
            )
            raise SynapseError(
                401,
                "Invalid signature for server %s with key %s:%s: %s"
                % (
                    verify_request.server_name,
                    verify_key.alg,
                    verify_key.version,
                    str(e),
                ),
                Codes.UNAUTHORIZED,
            )

    async def _inner_fetch_key_requests(
        self, requests: List[_FetchKeyRequest]
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        """Processing function for the queue of `_FetchKeyRequest`.

        Returns:
            Map from server name -> key ID -> the result with the highest
            `valid_until_ts` found for that key.
        """

        logger.debug("Starting fetch for %s", requests)

        # First we need to deduplicate requests for the same key. We do this by
        # taking the *maximum* requested `minimum_valid_until_ts` for each pair
        # of server name/key ID.
        server_to_key_to_ts: Dict[str, Dict[str, int]] = {}
        for request in requests:
            by_server = server_to_key_to_ts.setdefault(request.server_name, {})
            for key_id in request.key_ids:
                existing_ts = by_server.get(key_id, 0)
                by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts)

        deduped_requests = [
            _FetchKeyRequest(server_name, minimum_valid_ts, [key_id])
            for server_name, by_server in server_to_key_to_ts.items()
            for key_id, minimum_valid_ts in by_server.items()
        ]

        logger.debug("Deduplicated key requests to %s", deduped_requests)

        # For each key we call `_inner_fetch_key_request` which will handle
        # fetching each key. Note these shouldn't throw if we fail to contact
        # other servers etc.
        results_per_request = await yieldable_gather_results(
            self._inner_fetch_key_request,
            deduped_requests,
        )

        # We now convert the returned list of results into a map from server
        # name to key ID to FetchKeyResult, to return.
        to_return: Dict[str, Dict[str, FetchKeyResult]] = {}
        for request, results in zip(deduped_requests, results_per_request):
            to_return_by_server = to_return.setdefault(request.server_name, {})
            for key_id, key_result in results.items():
                existing = to_return_by_server.get(key_id)
                if not existing or existing.valid_until_ts < key_result.valid_until_ts:
                    to_return_by_server[key_id] = key_result

        return to_return

    async def _inner_fetch_key_request(
        self, verify_request: _FetchKeyRequest
    ) -> Dict[str, FetchKeyResult]:
        """Attempt to fetch the given key by calling each key fetcher one by
        one.

        Returns:
            Map from key ID to the best (highest `valid_until_ts`) result seen.
            A returned key is not guaranteed to satisfy the request's
            `minimum_valid_until_ts`.
        """
        logger.debug("Starting fetch for %s", verify_request)

        found_keys: Dict[str, FetchKeyResult] = {}
        missing_key_ids = set(verify_request.key_ids)

        for fetcher in self._key_fetchers:
            if not missing_key_ids:
                break

            logger.debug("Getting keys from %s for %s", fetcher, verify_request)
            keys = await fetcher.get_keys(
                verify_request.server_name,
                list(missing_key_ids),
                verify_request.minimum_valid_until_ts,
            )

            for key_id, key in keys.items():
                if not key:
                    continue

                # If we already have a result for the given key ID we keep the
                # one with the highest `valid_until_ts`.
                existing_key = found_keys.get(key_id)
                if existing_key:
                    if key.valid_until_ts <= existing_key.valid_until_ts:
                        continue

                # We always store the returned key even if it doesn't meet the
                # `minimum_valid_until_ts` requirement, as some verification
                # requests may still be able to be satisfied by it.
                #
                # We still keep looking for the key from other fetchers in that
                # case though.
                found_keys[key_id] = key

                if key.valid_until_ts < verify_request.minimum_valid_until_ts:
                    continue

                missing_key_ids.discard(key_id)

        return found_keys
446
447
class KeyFetcher(metaclass=abc.ABCMeta):
    """Abstract source of server signing keys used by `Keyring`.

    Calls to `get_keys` go through a `BatchingQueue` (named after the concrete
    subclass), whose batches are handed to the subclass's `_fetch_keys`.
    """

    def __init__(self, hs: "HomeServer"):
        queue_name = type(self).__name__
        self._queue = BatchingQueue(queue_name, hs.get_clock(), self._fetch_keys)

    async def get_keys(
        self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
    ) -> Dict[str, FetchKeyResult]:
        """Look up some of `server_name`'s signing keys.

        Args:
            server_name: the server that owns the keys.
            key_ids: the IDs of the keys wanted.
            minimum_valid_until_ts: the timestamp the keys should be valid until.

        Returns:
            Map from key ID to result, for whichever keys the batch produced
            for `server_name` (possibly empty).
        """
        request = _FetchKeyRequest(
            server_name=server_name,
            key_ids=key_ids,
            minimum_valid_until_ts=minimum_valid_until_ts,
        )
        keys_by_server = await self._queue.add_to_queue(request)
        return keys_by_server.get(server_name, {})

    @abc.abstractmethod
    async def _fetch_keys(
        self, keys_to_fetch: List[_FetchKeyRequest]
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        """Process one batch of requests, returning
        server_name -> key_id -> FetchKeyResult.
        """
471
472
class StoreKeyFetcher(KeyFetcher):
    """KeyFetcher impl which fetches keys from our data store"""

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)

        self.store = hs.get_datastore()

    async def _fetch_keys(
        self, keys_to_fetch: List[_FetchKeyRequest]
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        """Look up every requested (server_name, key_id) pair in the database.

        Returns:
            Map from server_name -> key_id -> whatever the store returned for
            that pair.
        """
        wanted_pairs = (
            (request.server_name, wanted_key_id)
            for request in keys_to_fetch
            for wanted_key_id in request.key_ids
        )
        stored = await self.store.get_server_verify_keys(wanted_pairs)

        keys_by_server: Dict[str, Dict[str, FetchKeyResult]] = {}
        for (origin, stored_key_id), stored_key in stored.items():
            server_keys = keys_by_server.get(origin)
            if server_keys is None:
                server_keys = keys_by_server[origin] = {}
            server_keys[stored_key_id] = stored_key
        return keys_by_server
495
496
class BaseV2KeyFetcher(KeyFetcher):
    """Common base for fetchers that parse `/_matrix/key/v2` responses."""

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)

        self.store = hs.get_datastore()
        self.config = hs.config

    async def process_v2_response(
        self, from_server: str, response_json: JsonDict, time_added_ms: int
    ) -> Dict[str, FetchKeyResult]:
        """Parse a 'Server Keys' structure from the result of a /key request

        This is used to parse either the entirety of the response from
        GET /_matrix/key/v2/server, or a single entry from the list returned by
        POST /_matrix/key/v2/query.

        The response must carry a signature from the origin server made with
        one of the keys listed in its own `verify_keys`. The first such
        signature found is checked; signatures by unknown keys are skipped.

        Stores the json in server_keys_json so that it can be used for future responses
        to /_matrix/key/v2/query.

        Args:
            from_server: the name of the server producing this result: either
                the origin server for a /_matrix/key/v2/server request, or the notary
                for a /_matrix/key/v2/query.

            response_json: the json-decoded Server Keys response object

            time_added_ms: the timestamp to record in server_keys_json

        Returns:
            Map from key_id to result object, covering both `verify_keys`
            (valid until the response's `valid_until_ts`) and `old_verify_keys`
            (valid until their `expired_ts`). Keys using unsupported
            algorithms are ignored.

        Raises:
            KeyLookupError: if no origin-server signature uses a known key.
            SignatureVerifyException: if the chosen signature does not verify.
        """
        ts_valid_until_ms = response_json["valid_until_ts"]

        def _decode(key_id: str, key_data: JsonDict) -> VerifyKey:
            return decode_verify_key_bytes(key_id, decode_base64(key_data["key"]))

        # The current keys come first: the origin's signature on this very
        # response has to be checked with one of them.
        verify_keys: Dict[str, FetchKeyResult] = {
            key_id: FetchKeyResult(
                verify_key=_decode(key_id, key_data),
                valid_until_ts=ts_valid_until_ms,
            )
            for key_id, key_data in response_json["verify_keys"].items()
            if is_signing_algorithm_supported(key_id)
        }

        server_name = response_json["server_name"]
        origin_signature_ids = response_json["signatures"].get(server_name, {})
        # A signing key ID may be absent from verify_keys if we got the
        # response from a notary server, the key belongs to the notary, and the
        # notary signs its notary responses with a different key.
        signing_key = next(
            (
                verify_keys[key_id]
                for key_id in origin_signature_ids
                if key_id in verify_keys
            ),
            None,
        )
        if signing_key is None:
            raise KeyLookupError(
                "Key response for %s is not signed by the origin server"
                % (server_name,)
            )
        verify_signed_json(response_json, server_name, signing_key.verify_key)

        for key_id, key_data in response_json["old_verify_keys"].items():
            if not is_signing_algorithm_supported(key_id):
                continue
            verify_keys[key_id] = FetchKeyResult(
                verify_key=_decode(key_id, key_data),
                valid_until_ts=key_data["expired_ts"],
            )

        key_json_bytes = encode_canonical_json(response_json)
        store_calls = [
            run_in_background(
                self.store.store_server_keys_json,
                server_name=server_name,
                key_id=key_id,
                from_server=from_server,
                ts_now_ms=time_added_ms,
                ts_expires_ms=ts_valid_until_ms,
                key_json_bytes=key_json_bytes,
            )
            for key_id in verify_keys
        ]
        await make_deferred_yieldable(
            defer.gatherResults(store_calls, consumeErrors=True).addErrback(
                unwrapFirstError
            )
        )

        return verify_keys
597
598
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
    """KeyFetcher impl which fetches keys from the "perspectives" servers"""

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
        self.clock = hs.get_clock()
        self.client = hs.get_federation_http_client()
        self.key_servers = self.config.key.key_servers

    async def _fetch_keys(
        self, keys_to_fetch: List[_FetchKeyRequest]
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        """see KeyFetcher._fetch_keys

        Every configured notary is queried concurrently; a notary that fails
        contributes nothing, and the results from the others are unioned.
        """

        async def get_key(
            key_server: TrustedKeyServer,
        ) -> Dict[str, Dict[str, FetchKeyResult]]:
            try:
                return await self.get_server_verify_key_v2_indirect(
                    keys_to_fetch, key_server
                )
            except KeyLookupError as e:
                logger.warning(
                    "Key lookup failed from %r: %s", key_server.server_name, e
                )
            except Exception as e:
                logger.exception(
                    "Unable to get key from %r: %s %s",
                    key_server.server_name,
                    type(e).__name__,
                    str(e),
                )

            return {}

        results = await make_deferred_yieldable(
            defer.gatherResults(
                [run_in_background(get_key, server) for server in self.key_servers],
                consumeErrors=True,
            ).addErrback(unwrapFirstError)
        )

        union_of_keys: Dict[str, Dict[str, FetchKeyResult]] = {}
        for result in results:
            for server_name, keys in result.items():
                union_of_keys.setdefault(server_name, {}).update(keys)

        return union_of_keys

    async def get_server_verify_key_v2_indirect(
        self, keys_to_fetch: List[_FetchKeyRequest], key_server: TrustedKeyServer
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        """
        Args:
            keys_to_fetch:
                the keys to be fetched.

            key_server: notary server to query for the keys

        Returns:
            Map from server_name -> key_id -> FetchKeyResult

        Raises:
            KeyLookupError if there was an error processing the entire response from
                the server
        """
        perspective_name = key_server.server_name
        logger.info(
            "Requesting keys %s from notary server %s",
            keys_to_fetch,
            perspective_name,
        )

        request: JsonDict = {}
        for queue_value in keys_to_fetch:
            # There may be several requests for the same server (and even the
            # same key). Merge them, asking for the *latest*
            # `minimum_valid_until_ts` any caller needs, so that one notary
            # response satisfies every request. A plain dict update would let
            # the last request win, even if it asked for an earlier timestamp.
            request_for_server = request.setdefault(queue_value.server_name, {})
            for key_id in queue_value.key_ids:
                wanted_ts = queue_value.minimum_valid_until_ts
                existing = request_for_server.get(key_id)
                if existing is not None:
                    wanted_ts = max(wanted_ts, existing["minimum_valid_until_ts"])
                request_for_server[key_id] = {"minimum_valid_until_ts": wanted_ts}

        logger.debug("Request to notary server %s: %s", perspective_name, request)

        try:
            query_response = await self.client.post_json(
                destination=perspective_name,
                path="/_matrix/key/v2/query",
                data={"server_keys": request},
            )
        except (NotRetryingDestination, RequestSendFailed) as e:
            # these both have str() representations which we can't really improve upon
            raise KeyLookupError(str(e))
        except HttpResponseException as e:
            raise KeyLookupError("Remote server returned an error: %s" % (e,))

        logger.debug(
            "Response from notary server %s: %s", perspective_name, query_response
        )

        keys: Dict[str, Dict[str, FetchKeyResult]] = {}
        added_keys: List[Tuple[str, str, FetchKeyResult]] = []

        time_now_ms = self.clock.time_msec()

        # Validate explicitly rather than with `assert`: this is remote input,
        # and asserts are stripped when Python runs with -O.
        if not isinstance(query_response, dict):
            raise KeyLookupError(
                "Malformed response from key notary server %s: expected a JSON object"
                % (perspective_name,)
            )
        for response in query_response["server_keys"]:
            # do this first, so that we can give useful errors thereafter
            server_name = response.get("server_name")
            if not isinstance(server_name, str):
                raise KeyLookupError(
                    "Malformed response from key notary server %s: invalid server_name"
                    % (perspective_name,)
                )

            try:
                self._validate_perspectives_response(key_server, response)

                processed_response = await self.process_v2_response(
                    perspective_name, response, time_added_ms=time_now_ms
                )
            except KeyLookupError as e:
                logger.warning(
                    "Error processing response from key notary server %s for origin "
                    "server %s: %s",
                    perspective_name,
                    server_name,
                    e,
                )
                # we continue to process the rest of the response
                continue

            added_keys.extend(
                (server_name, key_id, key) for key_id, key in processed_response.items()
            )
            keys.setdefault(server_name, {}).update(processed_response)

        await self.store.store_server_verify_keys(
            perspective_name, time_now_ms, added_keys
        )

        return keys

    def _validate_perspectives_response(
        self, key_server: TrustedKeyServer, response: JsonDict
    ) -> None:
        """Optionally check the signature on the result of a /key/query request

        Checking is skipped when the notary has no configured `verify_keys`.
        Otherwise every notary signature made with a configured key is
        verified.

        Args:
            key_server: the notary server that produced this result

            response: the json-decoded Server Keys response object

        Raises:
            KeyLookupError: if the response has no notary signature made with
                a configured key.
        """
        perspective_name = key_server.server_name
        perspective_keys = key_server.verify_keys

        if perspective_keys is None:
            # signature checking is disabled on this server
            return

        if (
            "signatures" not in response
            or perspective_name not in response["signatures"]
        ):
            raise KeyLookupError("Response not signed by the notary server")

        verified = False
        for key_id in response["signatures"][perspective_name]:
            if key_id in perspective_keys:
                verify_signed_json(response, perspective_name, perspective_keys[key_id])
                verified = True

        if not verified:
            raise KeyLookupError(
                "Response not signed with a known key: signed with: %r, known keys: %r"
                % (
                    list(response["signatures"][perspective_name].keys()),
                    list(perspective_keys.keys()),
                )
            )
780
781
class ServerKeyFetcher(BaseV2KeyFetcher):
    """KeyFetcher impl which fetches keys from the origin servers"""

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
        self.clock = hs.get_clock()
        self.client = hs.get_federation_http_client()

    async def get_keys(
        self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
    ) -> Dict[str, FetchKeyResult]:
        """Like `KeyFetcher.get_keys`, but the queue is keyed by server name,
        so each batch only contains requests for a single origin server.
        """
        results = await self._queue.add_to_queue(
            _FetchKeyRequest(
                server_name=server_name,
                key_ids=key_ids,
                minimum_valid_until_ts=minimum_valid_until_ts,
            ),
            key=server_name,
        )
        return results.get(server_name, {})

    async def _fetch_keys(
        self, keys_to_fetch: List[_FetchKeyRequest]
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        """
        Args:
            keys_to_fetch:
                the requests to be fulfilled. Each is looked up directly
                against its origin server; failures are logged and skipped.

        Returns:
            Map from server_name -> key_id -> FetchKeyResult
        """

        results: Dict[str, Dict[str, FetchKeyResult]] = {}

        async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None:
            server_name = key_to_fetch_item.server_name
            key_ids = key_to_fetch_item.key_ids

            try:
                keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
            except KeyLookupError as e:
                logger.warning(
                    "Error looking up keys %s from %s: %s", key_ids, server_name, e
                )
                return
            except Exception:
                logger.exception("Error getting keys %s from %s", key_ids, server_name)
                return

            # A batch may contain several requests for the same server (the
            # queue is keyed by server name). Merge instead of overwriting, or
            # whichever lookup finished last would discard the other requests'
            # keys.
            results.setdefault(server_name, {}).update(keys)

        await yieldable_gather_results(get_key, keys_to_fetch)
        return results

    async def get_server_verify_key_v2_direct(
        self, server_name: str, key_ids: Iterable[str]
    ) -> Dict[str, FetchKeyResult]:
        """Fetch keys directly from the origin server, one key ID per request.

        Args:
            server_name: the origin server to query.
            key_ids: the key IDs to request from it.

        Returns:
            Map from key ID to lookup result. Keys the server returns beyond
            those requested are included too.

        Raises:
            KeyLookupError if there was a problem making the lookup
        """
        keys: Dict[str, FetchKeyResult] = {}

        for requested_key_id in key_ids:
            # we may have found this key as a side-effect of asking for another.
            if requested_key_id in keys:
                continue

            time_now_ms = self.clock.time_msec()
            try:
                response = await self.client.get_json(
                    destination=server_name,
                    path="/_matrix/key/v2/server/"
                    + urllib.parse.quote(requested_key_id),
                    ignore_backoff=True,
                    # we only give the remote server 10s to respond. It should be an
                    # easy request to handle, so if it doesn't reply within 10s, it's
                    # probably not going to.
                    #
                    # Furthermore, when we are acting as a notary server, we cannot
                    # wait all day for all of the origin servers, as the requesting
                    # server will otherwise time out before we can respond.
                    #
                    # (Note that get_json may make 4 attempts, so this can still take
                    # almost 45 seconds to fetch the headers, plus up to another 60s to
                    # read the response).
                    timeout=10000,
                )
            except (NotRetryingDestination, RequestSendFailed) as e:
                # these both have str() representations which we can't really improve
                # upon
                raise KeyLookupError(str(e))
            except HttpResponseException as e:
                raise KeyLookupError("Remote server returned an error: %s" % (e,))

            # Validate explicitly rather than with `assert`: this is remote
            # input, and asserts are stripped when Python runs with -O.
            if not isinstance(response, dict):
                raise KeyLookupError(
                    "Malformed response from %s: expected a JSON object"
                    % (server_name,)
                )
            response_server_name = response.get("server_name")
            if response_server_name != server_name:
                raise KeyLookupError(
                    "Expected a response for server %r not %r"
                    % (server_name, response_server_name)
                )

            response_keys = await self.process_v2_response(
                from_server=server_name,
                response_json=response,
                time_added_ms=time_now_ms,
            )
            await self.store.store_server_verify_keys(
                server_name,
                time_now_ms,
                ((server_name, key_id, key) for key_id, key in response_keys.items()),
            )
            keys.update(response_keys)

        return keys
903