1# Copyright 2016 OpenMarket Ltd
2# Copyright 2018-2019 New Vector Ltd
3# Copyright 2019 The Matrix.org Foundation C.I.C.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import logging
18from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
19
20import attr
21from canonicaljson import encode_canonical_json
22from signedjson.key import VerifyKey, decode_verify_key_bytes
23from signedjson.sign import SignatureVerifyException, verify_signed_json
24from unpaddedbase64 import decode_base64
25
26from twisted.internet import defer
27
28from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
29from synapse.logging.context import make_deferred_yieldable, run_in_background
30from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
31from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
32from synapse.types import (
33    JsonDict,
34    UserID,
35    get_domain_from_id,
36    get_verify_key_from_cross_signing_key,
37)
38from synapse.util import json_decoder, unwrapFirstError
39from synapse.util.async_helpers import Linearizer
40from synapse.util.retryutils import NotRetryingDestination
41
42if TYPE_CHECKING:
43    from synapse.server import HomeServer
44
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
46
47
48class E2eKeysHandler:
49    def __init__(self, hs: "HomeServer"):
50        self.store = hs.get_datastore()
51        self.federation = hs.get_federation_client()
52        self.device_handler = hs.get_device_handler()
53        self.is_mine = hs.is_mine
54        self.clock = hs.get_clock()
55
56        self._edu_updater = SigningKeyEduUpdater(hs, self)
57
58        federation_registry = hs.get_federation_registry()
59
60        self._is_master = hs.config.worker.worker_app is None
61        if not self._is_master:
62            self._user_device_resync_client = (
63                ReplicationUserDevicesResyncRestServlet.make_client(hs)
64            )
65        else:
66            # Only register this edu handler on master as it requires writing
67            # device updates to the db
68            federation_registry.register_edu_handler(
69                "m.signing_key_update",
70                self._edu_updater.incoming_signing_key_update,
71            )
72            # also handle the unstable version
73            # FIXME: remove this when enough servers have upgraded
74            federation_registry.register_edu_handler(
75                "org.matrix.signing_key_update",
76                self._edu_updater.incoming_signing_key_update,
77            )
78
79        # doesn't really work as part of the generic query API, because the
80        # query request requires an object POST, but we abuse the
81        # "query handler" interface.
82        federation_registry.register_query_handler(
83            "client_keys", self.on_federation_query_client_keys
84        )
85
86        # Limit the number of in-flight requests from a single device.
87        self._query_devices_linearizer = Linearizer(
88            name="query_devices",
89            max_count=10,
90        )
91
    @trace
    async def query_devices(
        self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
    ) -> JsonDict:
        """Handle a device key query from a client

        {
            "device_keys": {
                "<user_id>": ["<device_id>"]
            }
        }
        ->
        {
            "device_keys": {
                "<user_id>": {
                    "<device_id>": {
                        ...
                    }
                }
            }
        }

        Keys for local users come from query_local_devices. For remote users,
        the device list cache is consulted first. Users not found there are
        queried over federation, with one concurrent background task per
        destination server.

        Args:
            query_body: the client's request body. Only its "device_keys"
                entry is read; if it is missing, the query is treated as empty.
            timeout: the timeout for remote HTTP requests, passed through to
                the per-destination federation queries.
            from_user_id: the user making the query.  This is used when
                adding cross-signing signatures to limit what signatures users
                can see.
            from_device_id: the device making the query. This is used to limit
                the number of in-flight queries at a time.

        Returns:
            A dict with:
            - "device_keys": user_id -> device_id -> keys;
            - "failures": destination -> failure description;
            - the cross-signing maps from get_cross_signing_keys_from_cache
              ("master_keys", "self_signing_keys", "user_signing_keys").
        """
        # Queue on the (user, device) pair so that a single device can only
        # have a bounded number of queries in flight (see the linearizer
        # created in __init__).
        with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
            device_keys_query: Dict[str, Iterable[str]] = query_body.get(
                "device_keys", {}
            )

            # Split the query into local and remote users. Both maps are
            # user_id -> device_ids, as given in the request.
            local_query = {}
            remote_queries = {}

            for user_id, device_ids in device_keys_query.items():
                # we use UserID.from_string to catch invalid user ids
                if self.is_mine(UserID.from_string(user_id)):
                    local_query[user_id] = device_ids
                else:
                    remote_queries[user_id] = device_ids

            set_tag("local_key_query", local_query)
            set_tag("remote_key_query", remote_queries)

            # First get local devices.
            # A map of destination -> failure response.
            failures: Dict[str, JsonDict] = {}
            # A map of user ID -> device ID -> keys, filled in below and by
            # the per-destination federation tasks.
            results = {}
            if local_query:
                local_result = await self.query_local_devices(local_query)
                for user_id, keys in local_result.items():
                    if user_id in local_query:
                        results[user_id] = keys

            # Get cached cross-signing keys for every queried user, local or
            # remote.
            cross_signing_keys = await self.get_cross_signing_keys_from_cache(
                device_keys_query, from_user_id
            )

            # Now attempt to get any remote devices from our local cache.
            # A map of destination -> user ID -> device IDs.
            remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
            if remote_queries:
                # Flatten to (user_id, device_id) pairs. An empty device list
                # becomes (user_id, None), mirroring query_local_devices, where
                # None means "all devices".
                query_list: List[Tuple[str, Optional[str]]] = []
                for user_id, device_ids in remote_queries.items():
                    if device_ids:
                        query_list.extend(
                            (user_id, device_id) for device_id in device_ids
                        )
                    else:
                        query_list.append((user_id, None))

                (
                    user_ids_not_in_cache,
                    remote_results,
                ) = await self.store.get_user_devices_from_cache(query_list)
                for user_id, devices in remote_results.items():
                    user_devices = results.setdefault(user_id, {})
                    for device_id, device in devices.items():
                        keys = device.get("keys", None)
                        device_display_name = device.get("device_display_name", None)
                        # Cached devices without keys are left out of the
                        # response.
                        if keys:
                            # Shallow copy of the cached keys, with the
                            # display name added under "unsigned".
                            # NOTE(review): if the cached keys already contain
                            # an "unsigned" dict, it is shared with the cached
                            # object and mutated here; confirm the store
                            # returns fresh objects.
                            result = dict(keys)
                            unsigned = result.setdefault("unsigned", {})
                            if device_display_name:
                                unsigned["device_display_name"] = device_display_name
                            user_devices[device_id] = result

                # check for missing cross-signing keys.
                for user_id in remote_queries.keys():
                    cached_cross_master = user_id in cross_signing_keys["master_keys"]
                    cached_cross_selfsigning = (
                        user_id in cross_signing_keys["self_signing_keys"]
                    )

                    # check if we are missing only one of cross-signing master or
                    # self-signing key, but the other one is cached.
                    # as we need both, this will issue a federation request.
                    # if we don't have any of the keys, either the user doesn't have
                    # cross-signing set up, or the cached device list
                    # is not (yet) updated.
                    if cached_cross_master ^ cached_cross_selfsigning:
                        user_ids_not_in_cache.add(user_id)

                # add those users to the list to fetch over federation,
                # grouped by their homeserver's domain.
                for user_id in user_ids_not_in_cache:
                    domain = get_domain_from_id(user_id)
                    r = remote_queries_not_in_cache.setdefault(domain, {})
                    r[user_id] = remote_queries[user_id]

            # Now fetch any devices that we don't have in our cache. Each
            # per-destination task writes into the shared results,
            # cross_signing_keys and failures dicts.
            await make_deferred_yieldable(
                defer.gatherResults(
                    [
                        run_in_background(
                            self._query_devices_for_destination,
                            results,
                            cross_signing_keys,
                            failures,
                            destination,
                            queries,
                            timeout,
                        )
                        for destination, queries in remote_queries_not_in_cache.items()
                    ],
                    consumeErrors=True,
                ).addErrback(unwrapFirstError)
            )

            ret = {"device_keys": results, "failures": failures}

            # Merge in master_keys / self_signing_keys / user_signing_keys.
            ret.update(cross_signing_keys)

            return ret
231
232    @trace
233    async def _query_devices_for_destination(
234        self,
235        results: JsonDict,
236        cross_signing_keys: JsonDict,
237        failures: Dict[str, JsonDict],
238        destination: str,
239        destination_query: Dict[str, Iterable[str]],
240        timeout: int,
241    ) -> None:
242        """This is called when we are querying the device list of a user on
243        a remote homeserver and their device list is not in the device list
244        cache. If we share a room with this user and we're not querying for
245        specific user we will update the cache with their device list.
246
247        Args:
248            results: A map from user ID to their device keys, which gets
249                updated with the newly fetched keys.
250            cross_signing_keys: Map from user ID to their cross signing keys,
251                which gets updated with the newly fetched keys.
252            failures: Map of destinations to failures that have occurred while
253                attempting to fetch keys.
254            destination: The remote server to query
255            destination_query: The query dict of devices to query the remote
256                server for.
257            timeout: The timeout for remote HTTP requests.
258        """
259
260        # We first consider whether we wish to update the device list cache with
261        # the users device list. We want to track a user's devices when the
262        # authenticated user shares a room with the queried user and the query
263        # has not specified a particular device.
264        # If we update the cache for the queried user we remove them from further
265        # queries. We use the more efficient batched query_client_keys for all
266        # remaining users
267        user_ids_updated = []
268        for (user_id, device_list) in destination_query.items():
269            if user_id in user_ids_updated:
270                continue
271
272            if device_list:
273                continue
274
275            room_ids = await self.store.get_rooms_for_user(user_id)
276            if not room_ids:
277                continue
278
279            # We've decided we're sharing a room with this user and should
280            # probably be tracking their device lists. However, we haven't
281            # done an initial sync on the device list so we do it now.
282            try:
283                if self._is_master:
284                    resync_results = await self.device_handler.device_list_updater.user_device_resync(
285                        user_id
286                    )
287                else:
288                    resync_results = await self._user_device_resync_client(
289                        user_id=user_id
290                    )
291
292                # Add the device keys to the results.
293                user_devices = resync_results["devices"]
294                user_results = results.setdefault(user_id, {})
295                for device in user_devices:
296                    user_results[device["device_id"]] = device["keys"]
297                user_ids_updated.append(user_id)
298
299                # Add any cross signing keys to the results.
300                master_key = resync_results.get("master_key")
301                self_signing_key = resync_results.get("self_signing_key")
302
303                if master_key:
304                    cross_signing_keys["master_keys"][user_id] = master_key
305
306                if self_signing_key:
307                    cross_signing_keys["self_signing_keys"][user_id] = self_signing_key
308            except Exception as e:
309                failures[destination] = _exception_to_failure(e)
310
311        if len(destination_query) == len(user_ids_updated):
312            # We've updated all the users in the query and we do not need to
313            # make any further remote calls.
314            return
315
316        # Remove all the users from the query which we have updated
317        for user_id in user_ids_updated:
318            destination_query.pop(user_id)
319
320        try:
321            remote_result = await self.federation.query_client_keys(
322                destination, {"device_keys": destination_query}, timeout=timeout
323            )
324
325            for user_id, keys in remote_result["device_keys"].items():
326                if user_id in destination_query:
327                    results[user_id] = keys
328
329            if "master_keys" in remote_result:
330                for user_id, key in remote_result["master_keys"].items():
331                    if user_id in destination_query:
332                        cross_signing_keys["master_keys"][user_id] = key
333
334            if "self_signing_keys" in remote_result:
335                for user_id, key in remote_result["self_signing_keys"].items():
336                    if user_id in destination_query:
337                        cross_signing_keys["self_signing_keys"][user_id] = key
338
339        except Exception as e:
340            failure = _exception_to_failure(e)
341            failures[destination] = failure
342            set_tag("error", True)
343            set_tag("reason", failure)
344
345        return
346
347    async def get_cross_signing_keys_from_cache(
348        self, query: Iterable[str], from_user_id: Optional[str]
349    ) -> Dict[str, Dict[str, dict]]:
350        """Get cross-signing keys for users from the database
351
352        Args:
353            query: an iterable of user IDs.  A dict whose keys
354                are user IDs satisfies this, so the query format used for
355                query_devices can be used here.
356            from_user_id: the user making the query.  This is used when
357                adding cross-signing signatures to limit what signatures users
358                can see.
359
360        Returns:
361            A map from (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
362        """
363        master_keys = {}
364        self_signing_keys = {}
365        user_signing_keys = {}
366
367        user_ids = list(query)
368
369        keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
370
371        for user_id, user_info in keys.items():
372            if user_info is None:
373                continue
374            if "master" in user_info:
375                master_keys[user_id] = user_info["master"]
376            if "self_signing" in user_info:
377                self_signing_keys[user_id] = user_info["self_signing"]
378
379        # users can see other users' master and self-signing keys, but can
380        # only see their own user-signing keys
381        if from_user_id:
382            from_user_key = keys.get(from_user_id)
383            if from_user_key and "user_signing" in from_user_key:
384                user_signing_keys[from_user_id] = from_user_key["user_signing"]
385
386        return {
387            "master_keys": master_keys,
388            "self_signing_keys": self_signing_keys,
389            "user_signing_keys": user_signing_keys,
390        }
391
392    @trace
393    async def query_local_devices(
394        self, query: Dict[str, Optional[List[str]]]
395    ) -> Dict[str, Dict[str, dict]]:
396        """Get E2E device keys for local users
397
398        Args:
399            query: map from user_id to a list
400                 of devices to query (None for all devices)
401
402        Returns:
403            A map from user_id -> device_id -> device details
404        """
405        set_tag("local_query", query)
406        local_query: List[Tuple[str, Optional[str]]] = []
407
408        result_dict: Dict[str, Dict[str, dict]] = {}
409        for user_id, device_ids in query.items():
410            # we use UserID.from_string to catch invalid user ids
411            if not self.is_mine(UserID.from_string(user_id)):
412                logger.warning("Request for keys for non-local user %s", user_id)
413                log_kv(
414                    {
415                        "message": "Requested a local key for a user which"
416                        " was not local to the homeserver",
417                        "user_id": user_id,
418                    }
419                )
420                set_tag("error", True)
421                raise SynapseError(400, "Not a user here")
422
423            if not device_ids:
424                local_query.append((user_id, None))
425            else:
426                for device_id in device_ids:
427                    local_query.append((user_id, device_id))
428
429            # make sure that each queried user appears in the result dict
430            result_dict[user_id] = {}
431
432        results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
433
434        # Build the result structure
435        for user_id, device_keys in results.items():
436            for device_id, device_info in device_keys.items():
437                result_dict[user_id][device_id] = device_info
438
439        log_kv(results)
440        return result_dict
441
442    async def on_federation_query_client_keys(
443        self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
444    ) -> JsonDict:
445        """Handle a device key query from a federated server"""
446        device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
447            "device_keys", {}
448        )
449        res = await self.query_local_devices(device_keys_query)
450        ret = {"device_keys": res}
451
452        # add in the cross-signing keys
453        cross_signing_keys = await self.get_cross_signing_keys_from_cache(
454            device_keys_query, None
455        )
456
457        ret.update(cross_signing_keys)
458
459        return ret
460
461    @trace
462    async def claim_one_time_keys(
463        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
464    ) -> JsonDict:
465        local_query: List[Tuple[str, str, str]] = []
466        remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
467
468        for user_id, one_time_keys in query.get("one_time_keys", {}).items():
469            # we use UserID.from_string to catch invalid user ids
470            if self.is_mine(UserID.from_string(user_id)):
471                for device_id, algorithm in one_time_keys.items():
472                    local_query.append((user_id, device_id, algorithm))
473            else:
474                domain = get_domain_from_id(user_id)
475                remote_queries.setdefault(domain, {})[user_id] = one_time_keys
476
477        set_tag("local_key_query", local_query)
478        set_tag("remote_key_query", remote_queries)
479
480        results = await self.store.claim_e2e_one_time_keys(local_query)
481
482        # A map of user ID -> device ID -> key ID -> key.
483        json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
484        failures: Dict[str, JsonDict] = {}
485        for user_id, device_keys in results.items():
486            for device_id, keys in device_keys.items():
487                for key_id, json_str in keys.items():
488                    json_result.setdefault(user_id, {})[device_id] = {
489                        key_id: json_decoder.decode(json_str)
490                    }
491
492        @trace
493        async def claim_client_keys(destination: str) -> None:
494            set_tag("destination", destination)
495            device_keys = remote_queries[destination]
496            try:
497                remote_result = await self.federation.claim_client_keys(
498                    destination, {"one_time_keys": device_keys}, timeout=timeout
499                )
500                for user_id, keys in remote_result["one_time_keys"].items():
501                    if user_id in device_keys:
502                        json_result[user_id] = keys
503
504            except Exception as e:
505                failure = _exception_to_failure(e)
506                failures[destination] = failure
507                set_tag("error", True)
508                set_tag("reason", failure)
509
510        await make_deferred_yieldable(
511            defer.gatherResults(
512                [
513                    run_in_background(claim_client_keys, destination)
514                    for destination in remote_queries
515                ],
516                consumeErrors=True,
517            )
518        )
519
520        logger.info(
521            "Claimed one-time-keys: %s",
522            ",".join(
523                (
524                    "%s for %s:%s" % (key_id, user_id, device_id)
525                    for user_id, user_keys in json_result.items()
526                    for device_id, device_keys in user_keys.items()
527                    for key_id, _ in device_keys.items()
528                )
529            ),
530        )
531
532        log_kv({"one_time_keys": json_result, "failures": failures})
533        return {"one_time_keys": json_result, "failures": failures}
534
535    @tag_args
536    async def upload_keys_for_user(
537        self, user_id: str, device_id: str, keys: JsonDict
538    ) -> JsonDict:
539
540        time_now = self.clock.time_msec()
541
542        # TODO: Validate the JSON to make sure it has the right keys.
543        device_keys = keys.get("device_keys", None)
544        if device_keys:
545            logger.info(
546                "Updating device_keys for device %r for user %s at %d",
547                device_id,
548                user_id,
549                time_now,
550            )
551            log_kv(
552                {
553                    "message": "Updating device_keys for user.",
554                    "user_id": user_id,
555                    "device_id": device_id,
556                }
557            )
558            # TODO: Sign the JSON with the server key
559            changed = await self.store.set_e2e_device_keys(
560                user_id, device_id, time_now, device_keys
561            )
562            if changed:
563                # Only notify about device updates *if* the keys actually changed
564                await self.device_handler.notify_device_update(user_id, [device_id])
565        else:
566            log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
567        one_time_keys = keys.get("one_time_keys", None)
568        if one_time_keys:
569            log_kv(
570                {
571                    "message": "Updating one_time_keys for device.",
572                    "user_id": user_id,
573                    "device_id": device_id,
574                }
575            )
576            await self._upload_one_time_keys_for_user(
577                user_id, device_id, time_now, one_time_keys
578            )
579        else:
580            log_kv(
581                {"message": "Did not update one_time_keys", "reason": "no keys given"}
582            )
583        fallback_keys = keys.get("fallback_keys") or keys.get(
584            "org.matrix.msc2732.fallback_keys"
585        )
586        if fallback_keys and isinstance(fallback_keys, dict):
587            log_kv(
588                {
589                    "message": "Updating fallback_keys for device.",
590                    "user_id": user_id,
591                    "device_id": device_id,
592                }
593            )
594            await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
595        elif fallback_keys:
596            log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
597        else:
598            log_kv(
599                {"message": "Did not update fallback_keys", "reason": "no keys given"}
600            )
601
602        # the device should have been registered already, but it may have been
603        # deleted due to a race with a DELETE request. Or we may be using an
604        # old access_token without an associated device_id. Either way, we
605        # need to double-check the device is registered to avoid ending up with
606        # keys without a corresponding device.
607        await self.device_handler.check_device_registered(user_id, device_id)
608
609        result = await self.store.count_e2e_one_time_keys(user_id, device_id)
610
611        set_tag("one_time_key_counts", result)
612        return {"one_time_key_counts": result}
613
614    async def _upload_one_time_keys_for_user(
615        self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
616    ) -> None:
617        logger.info(
618            "Adding one_time_keys %r for device %r for user %r at %d",
619            one_time_keys.keys(),
620            device_id,
621            user_id,
622            time_now,
623        )
624
625        # make a list of (alg, id, key) tuples
626        key_list = []
627        for key_id, key_obj in one_time_keys.items():
628            algorithm, key_id = key_id.split(":")
629            key_list.append((algorithm, key_id, key_obj))
630
631        # First we check if we have already persisted any of the keys.
632        existing_key_map = await self.store.get_e2e_one_time_keys(
633            user_id, device_id, [k_id for _, k_id, _ in key_list]
634        )
635
636        new_keys = []  # Keys that we need to insert. (alg, id, json) tuples.
637        for algorithm, key_id, key in key_list:
638            ex_json = existing_key_map.get((algorithm, key_id), None)
639            if ex_json:
640                if not _one_time_keys_match(ex_json, key):
641                    raise SynapseError(
642                        400,
643                        (
644                            "One time key %s:%s already exists. "
645                            "Old key: %s; new key: %r"
646                        )
647                        % (algorithm, key_id, ex_json, key),
648                    )
649            else:
650                new_keys.append(
651                    (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
652                )
653
654        log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
655        await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
656
657    async def upload_signing_keys_for_user(
658        self, user_id: str, keys: JsonDict
659    ) -> JsonDict:
660        """Upload signing keys for cross-signing
661
662        Args:
663            user_id: the user uploading the keys
664            keys: the signing keys
665        """
666
667        # if a master key is uploaded, then check it.  Otherwise, load the
668        # stored master key, to check signatures on other keys
669        if "master_key" in keys:
670            master_key = keys["master_key"]
671
672            _check_cross_signing_key(master_key, user_id, "master")
673        else:
674            master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
675
676        # if there is no master key, then we can't do anything, because all the
677        # other cross-signing keys need to be signed by the master key
678        if not master_key:
679            raise SynapseError(400, "No master key available", Codes.MISSING_PARAM)
680
681        try:
682            master_key_id, master_verify_key = get_verify_key_from_cross_signing_key(
683                master_key
684            )
685        except ValueError:
686            if "master_key" in keys:
687                # the invalid key came from the request
688                raise SynapseError(400, "Invalid master key", Codes.INVALID_PARAM)
689            else:
690                # the invalid key came from the database
691                logger.error("Invalid master key found for user %s", user_id)
692                raise SynapseError(500, "Invalid master key")
693
694        # for the other cross-signing keys, make sure that they have valid
695        # signatures from the master key
696        if "self_signing_key" in keys:
697            self_signing_key = keys["self_signing_key"]
698
699            _check_cross_signing_key(
700                self_signing_key, user_id, "self_signing", master_verify_key
701            )
702
703        if "user_signing_key" in keys:
704            user_signing_key = keys["user_signing_key"]
705
706            _check_cross_signing_key(
707                user_signing_key, user_id, "user_signing", master_verify_key
708            )
709
710        # if everything checks out, then store the keys and send notifications
711        deviceids = []
712        if "master_key" in keys:
713            await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
714            deviceids.append(master_verify_key.version)
715        if "self_signing_key" in keys:
716            await self.store.set_e2e_cross_signing_key(
717                user_id, "self_signing", self_signing_key
718            )
719            try:
720                deviceids.append(
721                    get_verify_key_from_cross_signing_key(self_signing_key)[1].version
722                )
723            except ValueError:
724                raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM)
725        if "user_signing_key" in keys:
726            await self.store.set_e2e_cross_signing_key(
727                user_id, "user_signing", user_signing_key
728            )
729            # the signature stream matches the semantics that we want for
730            # user-signing key updates: only the user themselves is notified of
731            # their own user-signing key updates
732            await self.device_handler.notify_user_signature_update(user_id, [user_id])
733
734        # master key and self-signing key updates match the semantics of device
735        # list updates: all users who share an encrypted room are notified
736        if len(deviceids):
737            await self.device_handler.notify_device_update(user_id, deviceids)
738
739        return {}
740
741    async def upload_signatures_for_device_keys(
742        self, user_id: str, signatures: JsonDict
743    ) -> JsonDict:
744        """Upload device signatures for cross-signing
745
746        Args:
747            user_id: the user uploading the signatures
748            signatures: map of users to devices to signed keys. This is the submission
749            from the user; an exception will be raised if it is malformed.
750        Returns:
751            The response to be sent back to the client.  The response will have
752                a "failures" key, which will be a dict mapping users to devices
753                to errors for the signatures that failed.
754        Raises:
755            SynapseError: if the signatures dict is not valid.
756        """
757        failures = {}
758
759        # signatures to be stored.  Each item will be a SignatureListItem
760        signature_list = []
761
762        # split between checking signatures for own user and signatures for
763        # other users, since we verify them with different keys
764        self_signatures = signatures.get(user_id, {})
765        other_signatures = {k: v for k, v in signatures.items() if k != user_id}
766
767        self_signature_list, self_failures = await self._process_self_signatures(
768            user_id, self_signatures
769        )
770        signature_list.extend(self_signature_list)
771        failures.update(self_failures)
772
773        other_signature_list, other_failures = await self._process_other_signatures(
774            user_id, other_signatures
775        )
776        signature_list.extend(other_signature_list)
777        failures.update(other_failures)
778
779        # store the signature, and send the appropriate notifications for sync
780        logger.debug("upload signature failures: %r", failures)
781        await self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
782
783        self_device_ids = [item.target_device_id for item in self_signature_list]
784        if self_device_ids:
785            await self.device_handler.notify_device_update(user_id, self_device_ids)
786        signed_users = [item.target_user_id for item in other_signature_list]
787        if signed_users:
788            await self.device_handler.notify_user_signature_update(
789                user_id, signed_users
790            )
791
792        return {"failures": failures}
793
    async def _process_self_signatures(
        self, user_id: str, signatures: JsonDict
    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
        """Process uploaded signatures of the user's own keys.

        Signatures of the user's own keys from this API come in two forms:
        - signatures of the user's devices by the user's self-signing key,
        - signatures of the user's master key by the user's devices.

        Both forms arrive in the same map. An entry is treated as the master
        key when its key equals the master key's `version` (the part of the
        master key's ID after the colon); otherwise it is treated as a device.

        Args:
            user_id: the user uploading the keys
            signatures: map of device IDs (or the master key's ID suffix) to
                signed key objects

        Returns:
            A tuple of a list of signatures to store, and a map of users to
            devices to failure reasons

        Raises:
            SynapseError: if the input is malformed
        """
        signature_list: List["SignatureListItem"] = []
        failures: Dict[str, Dict[str, JsonDict]] = {}
        # nothing was submitted for the uploader's own keys
        if not signatures:
            return signature_list, failures

        if not isinstance(signatures, dict):
            raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM)

        try:
            # get our self-signing key to verify the signatures
            (
                _,
                self_signing_key_id,
                self_signing_verify_key,
            ) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing")

            # get our master key, since we may have received a signature of it.
            # We need to fetch it here so that we know what its key ID is, so
            # that we can check if a signature that was sent is a signature of
            # the master key or of a device
            (
                master_key,
                _,
                master_verify_key,
            ) = await self._get_e2e_cross_signing_verify_key(user_id, "master")

            # fetch our stored devices.  This is used to 1. verify
            # signatures on the master key, and 2. to compare with what
            # was sent if the device was signed
            devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)])

            if user_id not in devices:
                raise NotFoundError("No device keys found")

            devices = devices[user_id]
        except SynapseError as e:
            # Without these keys nothing can be verified, so every submitted
            # entry is reported as failed with the same error.
            failure = _exception_to_failure(e)
            failures[user_id] = {device: failure for device in signatures.keys()}
            return signature_list, failures

        for device_id, device in signatures.items():
            # make sure submitted data is in the right form; structurally
            # malformed input aborts the whole request rather than being
            # recorded as a per-device failure
            if not isinstance(device, dict):
                raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM)

            # Errors raised inside this try are recorded against this entry
            # only; processing continues with the remaining entries.
            try:
                if "signatures" not in device or user_id not in device["signatures"]:
                    # no signature was sent
                    raise SynapseError(
                        400, "Invalid signature", Codes.INVALID_SIGNATURE
                    )

                if device_id == master_verify_key.version:
                    # The signature is of the master key. This needs to be
                    # handled differently from signatures of normal devices.
                    master_key_signature_list = self._check_master_key_signature(
                        user_id, device_id, device, master_key, devices
                    )
                    signature_list.extend(master_key_signature_list)
                    continue

                # at this point, we have a device that should be signed
                # by the self-signing key
                if self_signing_key_id not in device["signatures"][user_id]:
                    # no signature was sent
                    raise SynapseError(
                        400, "Invalid signature", Codes.INVALID_SIGNATURE
                    )

                try:
                    stored_device = devices[device_id]
                except KeyError:
                    raise NotFoundError("Unknown device")
                if self_signing_key_id in stored_device.get("signatures", {}).get(
                    user_id, {}
                ):
                    # we already have a signature on this device, so we
                    # can skip it, since it should be exactly the same
                    continue

                # checks the uploaded device matches the stored one (ignoring
                # signatures/unsigned) and that the signature verifies
                _check_device_signature(
                    user_id, self_signing_verify_key, device, stored_device
                )

                signature = device["signatures"][user_id][self_signing_key_id]
                signature_list.append(
                    SignatureListItem(
                        self_signing_key_id, user_id, device_id, signature
                    )
                )
            except SynapseError as e:
                failures.setdefault(user_id, {})[device_id] = _exception_to_failure(e)

        return signature_list, failures
908
909    def _check_master_key_signature(
910        self,
911        user_id: str,
912        master_key_id: str,
913        signed_master_key: JsonDict,
914        stored_master_key: JsonDict,
915        devices: Dict[str, Dict[str, JsonDict]],
916    ) -> List["SignatureListItem"]:
917        """Check signatures of a user's master key made by their devices.
918
919        Args:
920            user_id: the user whose master key is being checked
921            master_key_id: the ID of the user's master key
922            signed_master_key: the user's signed master key that was uploaded
923            stored_master_key: our previously-stored copy of the user's master key
924            devices: the user's devices
925
926        Returns:
927            A list of signatures to store
928
929        Raises:
930            SynapseError: if a signature is invalid
931        """
932        # for each device that signed the master key, check the signature.
933        master_key_signature_list = []
934        sigs = signed_master_key["signatures"]
935        for signing_key_id, signature in sigs[user_id].items():
936            _, signing_device_id = signing_key_id.split(":", 1)
937            if (
938                signing_device_id not in devices
939                or signing_key_id not in devices[signing_device_id]["keys"]
940            ):
941                # signed by an unknown device, or the
942                # device does not have the key
943                raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
944
945            # get the key and check the signature
946            pubkey = devices[signing_device_id]["keys"][signing_key_id]
947            verify_key = decode_verify_key_bytes(signing_key_id, decode_base64(pubkey))
948            _check_device_signature(
949                user_id, verify_key, signed_master_key, stored_master_key
950            )
951
952            master_key_signature_list.append(
953                SignatureListItem(signing_key_id, user_id, master_key_id, signature)
954            )
955
956        return master_key_signature_list
957
958    async def _process_other_signatures(
959        self, user_id: str, signatures: Dict[str, dict]
960    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
961        """Process uploaded signatures of other users' keys.  These will be the
962        target user's master keys, signed by the uploading user's user-signing
963        key.
964
965        Args:
966            user_id: the user uploading the keys
967            signatures: map of users to devices to signed keys
968
969        Returns:
970            A list of signatures to store, and a map of users to devices to failure
971            reasons
972
973        Raises:
974            SynapseError: if the input is malformed
975        """
976        signature_list: List["SignatureListItem"] = []
977        failures: Dict[str, Dict[str, JsonDict]] = {}
978        if not signatures:
979            return signature_list, failures
980
981        try:
982            # get our user-signing key to verify the signatures
983            (
984                user_signing_key,
985                user_signing_key_id,
986                user_signing_verify_key,
987            ) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
988        except SynapseError as e:
989            failure = _exception_to_failure(e)
990            for user, devicemap in signatures.items():
991                failures[user] = {device_id: failure for device_id in devicemap.keys()}
992            return signature_list, failures
993
994        for target_user, devicemap in signatures.items():
995            # make sure submitted data is in the right form
996            if not isinstance(devicemap, dict):
997                raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM)
998            for device in devicemap.values():
999                if not isinstance(device, dict):
1000                    raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM)
1001
1002            device_id = None
1003            try:
1004                # get the target user's master key, to make sure it matches
1005                # what was sent
1006                (
1007                    master_key,
1008                    master_key_id,
1009                    _,
1010                ) = await self._get_e2e_cross_signing_verify_key(
1011                    target_user, "master", user_id
1012                )
1013
1014                # make sure that the target user's master key is the one that
1015                # was signed (and no others)
1016                device_id = master_key_id.split(":", 1)[1]
1017                if device_id not in devicemap:
1018                    logger.debug(
1019                        "upload signature: could not find signature for device %s",
1020                        device_id,
1021                    )
1022                    # set device to None so that the failure gets
1023                    # marked on all the signatures
1024                    device_id = None
1025                    raise NotFoundError("Unknown device")
1026                key = devicemap[device_id]
1027                other_devices = [k for k in devicemap.keys() if k != device_id]
1028                if other_devices:
1029                    # other devices were signed -- mark those as failures
1030                    logger.debug("upload signature: too many devices specified")
1031                    failure = _exception_to_failure(NotFoundError("Unknown device"))
1032                    failures[target_user] = {
1033                        device: failure for device in other_devices
1034                    }
1035
1036                if user_signing_key_id in master_key.get("signatures", {}).get(
1037                    user_id, {}
1038                ):
1039                    # we already have the signature, so we can skip it
1040                    continue
1041
1042                _check_device_signature(
1043                    user_id, user_signing_verify_key, key, master_key
1044                )
1045
1046                signature = key["signatures"][user_id][user_signing_key_id]
1047                signature_list.append(
1048                    SignatureListItem(
1049                        user_signing_key_id, target_user, device_id, signature
1050                    )
1051                )
1052            except SynapseError as e:
1053                failure = _exception_to_failure(e)
1054                if device_id is None:
1055                    failures[target_user] = {
1056                        device_id: failure for device_id in devicemap.keys()
1057                    }
1058                else:
1059                    failures.setdefault(target_user, {})[device_id] = failure
1060
1061        return signature_list, failures
1062
1063    async def _get_e2e_cross_signing_verify_key(
1064        self, user_id: str, key_type: str, from_user_id: Optional[str] = None
1065    ) -> Tuple[JsonDict, str, VerifyKey]:
1066        """Fetch locally or remotely query for a cross-signing public key.
1067
1068        First, attempt to fetch the cross-signing public key from storage.
1069        If that fails, query the keys from the homeserver they belong to
1070        and update our local copy.
1071
1072        Args:
1073            user_id: the user whose key should be fetched
1074            key_type: the type of key to fetch
1075            from_user_id: the user that we are fetching the keys for.
1076                This affects what signatures are fetched.
1077
1078        Returns:
1079            The raw key data, the key ID, and the signedjson verify key
1080
1081        Raises:
1082            NotFoundError: if the key is not found
1083            SynapseError: if `user_id` is invalid
1084        """
1085        user = UserID.from_string(user_id)
1086        key = await self.store.get_e2e_cross_signing_key(
1087            user_id, key_type, from_user_id
1088        )
1089
1090        if key:
1091            # We found a copy of this key in our database. Decode and return it
1092            key_id, verify_key = get_verify_key_from_cross_signing_key(key)
1093            return key, key_id, verify_key
1094
1095        # If we couldn't find the key locally, and we're looking for keys of
1096        # another user then attempt to fetch the missing key from the remote
1097        # user's server.
1098        #
1099        # We may run into this in possible edge cases where a user tries to
1100        # cross-sign a remote user, but does not share any rooms with them yet.
1101        # Thus, we would not have their key list yet. We instead fetch the key,
1102        # store it and notify clients of new, associated device IDs.
1103        if self.is_mine(user) or key_type not in ["master", "self_signing"]:
1104            # Note that master and self_signing keys are the only cross-signing keys we
1105            # can request over federation
1106            raise NotFoundError("No %s key found for %s" % (key_type, user_id))
1107
1108        (
1109            key,
1110            key_id,
1111            verify_key,
1112        ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
1113
1114        if key is None:
1115            raise NotFoundError("No %s key found for %s" % (key_type, user_id))
1116
1117        return key, key_id, verify_key
1118
1119    async def _retrieve_cross_signing_keys_for_remote_user(
1120        self,
1121        user: UserID,
1122        desired_key_type: str,
1123    ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
1124        """Queries cross-signing keys for a remote user and saves them to the database
1125
1126        Only the key specified by `key_type` will be returned, while all retrieved keys
1127        will be saved regardless
1128
1129        Args:
1130            user: The user to query remote keys for
1131            desired_key_type: The type of key to receive. One of "master", "self_signing"
1132
1133        Returns:
1134            A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
1135            If the key cannot be retrieved, all values in the tuple will instead be None.
1136        """
1137        try:
1138            remote_result = await self.federation.query_user_devices(
1139                user.domain, user.to_string()
1140            )
1141        except Exception as e:
1142            logger.warning(
1143                "Unable to query %s for cross-signing keys of user %s: %s %s",
1144                user.domain,
1145                user.to_string(),
1146                type(e),
1147                e,
1148            )
1149            return None, None, None
1150
1151        # Process each of the retrieved cross-signing keys
1152        desired_key = None
1153        desired_key_id = None
1154        desired_verify_key = None
1155        retrieved_device_ids = []
1156        for key_type in ["master", "self_signing"]:
1157            key_content = remote_result.get(key_type + "_key")
1158            if not key_content:
1159                continue
1160
1161            # Ensure these keys belong to the correct user
1162            if "user_id" not in key_content:
1163                logger.warning(
1164                    "Invalid %s key retrieved, missing user_id field: %s",
1165                    key_type,
1166                    key_content,
1167                )
1168                continue
1169            if user.to_string() != key_content["user_id"]:
1170                logger.warning(
1171                    "Found %s key of user %s when querying for keys of user %s",
1172                    key_type,
1173                    key_content["user_id"],
1174                    user.to_string(),
1175                )
1176                continue
1177
1178            # Validate the key contents
1179            try:
1180                # verify_key is a VerifyKey from signedjson, which uses
1181                # .version to denote the portion of the key ID after the
1182                # algorithm and colon, which is the device ID
1183                key_id, verify_key = get_verify_key_from_cross_signing_key(key_content)
1184            except ValueError as e:
1185                logger.warning(
1186                    "Invalid %s key retrieved: %s - %s %s",
1187                    key_type,
1188                    key_content,
1189                    type(e),
1190                    e,
1191                )
1192                continue
1193
1194            # Note down the device ID attached to this key
1195            retrieved_device_ids.append(verify_key.version)
1196
1197            # If this is the desired key type, save it and its ID/VerifyKey
1198            if key_type == desired_key_type:
1199                desired_key = key_content
1200                desired_verify_key = verify_key
1201                desired_key_id = key_id
1202
1203            # At the same time, store this key in the db for subsequent queries
1204            await self.store.set_e2e_cross_signing_key(
1205                user.to_string(), key_type, key_content
1206            )
1207
1208        # Notify clients that new devices for this user have been discovered
1209        if retrieved_device_ids:
1210            # XXX is this necessary?
1211            await self.device_handler.notify_device_update(
1212                user.to_string(), retrieved_device_ids
1213            )
1214
1215        return desired_key, desired_key_id, desired_verify_key
1216
1217
1218def _check_cross_signing_key(
1219    key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
1220) -> None:
1221    """Check a cross-signing key uploaded by a user.  Performs some basic sanity
1222    checking, and ensures that it is signed, if a signature is required.
1223
1224    Args:
1225        key: the key data to verify
1226        user_id: the user whose key is being checked
1227        key_type: the type of key that the key should be
1228        signing_key: the signing key that the key should be signed with.  If
1229            omitted, signatures will not be checked.
1230    """
1231    if (
1232        key.get("user_id") != user_id
1233        or key_type not in key.get("usage", [])
1234        or len(key.get("keys", {})) != 1
1235    ):
1236        raise SynapseError(400, ("Invalid %s key" % (key_type,)), Codes.INVALID_PARAM)
1237
1238    if signing_key:
1239        try:
1240            verify_signed_json(key, user_id, signing_key)
1241        except SignatureVerifyException:
1242            raise SynapseError(
1243                400, ("Invalid signature on %s key" % key_type), Codes.INVALID_SIGNATURE
1244            )
1245
1246
def _check_device_signature(
    user_id: str,
    verify_key: VerifyKey,
    signed_device: JsonDict,
    stored_device: JsonDict,
) -> None:
    """Verify an uploaded signature on a device or cross-signing key.

    The uploaded object must equal our stored copy once the "signatures" and
    "unsigned" fields are removed from both sides. The uploaded object must
    also carry a valid signature by `verify_key` under `user_id`.

    Args:
        user_id: the user ID whose signature is being checked
        verify_key: the key to verify the device with
        signed_device: the uploaded signed device data
        stored_device: our previously stored copy of the device

    Raises:
        SynapseError: if the sent device is not the same as the stored device,
            or if the signature was invalid
    """
    volatile_fields = ("signatures", "unsigned")

    def _content_only(obj: JsonDict) -> JsonDict:
        # drop the fields that legitimately differ between copies
        return {name: val for name, val in obj.items() if name not in volatile_fields}

    if _content_only(signed_device) != _content_only(stored_device):
        logger.debug(
            "upload signatures: key does not match %s vs %s",
            signed_device,
            stored_device,
        )
        raise SynapseError(400, "Key does not match")

    try:
        verify_signed_json(signed_device, user_id, verify_key)
    except SignatureVerifyException:
        logger.debug("invalid signature on key")
        raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
1289
1290
def _exception_to_failure(e: Exception) -> JsonDict:
    """Turn an exception into a JSON-serialisable failure description.

    The result always has a "status" and a "message". A "errcode" is added
    for SynapseErrors. The checks run from most to least specific:
    SynapseError is a CodeMessageException, so it must be tested first.
    """
    if isinstance(e, SynapseError):
        failure = {"status": e.code, "errcode": e.errcode, "message": str(e)}
    elif isinstance(e, CodeMessageException):
        failure = {"status": e.code, "message": str(e)}
    elif isinstance(e, NotRetryingDestination):
        failure = {"status": 503, "message": "Not ready for retry"}
    else:
        # Anything else (ConnectionRefused and friends). str(e) is used
        # because some exceptions, notably twisted's ResponseFailed, have no
        # usable e.message, which json would then fail to serialise.
        failure = {"status": 503, "message": str(e)}
    return failure
1306
1307
def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
    """Decide whether a re-uploaded one-time key equals the stored one.

    Args:
        old_key_json: the stored key, as a JSON string
        new_key: the newly uploaded key

    Returns:
        True if they match. Non-object keys must be exactly equal. Object keys
        are compared ignoring their "signatures" field, since separate upload
        attempts may legitimately carry different signatures.
    """
    old_key = json_decoder.decode(old_key_json)

    both_objects = isinstance(old_key, dict) and isinstance(new_key, dict)
    if not both_objects:
        return old_key == new_key

    def _unsigned(obj: dict) -> dict:
        return {field: val for field, val in obj.items() if field != "signatures"}

    return _unsigned(old_key) == _unsigned(new_key)
1322
1323
@attr.s(slots=True)
class SignatureListItem:
    """An item in the signature list as used by upload_signatures_for_device_keys.

    Records that the key `signing_key_id` signed the key or device
    `target_device_id` belonging to `target_user_id`.
    """

    # the ID of the key that made the signature, e.g. "ed25519:<id>"
    signing_key_id = attr.ib(type=str)
    # the user owning the signed key/device
    target_user_id = attr.ib(type=str)
    # the signed device ID (or the part of the master key's ID after the colon)
    target_device_id = attr.ib(type=str)
    # The signature itself: the value taken from
    # signatures[<user_id>][<signing_key_id>] of the uploaded object, which is
    # a string rather than a JSON object.
    signature = attr.ib(type=str)
1332
1333
class SigningKeyEduUpdater:
    """Handles incoming signing key updates from federation and updates the DB"""

    def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.e2e_keys_handler = e2e_keys_handler

        # queues processing of updates per user (keyed on user_id)
        self._remote_edu_linearizer = Linearizer(name="remote_signing_key")

        # user_id -> list of (master_key, self_signing_key) updates waiting to
        # be handled. Either key may be None if the EDU did not include it.
        self._pending_updates: Dict[
            str, List[Tuple[Optional[JsonDict], Optional[JsonDict]]]
        ] = {}

    async def incoming_signing_key_update(
        self, origin: str, edu_content: JsonDict
    ) -> None:
        """Called on incoming signing key update from federation. Responsible for
        parsing the EDU and adding to pending updates list.

        The EDU content is only read and is left unmodified.

        Args:
            origin: the server that sent the EDU
            edu_content: the contents of the EDU
        """

        user_id = edu_content.get("user_id")
        master_key = edu_content.get("master_key")
        self_signing_key = edu_content.get("self_signing_key")

        # The EDU comes from a remote server: reject it up front instead of
        # crashing (KeyError / AttributeError) on a missing or non-string user_id.
        if not isinstance(user_id, str):
            logger.warning(
                "Got signing key update edu from %r with invalid user_id %r",
                origin,
                user_id,
            )
            return

        if get_domain_from_id(user_id) != origin:
            logger.warning("Got signing key update edu for %r from %r", user_id, origin)
            return

        room_ids = await self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            return

        self._pending_updates.setdefault(user_id, []).append(
            (master_key, self_signing_key)
        )

        await self._handle_signing_key_updates(user_id)

    async def _handle_signing_key_updates(self, user_id: str) -> None:
        """Actually handle pending updates.

        All updates queued for the user are drained in one pass, under the
        per-user linearizer. A single device-list notification is then sent
        for every device ID they touched.

        Args:
            user_id: the user whose updates we are processing
        """

        device_handler = self.e2e_keys_handler.device_handler
        device_list_updater = device_handler.device_list_updater

        with (await self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            device_ids: List[str] = []

            logger.info("pending updates: %r", pending_updates)

            for master_key, self_signing_key in pending_updates:
                new_device_ids = (
                    await device_list_updater.process_cross_signing_key_update(
                        user_id,
                        master_key,
                        self_signing_key,
                    )
                )
                device_ids.extend(new_device_ids)

            await device_handler.notify_device_update(user_id, device_ids)
1410